Add simple test for one time keys (#3239)
This commit is contained in:
parent
e02a7948d8
commit
a0375d41fb
|
@ -1,6 +1,7 @@
|
||||||
package storage_test
|
package storage_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -758,3 +759,53 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOneTimeKeys(t *testing.T) {
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||||
|
defer clean()
|
||||||
|
userID := "@alice:localhost"
|
||||||
|
deviceID := "alice_device"
|
||||||
|
otk := api.OneTimeKeys{
|
||||||
|
UserID: userID,
|
||||||
|
DeviceID: deviceID,
|
||||||
|
KeyJSON: map[string]json.RawMessage{"curve25519:KEY1": []byte(`{"key":"v1"}`)},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a one time key to the DB
|
||||||
|
_, err := db.StoreOneTimeKeys(ctx, otk)
|
||||||
|
MustNotError(t, err)
|
||||||
|
|
||||||
|
// Check the count of one time keys is correct
|
||||||
|
count, err := db.OneTimeKeysCount(ctx, userID, deviceID)
|
||||||
|
MustNotError(t, err)
|
||||||
|
if count.KeyCount["curve25519"] != 1 {
|
||||||
|
t.Fatalf("Expected 1 key, got %d", count.KeyCount["curve25519"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the actual key contents are correct
|
||||||
|
keysJSON, err := db.ExistingOneTimeKeys(ctx, userID, deviceID, []string{"curve25519:KEY1"})
|
||||||
|
MustNotError(t, err)
|
||||||
|
keyJSON, err := keysJSON["curve25519:KEY1"].MarshalJSON()
|
||||||
|
MustNotError(t, err)
|
||||||
|
if !bytes.Equal(keyJSON, []byte(`{"key":"v1"}`)) {
|
||||||
|
t.Fatalf("Existing keys do not match expected. Got %v", keysJSON["curve25519:KEY1"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Claim a one time key from the database. This should remove it from the database.
|
||||||
|
claimedKeys, err := db.ClaimKeys(ctx, map[string]map[string]string{userID: {deviceID: "curve25519"}})
|
||||||
|
MustNotError(t, err)
|
||||||
|
|
||||||
|
// Check the claimed key contents are correct
|
||||||
|
if !reflect.DeepEqual(claimedKeys[0], otk) {
|
||||||
|
t.Fatalf("Expected to claim stored key %v. Got %v", otk, claimedKeys[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the count of one time keys is now zero
|
||||||
|
count, err = db.OneTimeKeysCount(ctx, userID, deviceID)
|
||||||
|
MustNotError(t, err)
|
||||||
|
if count.KeyCount["curve25519"] != 0 {
|
||||||
|
t.Fatalf("Expected 0 keys, got %d", count.KeyCount["curve25519"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue