mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-22 06:11:55 -06:00
Fix notary keys requests for all keys (#3296)
This should be more spec compliant: > If no key IDs are given to be queried, the notary server should query for all keys.
This commit is contained in:
parent
edd02ec468
commit
13c5173273
|
@ -5,11 +5,14 @@ import (
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/routing"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
@ -17,7 +20,10 @@ import (
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/federationapi"
|
"github.com/matrix-org/dendrite/federationapi"
|
||||||
"github.com/matrix-org/dendrite/federationapi/api"
|
"github.com/matrix-org/dendrite/federationapi/api"
|
||||||
|
@ -362,3 +368,126 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNotaryServer(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
httpBody string
|
||||||
|
pubKeyRequest *gomatrixserverlib.PublicKeyNotaryLookupRequest
|
||||||
|
validateFunc func(t *testing.T, response util.JSONResponse)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty httpBody",
|
||||||
|
validateFunc: func(t *testing.T, resp util.JSONResponse) {
|
||||||
|
assert.Equal(t, http.StatusBadRequest, resp.Code)
|
||||||
|
nk, ok := resp.JSON.(spec.MatrixError)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, spec.ErrorBadJSON, nk.ErrCode)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid but empty httpBody",
|
||||||
|
httpBody: "{}",
|
||||||
|
validateFunc: func(t *testing.T, resp util.JSONResponse) {
|
||||||
|
want := util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: routing.NotaryKeysResponse{ServerKeys: []json.RawMessage{}},
|
||||||
|
}
|
||||||
|
assert.Equal(t, want, resp)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "request all keys using an empty criteria",
|
||||||
|
httpBody: `{"server_keys":{"servera":{}}}`,
|
||||||
|
validateFunc: func(t *testing.T, resp util.JSONResponse) {
|
||||||
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
|
nk, ok := resp.JSON.(routing.NotaryKeysResponse)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "servera", gjson.GetBytes(nk.ServerKeys[0], "server_name").Str)
|
||||||
|
assert.True(t, gjson.GetBytes(nk.ServerKeys[0], "verify_keys.ed25519:someID").Exists())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "request all keys using null as the criteria",
|
||||||
|
httpBody: `{"server_keys":{"servera":null}}`,
|
||||||
|
validateFunc: func(t *testing.T, resp util.JSONResponse) {
|
||||||
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
|
nk, ok := resp.JSON.(routing.NotaryKeysResponse)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "servera", gjson.GetBytes(nk.ServerKeys[0], "server_name").Str)
|
||||||
|
assert.True(t, gjson.GetBytes(nk.ServerKeys[0], "verify_keys.ed25519:someID").Exists())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "request specific key",
|
||||||
|
httpBody: `{"server_keys":{"servera":{"ed25519:someID":{}}}}`,
|
||||||
|
validateFunc: func(t *testing.T, resp util.JSONResponse) {
|
||||||
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
|
nk, ok := resp.JSON.(routing.NotaryKeysResponse)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "servera", gjson.GetBytes(nk.ServerKeys[0], "server_name").Str)
|
||||||
|
assert.True(t, gjson.GetBytes(nk.ServerKeys[0], "verify_keys.ed25519:someID").Exists())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "request multiple servers",
|
||||||
|
httpBody: `{"server_keys":{"servera":{"ed25519:someID":{}},"serverb":{"ed25519:someID":{}}}}`,
|
||||||
|
validateFunc: func(t *testing.T, resp util.JSONResponse) {
|
||||||
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
|
nk, ok := resp.JSON.(routing.NotaryKeysResponse)
|
||||||
|
assert.True(t, ok)
|
||||||
|
wantServers := map[string]struct{}{
|
||||||
|
"servera": {},
|
||||||
|
"serverb": {},
|
||||||
|
}
|
||||||
|
for _, js := range nk.ServerKeys {
|
||||||
|
serverName := gjson.GetBytes(js, "server_name").Str
|
||||||
|
_, ok = wantServers[serverName]
|
||||||
|
assert.True(t, ok, "unexpected servername: %s", serverName)
|
||||||
|
delete(wantServers, serverName)
|
||||||
|
assert.True(t, gjson.GetBytes(js, "verify_keys.ed25519:someID").Exists())
|
||||||
|
}
|
||||||
|
if len(wantServers) > 0 {
|
||||||
|
t.Fatalf("expected response to also contain: %#v", wantServers)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||||
|
defer close()
|
||||||
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
|
natsInstance := jetstream.NATSInstance{}
|
||||||
|
fc := &fedClient{
|
||||||
|
keys: map[spec.ServerName]struct {
|
||||||
|
key ed25519.PrivateKey
|
||||||
|
keyID gomatrixserverlib.KeyID
|
||||||
|
}{
|
||||||
|
"servera": {
|
||||||
|
key: test.PrivateKeyA,
|
||||||
|
keyID: "ed25519:someID",
|
||||||
|
},
|
||||||
|
"serverb": {
|
||||||
|
key: test.PrivateKeyB,
|
||||||
|
keyID: "ed25519:someID",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
fedAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, fc, nil, caches, nil, true)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(tc.httpBody))
|
||||||
|
req.Host = string(cfg.Global.ServerName)
|
||||||
|
|
||||||
|
resp := routing.NotaryKeys(req, &cfg.FederationAPI, fedAPI, tc.pubKeyRequest)
|
||||||
|
// assert that we received the expected response
|
||||||
|
tc.validateFunc(t, resp)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -43,6 +43,15 @@ func (a *FederationInternalAPI) fetchServerKeysFromCache(
|
||||||
ctx context.Context, req *api.QueryServerKeysRequest,
|
ctx context.Context, req *api.QueryServerKeysRequest,
|
||||||
) ([]gomatrixserverlib.ServerKeys, error) {
|
) ([]gomatrixserverlib.ServerKeys, error) {
|
||||||
var results []gomatrixserverlib.ServerKeys
|
var results []gomatrixserverlib.ServerKeys
|
||||||
|
|
||||||
|
// We got a request for _all_ server keys, return them.
|
||||||
|
if len(req.KeyIDToCriteria) == 0 {
|
||||||
|
serverKeysResponses, _ := a.db.GetNotaryKeys(ctx, req.ServerName, []gomatrixserverlib.KeyID{})
|
||||||
|
if len(serverKeysResponses) == 0 {
|
||||||
|
return nil, fmt.Errorf("failed to find server key response for server %s", req.ServerName)
|
||||||
|
}
|
||||||
|
return serverKeysResponses, nil
|
||||||
|
}
|
||||||
for keyID, criteria := range req.KeyIDToCriteria {
|
for keyID, criteria := range req.KeyIDToCriteria {
|
||||||
serverKeysResponses, _ := a.db.GetNotaryKeys(ctx, req.ServerName, []gomatrixserverlib.KeyID{keyID})
|
serverKeysResponses, _ := a.db.GetNotaryKeys(ctx, req.ServerName, []gomatrixserverlib.KeyID{keyID})
|
||||||
if len(serverKeysResponses) == 0 {
|
if len(serverKeysResponses) == 0 {
|
||||||
|
|
|
@ -197,6 +197,10 @@ func localKeys(cfg *config.FederationAPI, serverName spec.ServerName) (*gomatrix
|
||||||
return &keys, err
|
return &keys, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NotaryKeysResponse struct {
|
||||||
|
ServerKeys []json.RawMessage `json:"server_keys"`
|
||||||
|
}
|
||||||
|
|
||||||
func NotaryKeys(
|
func NotaryKeys(
|
||||||
httpReq *http.Request, cfg *config.FederationAPI,
|
httpReq *http.Request, cfg *config.FederationAPI,
|
||||||
fsAPI federationAPI.FederationInternalAPI,
|
fsAPI federationAPI.FederationInternalAPI,
|
||||||
|
@ -217,10 +221,9 @@ func NotaryKeys(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var response struct {
|
response := NotaryKeysResponse{
|
||||||
ServerKeys []json.RawMessage `json:"server_keys"`
|
ServerKeys: []json.RawMessage{},
|
||||||
}
|
}
|
||||||
response.ServerKeys = []json.RawMessage{}
|
|
||||||
|
|
||||||
for serverName, kidToCriteria := range req.ServerKeys {
|
for serverName, kidToCriteria := range req.ServerKeys {
|
||||||
var keyList []gomatrixserverlib.ServerKeys
|
var keyList []gomatrixserverlib.ServerKeys
|
||||||
|
|
Loading…
Reference in a new issue