From a0375d41fbbdabd98df743d2e7fa77b4d0c44d4b Mon Sep 17 00:00:00 2001 From: devonh Date: Wed, 25 Oct 2023 08:13:18 +0000 Subject: [PATCH] Add simple test for one time keys (#3239) --- userapi/storage/storage_test.go | 51 +++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index a46ee9ebb..5a789dfd7 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -1,6 +1,7 @@ package storage_test import ( + "bytes" "context" "encoding/json" "fmt" @@ -758,3 +759,53 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { } }) } + +func TestOneTimeKeys(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + userID := "@alice:localhost" + deviceID := "alice_device" + otk := api.OneTimeKeys{ + UserID: userID, + DeviceID: deviceID, + KeyJSON: map[string]json.RawMessage{"curve25519:KEY1": []byte(`{"key":"v1"}`)}, + } + + // Add a one time key to the DB + _, err := db.StoreOneTimeKeys(ctx, otk) + MustNotError(t, err) + + // Check the count of one time keys is correct + count, err := db.OneTimeKeysCount(ctx, userID, deviceID) + MustNotError(t, err) + if count.KeyCount["curve25519"] != 1 { + t.Fatalf("Expected 1 key, got %d", count.KeyCount["curve25519"]) + } + + // Check the actual key contents are correct + keysJSON, err := db.ExistingOneTimeKeys(ctx, userID, deviceID, []string{"curve25519:KEY1"}) + MustNotError(t, err) + keyJSON, err := keysJSON["curve25519:KEY1"].MarshalJSON() + MustNotError(t, err) + if !bytes.Equal(keyJSON, []byte(`{"key":"v1"}`)) { + t.Fatalf("Existing keys do not match expected. Got %v", keysJSON["curve25519:KEY1"]) + } + + // Claim a one time key from the database. This should remove it from the database. + claimedKeys, err := db.ClaimKeys(ctx, map[string]map[string]string{userID: {deviceID: "curve25519"}}) + MustNotError(t, err) + + // Check the claimed key contents are correct + if !reflect.DeepEqual(claimedKeys[0], otk) { + t.Fatalf("Expected to claim stored key %v. Got %v", otk, claimedKeys[0]) + } + + // Check the count of one time keys is now zero + count, err = db.OneTimeKeysCount(ctx, userID, deviceID) + MustNotError(t, err) + if count.KeyCount["curve25519"] != 0 { + t.Fatalf("Expected 0 keys, got %d", count.KeyCount["curve25519"]) + } + }) +}