diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go
index af2641cce..e1b4e9475 100644
--- a/keyserver/storage/postgres/device_keys_table.go
+++ b/keyserver/storage/postgres/device_keys_table.go
@@ -102,8 +102,16 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
 	return nil
 }
 
-func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int, err error) {
-	err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&streamID)
+func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
+	// nullable if there are no results
+	var nullStream sql.NullInt32
+	err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
+	if err == sql.ErrNoRows {
+		err = nil
+	}
+	if nullStream.Valid {
+		streamID = nullStream.Int32
+	}
 	return
 }
 
diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go
index 53d16f96c..e78ee9433 100644
--- a/keyserver/storage/shared/storage.go
+++ b/keyserver/storage/shared/storage.go
@@ -59,7 +59,7 @@ func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage
 			if err != nil {
 				return err
 			}
-			userIDToStreamID[userID] = streamID
+			userIDToStreamID[userID] = int(streamID)
 		}
 		// set the stream IDs for each key
 		for i := range keys {
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go
index dd7cf9faa..9f70885ad 100644
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ b/keyserver/storage/sqlite3/device_keys_table.go
@@ -127,8 +127,16 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
 	return nil
 }
 
-func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int, err error) {
-	err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&streamID)
+func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
+	// nullable if there are no results
+	var nullStream sql.NullInt32
+	err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
+	if err == sql.ErrNoRows {
+		err = nil
+	}
+	if nullStream.Valid {
+		streamID = nullStream.Int32
+	}
 	return
 }
 
diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go
index 66f6930f9..b3e45e6cc 100644
--- a/keyserver/storage/storage_test.go
+++ b/keyserver/storage/storage_test.go
@@ -6,6 +6,7 @@ import (
 	"testing"
 
 	"github.com/Shopify/sarama"
+	"github.com/matrix-org/dendrite/keyserver/api"
 )
 
 var ctx = context.Background()
@@ -77,3 +78,84 @@ func TestKeyChangesUpperLimit(t *testing.T) {
 		t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
 	}
 }
+
+// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
+// and that they are returned correctly when querying for device keys.
+func TestDeviceKeysStreamIDGeneration(t *testing.T) {
+	db, err := NewDatabase("file::memory:", nil)
+	if err != nil {
+		t.Fatalf("Failed to NewDatabase: %s", err)
+	}
+	alice := "@alice:TestDeviceKeysStreamIDGeneration"
+	bob := "@bob:TestDeviceKeysStreamIDGeneration"
+	msgs := []api.DeviceMessage{
+		{
+			DeviceKeys: api.DeviceKeys{
+				DeviceID: "AAA",
+				UserID:   alice,
+				KeyJSON:  []byte(`{"key":"v1"}`),
+			},
+			// StreamID: 1
+		},
+		{
+			DeviceKeys: api.DeviceKeys{
+				DeviceID: "AAA",
+				UserID:   bob,
+				KeyJSON:  []byte(`{"key":"v1"}`),
+			},
+			// StreamID: 1 as this is a different user
+		},
+		{
+			DeviceKeys: api.DeviceKeys{
+				DeviceID: "another_device",
+				UserID:   alice,
+				KeyJSON:  []byte(`{"key":"v1"}`),
+			},
+			// StreamID: 2 as this is a 2nd device key
+		},
+	}
+	MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
+	if msgs[0].StreamID != 1 {
+		t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
+	}
+	if msgs[1].StreamID != 1 {
+		t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
+	}
+	if msgs[2].StreamID != 2 {
+		t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
+	}
+
+	// updating a device sets the next stream ID for that user
+	msgs = []api.DeviceMessage{
+		{
+			DeviceKeys: api.DeviceKeys{
+				DeviceID: "AAA",
+				UserID:   alice,
+				KeyJSON:  []byte(`{"key":"v2"}`),
+			},
+			// StreamID: 3
+		},
+	}
+	MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
+	if msgs[0].StreamID != 3 {
+		t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
+	}
+
+	// Querying for device keys returns the latest stream IDs
+	msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
+	if err != nil {
+		t.Fatalf("DeviceKeysForUser returned error: %s", err)
+	}
+	wantStreamIDs := map[string]int{
+		"AAA":            3,
+		"another_device": 2,
+	}
+	if len(msgs) != len(wantStreamIDs) {
+		t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
+	}
+	for _, m := range msgs {
+		if m.StreamID != wantStreamIDs[m.DeviceID] {
+			t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
+		}
+	}
+}
diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go
index d7c06a764..65da3310c 100644
--- a/keyserver/storage/tables/interface.go
+++ b/keyserver/storage/tables/interface.go
@@ -34,7 +34,7 @@ type OneTimeKeys interface {
 type DeviceKeys interface {
 	SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
 	InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
-	SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int, err error)
+	SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
 	SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
 }
 
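Note on the sql.NullInt32 change above: SELECT MAX(stream_id) over zero matching rows returns a single row whose value is NULL rather than returning no rows, so scanning straight into a plain integer fails, which is why the new implementations scan into sql.NullInt32 and treat sql.ErrNoRows as a defensive fallback. A minimal standalone sketch of that pattern follows; the table and column names and the helper function are illustrative assumptions, not taken from this diff (the real query sits behind the prepared selectMaxStreamForUserStmt).

```go
package sketch

import (
	"context"
	"database/sql"
)

// maxStreamIDForUser returns the highest stream ID stored for userID, or 0 if
// the user has no device keys yet. Table/column names are assumptions.
func maxStreamIDForUser(ctx context.Context, db *sql.DB, userID string) (int32, error) {
	// MAX() over an empty result set yields one row containing NULL, so scan
	// into sql.NullInt32 instead of a plain int32.
	var nullStream sql.NullInt32
	err := db.QueryRowContext(ctx,
		"SELECT MAX(stream_id) FROM device_keys WHERE user_id = $1", userID,
	).Scan(&nullStream)
	if err == sql.ErrNoRows {
		// Defensive: treat "no rows" the same as a NULL maximum.
		err = nil
	}
	if err != nil {
		return 0, err
	}
	if !nullStream.Valid {
		return 0, nil // no keys stored for this user yet
	}
	return nullStream.Int32, nil
}
```

This mirrors what the postgres and sqlite3 hunks in this diff do inside the existing prepared statement, and is why the shared storage layer can simply convert the result with int(streamID) when building userIDToStreamID.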