Add tests for keybackup

This commit is contained in:
Till Faelligen 2022-04-27 09:33:42 +02:00
parent cf5acdd16f
commit 8be56eeb65
2 changed files with 103 additions and 17 deletions

View file

@ -74,19 +74,7 @@ type Device interface {
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
}
type Database interface {
Account
AccountData
Device
Profile
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
// Key backups
type KeyBackup interface {
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error)
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
@ -94,6 +82,20 @@ type Database interface {
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
}
type Database interface {
Account
AccountData
Device
KeyBackup
Profile
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.

View file

@ -71,13 +71,16 @@ func Test_Accounts(t *testing.T) {
alice := test.NewUser()
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account")
// verify the newly create account is the same as returned by CreateAccount
accGet, err := db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
var accGet *api.Account
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
assert.NoError(t, err, "failed to get account by password")
assert.Equal(t, accAlice, accGet)
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
assert.NoError(t, err, "failed to get account by localpart")
assert.Equal(t, accAlice, accGet)
// check account availability
@ -91,7 +94,7 @@ func Test_Accounts(t *testing.T) {
// get guest account numeric aliceLocalpart
first, err := db.GetNewNumericLocalpart(ctx)
assert.NoError(t, err)
assert.NoError(t, err, "failed to get new numeric localpart")
// SQLite requires a new user to be created, as it doesn't have a sequence and uses the count(localpart) instead
_, err = db.CreateAccount(ctx, strconv.Itoa(int(first)), "testing", "", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account")
@ -103,6 +106,7 @@ func Test_Accounts(t *testing.T) {
err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
assert.NoError(t, err, "failed to update password")
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
assert.NoError(t, err, "failed to get account by new password")
assert.Equal(t, accAlice, accGet)
// deactivate account
@ -159,6 +163,86 @@ func Test_Devices(t *testing.T) {
devices2, err := db.GetDevicesByID(ctx, deviceIDs)
assert.NoError(t, err, "unable to get devices by id")
assert.Equal(t, devices, devices2)
t.Logf("%+v", devices)
})
}
func Test_KeyBackup(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
//localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
//assert.NoError(t, err)
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
wantAuthData := json.RawMessage("my auth data")
wantVersion, err := db.CreateKeyBackup(ctx, alice.ID, "dummyAlgo", wantAuthData)
assert.NoError(t, err, "unable to create key backup")
// get key backup by version
gotVersion, gotAlgo, gotAuthData, _, _, err := db.GetKeyBackup(ctx, alice.ID, wantVersion)
assert.NoError(t, err, "unable to get key backup")
assert.Equal(t, wantVersion, gotVersion, "backup version mismatch")
assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch")
assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch")
// get any key backup
gotVersion, gotAlgo, gotAuthData, _, _, err = db.GetKeyBackup(ctx, alice.ID, "")
assert.NoError(t, err, "unable to get key backup")
assert.Equal(t, wantVersion, gotVersion, "backup version mismatch")
assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch")
assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch")
err = db.UpdateKeyBackupAuthData(ctx, alice.ID, wantVersion, json.RawMessage("my updated auth data"))
assert.NoError(t, err, "unable to update key backup auth data")
uploads := []api.InternalKeyBackupSession{
{
KeyBackupSession: api.KeyBackupSession{
IsVerified: true,
SessionData: wantAuthData,
},
RoomID: room.ID,
SessionID: "1",
},
{
KeyBackupSession: api.KeyBackupSession{},
RoomID: room.ID,
SessionID: "2",
},
}
count, _, err := db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads)
assert.NoError(t, err, "unable to upsert backup keys")
assert.Equal(t, int64(len(uploads)), count, "unexpected backup count")
// do it again to update a key
uploads[1].IsVerified = true
count, _, err = db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads[1:])
assert.NoError(t, err, "unable to upsert backup keys")
assert.Equal(t, int64(len(uploads)), count, "unexpected backup count")
// get backup keys by session id
gotBackupKeys, err := db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "1")
assert.NoError(t, err, "unable to get backup keys")
assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"])
// get backup keys by room id
gotBackupKeys, err = db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "")
assert.NoError(t, err, "unable to get backup keys")
assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"])
gotCount, err := db.CountBackupKeys(ctx, wantVersion, alice.ID)
assert.NoError(t, err, "unable to get backup keys count")
assert.Equal(t, count, gotCount, "unexpected backup count")
// finally delete a key
exists, err := db.DeleteKeyBackup(ctx, alice.ID, wantVersion)
assert.NoError(t, err, "unable to delete key backup")
assert.True(t, exists)
// this key should not exist
exists, err = db.DeleteKeyBackup(ctx, alice.ID, "3")
assert.NoError(t, err, "unable to delete key backup")
assert.False(t, exists)
})
}