diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 98b06c17a..9dfb87ab3 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -74,19 +74,7 @@ type Device interface { RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) } -type Database interface { - Account - AccountData - Device - Profile - SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) - RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) - GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) - GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) - CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error) - GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) - - // Key backups +type KeyBackup interface { CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error) DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error) @@ -94,6 +82,20 @@ type Database interface { UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error) GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error) CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error) +} + +type Database interface { + Account + AccountData + Device + KeyBackup + Profile + SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) + RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) + GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) + GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) + CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error) + GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) // CreateLoginToken generates a token, stores and returns it. The lifetime is // determined by the loginTokenLifetime given to the Database constructor. diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 3c3457bc1..54cf7def8 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -71,13 +71,16 @@ func Test_Accounts(t *testing.T) { alice := test.NewUser() aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) - + accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") // verify the newly create account is the same as returned by CreateAccount - accGet, err := db.GetAccountByPassword(ctx, aliceLocalpart, "testing") + var accGet *api.Account + accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing") + assert.NoError(t, err, "failed to get account by password") assert.Equal(t, accAlice, accGet) accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart) + assert.NoError(t, err, "failed to get account by localpart") assert.Equal(t, accAlice, accGet) // check account availability @@ -91,7 +94,7 @@ func Test_Accounts(t *testing.T) { // get guest account numeric aliceLocalpart first, err := db.GetNewNumericLocalpart(ctx) - assert.NoError(t, err) + assert.NoError(t, err, "failed to get new numeric localpart") // SQLite requires a new user to be created, as it doesn't have a sequence and uses the count(localpart) instead _, err = db.CreateAccount(ctx, strconv.Itoa(int(first)), "testing", "", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") @@ -103,6 +106,7 @@ func Test_Accounts(t *testing.T) { err = db.SetPassword(ctx, aliceLocalpart, "newPassword") assert.NoError(t, err, "failed to update password") accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") + assert.NoError(t, err, "failed to get account by new password") assert.Equal(t, accAlice, accGet) // deactivate account @@ -159,6 +163,86 @@ func Test_Devices(t *testing.T) { devices2, err := db.GetDevicesByID(ctx, deviceIDs) assert.NoError(t, err, "unable to get devices by id") assert.Equal(t, devices, devices2) - t.Logf("%+v", devices) + }) +} + +func Test_KeyBackup(t *testing.T) { + ctx := context.Background() + alice := test.NewUser() + //localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + //assert.NoError(t, err) + room := test.NewRoom(t, alice) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + + wantAuthData := json.RawMessage("my auth data") + wantVersion, err := db.CreateKeyBackup(ctx, alice.ID, "dummyAlgo", wantAuthData) + assert.NoError(t, err, "unable to create key backup") + // get key backup by version + gotVersion, gotAlgo, gotAuthData, _, _, err := db.GetKeyBackup(ctx, alice.ID, wantVersion) + assert.NoError(t, err, "unable to get key backup") + assert.Equal(t, wantVersion, gotVersion, "backup version mismatch") + assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch") + assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch") + + // get any key backup + gotVersion, gotAlgo, gotAuthData, _, _, err = db.GetKeyBackup(ctx, alice.ID, "") + assert.NoError(t, err, "unable to get key backup") + assert.Equal(t, wantVersion, gotVersion, "backup version mismatch") + assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch") + assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch") + + err = db.UpdateKeyBackupAuthData(ctx, alice.ID, wantVersion, json.RawMessage("my updated auth data")) + assert.NoError(t, err, "unable to update key backup auth data") + + uploads := []api.InternalKeyBackupSession{ + { + KeyBackupSession: api.KeyBackupSession{ + IsVerified: true, + SessionData: wantAuthData, + }, + RoomID: room.ID, + SessionID: "1", + }, + { + KeyBackupSession: api.KeyBackupSession{}, + RoomID: room.ID, + SessionID: "2", + }, + } + count, _, err := db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads) + assert.NoError(t, err, "unable to upsert backup keys") + assert.Equal(t, int64(len(uploads)), count, "unexpected backup count") + + // do it again to update a key + uploads[1].IsVerified = true + count, _, err = db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads[1:]) + assert.NoError(t, err, "unable to upsert backup keys") + assert.Equal(t, int64(len(uploads)), count, "unexpected backup count") + + // get backup keys by session id + gotBackupKeys, err := db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "1") + assert.NoError(t, err, "unable to get backup keys") + assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"]) + + // get backup keys by room id + gotBackupKeys, err = db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "") + assert.NoError(t, err, "unable to get backup keys") + assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"]) + + gotCount, err := db.CountBackupKeys(ctx, wantVersion, alice.ID) + assert.NoError(t, err, "unable to get backup keys count") + assert.Equal(t, count, gotCount, "unexpected backup count") + + // finally delete a key + exists, err := db.DeleteKeyBackup(ctx, alice.ID, wantVersion) + assert.NoError(t, err, "unable to delete key backup") + assert.True(t, exists) + + // this key should not exist + exists, err = db.DeleteKeyBackup(ctx, alice.ID, "3") + assert.NoError(t, err, "unable to delete key backup") + assert.False(t, exists) }) }