diff --git a/keyserver/storage/postgres/key_changes_table.go b/keyserver/storage/postgres/key_changes_table.go
index 3dfec5b7e..51debad5c 100644
--- a/keyserver/storage/postgres/key_changes_table.go
+++ b/keyserver/storage/postgres/key_changes_table.go
@@ -43,10 +43,8 @@ const upsertKeyChangeSQL = "" +
 	" DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" +
 	" RETURNING change_id"
 
-// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
-// take the max offset value as the latest offset.
 const selectKeyChangesSQL = "" +
-	"SELECT user_id, MAX(change_id) FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2 GROUP BY user_id"
+	"SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
 
 type keyChangesStatements struct {
 	db *sql.DB
diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go
index eff371dab..10354b978 100644
--- a/keyserver/storage/sqlite3/key_changes_table.go
+++ b/keyserver/storage/sqlite3/key_changes_table.go
@@ -44,10 +44,8 @@ const upsertKeyChangeSQL = "" +
 	" DO UPDATE SET change_id = change_id + 1" +
 	" RETURNING change_id"
 
-// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
-// take the max offset value as the latest offset.
 const selectKeyChangesSQL = "" +
-	"SELECT user_id, MAX(change_id) FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2 GROUP BY user_id"
+	"SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
 
 type keyChangesStatements struct {
 	db *sql.DB
diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go
index d63e4832b..97685cc04 100644
--- a/syncapi/consumers/keychange.go
+++ b/syncapi/consumers/keychange.go
@@ -17,7 +17,6 @@ package consumers
 import (
 	"context"
 	"encoding/json"
-	"sync"
 
 	"github.com/Shopify/sarama"
 	"github.com/getsentry/sentry-go"
@@ -34,16 +33,14 @@ import (
 
 // OutputKeyChangeEventConsumer consumes events that originated in the key server.
 type OutputKeyChangeEventConsumer struct {
-	ctx                 context.Context
-	keyChangeConsumer   *internal.ContinualConsumer
-	db                  storage.Database
-	notifier            *notifier.Notifier
-	stream              types.StreamProvider
-	serverName          gomatrixserverlib.ServerName // our server name
-	rsAPI               roomserverAPI.RoomserverInternalAPI
-	keyAPI              api.KeyInternalAPI
-	partitionToOffset   map[int32]int64
-	partitionToOffsetMu sync.Mutex
+	ctx               context.Context
+	keyChangeConsumer *internal.ContinualConsumer
+	db                storage.Database
+	notifier          *notifier.Notifier
+	stream            types.StreamProvider
+	serverName        gomatrixserverlib.ServerName // our server name
+	rsAPI             roomserverAPI.RoomserverInternalAPI
+	keyAPI            api.KeyInternalAPI
 }
 
 // NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer.
@@ -69,16 +66,14 @@ func NewOutputKeyChangeEventConsumer(
 	}
 
 	s := &OutputKeyChangeEventConsumer{
-		ctx:                 process.Context(),
-		keyChangeConsumer:   &consumer,
-		db:                  store,
-		serverName:          serverName,
-		keyAPI:              keyAPI,
-		rsAPI:               rsAPI,
-		partitionToOffset:   make(map[int32]int64),
-		partitionToOffsetMu: sync.Mutex{},
-		notifier:            notifier,
-		stream:              stream,
+		ctx:               process.Context(),
+		keyChangeConsumer: &consumer,
+		db:                store,
+		serverName:        serverName,
+		keyAPI:            keyAPI,
+		rsAPI:             rsAPI,
+		notifier:          notifier,
+		stream:            stream,
 	}
 
 	consumer.ProcessMessage = s.onMessage
@@ -88,24 +83,10 @@ func NewOutputKeyChangeEventConsumer(
 
 // Start consuming from the key server
 func (s *OutputKeyChangeEventConsumer) Start() error {
-	offsets, err := s.keyChangeConsumer.StartOffsets()
-	s.partitionToOffsetMu.Lock()
-	for _, o := range offsets {
-		s.partitionToOffset[o.Partition] = o.Offset
-	}
-	s.partitionToOffsetMu.Unlock()
-	return err
-}
-
-func (s *OutputKeyChangeEventConsumer) updateOffset(msg *sarama.ConsumerMessage) {
-	s.partitionToOffsetMu.Lock()
-	defer s.partitionToOffsetMu.Unlock()
-	s.partitionToOffset[msg.Partition] = msg.Offset
+	return s.keyChangeConsumer.Start()
 }
 
 func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
-	defer s.updateOffset(msg)
-
 	var m api.DeviceMessage
 	if err := json.Unmarshal(msg.Value, &m); err != nil {
 		logrus.WithError(err).Errorf("failed to read device message from key change topic")
@@ -118,15 +99,15 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
 	}
 	switch m.Type {
 	case api.TypeCrossSigningUpdate:
-		return s.onCrossSigningMessage(m, msg.Offset)
+		return s.onCrossSigningMessage(m, m.DeviceChangeID)
 	case api.TypeDeviceKeyUpdate:
 		fallthrough
 	default:
-		return s.onDeviceKeyMessage(m, msg.Offset)
+		return s.onDeviceKeyMessage(m, m.DeviceChangeID)
 	}
 }
 
-func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, offset int64) error {
+func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, deviceChangeID int64) error {
 	if m.DeviceKeys == nil {
 		return nil
 	}
@@ -143,7 +124,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, o
 	}
 	// make sure we get our own key updates too!
 	queryRes.UserIDsToCount[output.UserID] = 1
-	posUpdate := types.StreamPosition(offset)
+	posUpdate := types.StreamPosition(deviceChangeID)
 	s.stream.Advance(posUpdate)
 
 	for userID := range queryRes.UserIDsToCount {
@@ -153,7 +134,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, o
 	}
 	return nil
 }
-func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, offset int64) error {
+func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, deviceChangeID int64) error {
 	output := m.CrossSigningKeyUpdate
 	// work out who we need to notify about the new key
 	var queryRes roomserverAPI.QuerySharedUsersResponse
@@ -167,7 +148,7 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage
 	}
 	// make sure we get our own key updates too!
 	queryRes.UserIDsToCount[output.UserID] = 1
-	posUpdate := types.StreamPosition(offset)
+	posUpdate := types.StreamPosition(deviceChangeID)
 	s.stream.Advance(posUpdate)
 
 	for userID := range queryRes.UserIDsToCount {
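A note on the simplified `selectKeyChangesSQL` (commentary, not part of the change): the `MAX(change_id)`/`GROUP BY user_id` can be dropped because the upsert above keeps at most one row per user and bumps its `change_id` in place, so a plain range scan already returns one latest change ID per user. The sketch below shows how such a query might be consumed; the helper name, signature and return shape are illustrative assumptions, not Dendrite's actual storage API.

```go
package storage

import (
	"context"
	"database/sql"
)

// The simplified range query from the diff (Postgres placeholder style).
const selectKeyChangesSQL = "" +
	"SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"

// selectKeyChanges is an illustrative helper (assumed name and shape). It scans
// the (user_id, change_id) pairs in the range (fromID, toID] and tracks the
// highest change ID seen. Because the upsert keeps a single row per user, each
// user appears at most once, so no MAX()/GROUP BY is needed.
func selectKeyChanges(ctx context.Context, db *sql.DB, fromID, toID int64) ([]string, int64, error) {
	rows, err := db.QueryContext(ctx, selectKeyChangesSQL, fromID, toID)
	if err != nil {
		return nil, 0, err
	}
	defer rows.Close() // nolint: errcheck

	latest := fromID
	var userIDs []string
	for rows.Next() {
		var userID string
		var changeID int64
		if err := rows.Scan(&userID, &changeID); err != nil {
			return nil, 0, err
		}
		userIDs = append(userIDs, userID)
		if changeID > latest {
			latest = changeID
		}
	}
	return userIDs, latest, rows.Err()
}
```

The same reasoning underpins the consumer change: the database-generated `m.DeviceChangeID` carried in the message is used as the device-list stream position, so the Kafka partition/offset bookkeeping is no longer needed.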