From 64db98aab5d81494551ae20b1a4babecb0718b59 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Wed, 12 Aug 2020 21:56:01 +0100 Subject: [PATCH] Handle deleted devices correctly over federation --- keyserver/internal/device_list_update.go | 16 +++++++++++----- keyserver/internal/device_list_update_test.go | 2 +- keyserver/internal/internal.go | 14 +++++++++++++- keyserver/storage/interface.go | 5 +++-- keyserver/storage/postgres/device_keys_table.go | 12 ++++++++++++ keyserver/storage/shared/storage.go | 8 +++++++- keyserver/storage/sqlite3/device_keys_table.go | 12 ++++++++++++ keyserver/storage/tables/interface.go | 1 + 8 files changed, 60 insertions(+), 10 deletions(-) diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 95b052452..573285e88 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -85,8 +85,9 @@ type DeviceListUpdaterDatabase interface { MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). Does not modify the stream ID for keys. - StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior + // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. + StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error // PrevIDsExists returns true if all prev IDs exist for this user. PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) @@ -192,22 +193,27 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib. "stream_id": event.StreamID, "prev_ids": event.PrevID, "display_name": event.DeviceDisplayName, + "deleted": event.Deleted, }).Info("DeviceListUpdater.Update") // if we haven't missed anything update the database and notify users if exists { + k := event.Keys + if event.Deleted { + k = nil + } keys := []api.DeviceMessage{ { DeviceKeys: api.DeviceKeys{ DeviceID: event.DeviceID, DisplayName: event.DeviceDisplayName, - KeyJSON: event.Keys, + KeyJSON: k, UserID: event.UserID, }, StreamID: event.StreamID, }, } - err = u.db.StoreRemoteDeviceKeys(ctx, keys) + err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil) if err != nil { return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err) } @@ -362,7 +368,7 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi }, } } - err := u.db.StoreRemoteDeviceKeys(ctx, keys) + err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID}) if err != nil { return fmt.Errorf("failed to store remote device keys: %w", err) } diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index dcb981c4f..c42a7cdfc 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -81,7 +81,7 @@ func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key // for this (user, device). Does not modify the stream ID for keys. -func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { +func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error { d.storedKeys = append(d.storedKeys, keys...) return nil } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 356a55523..ef52d0141 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -206,7 +206,15 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query maxStreamID = m.StreamID } } - res.Devices = msgs + // remove deleted devices + var result []api.DeviceMessage + for _, m := range msgs { + if m.KeyJSON == nil { + continue + } + result = append(result, m) + } + res.Devices = result res.StreamID = maxStreamID } @@ -408,7 +416,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( if res.DeviceKeys[userID] == nil { res.DeviceKeys[userID] = make(map[string]json.RawMessage) } + for _, key := range keys { + if len(key.KeyJSON) == 0 { + continue // ignore deleted keys + } // inject the display name key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { DisplayName string `json:"device_display_name,omitempty"` diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 2a60aaccc..0ec62f567 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -43,8 +43,9 @@ type Database interface { StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). Does not modify the stream ID for keys. - StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior + // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. + StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error // PrevIDsExists returns true if all prev IDs exist for this user. PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index b9d5d4c36..779d02c03 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -61,6 +61,9 @@ const selectMaxStreamForUserSQL = "" + const countStreamIDsForUserSQL = "" + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)" +const deleteAllDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1" + type deviceKeysStatements struct { db *sql.DB upsertDeviceKeysStmt *sql.Stmt @@ -68,6 +71,7 @@ type deviceKeysStatements struct { selectBatchDeviceKeysStmt *sql.Stmt selectMaxStreamForUserStmt *sql.Stmt countStreamIDsForUserStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt } func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -93,6 +97,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil { return nil, err } + if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { + return nil, err + } return s, nil } @@ -154,6 +161,11 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx return nil } +func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + return err +} + func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) if err != nil { diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 4279eae77..a4c35a4bd 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -61,8 +61,14 @@ func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []i return count == len(prevIDs), nil } -func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { +func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + for _, userID := range clearUserIDs { + err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID) + if err != nil { + return err + } + } return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) }) } diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index abe6636af..a4d71fe13 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -58,6 +58,9 @@ const selectMaxStreamForUserSQL = "" + const countStreamIDsForUserSQL = "" + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" +const deleteAllDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1" + type deviceKeysStatements struct { db *sql.DB writer *sqlutil.TransactionWriter @@ -65,6 +68,7 @@ type deviceKeysStatements struct { selectDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt selectMaxStreamForUserStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt } func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -88,9 +92,17 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { return nil, err } + if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { + return nil, err + } return s, nil } +func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + return err +} + func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { deviceIDMap := make(map[string]bool) for _, d := range deviceIDs { diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index a4d5dede2..f97e871f6 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -38,6 +38,7 @@ type DeviceKeys interface { SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) + DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error } type KeyChanges interface {