mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-01 03:03:10 -06:00
Allow including empty entries so we can clean them up
This commit is contained in:
parent
1357ae5248
commit
5428aae442
|
|
@ -198,7 +198,7 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) {
|
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) {
|
||||||
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil)
|
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = &api.KeyError{
|
res.Error = &api.KeyError{
|
||||||
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
|
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
|
||||||
|
|
@ -244,7 +244,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||||
domain := string(serverName)
|
domain := string(serverName)
|
||||||
// query local devices
|
// query local devices
|
||||||
if serverName == a.ThisServer {
|
if serverName == a.ThisServer {
|
||||||
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
|
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = &api.KeyError{
|
res.Error = &api.KeyError{
|
||||||
Err: fmt.Sprintf("failed to query local device keys: %s", err),
|
Err: fmt.Sprintf("failed to query local device keys: %s", err),
|
||||||
|
|
@ -520,7 +520,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
|
||||||
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||||
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
|
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
|
||||||
) error {
|
) error {
|
||||||
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
|
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||||
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
|
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
|
||||||
|
|
@ -564,19 +564,23 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices))
|
||||||
|
for _, key := range uapidevices.Devices {
|
||||||
|
existingDeviceMap[key.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.WithField("user_id", req.UserID).Infof("XXX: Existing devices: %+v", existingDeviceMap)
|
||||||
|
|
||||||
// Get all of the user existing device keys so we can check for changes.
|
// Get all of the user existing device keys so we can check for changes.
|
||||||
existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil)
|
existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = &api.KeyError{
|
res.Error = &api.KeyError{
|
||||||
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
|
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
existingDeviceMap := make(map[string]struct{}, len(existingKeys))
|
|
||||||
for _, key := range uapidevices.Devices {
|
logrus.WithField("user_id", req.UserID).Infof("XXX: Existing keys: %+v", existingKeys)
|
||||||
existingDeviceMap[key.ID] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Work out whether we have device keys in the keyserver for devices that
|
// Work out whether we have device keys in the keyserver for devices that
|
||||||
// no longer exist in the user API. This is mostly an exercise to ensure
|
// no longer exist in the user API. This is mostly an exercise to ensure
|
||||||
|
|
@ -587,13 +591,18 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
||||||
toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
|
toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logrus.WithField("user_id", req.UserID).Infof("XXX: Clean keys: %+v", toClean)
|
||||||
|
|
||||||
if len(toClean) > 0 {
|
if len(toClean) > 0 {
|
||||||
|
logrus.WithField("user_id", req.UserID).Infof("Cleaning up %d stale device keys for user", len(toClean))
|
||||||
if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
|
if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
|
||||||
res.Error = &api.KeyError{
|
res.Error = &api.KeyError{
|
||||||
Err: fmt.Sprintf("failed to clean device keys: %s", err.Error()),
|
Err: fmt.Sprintf("failed to clean device keys: %s", err.Error()),
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
logrus.WithField("user_id", req.UserID).Infof("Cleaned up %d stale device keys for user", len(toClean))
|
||||||
}
|
}
|
||||||
|
|
||||||
var keysToStore []api.DeviceMessage
|
var keysToStore []api.DeviceMessage
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ type Database interface {
|
||||||
|
|
||||||
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
||||||
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
||||||
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
|
||||||
|
|
||||||
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
|
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
|
||||||
// cross-signing signatures relating to that device.
|
// cross-signing signatures relating to that device.
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,9 @@ const selectDeviceKeysSQL = "" +
|
||||||
const selectBatchDeviceKeysSQL = "" +
|
const selectBatchDeviceKeysSQL = "" +
|
||||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||||
|
|
||||||
|
const selectBatchDeviceKeysWithEmptiesSQL = "" +
|
||||||
|
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
const selectMaxStreamForUserSQL = "" +
|
const selectMaxStreamForUserSQL = "" +
|
||||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
|
|
@ -73,6 +76,7 @@ type deviceKeysStatements struct {
|
||||||
upsertDeviceKeysStmt *sql.Stmt
|
upsertDeviceKeysStmt *sql.Stmt
|
||||||
selectDeviceKeysStmt *sql.Stmt
|
selectDeviceKeysStmt *sql.Stmt
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
|
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
|
||||||
selectMaxStreamForUserStmt *sql.Stmt
|
selectMaxStreamForUserStmt *sql.Stmt
|
||||||
countStreamIDsForUserStmt *sql.Stmt
|
countStreamIDsForUserStmt *sql.Stmt
|
||||||
deleteDeviceKeysStmt *sql.Stmt
|
deleteDeviceKeysStmt *sql.Stmt
|
||||||
|
|
@ -96,6 +100,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -180,8 +187,14 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||||
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
var stmt *sql.Stmt
|
||||||
|
if includeEmpty {
|
||||||
|
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
|
||||||
|
} else {
|
||||||
|
stmt = s.selectBatchDeviceKeysStmt
|
||||||
|
}
|
||||||
|
rows, err := stmt.QueryContext(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -108,8 +108,8 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||||
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
|
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
|
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,9 @@ const selectDeviceKeysSQL = "" +
|
||||||
const selectBatchDeviceKeysSQL = "" +
|
const selectBatchDeviceKeysSQL = "" +
|
||||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||||
|
|
||||||
|
const selectBatchDeviceKeysWithEmptiesSQL = "" +
|
||||||
|
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
const selectMaxStreamForUserSQL = "" +
|
const selectMaxStreamForUserSQL = "" +
|
||||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
|
|
@ -69,6 +72,7 @@ type deviceKeysStatements struct {
|
||||||
upsertDeviceKeysStmt *sql.Stmt
|
upsertDeviceKeysStmt *sql.Stmt
|
||||||
selectDeviceKeysStmt *sql.Stmt
|
selectDeviceKeysStmt *sql.Stmt
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
|
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
|
||||||
selectMaxStreamForUserStmt *sql.Stmt
|
selectMaxStreamForUserStmt *sql.Stmt
|
||||||
deleteDeviceKeysStmt *sql.Stmt
|
deleteDeviceKeysStmt *sql.Stmt
|
||||||
deleteAllDeviceKeysStmt *sql.Stmt
|
deleteAllDeviceKeysStmt *sql.Stmt
|
||||||
|
|
@ -91,6 +95,9 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -113,12 +120,18 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||||
deviceIDMap := make(map[string]bool)
|
deviceIDMap := make(map[string]bool)
|
||||||
for _, d := range deviceIDs {
|
for _, d := range deviceIDs {
|
||||||
deviceIDMap[d] = true
|
deviceIDMap[d] = true
|
||||||
}
|
}
|
||||||
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
var stmt *sql.Stmt
|
||||||
|
if includeEmpty {
|
||||||
|
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
|
||||||
|
} else {
|
||||||
|
stmt = s.selectBatchDeviceKeysStmt
|
||||||
|
}
|
||||||
|
rows, err := stmt.QueryContext(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -173,7 +173,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Querying for device keys returns the latest stream IDs
|
// Querying for device keys returns the latest stream IDs
|
||||||
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
|
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ type DeviceKeys interface {
|
||||||
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
||||||
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
||||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
|
||||||
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||||
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
|
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue