mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 05:43:09 -06:00
Add display_name col to store remote device names
Few other tweaks to make `Server correctly handles incoming m.device_list_update` pass.
This commit is contained in:
parent
77530ac501
commit
54ceb50e88
|
|
@ -158,12 +158,17 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
|
return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
|
||||||
}
|
}
|
||||||
|
// if this is the first time we're hearing about this user, sync the device list manually.
|
||||||
|
if len(event.PrevID) == 0 {
|
||||||
|
exists = false
|
||||||
|
}
|
||||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||||
"prev_ids_exist": exists,
|
"prev_ids_exist": exists,
|
||||||
"user_id": event.UserID,
|
"user_id": event.UserID,
|
||||||
"device_id": event.DeviceID,
|
"device_id": event.DeviceID,
|
||||||
"stream_id": event.StreamID,
|
"stream_id": event.StreamID,
|
||||||
"prev_ids": event.PrevID,
|
"prev_ids": event.PrevID,
|
||||||
|
"display_name": event.DeviceDisplayName,
|
||||||
}).Info("DeviceListUpdater.Update")
|
}).Info("DeviceListUpdater.Update")
|
||||||
|
|
||||||
// if we haven't missed anything update the database and notify users
|
// if we haven't missed anything update the database and notify users
|
||||||
|
|
|
||||||
|
|
@ -250,10 +250,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||||
if len(dk.KeyJSON) == 0 {
|
if len(dk.KeyJSON) == 0 {
|
||||||
continue // don't include blank keys
|
continue // don't include blank keys
|
||||||
}
|
}
|
||||||
// inject display name if known
|
// inject display name if known (either locally or remotely)
|
||||||
|
displayName := dk.DisplayName
|
||||||
|
if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
|
||||||
|
displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
|
||||||
|
}
|
||||||
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
|
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
|
||||||
DisplayName string `json:"device_display_name,omitempty"`
|
DisplayName string `json:"device_display_name,omitempty"`
|
||||||
}{queryRes.DeviceInfo[dk.DeviceID].DisplayName})
|
}{displayName})
|
||||||
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
|
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -261,7 +265,6 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||||
domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
|
domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO: set device display names when they are known
|
|
||||||
|
|
||||||
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
|
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
|
||||||
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
|
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
|
||||||
|
|
@ -294,6 +297,10 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
|
||||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||||
}
|
}
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
|
// inject the display name
|
||||||
|
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
||||||
|
DisplayName string `json:"device_display_name,omitempty"`
|
||||||
|
}{key.DisplayName})
|
||||||
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -37,22 +37,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
||||||
-- required in the spec because in the event of a missed update the server fetches the entire
|
-- required in the spec because in the event of a missed update the server fetches the entire
|
||||||
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
|
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
|
||||||
stream_id BIGINT NOT NULL,
|
stream_id BIGINT NOT NULL,
|
||||||
|
display_name TEXT,
|
||||||
-- Clobber based on tuple of user/device.
|
-- Clobber based on tuple of user/device.
|
||||||
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
|
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const upsertDeviceKeysSQL = "" +
|
const upsertDeviceKeysSQL = "" +
|
||||||
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
|
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
|
||||||
" VALUES ($1, $2, $3, $4, $5)" +
|
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||||
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
|
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
|
||||||
" DO UPDATE SET key_json = $4, stream_id = $5"
|
" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
|
||||||
|
|
||||||
const selectDeviceKeysSQL = "" +
|
const selectDeviceKeysSQL = "" +
|
||||||
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
|
|
||||||
const selectBatchDeviceKeysSQL = "" +
|
const selectBatchDeviceKeysSQL = "" +
|
||||||
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
const selectMaxStreamForUserSQL = "" +
|
const selectMaxStreamForUserSQL = "" +
|
||||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
@ -99,13 +100,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
|
||||||
for i, key := range keys {
|
for i, key := range keys {
|
||||||
var keyJSONStr string
|
var keyJSONStr string
|
||||||
var streamID int
|
var streamID int
|
||||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
|
var displayName sql.NullString
|
||||||
|
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// this will be '' when there is no device
|
// this will be '' when there is no device
|
||||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
keys[i].KeyJSON = []byte(keyJSONStr)
|
||||||
keys[i].StreamID = streamID
|
keys[i].StreamID = streamID
|
||||||
|
if displayName.Valid {
|
||||||
|
keys[i].DisplayName = displayName.String
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -140,7 +145,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
||||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
|
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -165,11 +170,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
dk.UserID = userID
|
dk.UserID = userID
|
||||||
var keyJSON string
|
var keyJSON string
|
||||||
var streamID int
|
var streamID int
|
||||||
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
|
var displayName sql.NullString
|
||||||
|
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dk.KeyJSON = []byte(keyJSON)
|
dk.KeyJSON = []byte(keyJSON)
|
||||||
dk.StreamID = streamID
|
dk.StreamID = streamID
|
||||||
|
if displayName.Valid {
|
||||||
|
dk.DisplayName = displayName.String
|
||||||
|
}
|
||||||
// include the key if we want all keys (no device) or it was asked
|
// include the key if we want all keys (no device) or it was asked
|
||||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||||
result = append(result, dk)
|
result = append(result, dk)
|
||||||
|
|
|
||||||
|
|
@ -34,22 +34,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
||||||
ts_added_secs BIGINT NOT NULL,
|
ts_added_secs BIGINT NOT NULL,
|
||||||
key_json TEXT NOT NULL,
|
key_json TEXT NOT NULL,
|
||||||
stream_id BIGINT NOT NULL,
|
stream_id BIGINT NOT NULL,
|
||||||
|
display_name TEXT,
|
||||||
-- Clobber based on tuple of user/device.
|
-- Clobber based on tuple of user/device.
|
||||||
UNIQUE (user_id, device_id)
|
UNIQUE (user_id, device_id)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const upsertDeviceKeysSQL = "" +
|
const upsertDeviceKeysSQL = "" +
|
||||||
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
|
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
|
||||||
" VALUES ($1, $2, $3, $4, $5)" +
|
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||||
" ON CONFLICT (user_id, device_id)" +
|
" ON CONFLICT (user_id, device_id)" +
|
||||||
" DO UPDATE SET key_json = $4, stream_id = $5"
|
" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
|
||||||
|
|
||||||
const selectDeviceKeysSQL = "" +
|
const selectDeviceKeysSQL = "" +
|
||||||
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
|
|
||||||
const selectBatchDeviceKeysSQL = "" +
|
const selectBatchDeviceKeysSQL = "" +
|
||||||
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
const selectMaxStreamForUserSQL = "" +
|
const selectMaxStreamForUserSQL = "" +
|
||||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
@ -106,11 +107,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
dk.UserID = userID
|
dk.UserID = userID
|
||||||
var keyJSON string
|
var keyJSON string
|
||||||
var streamID int
|
var streamID int
|
||||||
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
|
var displayName sql.NullString
|
||||||
|
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dk.KeyJSON = []byte(keyJSON)
|
dk.KeyJSON = []byte(keyJSON)
|
||||||
dk.StreamID = streamID
|
dk.StreamID = streamID
|
||||||
|
if displayName.Valid {
|
||||||
|
dk.DisplayName = displayName.String
|
||||||
|
}
|
||||||
// include the key if we want all keys (no device) or it was asked
|
// include the key if we want all keys (no device) or it was asked
|
||||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||||
result = append(result, dk)
|
result = append(result, dk)
|
||||||
|
|
@ -123,13 +128,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
|
||||||
for i, key := range keys {
|
for i, key := range keys {
|
||||||
var keyJSONStr string
|
var keyJSONStr string
|
||||||
var streamID int
|
var streamID int
|
||||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
|
var displayName sql.NullString
|
||||||
|
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// this will be '' when there is no device
|
// this will be '' when there is no device
|
||||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
keys[i].KeyJSON = []byte(keyJSONStr)
|
||||||
keys[i].StreamID = streamID
|
keys[i].StreamID = streamID
|
||||||
|
if displayName.Valid {
|
||||||
|
keys[i].DisplayName = displayName.String
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -171,7 +180,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
||||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
|
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
||||||
|
|
@ -138,6 +138,7 @@ Users receive device_list updates for their own devices
|
||||||
Get left notifs for other users in sync and /keys/changes when user leaves
|
Get left notifs for other users in sync and /keys/changes when user leaves
|
||||||
Local device key changes get to remote servers
|
Local device key changes get to remote servers
|
||||||
Local device key changes get to remote servers with correct prev_id
|
Local device key changes get to remote servers with correct prev_id
|
||||||
|
Server correctly handles incoming m.device_list_update
|
||||||
Can add account data
|
Can add account data
|
||||||
Can add account data to room
|
Can add account data to room
|
||||||
Can get account data without syncing
|
Can get account data without syncing
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue