diff --git a/clientapi/clientapi_test.go b/clientapi/clientapi_test.go index a8c09bd38..09460ea13 100644 --- a/clientapi/clientapi_test.go +++ b/clientapi/clientapi_test.go @@ -1660,6 +1660,7 @@ func TestKeys(t *testing.T) { } createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + // Start a TLSServer with our client mux srv := httptest.NewTLSServer(routers.Client) defer srv.Close() @@ -1667,11 +1668,11 @@ func TestKeys(t *testing.T) { if err != nil { t.Fatal(err) } + // Set the client so the self-signed certificate is trusted cl.Client = srv.Client() cl.DeviceID = id.DeviceID(accessTokens[alice].deviceID) cs := crypto.NewMemoryStore(nil) - oc := crypto.NewOlmMachine(cl, nil, cs, dummyStore{}) if err = oc.Load(); err != nil { t.Fatal(err) @@ -1694,30 +1695,48 @@ func TestKeys(t *testing.T) { t.Fatal(err) } + // Validate that the keys returned from the server are what the client has stored + oi := oc.OwnIdentity() + if oi.SigningKey != dev.SigningKey { + t.Fatalf("expected signing key '%s', got '%s'", oi.SigningKey, dev.SigningKey) + } + if oi.IdentityKey != dev.IdentityKey { + t.Fatalf("expected identity '%s', got '%s'", oi.IdentityKey, dev.IdentityKey) + } + // tests `/keys/signatures/upload` if err = oc.SignOwnMasterKey(); err != nil { t.Fatal(err) } - t.Logf("Dev: %#v", dev) - + // tests `/keys/claim` otks := make(map[string]map[string]string) otks[alice.ID] = map[string]string{ accessTokens[alice].deviceID: string(id.KeyAlgorithmSignedCurve25519), } - data, _ := json.Marshal(claimKeysRequest{OneTimeKeys: otks}) - req, _ := http.NewRequest(http.MethodPost, srv.URL+"/_matrix/client/v3/keys/claim", bytes.NewBuffer(data)) + data, err := json.Marshal(claimKeysRequest{OneTimeKeys: otks}) + if err != nil { + t.Fatal(err) + } + req, err := http.NewRequest(http.MethodPost, srv.URL+"/_matrix/client/v3/keys/claim", bytes.NewBuffer(data)) + if err != nil { + t.Fatal(err) + } req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) resp, err := srv.Client().Do(req) if err != nil { t.Fatal(err) } + respBody, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } - t.Logf("%s", string(respBody)) + + if !gjson.GetBytes(respBody, "one_time_keys."+alice.ID+"."+string(dev.DeviceID)).Exists() { + t.Fatalf("expected one time keys for alice, but didn't find any: %s", string(respBody)) + } }) }