diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 6a737d0ad..3178a781b 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -17,94 +17,89 @@ package consumers import ( "context" "encoding/json" - "fmt" - "github.com/Shopify/sarama" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) // KeyChangeConsumer consumes events that originate in key server. type KeyChangeConsumer struct { ctx context.Context - consumer *internal.ContinualConsumer + jetstream nats.JetStreamContext db storage.Database queues *queue.OutgoingQueues serverName gomatrixserverlib.ServerName rsAPI roomserverAPI.RoomserverInternalAPI + topic string } // NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers. func NewKeyChangeConsumer( process *process.ProcessContext, cfg *config.KeyServer, - kafkaConsumer sarama.Consumer, + js nats.JetStreamContext, queues *queue.OutgoingQueues, store storage.Database, rsAPI roomserverAPI.RoomserverInternalAPI, ) *KeyChangeConsumer { - c := &KeyChangeConsumer{ - ctx: process.Context(), - consumer: &internal.ContinualConsumer{ - Process: process, - ComponentName: "federationapi/keychange", - Topic: string(cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent)), - Consumer: kafkaConsumer, - PartitionStore: store, - }, + return &KeyChangeConsumer{ + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), queues: queues, db: store, serverName: cfg.Matrix.ServerName, rsAPI: rsAPI, } - c.consumer.ProcessMessage = c.onMessage - - return c } // Start consuming from key servers func (t *KeyChangeConsumer) Start() error { - if err := t.consumer.Start(); err != nil { - return fmt.Errorf("t.consumer.Start: %w", err) - } - return nil + _, err := t.jetstream.Subscribe( + t.topic, t.onMessage, + nats.DeliverAll(), + ) + return err } // onMessage is called in response to a message received on the // key change events topic from the key server. -func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error { - var m api.DeviceMessage - if err := json.Unmarshal(msg.Value, &m); err != nil { - logrus.WithError(err).Errorf("failed to read device message from key change topic") - return nil - } - if m.DeviceKeys == nil && m.OutputCrossSigningKeyUpdate == nil { - // This probably shouldn't happen but stops us from panicking if we come - // across an update that doesn't satisfy either types. - return nil - } - switch m.Type { - case api.TypeCrossSigningUpdate: - return t.onCrossSigningMessage(m) - case api.TypeDeviceKeyUpdate: - fallthrough - default: - return t.onDeviceKeyMessage(m) - } +func (t *KeyChangeConsumer) onMessage(msg *nats.Msg) { + jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { + var m api.DeviceMessage + if err := json.Unmarshal(msg.Data, &m); err != nil { + logrus.WithError(err).Errorf("failed to read device message from key change topic") + return true + } + if m.DeviceKeys == nil && m.OutputCrossSigningKeyUpdate == nil { + // This probably shouldn't happen but stops us from panicking if we come + // across an update that doesn't satisfy either types. + return true + } + switch m.Type { + case api.TypeCrossSigningUpdate: + return t.onCrossSigningMessage(m) + case api.TypeDeviceKeyUpdate: + fallthrough + default: + return t.onDeviceKeyMessage(m) + } + }) + } -func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { +func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { if m.DeviceKeys == nil { - return nil + return true } logger := logrus.WithField("user_id", m.UserID) @@ -112,10 +107,10 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { _, originServerName, err := gomatrixserverlib.SplitID('@', m.UserID) if err != nil { logger.WithError(err).Error("Failed to extract domain from key change event") - return nil + return true } if originServerName != t.serverName { - return nil + return true } var queryRes roomserverAPI.QueryRoomsForUserResponse @@ -125,13 +120,13 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { }, &queryRes) if err != nil { logger.WithError(err).Error("failed to calculate joined rooms for user") - return nil + return true } // send this key change to all servers who share rooms with this user. destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) if err != nil { logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") - return nil + return true } // Pack the EDU and marshal it @@ -149,24 +144,25 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { Keys: m.KeyJSON, } if edu.Content, err = json.Marshal(event); err != nil { - return err + return true } logrus.Infof("Sending device list update message to %q", destinations) - return t.queues.SendEDU(edu, t.serverName, destinations) + err = t.queues.SendEDU(edu, t.serverName, destinations) + return err == nil } -func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { +func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { output := m.CrossSigningKeyUpdate _, host, err := gomatrixserverlib.SplitID('@', output.UserID) if err != nil { logrus.WithError(err).Errorf("fedsender key change consumer: user ID parse failure") - return nil + return true } if host != gomatrixserverlib.ServerName(t.serverName) { // Ignore any messages that didn't originate locally, otherwise we'll // end up parroting information we received from other servers. - return nil + return true } logger := logrus.WithField("user_id", output.UserID) @@ -177,13 +173,13 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { }, &queryRes) if err != nil { logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined rooms for user") - return nil + return true } // send this key change to all servers who share rooms with this user. destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) if err != nil { logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") - return nil + return true } // Pack the EDU and marshal it @@ -193,11 +189,12 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { } if edu.Content, err = json.Marshal(output); err != nil { logger.WithError(err).Error("fedsender key change consumer: failed to marshal output, dropping") - return nil + return true } logger.Infof("Sending cross-signing update message to %q", destinations) - return t.queues.SendEDU(edu, t.serverName, destinations) + err = t.queues.SendEDU(edu, t.serverName, destinations) + return err == nil } func prevID(streamID int) []int { diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 25ea78274..e0990cec0 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -125,7 +125,10 @@ func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { }).Panicf("roomserver output log: remote peek event failure") return false } - + case api.OutputTypeNewInviteEvent: + log.WithField("type", output.Type).Debug( + "received new invite, send device keys", + ) default: log.WithField("type", output.Type).Debug( "roomserver output log: ignoring unknown output type", diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 63387b9d8..a982d8009 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -92,7 +92,7 @@ func NewInternalAPI( FailuresUntilBlacklist: cfg.FederationMaxRetries, } - js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, @@ -120,7 +120,7 @@ func NewInternalAPI( logrus.WithError(err).Panic("failed to start typing server consumer") } keyConsumer := consumers.NewKeyChangeConsumer( - base.ProcessContext, &base.Cfg.KeyServer, consumer, queues, federationDB, rsAPI, + base.ProcessContext, &base.Cfg.KeyServer, js, queues, federationDB, rsAPI, ) if err := keyConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start key server consumer") diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 21a919f6a..3fa8d1f7a 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -19,12 +19,10 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/types" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { - internal.PartitionStorer gomatrixserverlib.KeyDatabase UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error)