From fb56bbf0b7d4b21da3f55b066e71d24bf4599887 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Mon, 3 Aug 2020 17:07:06 +0100 Subject: [PATCH] Generate stream IDs for locally uploaded device keys (#1236) * Breaking: add stream_id to keyserver_device_keys table * Add tests for stream ID generation * Fix whitelist --- keyserver/api/api.go | 17 ++++ keyserver/internal/internal.go | 36 +++++--- keyserver/producers/keychange.go | 2 +- keyserver/storage/interface.go | 9 +- .../storage/postgres/device_keys_table.go | 84 ++++++++++++------- keyserver/storage/shared/storage.go | 29 ++++++- .../storage/sqlite3/device_keys_table.go | 80 +++++++++++------- keyserver/storage/storage_test.go | 82 ++++++++++++++++++ keyserver/storage/tables/interface.go | 7 +- syncapi/consumers/keychange.go | 2 +- sytest-whitelist | 1 + 11 files changed, 265 insertions(+), 84 deletions(-) diff --git a/keyserver/api/api.go b/keyserver/api/api.go index eb2f9e24a..080d0e5fd 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -43,6 +43,13 @@ func (k *KeyError) Error() string { return k.Err } +// DeviceMessage represents the message produced into Kafka by the key server. +type DeviceMessage struct { + DeviceKeys + // A monotonically increasing number which represents device changes for this user. + StreamID int +} + // DeviceKeys represents a set of device keys for a single device // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload type DeviceKeys struct { @@ -50,10 +57,20 @@ type DeviceKeys struct { UserID string // The device ID of this device DeviceID string + // The device display name + DisplayName string // The raw device key JSON KeyJSON []byte } +// WithStreamID returns a copy of this device message with the given stream ID +func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage { + return DeviceMessage{ + DeviceKeys: *k, + StreamID: streamID, + } +} + // OneTimeKeys represents a set of one-time keys for a single device // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload type OneTimeKeys struct { diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 3c8dff847..9027cbf4f 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -61,7 +61,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { res.KeyErrors = make(map[string]map[string]*api.KeyError) - a.uploadDeviceKeys(ctx, req, res) + a.uploadLocalDeviceKeys(ctx, req, res) a.uploadOneTimeKeys(ctx, req, res) } @@ -286,18 +286,25 @@ func (a *KeyInternalAPI) queryRemoteKeys( } } -func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { - var keysToStore []api.DeviceKeys +func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { + var keysToStore []api.DeviceMessage // assert that the user ID / device ID are not lying for each key for _, key := range req.DeviceKeys { + _, serverName, err := gomatrixserverlib.SplitID('@', key.UserID) + if err != nil { + continue // ignore invalid users + } + if serverName != a.ThisServer { + continue // ignore remote users + } if len(key.KeyJSON) == 0 { - keysToStore = append(keysToStore, key) + keysToStore = append(keysToStore, key.WithStreamID(0)) continue // deleted keys don't need sanity checking } gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str if gotUserID == key.UserID && gotDeviceID == key.DeviceID { - keysToStore = append(keysToStore, key) + keysToStore = append(keysToStore, key.WithStreamID(0)) continue } @@ -310,11 +317,13 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU } // get existing device keys so we can check for changes - existingKeys := make([]api.DeviceKeys, len(keysToStore)) + existingKeys := make([]api.DeviceMessage, len(keysToStore)) for i := range keysToStore { - existingKeys[i] = api.DeviceKeys{ - UserID: keysToStore[i].UserID, - DeviceID: keysToStore[i].DeviceID, + existingKeys[i] = api.DeviceMessage{ + DeviceKeys: api.DeviceKeys{ + UserID: keysToStore[i].UserID, + DeviceID: keysToStore[i].DeviceID, + }, } } if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil { @@ -324,13 +333,14 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU return } // store the device keys and emit changes - if err := a.DB.StoreDeviceKeys(ctx, keysToStore); err != nil { + err := a.DB.StoreDeviceKeys(ctx, keysToStore) + if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), } return } - err := a.emitDeviceKeyChanges(existingKeys, keysToStore) + err = a.emitDeviceKeyChanges(existingKeys, keysToStore) if err != nil { util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err) } @@ -375,9 +385,9 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform } -func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) error { +func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage) error { // find keys in new that are not in existing - var keysAdded []api.DeviceKeys + var keysAdded []api.DeviceMessage for _, newKey := range new { exists := false for _, existingKey := range existing { diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go index 6035b67bd..99629b42e 100644 --- a/keyserver/producers/keychange.go +++ b/keyserver/producers/keychange.go @@ -41,7 +41,7 @@ func (p *KeyChange) DefaultPartition() int32 { } // ProduceKeyChanges creates new change events for each key -func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error { +func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error { for _, key := range keys { var m sarama.ProducerMessage diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 0e0158e58..11284d86b 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -32,17 +32,18 @@ type Database interface { // OneTimeKeysCount returns a count of all OTKs for this device. OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced. - DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. + DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error // StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key // for this (user, device). + // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. // Returns an error if there was a problem storing the keys. - StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error + StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. - DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) + DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index d915246c7..e1b4e9475 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -20,7 +20,6 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -32,28 +31,37 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( device_id TEXT NOT NULL, ts_added_secs BIGINT NOT NULL, key_json TEXT NOT NULL, + -- the stream ID of this key, scoped per-user. This gets updated when the device key changes. + -- This means we do not store an unbounded append-only log of device keys, which is not actually + -- required in the spec because in the event of a missed update the server fetches the entire + -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs. + stream_id BIGINT NOT NULL, -- Clobber based on tuple of user/device. CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" + - " VALUES ($1, $2, $3, $4)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + + " VALUES ($1, $2, $3, $4, $5)" + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + - " DO UPDATE SET key_json = $4" + " DO UPDATE SET key_json = $4, stream_id = $5" const selectDeviceKeysSQL = "" + - "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + +const selectMaxStreamForUserSQL = "" + + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt } func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -73,38 +81,54 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { + return nil, err + } return s, nil } -func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { for i, key := range keys { var keyJSONStr string - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr) + var streamID int + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].StreamID = streamID } return nil } -func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { - now := time.Now().Unix() - return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { - for _, key := range keys { - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), - ) - if err != nil { - return err - } - } - return nil - }) +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { + // nullable if there are no results + var nullStream sql.NullInt32 + err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + if err == sql.ErrNoRows { + err = nil + } + if nullStream.Valid { + streamID = nullStream.Int32 + } + return } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ) + if err != nil { + return err + } + } + return nil +} + +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) if err != nil { return nil, err @@ -114,15 +138,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID for _, d := range deviceIDs { deviceIDMap[d] = true } - var result []api.DeviceKeys + var result []api.DeviceMessage for rows.Next() { - var dk api.DeviceKeys + var dk api.DeviceMessage dk.UserID = userID var keyJSON string - if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil { + var streamID int + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) + dk.StreamID = streamID // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 44cb0cc25..e78ee9433 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -43,15 +43,36 @@ func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) } -func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { +func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) } -func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { - return d.DeviceKeysTable.InsertDeviceKeys(ctx, keys) +func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { + // work out the latest stream IDs for each user + userIDToStreamID := make(map[string]int) + for _, k := range keys { + userIDToStreamID[k.UserID] = 0 + } + return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + for userID := range userIDToStreamID { + streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) + if err != nil { + return err + } + userIDToStreamID[userID] = int(streamID) + } + // set the stream IDs for each key + for i := range keys { + k := keys[i] + userIDToStreamID[k.UserID]++ // start stream from 1 + k.StreamID = userIDToStreamID[k.UserID] + keys[i] = k + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) } -func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { +func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs) } diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 69fe7a6e4..9f70885ad 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -20,7 +20,6 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -32,28 +31,33 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( device_id TEXT NOT NULL, ts_added_secs BIGINT NOT NULL, key_json TEXT NOT NULL, + stream_id BIGINT NOT NULL, -- Clobber based on tuple of user/device. UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" + - " VALUES ($1, $2, $3, $4)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + + " VALUES ($1, $2, $3, $4, $5)" + " ON CONFLICT (user_id, device_id)" + - " DO UPDATE SET key_json = $4" + " DO UPDATE SET key_json = $4, stream_id = $5" const selectDeviceKeysSQL = "" + - "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + +const selectMaxStreamForUserSQL = "" + + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt } func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -73,10 +77,13 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { + return nil, err + } return s, nil } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { deviceIDMap := make(map[string]bool) for _, d := range deviceIDs { deviceIDMap[d] = true @@ -86,15 +93,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") - var result []api.DeviceKeys + var result []api.DeviceMessage for rows.Next() { - var dk api.DeviceKeys + var dk api.DeviceMessage dk.UserID = userID var keyJSON string - if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil { + var streamID int + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) + dk.StreamID = streamID // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) @@ -103,30 +112,43 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID return result, rows.Err() } -func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { for i, key := range keys { var keyJSONStr string - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr) + var streamID int + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].StreamID = streamID } return nil } -func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { - now := time.Now().Unix() - return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { - for _, key := range keys { - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), - ) - if err != nil { - return err - } - } - return nil - }) +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { + // nullable if there are no results + var nullStream sql.NullInt32 + err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + if err == sql.ErrNoRows { + err = nil + } + if nullStream.Valid { + streamID = nullStream.Int32 + } + return +} + +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ) + if err != nil { + return err + } + } + return nil } diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index 66f6930f9..b3e45e6cc 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/keyserver/api" ) var ctx = context.Background() @@ -77,3 +78,84 @@ func TestKeyChangesUpperLimit(t *testing.T) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) } } + +// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, +// and that they are returned correctly when querying for device keys. +func TestDeviceKeysStreamIDGeneration(t *testing.T) { + db, err := NewDatabase("file::memory:", nil) + if err != nil { + t.Fatalf("Failed to NewDatabase: %s", err) + } + alice := "@alice:TestDeviceKeysStreamIDGeneration" + bob := "@bob:TestDeviceKeysStreamIDGeneration" + msgs := []api.DeviceMessage{ + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 + }, + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "AAA", + UserID: bob, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 as this is a different user + }, + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "another_device", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 2 as this is a 2nd device key + }, + } + MustNotError(t, db.StoreDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 1 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) + } + if msgs[1].StreamID != 1 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) + } + if msgs[2].StreamID != 2 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) + } + + // updating a device sets the next stream ID for that user + msgs = []api.DeviceMessage{ + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v2"}`), + }, + // StreamID: 3 + }, + } + MustNotError(t, db.StoreDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 3 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) + } + + // Querying for device keys returns the latest stream IDs + msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}) + if err != nil { + t.Fatalf("DeviceKeysForUser returned error: %s", err) + } + wantStreamIDs := map[string]int{ + "AAA": 3, + "another_device": 2, + } + if len(msgs) != len(wantStreamIDs) { + t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) + } + for _, m := range msgs { + if m.StreamID != wantStreamIDs[m.DeviceID] { + t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) + } + } +} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index c6e43be45..65da3310c 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -32,9 +32,10 @@ type OneTimeKeys interface { } type DeviceKeys interface { - SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error - InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error - SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) + SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error + SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) + SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) } type KeyChanges interface { diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 35978be71..e14d2223e 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -98,7 +98,7 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er defer func() { s.updateOffset(msg) }() - var output api.DeviceKeys + var output api.DeviceMessage if err := json.Unmarshal(msg.Value, &output); err != nil { // If the message was invalid, log it and move on to the next message in the stream log.WithError(err).Error("syncapi: failed to unmarshal key change event from key server") diff --git a/sytest-whitelist b/sytest-whitelist index 16a71c648..a1d2e437c 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -110,6 +110,7 @@ Rooms a user is invited to appear in an incremental sync Sync can be polled for updates Sync is woken up for leaves Newly left rooms appear in the leave section of incremental sync +Rooms can be created with an initial invite list (SYN-205) We should see our own leave event, even if history_visibility is restricted (SYN-662) We should see our own leave event when rejecting an invite, even if history_visibility is restricted (riot-web/3462) Newly left rooms appear in the leave section of gapped sync