diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 1ea8c40ea..7e091d97b 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -5,11 +5,14 @@ import ( "crypto/ed25519" "encoding/json" "fmt" + "net/http" + "net/http/httptest" "strings" "sync" "testing" "time" + "github.com/matrix-org/dendrite/federationapi/routing" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -17,7 +20,10 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" @@ -362,3 +368,118 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { } } } + +func TestNotaryServer(t *testing.T) { + + // Start a server we are going to request keys from + testCases := []struct { + name string + httpBody string + pubKeyRequest *gomatrixserverlib.PublicKeyNotaryLookupRequest + validateFunc func(t *testing.T, response util.JSONResponse) + }{ + { + name: "empty httpBody", + validateFunc: func(t *testing.T, resp util.JSONResponse) { + want := util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("The request body could not be decoded into valid JSON. unexpected end of JSON input"), + } + assert.Equal(t, want, resp) + }, + }, + { + name: "valid but empty httpBody", + httpBody: "{}", + validateFunc: func(t *testing.T, resp util.JSONResponse) { + want := util.JSONResponse{ + Code: http.StatusOK, + JSON: routing.NotaryKeysResponse{}, + } + assert.Equal(t, want, resp) + }, + }, + { + name: "request all keys", + httpBody: `{"server_keys":{"servera":{}}}`, + validateFunc: func(t *testing.T, resp util.JSONResponse) { + assert.Equal(t, http.StatusOK, resp.Code) + nk, ok := resp.JSON.(routing.NotaryKeysResponse) + assert.True(t, ok) + assert.Equal(t, "servera", gjson.GetBytes(nk.ServerKeys[0], "server_name").Str) + assert.True(t, gjson.GetBytes(nk.ServerKeys[0], "verify_keys.ed25519:someID").Exists()) + }, + }, + { + name: "request specific key", + httpBody: `{"server_keys":{"servera":{"ed25519:someID":{}}}}`, + validateFunc: func(t *testing.T, resp util.JSONResponse) { + assert.Equal(t, http.StatusOK, resp.Code) + nk, ok := resp.JSON.(routing.NotaryKeysResponse) + assert.True(t, ok) + assert.Equal(t, "servera", gjson.GetBytes(nk.ServerKeys[0], "server_name").Str) + assert.True(t, gjson.GetBytes(nk.ServerKeys[0], "verify_keys.ed25519:someID").Exists()) + }, + }, + { + name: "request multiple servers", + httpBody: `{"server_keys":{"servera":{"ed25519:someID":{}},"serverb":{"ed25519:someID":{}}}}`, + validateFunc: func(t *testing.T, resp util.JSONResponse) { + assert.Equal(t, http.StatusOK, resp.Code) + nk, ok := resp.JSON.(routing.NotaryKeysResponse) + assert.True(t, ok) + wantServers := map[string]struct{}{ + "servera": {}, + "serverb": {}, + } + for _, js := range nk.ServerKeys { + serverName := gjson.GetBytes(js, "server_name").Str + _, ok = wantServers[serverName] + assert.True(t, ok, "unexpected servername: %s", serverName) + delete(wantServers, serverName) + assert.True(t, gjson.GetBytes(js, "verify_keys.ed25519:someID").Exists()) + } + if len(wantServers) > 0 { + t.Fatalf("expected response to also contain: %#v", wantServers) + } + }, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + fc := &fedClient{ + keys: map[spec.ServerName]struct { + key ed25519.PrivateKey + keyID gomatrixserverlib.KeyID + }{ + "servera": { + key: test.PrivateKeyA, + keyID: "ed25519:someID", + }, + "serverb": { + key: test.PrivateKeyB, + keyID: "ed25519:someID", + }, + }, + } + + fedAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, fc, nil, caches, nil, true) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(tc.httpBody)) + req.Host = string(cfg.Global.ServerName) + + resp := routing.NotaryKeys(req, &cfg.FederationAPI, fedAPI, tc.pubKeyRequest) + // assert that we received the expected response + tc.validateFunc(t, resp) + }) + } + + }) +} diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 3d8ff2dea..fea69bd05 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -197,6 +197,10 @@ func localKeys(cfg *config.FederationAPI, serverName spec.ServerName) (*gomatrix return &keys, err } +type NotaryKeysResponse struct { + ServerKeys []json.RawMessage `json:"server_keys"` +} + func NotaryKeys( httpReq *http.Request, cfg *config.FederationAPI, fsAPI federationAPI.FederationInternalAPI, @@ -217,10 +221,7 @@ func NotaryKeys( } } - var response struct { - ServerKeys []json.RawMessage `json:"server_keys"` - } - response.ServerKeys = []json.RawMessage{} + response := NotaryKeysResponse{} for serverName, kidToCriteria := range req.ServerKeys { var keyList []gomatrixserverlib.ServerKeys