diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index c27e291fc..dd8fb7008 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -91,6 +91,9 @@ type DeviceListUpdaterDatabase interface { // PrevIDsExists returns true if all prev IDs exist for this user. PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) + + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. + DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error } // KeyChangeProducer is the interface for producers.KeyChange useful for testing. @@ -354,6 +357,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error { ctx := context.Background() // we've got the keys, don't time out when persisting them to the database. keys := make([]api.DeviceMessage, len(res.Devices)) + existingKeys := make([]api.DeviceMessage, len(res.Devices)) for i, device := range res.Devices { keyJSON, err := json.Marshal(device.Keys) if err != nil { @@ -369,7 +373,21 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi KeyJSON: keyJSON, }, } + existingKeys[i] = api.DeviceMessage{ + DeviceKeys: api.DeviceKeys{ + UserID: res.UserID, + DeviceID: device.DeviceID, + }, + } } + // fetch what keys we had already and only emit changes + if err := u.db.DeviceKeysJSON(ctx, existingKeys); err != nil { + // non-fatal, log and continue + util.GetLogger(ctx).WithError(err).WithField("user_id", res.UserID).Errorf( + "failed to query device keys json for calculating diffs", + ) + } + err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID}) if err != nil { return fmt.Errorf("failed to store remote device keys: %w", err) @@ -378,7 +396,7 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi if err != nil { return fmt.Errorf("failed to mark device list as fresh: %w", err) } - err = u.producer.ProduceKeyChanges(keys) + err = emitDeviceKeyChanges(u.producer, existingKeys, keys) if err != nil { return fmt.Errorf("failed to emit key changes for fresh device list: %w", err) } diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index c42a7cdfc..56bb4888c 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -91,6 +91,10 @@ func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userI return d.prevIDsExist(userID, prevIDs), nil } +func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + return nil +} + type roundTripper struct { fn func(*http.Request) (*http.Response, error) } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 8904d4637..31fb12367 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -505,7 +505,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per } return } - err = a.emitDeviceKeyChanges(existingKeys, keysToStore) + err = emitDeviceKeyChanges(a.Producer, existingKeys, keysToStore) if err != nil { util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err) } @@ -550,7 +550,7 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform } -func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage) error { +func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage) error { // find keys in new that are not in existing var keysAdded []api.DeviceMessage for _, newKey := range new { @@ -567,7 +567,7 @@ func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage) keysAdded = append(keysAdded, newKey) } } - return a.Producer.ProduceKeyChanges(keysAdded) + return producer.ProduceKeyChanges(keysAdded) } func appendDisplayNames(existing, new []api.DeviceMessage) []api.DeviceMessage {