Handle deleted devices correctly over federation

This commit is contained in:
Kegan Dougal 2020-08-12 21:56:01 +01:00
parent a53030792b
commit 64db98aab5
8 changed files with 60 additions and 10 deletions

View file

@ -85,8 +85,9 @@ type DeviceListUpdaterDatabase interface {
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device). Does not modify the stream ID for keys. // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user. // PrevIDsExists returns true if all prev IDs exist for this user.
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
@ -192,22 +193,27 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
"stream_id": event.StreamID, "stream_id": event.StreamID,
"prev_ids": event.PrevID, "prev_ids": event.PrevID,
"display_name": event.DeviceDisplayName, "display_name": event.DeviceDisplayName,
"deleted": event.Deleted,
}).Info("DeviceListUpdater.Update") }).Info("DeviceListUpdater.Update")
// if we haven't missed anything update the database and notify users // if we haven't missed anything update the database and notify users
if exists { if exists {
k := event.Keys
if event.Deleted {
k = nil
}
keys := []api.DeviceMessage{ keys := []api.DeviceMessage{
{ {
DeviceKeys: api.DeviceKeys{ DeviceKeys: api.DeviceKeys{
DeviceID: event.DeviceID, DeviceID: event.DeviceID,
DisplayName: event.DeviceDisplayName, DisplayName: event.DeviceDisplayName,
KeyJSON: event.Keys, KeyJSON: k,
UserID: event.UserID, UserID: event.UserID,
}, },
StreamID: event.StreamID, StreamID: event.StreamID,
}, },
} }
err = u.db.StoreRemoteDeviceKeys(ctx, keys) err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil)
if err != nil { if err != nil {
return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err) return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
} }
@ -362,7 +368,7 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi
}, },
} }
} }
err := u.db.StoreRemoteDeviceKeys(ctx, keys) err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID})
if err != nil { if err != nil {
return fmt.Errorf("failed to store remote device keys: %w", err) return fmt.Errorf("failed to store remote device keys: %w", err)
} }

View file

@ -81,7 +81,7 @@ func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context,
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device). Does not modify the stream ID for keys. // for this (user, device). Does not modify the stream ID for keys.
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
d.storedKeys = append(d.storedKeys, keys...) d.storedKeys = append(d.storedKeys, keys...)
return nil return nil
} }

View file

@ -206,7 +206,15 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
maxStreamID = m.StreamID maxStreamID = m.StreamID
} }
} }
res.Devices = msgs // remove deleted devices
var result []api.DeviceMessage
for _, m := range msgs {
if m.KeyJSON == nil {
continue
}
result = append(result, m)
}
res.Devices = result
res.StreamID = maxStreamID res.StreamID = maxStreamID
} }
@ -408,7 +416,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
if res.DeviceKeys[userID] == nil { if res.DeviceKeys[userID] == nil {
res.DeviceKeys[userID] = make(map[string]json.RawMessage) res.DeviceKeys[userID] = make(map[string]json.RawMessage)
} }
for _, key := range keys { for _, key := range keys {
if len(key.KeyJSON) == 0 {
continue // ignore deleted keys
}
// inject the display name // inject the display name
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"` DisplayName string `json:"device_display_name,omitempty"`

View file

@ -43,8 +43,9 @@ type Database interface {
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device). Does not modify the stream ID for keys. // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user. // PrevIDsExists returns true if all prev IDs exist for this user.
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)

View file

@ -61,6 +61,9 @@ const selectMaxStreamForUserSQL = "" +
const countStreamIDsForUserSQL = "" + const countStreamIDsForUserSQL = "" +
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)" "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct { type deviceKeysStatements struct {
db *sql.DB db *sql.DB
upsertDeviceKeysStmt *sql.Stmt upsertDeviceKeysStmt *sql.Stmt
@ -68,6 +71,7 @@ type deviceKeysStatements struct {
selectBatchDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt selectMaxStreamForUserStmt *sql.Stmt
countStreamIDsForUserStmt *sql.Stmt countStreamIDsForUserStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
} }
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@ -93,6 +97,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil { if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
return nil, err return nil, err
} }
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -154,6 +161,11 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
return nil return nil
} }
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
return err
}
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
if err != nil { if err != nil {

View file

@ -61,8 +61,14 @@ func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []i
return count == len(prevIDs), nil return count == len(prevIDs), nil
} }
func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
for _, userID := range clearUserIDs {
err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
if err != nil {
return err
}
}
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
}) })
} }

View file

@ -58,6 +58,9 @@ const selectMaxStreamForUserSQL = "" +
const countStreamIDsForUserSQL = "" + const countStreamIDsForUserSQL = "" +
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct { type deviceKeysStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter writer *sqlutil.TransactionWriter
@ -65,6 +68,7 @@ type deviceKeysStatements struct {
selectDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt selectMaxStreamForUserStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
} }
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@ -88,9 +92,17 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err return nil, err
} }
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
return err
}
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool) deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs { for _, d := range deviceIDs {

View file

@ -38,6 +38,7 @@ type DeviceKeys interface {
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
} }
type KeyChanges interface { type KeyChanges interface {