mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-01 03:03:10 -06:00
Add tests for keybackup
This commit is contained in:
parent
cf5acdd16f
commit
8be56eeb65
|
|
@ -74,19 +74,7 @@ type Device interface {
|
|||
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
||||
}
|
||||
|
||||
type Database interface {
|
||||
Account
|
||||
AccountData
|
||||
Device
|
||||
Profile
|
||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
||||
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
|
||||
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
||||
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
|
||||
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||
|
||||
// Key backups
|
||||
type KeyBackup interface {
|
||||
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
||||
UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error)
|
||||
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
|
||||
|
|
@ -94,6 +82,20 @@ type Database interface {
|
|||
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
|
||||
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
|
||||
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
|
||||
}
|
||||
|
||||
type Database interface {
|
||||
Account
|
||||
AccountData
|
||||
Device
|
||||
KeyBackup
|
||||
Profile
|
||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
||||
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
|
||||
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
||||
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
|
||||
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||
|
||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||
// determined by the loginTokenLifetime given to the Database constructor.
|
||||
|
|
|
|||
|
|
@ -71,13 +71,16 @@ func Test_Accounts(t *testing.T) {
|
|||
alice := test.NewUser()
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
||||
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
// verify the newly create account is the same as returned by CreateAccount
|
||||
accGet, err := db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
|
||||
var accGet *api.Account
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
|
||||
assert.NoError(t, err, "failed to get account by password")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
|
||||
assert.NoError(t, err, "failed to get account by localpart")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
|
||||
// check account availability
|
||||
|
|
@ -91,7 +94,7 @@ func Test_Accounts(t *testing.T) {
|
|||
|
||||
// get guest account numeric aliceLocalpart
|
||||
first, err := db.GetNewNumericLocalpart(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, err, "failed to get new numeric localpart")
|
||||
// SQLite requires a new user to be created, as it doesn't have a sequence and uses the count(localpart) instead
|
||||
_, err = db.CreateAccount(ctx, strconv.Itoa(int(first)), "testing", "", api.AccountTypeAdmin)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
|
@ -103,6 +106,7 @@ func Test_Accounts(t *testing.T) {
|
|||
err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
|
||||
assert.NoError(t, err, "failed to update password")
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
||||
assert.NoError(t, err, "failed to get account by new password")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
|
||||
// deactivate account
|
||||
|
|
@ -159,6 +163,86 @@ func Test_Devices(t *testing.T) {
|
|||
devices2, err := db.GetDevicesByID(ctx, deviceIDs)
|
||||
assert.NoError(t, err, "unable to get devices by id")
|
||||
assert.Equal(t, devices, devices2)
|
||||
t.Logf("%+v", devices)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_KeyBackup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
alice := test.NewUser()
|
||||
//localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
//assert.NoError(t, err)
|
||||
room := test.NewRoom(t, alice)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
wantAuthData := json.RawMessage("my auth data")
|
||||
wantVersion, err := db.CreateKeyBackup(ctx, alice.ID, "dummyAlgo", wantAuthData)
|
||||
assert.NoError(t, err, "unable to create key backup")
|
||||
// get key backup by version
|
||||
gotVersion, gotAlgo, gotAuthData, _, _, err := db.GetKeyBackup(ctx, alice.ID, wantVersion)
|
||||
assert.NoError(t, err, "unable to get key backup")
|
||||
assert.Equal(t, wantVersion, gotVersion, "backup version mismatch")
|
||||
assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch")
|
||||
assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch")
|
||||
|
||||
// get any key backup
|
||||
gotVersion, gotAlgo, gotAuthData, _, _, err = db.GetKeyBackup(ctx, alice.ID, "")
|
||||
assert.NoError(t, err, "unable to get key backup")
|
||||
assert.Equal(t, wantVersion, gotVersion, "backup version mismatch")
|
||||
assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch")
|
||||
assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch")
|
||||
|
||||
err = db.UpdateKeyBackupAuthData(ctx, alice.ID, wantVersion, json.RawMessage("my updated auth data"))
|
||||
assert.NoError(t, err, "unable to update key backup auth data")
|
||||
|
||||
uploads := []api.InternalKeyBackupSession{
|
||||
{
|
||||
KeyBackupSession: api.KeyBackupSession{
|
||||
IsVerified: true,
|
||||
SessionData: wantAuthData,
|
||||
},
|
||||
RoomID: room.ID,
|
||||
SessionID: "1",
|
||||
},
|
||||
{
|
||||
KeyBackupSession: api.KeyBackupSession{},
|
||||
RoomID: room.ID,
|
||||
SessionID: "2",
|
||||
},
|
||||
}
|
||||
count, _, err := db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads)
|
||||
assert.NoError(t, err, "unable to upsert backup keys")
|
||||
assert.Equal(t, int64(len(uploads)), count, "unexpected backup count")
|
||||
|
||||
// do it again to update a key
|
||||
uploads[1].IsVerified = true
|
||||
count, _, err = db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads[1:])
|
||||
assert.NoError(t, err, "unable to upsert backup keys")
|
||||
assert.Equal(t, int64(len(uploads)), count, "unexpected backup count")
|
||||
|
||||
// get backup keys by session id
|
||||
gotBackupKeys, err := db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "1")
|
||||
assert.NoError(t, err, "unable to get backup keys")
|
||||
assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"])
|
||||
|
||||
// get backup keys by room id
|
||||
gotBackupKeys, err = db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "")
|
||||
assert.NoError(t, err, "unable to get backup keys")
|
||||
assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"])
|
||||
|
||||
gotCount, err := db.CountBackupKeys(ctx, wantVersion, alice.ID)
|
||||
assert.NoError(t, err, "unable to get backup keys count")
|
||||
assert.Equal(t, count, gotCount, "unexpected backup count")
|
||||
|
||||
// finally delete a key
|
||||
exists, err := db.DeleteKeyBackup(ctx, alice.ID, wantVersion)
|
||||
assert.NoError(t, err, "unable to delete key backup")
|
||||
assert.True(t, exists)
|
||||
|
||||
// this key should not exist
|
||||
exists, err = db.DeleteKeyBackup(ctx, alice.ID, "3")
|
||||
assert.NoError(t, err, "unable to delete key backup")
|
||||
assert.False(t, exists)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue