mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-22 06:13:10 -06:00
Handle deleted devices correctly over federation
This commit is contained in:
parent
a53030792b
commit
64db98aab5
|
|
@ -85,8 +85,9 @@ type DeviceListUpdaterDatabase interface {
|
||||||
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
|
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
|
||||||
|
|
||||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||||
// for this (user, device). Does not modify the stream ID for keys.
|
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
|
||||||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
|
||||||
|
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
|
||||||
|
|
||||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
||||||
|
|
@ -192,22 +193,27 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
|
||||||
"stream_id": event.StreamID,
|
"stream_id": event.StreamID,
|
||||||
"prev_ids": event.PrevID,
|
"prev_ids": event.PrevID,
|
||||||
"display_name": event.DeviceDisplayName,
|
"display_name": event.DeviceDisplayName,
|
||||||
|
"deleted": event.Deleted,
|
||||||
}).Info("DeviceListUpdater.Update")
|
}).Info("DeviceListUpdater.Update")
|
||||||
|
|
||||||
// if we haven't missed anything update the database and notify users
|
// if we haven't missed anything update the database and notify users
|
||||||
if exists {
|
if exists {
|
||||||
|
k := event.Keys
|
||||||
|
if event.Deleted {
|
||||||
|
k = nil
|
||||||
|
}
|
||||||
keys := []api.DeviceMessage{
|
keys := []api.DeviceMessage{
|
||||||
{
|
{
|
||||||
DeviceKeys: api.DeviceKeys{
|
DeviceKeys: api.DeviceKeys{
|
||||||
DeviceID: event.DeviceID,
|
DeviceID: event.DeviceID,
|
||||||
DisplayName: event.DeviceDisplayName,
|
DisplayName: event.DeviceDisplayName,
|
||||||
KeyJSON: event.Keys,
|
KeyJSON: k,
|
||||||
UserID: event.UserID,
|
UserID: event.UserID,
|
||||||
},
|
},
|
||||||
StreamID: event.StreamID,
|
StreamID: event.StreamID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err = u.db.StoreRemoteDeviceKeys(ctx, keys)
|
err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
|
return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
|
||||||
}
|
}
|
||||||
|
|
@ -362,7 +368,7 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err := u.db.StoreRemoteDeviceKeys(ctx, keys)
|
err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to store remote device keys: %w", err)
|
return fmt.Errorf("failed to store remote device keys: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,7 @@ func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context,
|
||||||
|
|
||||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||||
// for this (user, device). Does not modify the stream ID for keys.
|
// for this (user, device). Does not modify the stream ID for keys.
|
||||||
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
|
||||||
d.storedKeys = append(d.storedKeys, keys...)
|
d.storedKeys = append(d.storedKeys, keys...)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -206,7 +206,15 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
|
||||||
maxStreamID = m.StreamID
|
maxStreamID = m.StreamID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
res.Devices = msgs
|
// remove deleted devices
|
||||||
|
var result []api.DeviceMessage
|
||||||
|
for _, m := range msgs {
|
||||||
|
if m.KeyJSON == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, m)
|
||||||
|
}
|
||||||
|
res.Devices = result
|
||||||
res.StreamID = maxStreamID
|
res.StreamID = maxStreamID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -408,7 +416,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||||
if res.DeviceKeys[userID] == nil {
|
if res.DeviceKeys[userID] == nil {
|
||||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
|
if len(key.KeyJSON) == 0 {
|
||||||
|
continue // ignore deleted keys
|
||||||
|
}
|
||||||
// inject the display name
|
// inject the display name
|
||||||
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
||||||
DisplayName string `json:"device_display_name,omitempty"`
|
DisplayName string `json:"device_display_name,omitempty"`
|
||||||
|
|
|
||||||
|
|
@ -43,8 +43,9 @@ type Database interface {
|
||||||
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
|
|
||||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||||
// for this (user, device). Does not modify the stream ID for keys.
|
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
|
||||||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
|
||||||
|
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
|
||||||
|
|
||||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,9 @@ const selectMaxStreamForUserSQL = "" +
|
||||||
const countStreamIDsForUserSQL = "" +
|
const countStreamIDsForUserSQL = "" +
|
||||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
|
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
|
||||||
|
|
||||||
|
const deleteAllDeviceKeysSQL = "" +
|
||||||
|
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
type deviceKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
upsertDeviceKeysStmt *sql.Stmt
|
upsertDeviceKeysStmt *sql.Stmt
|
||||||
|
|
@ -68,6 +71,7 @@ type deviceKeysStatements struct {
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
selectMaxStreamForUserStmt *sql.Stmt
|
selectMaxStreamForUserStmt *sql.Stmt
|
||||||
countStreamIDsForUserStmt *sql.Stmt
|
countStreamIDsForUserStmt *sql.Stmt
|
||||||
|
deleteAllDeviceKeysStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
|
|
@ -93,6 +97,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
|
if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -154,6 +161,11 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||||
|
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||||
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -61,8 +61,14 @@ func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []i
|
||||||
return count == len(prevIDs), nil
|
return count == len(prevIDs), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
|
||||||
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
||||||
|
for _, userID := range clearUserIDs {
|
||||||
|
err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
|
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -58,6 +58,9 @@ const selectMaxStreamForUserSQL = "" +
|
||||||
const countStreamIDsForUserSQL = "" +
|
const countStreamIDsForUserSQL = "" +
|
||||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
|
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
|
||||||
|
|
||||||
|
const deleteAllDeviceKeysSQL = "" +
|
||||||
|
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
type deviceKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer *sqlutil.TransactionWriter
|
||||||
|
|
@ -65,6 +68,7 @@ type deviceKeysStatements struct {
|
||||||
selectDeviceKeysStmt *sql.Stmt
|
selectDeviceKeysStmt *sql.Stmt
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
selectMaxStreamForUserStmt *sql.Stmt
|
selectMaxStreamForUserStmt *sql.Stmt
|
||||||
|
deleteAllDeviceKeysStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
|
|
@ -88,9 +92,17 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||||
|
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||||
deviceIDMap := make(map[string]bool)
|
deviceIDMap := make(map[string]bool)
|
||||||
for _, d := range deviceIDs {
|
for _, d := range deviceIDs {
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ type DeviceKeys interface {
|
||||||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
||||||
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
||||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
||||||
|
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type KeyChanges interface {
|
type KeyChanges interface {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue