package storage_test import ( "bytes" "context" "encoding/json" "fmt" "reflect" "sync" "testing" "time" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/stretchr/testify/assert" "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" ) const loginTokenLifetime = time.Minute var ( openIDLifetimeMS = time.Minute.Milliseconds() ctx = context.Background() ) func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { connStr, close := test.PrepareDBConnectionString(t, dbType) cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) db, err := storage.NewUserDatabase(context.Background(), cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") if err != nil { t.Fatalf("NewUserDatabase returned %s", err) } return db, func() { close() } } // Tests storing and getting account data func Test_AccountData(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) room := test.NewRoom(t, alice) events := room.Events() contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID())) err = db.SaveAccountData(ctx, localpart, domain, room.ID, "m.fully_read", contentRoom) assert.NoError(t, err, "unable to save account data") contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID)) err = db.SaveAccountData(ctx, localpart, domain, "", "im.vector.setting.breadcrumbs", contentGlobal) assert.NoError(t, err, "unable to save account data") accountData, err := db.GetAccountDataByType(ctx, localpart, domain, room.ID, "m.fully_read") assert.NoError(t, err, "unable to get account data by type") assert.Equal(t, contentRoom, accountData) globalData, roomData, err := db.GetAccountData(ctx, localpart, domain) assert.NoError(t, err) assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"]) assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"]) }) } // Tests the creation of accounts func Test_Accounts(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) accAlice, err := db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") // verify the newly create account is the same as returned by CreateAccount var accGet *api.Account accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "testing") assert.NoError(t, err, "failed to get account by password") assert.Equal(t, accAlice, accGet) accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to get account by localpart") assert.Equal(t, accAlice, accGet) // check account availability available, err := db.CheckAccountAvailability(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to checkout account availability") assert.Equal(t, false, available) available, err = db.CheckAccountAvailability(ctx, "unusedname", aliceDomain) assert.NoError(t, err, "failed to checkout account availability") assert.Equal(t, true, available) // get guest account numeric aliceLocalpart first, err := db.GetNewNumericLocalpart(ctx, aliceDomain) assert.NoError(t, err, "failed to get new numeric localpart") // Create a new account to verify the numeric localpart is updated _, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest) assert.NoError(t, err, "failed to create account") second, err := db.GetNewNumericLocalpart(ctx, aliceDomain) assert.NoError(t, err) assert.Greater(t, second, first) // update password for alice err = db.SetPassword(ctx, aliceLocalpart, aliceDomain, "newPassword") assert.NoError(t, err, "failed to update password") accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword") assert.NoError(t, err, "failed to get account by new password") assert.Equal(t, accAlice, accGet) // deactivate account err = db.DeactivateAccount(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to deactivate account") // This should fail now, as the account is deactivated _, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword") assert.Error(t, err, "expected an error, got none") _, err = db.GetAccountByLocalpart(ctx, "unusename", aliceDomain) assert.Error(t, err, "expected an error for non existent localpart") // create an empty localpart; this should never happen, but is required to test getting a numeric localpart // if there's already a user without a localpart in the database _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeUser) assert.NoError(t, err) // test getting a numeric localpart, with an existing user without a localpart _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest) assert.NoError(t, err) // Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type _, err = db.CreateAccount(ctx, "2147483650", aliceDomain, "", "", api.AccountTypeUser) assert.NoError(t, err) // Now try to create a new guest user _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest) assert.NoError(t, err) }) } func Test_Devices(t *testing.T) { alice := test.NewUser(t) localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) deviceID := util.RandomString(8) accessToken := util.RandomString(16) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateUserDatabase(t, dbType) defer close() deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "") assert.NoError(t, err, "unable to create deviceWithoutID") gotDevice, err := db.GetDeviceByID(ctx, localpart, domain, deviceID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields gotDeviceAccessToken, err := db.GetDeviceByAccessToken(ctx, accessToken) assert.NoError(t, err, "unable to get device by access token") assert.Equal(t, deviceWithID.ID, gotDeviceAccessToken.ID) // GetDeviceByAccessToken doesn't populate all fields // create a device without existing device ID accessToken = util.RandomString(16) deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "") assert.NoError(t, err, "unable to create deviceWithoutID") gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, domain, deviceWithoutID.ID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields // Get devices devices, err := db.GetDevicesByLocalpart(ctx, localpart, domain) assert.NoError(t, err, "unable to get devices by localpart") assert.Equal(t, 2, len(devices)) deviceIDs := make([]string, 0, len(devices)) for _, dev := range devices { deviceIDs = append(deviceIDs, dev.ID) } devices2, err := db.GetDevicesByID(ctx, deviceIDs) assert.NoError(t, err, "unable to get devices by id") assert.ElementsMatch(t, devices, devices2) // Update device newName := "new display name" err = db.UpdateDevice(ctx, localpart, domain, deviceWithID.ID, &newName) assert.NoError(t, err, "unable to update device displayname") updatedAfterTimestamp := time.Now().Unix() err = db.UpdateDeviceLastSeen(ctx, localpart, domain, deviceWithID.ID, "127.0.0.1", "Element Web") assert.NoError(t, err, "unable to update device last seen") deviceWithID.DisplayName = newName deviceWithID.LastSeenIP = "127.0.0.1" gotDevice, err = db.GetDeviceByID(ctx, localpart, domain, deviceWithID.ID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 2, len(devices)) assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName) assert.Equal(t, deviceWithID.LastSeenIP, gotDevice.LastSeenIP) assert.Greater(t, gotDevice.LastSeenTS, updatedAfterTimestamp) // create one more device and remove the devices step by step newDeviceID := util.RandomString(16) accessToken = util.RandomString(16) _, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "") assert.NoError(t, err, "unable to create new device") devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 3, len(devices)) err = db.RemoveDevices(ctx, localpart, domain, deviceIDs) assert.NoError(t, err, "unable to remove devices") devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 1, len(devices)) deleted, err := db.RemoveAllDevices(ctx, localpart, domain, "") assert.NoError(t, err, "unable to remove all devices") assert.Equal(t, 1, len(deleted)) assert.Equal(t, newDeviceID, deleted[0].ID) }) } func Test_KeyBackup(t *testing.T) { alice := test.NewUser(t) room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateUserDatabase(t, dbType) defer close() wantAuthData := json.RawMessage("my auth data") wantVersion, err := db.CreateKeyBackup(ctx, alice.ID, "dummyAlgo", wantAuthData) assert.NoError(t, err, "unable to create key backup") // get key backup by version gotVersion, gotAlgo, gotAuthData, _, _, err := db.GetKeyBackup(ctx, alice.ID, wantVersion) assert.NoError(t, err, "unable to get key backup") assert.Equal(t, wantVersion, gotVersion, "backup version mismatch") assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch") assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch") // get any key backup gotVersion, gotAlgo, gotAuthData, _, _, err = db.GetKeyBackup(ctx, alice.ID, "") assert.NoError(t, err, "unable to get key backup") assert.Equal(t, wantVersion, gotVersion, "backup version mismatch") assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch") assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch") err = db.UpdateKeyBackupAuthData(ctx, alice.ID, wantVersion, json.RawMessage("my updated auth data")) assert.NoError(t, err, "unable to update key backup auth data") uploads := []api.InternalKeyBackupSession{ { KeyBackupSession: api.KeyBackupSession{ IsVerified: true, SessionData: wantAuthData, }, RoomID: room.ID, SessionID: "1", }, { KeyBackupSession: api.KeyBackupSession{}, RoomID: room.ID, SessionID: "2", }, } count, _, err := db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads) assert.NoError(t, err, "unable to upsert backup keys") assert.Equal(t, int64(len(uploads)), count, "unexpected backup count") // do it again to update a key uploads[1].IsVerified = true count, _, err = db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads[1:]) assert.NoError(t, err, "unable to upsert backup keys") assert.Equal(t, int64(len(uploads)), count, "unexpected backup count") // get backup keys by session id gotBackupKeys, err := db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "1") assert.NoError(t, err, "unable to get backup keys") assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"]) // get backup keys by room id gotBackupKeys, err = db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "") assert.NoError(t, err, "unable to get backup keys") assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"]) gotCount, err := db.CountBackupKeys(ctx, wantVersion, alice.ID) assert.NoError(t, err, "unable to get backup keys count") assert.Equal(t, count, gotCount, "unexpected backup count") // finally delete a key exists, err := db.DeleteKeyBackup(ctx, alice.ID, wantVersion) assert.NoError(t, err, "unable to delete key backup") assert.True(t, exists) // this key should not exist exists, err = db.DeleteKeyBackup(ctx, alice.ID, "3") assert.NoError(t, err, "unable to delete key backup") assert.False(t, exists) }) } func Test_LoginToken(t *testing.T) { alice := test.NewUser(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateUserDatabase(t, dbType) defer close() // create a new token wantLoginToken := &api.LoginTokenData{UserID: alice.ID} gotMetadata, err := db.CreateLoginToken(ctx, wantLoginToken) assert.NoError(t, err, "unable to create login token") assert.NotNil(t, gotMetadata) assert.Equal(t, time.Now().Add(loginTokenLifetime).Truncate(loginTokenLifetime), gotMetadata.Expiration.Truncate(loginTokenLifetime)) // get the new token gotLoginToken, err := db.GetLoginTokenDataByToken(ctx, gotMetadata.Token) assert.NoError(t, err, "unable to get login token") assert.NotNil(t, gotLoginToken) assert.Equal(t, wantLoginToken, gotLoginToken, "unexpected login token") // remove the login token again err = db.RemoveLoginToken(ctx, gotMetadata.Token) assert.NoError(t, err, "unable to remove login token") // check if the token was actually deleted _, err = db.GetLoginTokenDataByToken(ctx, gotMetadata.Token) assert.Error(t, err, "expected an error, but got none") }) } func Test_OpenID(t *testing.T) { alice := test.NewUser(t) token := util.RandomString(24) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateUserDatabase(t, dbType) defer close() expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS expires, err := db.CreateOpenIDToken(ctx, token, alice.ID) assert.NoError(t, err, "unable to create OpenID token") assert.InDelta(t, expiresAtMS, expires, 2) // 2ms leeway attributes, err := db.GetOpenIDTokenAttributes(ctx, token) assert.NoError(t, err, "unable to get OpenID token attributes") assert.Equal(t, alice.ID, attributes.UserID) assert.InDelta(t, expiresAtMS, attributes.ExpiresAtMS, 2) // 2ms leeway }) } func Test_Profile(t *testing.T) { alice := test.NewUser(t) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateUserDatabase(t, dbType) defer close() // create account, which also creates a profile _, err = db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get profile by localpart") wantProfile := &authtypes.Profile{ Localpart: aliceLocalpart, ServerName: string(aliceDomain), } assert.Equal(t, wantProfile, gotProfile) // set avatar & displayname wantProfile.DisplayName = "Alice" gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, aliceDomain, "Alice") assert.Equal(t, wantProfile, gotProfile) assert.NoError(t, err, "unable to set displayname") assert.True(t, changed) wantProfile.AvatarURL = "mxc://aliceAvatar" gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar") assert.NoError(t, err, "unable to set avatar url") assert.Equal(t, wantProfile, gotProfile) assert.True(t, changed) // Setting the same avatar again doesn't change anything wantProfile.AvatarURL = "mxc://aliceAvatar" gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar") assert.NoError(t, err, "unable to set avatar url") assert.Equal(t, wantProfile, gotProfile) assert.False(t, changed) // search profiles searchRes, err := db.SearchProfiles(ctx, "Alice", 2) assert.NoError(t, err, "unable to search profiles") assert.Equal(t, 1, len(searchRes)) assert.Equal(t, *wantProfile, searchRes[0]) }) } func Test_Pusher(t *testing.T) { alice := test.NewUser(t) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateUserDatabase(t, dbType) defer close() appID := util.RandomString(8) var pushKeys []string var gotPushers []api.Pusher for i := 0; i < 2; i++ { pushKey := util.RandomString(8) wantPusher := api.Pusher{ PushKey: pushKey, Kind: api.HTTPKind, AppID: appID, AppDisplayName: util.RandomString(8), DeviceDisplayName: util.RandomString(8), ProfileTag: util.RandomString(8), Language: util.RandomString(2), } err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to upsert pusher") // check it was actually persisted gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get pushers") assert.Equal(t, i+1, len(gotPushers)) assert.Equal(t, wantPusher, gotPushers[i]) pushKeys = append(pushKeys, pushKey) } // remove single pusher err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to remove pusher") gotPushers, err := db.GetPushers(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get pushers") assert.Equal(t, 1, len(gotPushers)) // remove last pusher err = db.RemovePushers(ctx, appID, pushKeys[1]) assert.NoError(t, err, "unable to remove pusher") gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get pushers") assert.Equal(t, 0, len(gotPushers)) }) } func Test_ThreePID(t *testing.T) { alice := test.NewUser(t) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateUserDatabase(t, dbType) defer close() threePID := util.RandomString(8) medium := util.RandomString(8) err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, aliceDomain, medium) assert.NoError(t, err, "unable to save threepid association") // get the stored threepid gotLocalpart, gotDomain, err := db.GetLocalpartForThreePID(ctx, threePID, medium) assert.NoError(t, err, "unable to get localpart for threepid") assert.Equal(t, aliceLocalpart, gotLocalpart) assert.Equal(t, aliceDomain, gotDomain) threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get threepids for localpart") assert.Equal(t, 1, len(threepids)) assert.Equal(t, authtypes.ThreePID{ Address: threePID, Medium: medium, }, threepids[0]) // remove threepid association err = db.RemoveThreePIDAssociation(ctx, threePID, medium) assert.NoError(t, err, "unexpected error") // verify it was deleted threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get threepids for localpart") assert.Equal(t, 0, len(threepids)) }) } func Test_Notification(t *testing.T) { alice := test.NewUser(t) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) room := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateUserDatabase(t, dbType) defer close() // generate some dummy notifications for i := 0; i < 10; i++ { eventID := util.RandomString(16) roomID := room.ID ts := time.Now() if i > 5 { roomID = room2.ID // create some old notifications to test DeleteOldNotifications ts = ts.AddDate(0, -2, 0) } notification := &api.Notification{ Actions: []*pushrules.Action{ {}, }, Event: synctypes.ClientEvent{ Content: spec.RawJSON("{}"), }, Read: false, RoomID: roomID, TS: spec.AsTimestamp(ts), } err = db.InsertNotification(ctx, aliceLocalpart, aliceDomain, eventID, uint64(i+1), nil, notification) assert.NoError(t, err, "unable to insert notification") } // get notifications count, err := db.GetNotificationCount(ctx, aliceLocalpart, aliceDomain, tables.AllNotifications) assert.NoError(t, err, "unable to get notification count") assert.Equal(t, int64(10), count) notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, aliceDomain, 0, 15, tables.AllNotifications) assert.NoError(t, err, "unable to get notifications") assert.Equal(t, int64(10), count) assert.Equal(t, 10, len(notifs)) // ... for a specific room total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID) assert.NoError(t, err, "unable to get notifications for room") assert.Equal(t, int64(4), total) // mark notification as read affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, aliceDomain, room2.ID, 7, true) assert.NoError(t, err, "unable to set notifications read") assert.True(t, affected) // this should delete 2 notifications affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, aliceDomain, room2.ID, 8) assert.NoError(t, err, "unable to set notifications read") assert.True(t, affected) total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID) assert.NoError(t, err, "unable to get notifications for room") assert.Equal(t, int64(2), total) // delete old notifications err = db.DeleteOldNotifications(ctx) assert.NoError(t, err) // this should now return 0 notifications total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID) assert.NoError(t, err, "unable to get notifications for room") assert.Equal(t, int64(0), total) }) } func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { cfg, processCtx, close := testrig.CreateConfig(t, dbType) cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) db, err := storage.NewKeyDatabase(cm, &cfg.KeyServer.Database) if err != nil { t.Fatalf("failed to create new database: %v", err) } return db, close } func MustNotError(t *testing.T, err error) { t.Helper() if err == nil { return } t.Fatalf("operation failed: %s", err) } func TestKeyChanges(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, clean := mustCreateKeyDatabase(t, dbType) defer clean() _, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") MustNotError(t, err) deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") MustNotError(t, err) userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } if latest != deviceChangeIDC { t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) } if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) } }) } func TestKeyChangesNoDupes(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, clean := mustCreateKeyDatabase(t, dbType) defer clean() deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) if deviceChangeIDA == deviceChangeIDB { t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) } deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } if latest != deviceChangeID { t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) } if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) } }) } func TestKeyChangesUpperLimit(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, clean := mustCreateKeyDatabase(t, dbType) defer clean() deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") MustNotError(t, err) _, err = db.StoreKeyChange(ctx, "@charlie:localhost") MustNotError(t, err) userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } if latest != deviceChangeIDB { t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) } if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) } }) } var dbLock sync.Mutex var deviceArray = []string{"AAA", "another_device"} // The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, // and that they are returned correctly when querying for device keys. func TestDeviceKeysStreamIDGeneration(t *testing.T) { var err error test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, clean := mustCreateKeyDatabase(t, dbType) defer clean() alice := "@alice:TestDeviceKeysStreamIDGeneration" bob := "@bob:TestDeviceKeysStreamIDGeneration" msgs := []api.DeviceMessage{ { Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ DeviceID: "AAA", UserID: alice, KeyJSON: []byte(`{"key":"v1"}`), }, // StreamID: 1 }, { Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ DeviceID: "AAA", UserID: bob, KeyJSON: []byte(`{"key":"v1"}`), }, // StreamID: 1 as this is a different user }, { Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ DeviceID: "another_device", UserID: alice, KeyJSON: []byte(`{"key":"v1"}`), }, // StreamID: 2 as this is a 2nd device key }, } MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) if msgs[0].StreamID != 1 { t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) } if msgs[1].StreamID != 1 { t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) } if msgs[2].StreamID != 2 { t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) } // updating a device sets the next stream ID for that user msgs = []api.DeviceMessage{ { Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ DeviceID: "AAA", UserID: alice, KeyJSON: []byte(`{"key":"v2"}`), }, // StreamID: 3 }, } MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) if msgs[0].StreamID != 3 { t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) } dbLock.Lock() defer dbLock.Unlock() // Querying for device keys returns the latest stream IDs msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false) if err != nil { t.Fatalf("DeviceKeysForUser returned error: %s", err) } wantStreamIDs := map[string]int64{ "AAA": 3, "another_device": 2, } if len(msgs) != len(wantStreamIDs) { t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) } for _, m := range msgs { if m.StreamID != wantStreamIDs[m.DeviceID] { t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) } } }) } func TestOneTimeCryptoIDs(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, clean := mustCreateKeyDatabase(t, dbType) defer clean() userID := "@alice:localhost" otk := api.OneTimeCryptoIDs{ UserID: userID, KeyJSON: map[string]json.RawMessage{"pseudoid_curve25519:KEY1": []byte(`{"key":"v1"}`)}, } // Add a one time pseudoID to the DB _, err := db.StoreOneTimeCryptoIDs(ctx, otk) MustNotError(t, err) // Check the count of one time pseudoIDs is correct count, err := db.OneTimeCryptoIDsCount(ctx, userID) MustNotError(t, err) if count.KeyCount["pseudoid_curve25519"] != 1 { t.Fatalf("Expected 1 pseudoID, got %d", count.KeyCount["pseudoid_curve25519"]) } // Check the actual pseudoid contents are correct keysJSON, err := db.ExistingOneTimeCryptoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"}) MustNotError(t, err) keyJSON, err := keysJSON["pseudoid_curve25519:KEY1"].MarshalJSON() MustNotError(t, err) if !bytes.Equal(keyJSON, []byte(`{"key":"v1"}`)) { t.Fatalf("Existing pseudoIDs do not match expected. Got %v", keysJSON["pseudoid_curve25519:KEY1"]) } }) }