From 54ceb50e88a9adc25f7973934b5fdc323cf9477a Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Fri, 7 Aug 2020 10:43:36 +0100 Subject: [PATCH] Add display_name col to store remote device names Few other tweaks to make `Server correctly handles incoming m.device_list_update` pass. --- keyserver/internal/device_list_update.go | 5 ++++ keyserver/internal/internal.go | 13 +++++++--- .../storage/postgres/device_keys_table.go | 25 +++++++++++++------ .../storage/sqlite3/device_keys_table.go | 25 +++++++++++++------ sytest-whitelist | 1 + 5 files changed, 50 insertions(+), 19 deletions(-) diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index ac6c89998..ec7dff560 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -158,12 +158,17 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib. if err != nil { return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err) } + // if this is the first time we're hearing about this user, sync the device list manually. + if len(event.PrevID) == 0 { + exists = false + } util.GetLogger(ctx).WithFields(logrus.Fields{ "prev_ids_exist": exists, "user_id": event.UserID, "device_id": event.DeviceID, "stream_id": event.StreamID, "prev_ids": event.PrevID, + "display_name": event.DeviceDisplayName, }).Info("DeviceListUpdater.Update") // if we haven't missed anything update the database and notify users diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 0a4c5defd..075622b7c 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -250,10 +250,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques if len(dk.KeyJSON) == 0 { continue // don't include blank keys } - // inject display name if known + // inject display name if known (either locally or remotely) + displayName := dk.DisplayName + if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" { + displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName + } dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { DisplayName string `json:"device_display_name,omitempty"` - }{queryRes.DeviceInfo[dk.DeviceID].DisplayName}) + }{displayName}) res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON } } else { @@ -261,7 +265,6 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...) } } - // TODO: set device display names when they are known // attempt to satisfy key queries from the local database first as we should get device updates pushed to us domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys) @@ -294,6 +297,10 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase( res.DeviceKeys[userID] = make(map[string]json.RawMessage) } for _, key := range keys { + // inject the display name + key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { + DisplayName string `json:"device_display_name,omitempty"` + }{key.DisplayName}) res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON } } diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index d321860d4..b9d5d4c36 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -37,22 +37,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( -- required in the spec because in the event of a missed update the server fetches the entire -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs. stream_id BIGINT NOT NULL, + display_name TEXT, -- Clobber based on tuple of user/device. CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + - " VALUES ($1, $2, $3, $4, $5)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + - " DO UPDATE SET key_json = $4, stream_id = $5" + " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" const selectDeviceKeysSQL = "" + - "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -99,13 +100,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] for i, key := range keys { var keyJSONStr string var streamID int - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) + var displayName sql.NullString + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) keys[i].StreamID = streamID + if displayName.Valid { + keys[i].DisplayName = displayName.String + } } return nil } @@ -140,7 +145,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx for _, key := range keys { now := time.Now().Unix() _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ) if err != nil { return err @@ -165,11 +170,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID dk.UserID = userID var keyJSON string var streamID int - if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { + var displayName sql.NullString + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) dk.StreamID = streamID + if displayName.Valid { + dk.DisplayName = displayName.String + } // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 15d9c775f..abe6636af 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -34,22 +34,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( ts_added_secs BIGINT NOT NULL, key_json TEXT NOT NULL, stream_id BIGINT NOT NULL, + display_name TEXT, -- Clobber based on tuple of user/device. UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + - " VALUES ($1, $2, $3, $4, $5)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + " ON CONFLICT (user_id, device_id)" + - " DO UPDATE SET key_json = $4, stream_id = $5" + " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" const selectDeviceKeysSQL = "" + - "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -106,11 +107,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID dk.UserID = userID var keyJSON string var streamID int - if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { + var displayName sql.NullString + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) dk.StreamID = streamID + if displayName.Valid { + dk.DisplayName = displayName.String + } // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) @@ -123,13 +128,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] for i, key := range keys { var keyJSONStr string var streamID int - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) + var displayName sql.NullString + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) keys[i].StreamID = streamID + if displayName.Valid { + keys[i].DisplayName = displayName.String + } } return nil } @@ -171,7 +180,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx for _, key := range keys { now := time.Now().Unix() _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ) if err != nil { return err diff --git a/sytest-whitelist b/sytest-whitelist index 18978bbe6..a726e51e8 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -138,6 +138,7 @@ Users receive device_list updates for their own devices Get left notifs for other users in sync and /keys/changes when user leaves Local device key changes get to remote servers Local device key changes get to remote servers with correct prev_id +Server correctly handles incoming m.device_list_update Can add account data Can add account data to room Can get account data without syncing