Add simple test for one time keys (#3239)
This commit is contained in:
parent
e02a7948d8
commit
a0375d41fb
|
@ -1,6 +1,7 @@
|
|||
package storage_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
@ -758,3 +759,53 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOneTimeKeys(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
userID := "@alice:localhost"
|
||||
deviceID := "alice_device"
|
||||
otk := api.OneTimeKeys{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
KeyJSON: map[string]json.RawMessage{"curve25519:KEY1": []byte(`{"key":"v1"}`)},
|
||||
}
|
||||
|
||||
// Add a one time key to the DB
|
||||
_, err := db.StoreOneTimeKeys(ctx, otk)
|
||||
MustNotError(t, err)
|
||||
|
||||
// Check the count of one time keys is correct
|
||||
count, err := db.OneTimeKeysCount(ctx, userID, deviceID)
|
||||
MustNotError(t, err)
|
||||
if count.KeyCount["curve25519"] != 1 {
|
||||
t.Fatalf("Expected 1 key, got %d", count.KeyCount["curve25519"])
|
||||
}
|
||||
|
||||
// Check the actual key contents are correct
|
||||
keysJSON, err := db.ExistingOneTimeKeys(ctx, userID, deviceID, []string{"curve25519:KEY1"})
|
||||
MustNotError(t, err)
|
||||
keyJSON, err := keysJSON["curve25519:KEY1"].MarshalJSON()
|
||||
MustNotError(t, err)
|
||||
if !bytes.Equal(keyJSON, []byte(`{"key":"v1"}`)) {
|
||||
t.Fatalf("Existing keys do not match expected. Got %v", keysJSON["curve25519:KEY1"])
|
||||
}
|
||||
|
||||
// Claim a one time key from the database. This should remove it from the database.
|
||||
claimedKeys, err := db.ClaimKeys(ctx, map[string]map[string]string{userID: {deviceID: "curve25519"}})
|
||||
MustNotError(t, err)
|
||||
|
||||
// Check the claimed key contents are correct
|
||||
if !reflect.DeepEqual(claimedKeys[0], otk) {
|
||||
t.Fatalf("Expected to claim stored key %v. Got %v", otk, claimedKeys[0])
|
||||
}
|
||||
|
||||
// Check the count of one time keys is now zero
|
||||
count, err = db.OneTimeKeysCount(ctx, userID, deviceID)
|
||||
MustNotError(t, err)
|
||||
if count.KeyCount["curve25519"] != 0 {
|
||||
t.Fatalf("Expected 0 keys, got %d", count.KeyCount["curve25519"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue