Add simple test for one time keys (#3239)

This commit is contained in:
devonh 2023-10-25 08:13:18 +00:00 committed by GitHub
parent e02a7948d8
commit a0375d41fb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,6 +1,7 @@
package storage_test
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -758,3 +759,53 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
}
})
}
func TestOneTimeKeys(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType)
defer clean()
userID := "@alice:localhost"
deviceID := "alice_device"
otk := api.OneTimeKeys{
UserID: userID,
DeviceID: deviceID,
KeyJSON: map[string]json.RawMessage{"curve25519:KEY1": []byte(`{"key":"v1"}`)},
}
// Add a one time key to the DB
_, err := db.StoreOneTimeKeys(ctx, otk)
MustNotError(t, err)
// Check the count of one time keys is correct
count, err := db.OneTimeKeysCount(ctx, userID, deviceID)
MustNotError(t, err)
if count.KeyCount["curve25519"] != 1 {
t.Fatalf("Expected 1 key, got %d", count.KeyCount["curve25519"])
}
// Check the actual key contents are correct
keysJSON, err := db.ExistingOneTimeKeys(ctx, userID, deviceID, []string{"curve25519:KEY1"})
MustNotError(t, err)
keyJSON, err := keysJSON["curve25519:KEY1"].MarshalJSON()
MustNotError(t, err)
if !bytes.Equal(keyJSON, []byte(`{"key":"v1"}`)) {
t.Fatalf("Existing keys do not match expected. Got %v", keysJSON["curve25519:KEY1"])
}
// Claim a one time key from the database. This should remove it from the database.
claimedKeys, err := db.ClaimKeys(ctx, map[string]map[string]string{userID: {deviceID: "curve25519"}})
MustNotError(t, err)
// Check the claimed key contents are correct
if !reflect.DeepEqual(claimedKeys[0], otk) {
t.Fatalf("Expected to claim stored key %v. Got %v", otk, claimedKeys[0])
}
// Check the count of one time keys is now zero
count, err = db.OneTimeKeysCount(ctx, userID, deviceID)
MustNotError(t, err)
if count.KeyCount["curve25519"] != 0 {
t.Fatalf("Expected 0 keys, got %d", count.KeyCount["curve25519"])
}
})
}