Add display_name col to store remote device names

Few other tweaks to make `Server correctly handles incoming m.device_list_update`
pass.
This commit is contained in:
Kegan Dougal 2020-08-07 10:43:36 +01:00
parent 77530ac501
commit 54ceb50e88
5 changed files with 50 additions and 19 deletions

View file

@ -158,12 +158,17 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
if err != nil { if err != nil {
return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err) return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
} }
// if this is the first time we're hearing about this user, sync the device list manually.
if len(event.PrevID) == 0 {
exists = false
}
util.GetLogger(ctx).WithFields(logrus.Fields{ util.GetLogger(ctx).WithFields(logrus.Fields{
"prev_ids_exist": exists, "prev_ids_exist": exists,
"user_id": event.UserID, "user_id": event.UserID,
"device_id": event.DeviceID, "device_id": event.DeviceID,
"stream_id": event.StreamID, "stream_id": event.StreamID,
"prev_ids": event.PrevID, "prev_ids": event.PrevID,
"display_name": event.DeviceDisplayName,
}).Info("DeviceListUpdater.Update") }).Info("DeviceListUpdater.Update")
// if we haven't missed anything update the database and notify users // if we haven't missed anything update the database and notify users

View file

@ -250,10 +250,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
if len(dk.KeyJSON) == 0 { if len(dk.KeyJSON) == 0 {
continue // don't include blank keys continue // don't include blank keys
} }
// inject display name if known // inject display name if known (either locally or remotely)
displayName := dk.DisplayName
if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
}
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"` DisplayName string `json:"device_display_name,omitempty"`
}{queryRes.DeviceInfo[dk.DeviceID].DisplayName}) }{displayName})
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
} }
} else { } else {
@ -261,7 +265,6 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...) domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
} }
} }
// TODO: set device display names when they are known
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us // attempt to satisfy key queries from the local database first as we should get device updates pushed to us
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys) domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
@ -294,6 +297,10 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
res.DeviceKeys[userID] = make(map[string]json.RawMessage) res.DeviceKeys[userID] = make(map[string]json.RawMessage)
} }
for _, key := range keys { for _, key := range keys {
// inject the display name
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"`
}{key.DisplayName})
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
} }
} }

View file

@ -37,22 +37,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
-- required in the spec because in the event of a missed update the server fetches the entire -- required in the spec because in the event of a missed update the server fetches the entire
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs. -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
stream_id BIGINT NOT NULL, stream_id BIGINT NOT NULL,
display_name TEXT,
-- Clobber based on tuple of user/device. -- Clobber based on tuple of user/device.
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
); );
` `
const upsertDeviceKeysSQL = "" + const upsertDeviceKeysSQL = "" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
" VALUES ($1, $2, $3, $4, $5)" + " VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
" DO UPDATE SET key_json = $4, stream_id = $5" " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
const selectDeviceKeysSQL = "" + const selectDeviceKeysSQL = "" +
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" + const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@ -99,13 +100,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
for i, key := range keys { for i, key := range keys {
var keyJSONStr string var keyJSONStr string
var streamID int var streamID int
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return err
} }
// this will be '' when there is no device // this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr) keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID keys[i].StreamID = streamID
if displayName.Valid {
keys[i].DisplayName = displayName.String
}
} }
return nil return nil
} }
@ -140,7 +145,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
for _, key := range keys { for _, key := range keys {
now := time.Now().Unix() now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
) )
if err != nil { if err != nil {
return err return err
@ -165,11 +170,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
dk.UserID = userID dk.UserID = userID
var keyJSON string var keyJSON string
var streamID int var streamID int
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err return nil, err
} }
dk.KeyJSON = []byte(keyJSON) dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID dk.StreamID = streamID
if displayName.Valid {
dk.DisplayName = displayName.String
}
// include the key if we want all keys (no device) or it was asked // include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk) result = append(result, dk)

View file

@ -34,22 +34,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
ts_added_secs BIGINT NOT NULL, ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL, key_json TEXT NOT NULL,
stream_id BIGINT NOT NULL, stream_id BIGINT NOT NULL,
display_name TEXT,
-- Clobber based on tuple of user/device. -- Clobber based on tuple of user/device.
UNIQUE (user_id, device_id) UNIQUE (user_id, device_id)
); );
` `
const upsertDeviceKeysSQL = "" + const upsertDeviceKeysSQL = "" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
" VALUES ($1, $2, $3, $4, $5)" + " VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT (user_id, device_id)" + " ON CONFLICT (user_id, device_id)" +
" DO UPDATE SET key_json = $4, stream_id = $5" " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
const selectDeviceKeysSQL = "" + const selectDeviceKeysSQL = "" +
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" + const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@ -106,11 +107,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
dk.UserID = userID dk.UserID = userID
var keyJSON string var keyJSON string
var streamID int var streamID int
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err return nil, err
} }
dk.KeyJSON = []byte(keyJSON) dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID dk.StreamID = streamID
if displayName.Valid {
dk.DisplayName = displayName.String
}
// include the key if we want all keys (no device) or it was asked // include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk) result = append(result, dk)
@ -123,13 +128,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
for i, key := range keys { for i, key := range keys {
var keyJSONStr string var keyJSONStr string
var streamID int var streamID int
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return err
} }
// this will be '' when there is no device // this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr) keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID keys[i].StreamID = streamID
if displayName.Valid {
keys[i].DisplayName = displayName.String
}
} }
return nil return nil
} }
@ -171,7 +180,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
for _, key := range keys { for _, key := range keys {
now := time.Now().Unix() now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
) )
if err != nil { if err != nil {
return err return err

View file

@ -138,6 +138,7 @@ Users receive device_list updates for their own devices
Get left notifs for other users in sync and /keys/changes when user leaves Get left notifs for other users in sync and /keys/changes when user leaves
Local device key changes get to remote servers Local device key changes get to remote servers
Local device key changes get to remote servers with correct prev_id Local device key changes get to remote servers with correct prev_id
Server correctly handles incoming m.device_list_update
Can add account data Can add account data
Can add account data to room Can add account data to room
Can get account data without syncing Can get account data without syncing