Update comments, add fixes from forward-merge

This commit is contained in:
Neil Alexander 2020-06-15 15:23:51 +01:00
parent bf6faa4290
commit 5c7a4571de
5 changed files with 63 additions and 58 deletions

View file

@ -66,7 +66,7 @@ func Setup(
FsAPI: fsAPI, FsAPI: fsAPI,
} }
localKeys := internal.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse { localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse {
request := &serverkeyAPI.QueryLocalKeysRequest{} request := &serverkeyAPI.QueryLocalKeysRequest{}
response := &serverkeyAPI.QueryLocalKeysResponse{} response := &serverkeyAPI.QueryLocalKeysResponse{}
if err := skAPI.QueryLocalKeys(req.Context(), request, response); err != nil { if err := skAPI.QueryLocalKeys(req.Context(), request, response); err != nil {

View file

@ -2,7 +2,6 @@ package caching
import ( import (
"fmt" "fmt"
"time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -25,7 +24,6 @@ func (c Caches) GetServerKey(
timestamp gomatrixserverlib.Timestamp, timestamp gomatrixserverlib.Timestamp,
) (gomatrixserverlib.PublicKeyLookupResult, bool) { ) (gomatrixserverlib.PublicKeyLookupResult, bool) {
key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID) key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID)
now := gomatrixserverlib.AsTimestamp(time.Now())
val, found := c.ServerKeys.Get(key) val, found := c.ServerKeys.Get(key)
if found && val != nil { if found && val != nil {
if keyLookupResult, ok := val.(gomatrixserverlib.PublicKeyLookupResult); ok { if keyLookupResult, ok := val.(gomatrixserverlib.PublicKeyLookupResult); ok {

View file

@ -141,5 +141,5 @@ func (h *httpServerKeyInternalAPI) QueryLocalKeys(
defer span.Finish() defer span.Finish()
apiURL := h.serverKeyAPIURL + ServerKeyQueryLocalKeysPath apiURL := h.serverKeyAPIURL + ServerKeyQueryLocalKeysPath
return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }

View file

@ -13,7 +13,7 @@ import (
func AddRoutes(s api.ServerKeyInternalAPI, internalAPIMux *mux.Router, cache caching.ServerKeyCache) { func AddRoutes(s api.ServerKeyInternalAPI, internalAPIMux *mux.Router, cache caching.ServerKeyCache) {
internalAPIMux.Handle(ServerKeyQueryLocalKeysPath, internalAPIMux.Handle(ServerKeyQueryLocalKeysPath,
internal.MakeInternalAPI("queryLocalKeys", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("queryLocalKeys", func(req *http.Request) util.JSONResponse {
request := api.QueryLocalKeysRequest{} request := api.QueryLocalKeysRequest{}
response := api.QueryLocalKeysResponse{} response := api.QueryLocalKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { if err := json.NewDecoder(req.Body).Decode(&request); err != nil {

View file

@ -29,6 +29,10 @@ type server struct {
} }
func (s *server) renew() { func (s *server) renew() {
// This updates the validity period to be an hour in the
// future, which is particularly useful in server A and
// server C's cases which have validity either as now or
// in the past.
s.validity = time.Hour s.validity = time.Hour
s.config.Matrix.KeyValidityPeriod = s.validity s.config.Matrix.KeyValidityPeriod = s.validity
} }
@ -47,17 +51,23 @@ var servers = map[string]*server{
} }
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
// Set up the server key API for each "server" that we
// will use in our tests.
for _, s := range servers { for _, s := range servers {
// Generate a new key.
_, testPriv, err := ed25519.GenerateKey(nil) _, testPriv, err := ed25519.GenerateKey(nil)
if err != nil { if err != nil {
panic("can't generate identity key: " + err.Error()) panic("can't generate identity key: " + err.Error())
} }
// Create a new cache but don't enable prometheus!
s.cache, err = caching.NewInMemoryLRUCache(false) s.cache, err = caching.NewInMemoryLRUCache(false)
if err != nil { if err != nil {
panic("can't create cache: " + err.Error()) panic("can't create cache: " + err.Error())
} }
// Draw up just enough Dendrite config for the server key
// API to work.
s.config = &config.Dendrite{} s.config = &config.Dendrite{}
s.config.SetDefaults() s.config.SetDefaults()
s.config.Matrix.KeyValidityPeriod = s.validity s.config.Matrix.KeyValidityPeriod = s.validity
@ -66,38 +76,51 @@ func TestMain(m *testing.M) {
s.config.Matrix.KeyID = serverKeyID s.config.Matrix.KeyID = serverKeyID
s.config.Database.ServerKey = config.DataSource("file::memory:") s.config.Database.ServerKey = config.DataSource("file::memory:")
// Create a transport which redirects federation requests to
// the mock round tripper. Since we're not *really* listening for
// federation requests then this will return the key instead.
transport := &http.Transport{} transport := &http.Transport{}
transport.RegisterProtocol("matrix", &MockRoundTripper{}) transport.RegisterProtocol("matrix", &MockRoundTripper{})
// Create the federation client.
s.fedclient = gomatrixserverlib.NewFederationClientWithTransport( s.fedclient = gomatrixserverlib.NewFederationClientWithTransport(
s.config.Matrix.ServerName, serverKeyID, testPriv, transport, s.config.Matrix.ServerName, serverKeyID, testPriv, transport,
) )
// Finally, build the server key APIs.
s.api = NewInternalAPI(s.config, s.fedclient, s.cache) s.api = NewInternalAPI(s.config, s.fedclient, s.cache)
} }
// Now that we have built our server key APIs, start the
// rest of the tests.
os.Exit(m.Run()) os.Exit(m.Run())
} }
type MockRoundTripper struct{} type MockRoundTripper struct{}
func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err error) { func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err error) {
// Check if the request is looking for keys from a server that
// we know about in the test. The only reason this should go wrong
// is if the test is broken.
s, ok := servers[req.Host] s, ok := servers[req.Host]
if !ok { if !ok {
return nil, fmt.Errorf("server not known: %s", req.Host) return nil, fmt.Errorf("server not known: %s", req.Host)
} }
// Query the local keys for the server in question.
request := &api.QueryLocalKeysRequest{} request := &api.QueryLocalKeysRequest{}
response := &api.QueryLocalKeysResponse{} response := &api.QueryLocalKeysResponse{}
if err = s.api.QueryLocalKeys(context.Background(), request, response); err != nil { if err = s.api.QueryLocalKeys(context.Background(), request, response); err != nil {
return nil, err return nil, err
} }
// Make a nice JSON response out of it.
body, err := json.MarshalIndent(response.ServerKeys, "", " ") body, err := json.MarshalIndent(response.ServerKeys, "", " ")
if err != nil { if err != nil {
return nil, err return nil, err
} }
// And respond.
res = &http.Response{ res = &http.Response{
StatusCode: 200, StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(body)), Body: ioutil.NopCloser(bytes.NewReader(body)),
@ -106,10 +129,8 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err
} }
func TestServersRequestOwnKeys(t *testing.T) { func TestServersRequestOwnKeys(t *testing.T) {
/* // Each server will request its own keys. There's no reason
Each server will request its own keys. There's no reason // for this to fail as each server should know its own keys.
for this to fail as each server should know its own keys.
*/
for name, s := range servers { for name, s := range servers {
req := gomatrixserverlib.PublicKeyLookupRequest{ req := gomatrixserverlib.PublicKeyLookupRequest{
@ -133,11 +154,9 @@ func TestServersRequestOwnKeys(t *testing.T) {
} }
func TestCachingBehaviour(t *testing.T) { func TestCachingBehaviour(t *testing.T) {
/* // Server A will request Server B's key, which has a validity
Server A will request Server B's key, which has a validity // period of an hour from now. We should retrieve the key and
period of an hour from now. We should retrieve the key and // it should make it into the cache automatically.
it should make it into the cache automatically.
*/
req := gomatrixserverlib.PublicKeyLookupRequest{ req := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: serverB.name, ServerName: serverB.name,
@ -161,12 +180,10 @@ func TestCachingBehaviour(t *testing.T) {
t.Fatalf("server B isn't included in the key fetch response") t.Fatalf("server B isn't included in the key fetch response")
} }
/* // At this point, if the previous key request was a success,
At this point, if the previous key request was a success, // then the cache should now contain the key. Check if that's
then the cache should now contain the key. Check if that's // the case - if it isn't then there's something wrong with
the case - if it isn't then there's something wrong with // the cache implementation or we failed to get the key.
the cache implementation.
*/
cres, ok := serverA.cache.GetServerKey(req, ts) cres, ok := serverA.cache.GetServerKey(req, ts)
if !ok { if !ok {
@ -176,11 +193,9 @@ func TestCachingBehaviour(t *testing.T) {
t.Fatalf("the cached result from server B wasn't what server B gave us") t.Fatalf("the cached result from server B wasn't what server B gave us")
} }
/* // If we ask the cache for the same key but this time for an event
If we ask the cache for the server key + 30 minutes, then // that happened in +30 minutes. Since the validity period is for
it should still be valid, as server B's validity period is // another hour, then we should get a response back from the cache.
an hour.
*/
_, ok = serverA.cache.GetServerKey( _, ok = serverA.cache.GetServerKey(
req, req,
@ -190,11 +205,10 @@ func TestCachingBehaviour(t *testing.T) {
t.Fatalf("server B key isn't in cache when it should be (+30 minutes)") t.Fatalf("server B key isn't in cache when it should be (+30 minutes)")
} }
/* // If we ask the cache for the same key but this time for an event
If we ask the cache for the server key + 90 minutes, then // that happened in +90 minutes then we should expect to get no
it will have passed the validity by that point, so we should // cache result. This is because the cache shouldn't return a result
expect to get no response. // that is obviously past the validity of the event.
*/
_, ok = serverA.cache.GetServerKey( _, ok = serverA.cache.GetServerKey(
req, req,
@ -206,10 +220,9 @@ func TestCachingBehaviour(t *testing.T) {
} }
func TestRenewalBehaviour(t *testing.T) { func TestRenewalBehaviour(t *testing.T) {
/* // Server A will request Server C's key but their validity period
Server A will request Server C's key but their validity period // is an hour in the past. We'll retrieve the key as, even though it's
is an hour in the past. We'll retrieve the key. // past its validity, it will be able to verify past events.
*/
req := gomatrixserverlib.PublicKeyLookupRequest{ req := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: serverC.name, ServerName: serverC.name,
@ -232,12 +245,11 @@ func TestRenewalBehaviour(t *testing.T) {
t.Fatalf("server C isn't included in the key fetch response") t.Fatalf("server C isn't included in the key fetch response")
} }
t.Log("server C's key expires at", res[req].ValidUntilTS.Time()) // If we ask the cache for the server key for an event that happened
// 90 minutes ago then we should get a cache result, as the key hadn't
/* // passed its validity by that point. The fact that the key is now in
If we ask the cache for the server key - 90 minutes, then // the cache is, in itself, proof that we successfully retrieved the
it should be valid as the key hadn't expired by that point. // key before.
*/
oldcached, ok := serverA.cache.GetServerKey( oldcached, ok := serverA.cache.GetServerKey(
req, req,
@ -247,11 +259,10 @@ func TestRenewalBehaviour(t *testing.T) {
t.Fatalf("server C key isn't in cache when it should be (-90 minutes)") t.Fatalf("server C key isn't in cache when it should be (-90 minutes)")
} }
/* // If we now ask the cache for the same key but this time for an event
If we ask the cache for the server key - 30 minutes, then // that only happened 30 minutes ago then we shouldn't get a cached
it will have passed the validity by that point, so we should // result, as the event happened after the key validity expired. This
expect to get no response. // is really just for sanity checking.
*/
_, ok = serverA.cache.GetServerKey( _, ok = serverA.cache.GetServerKey(
req, req,
@ -261,12 +272,10 @@ func TestRenewalBehaviour(t *testing.T) {
t.Fatalf("server B key is in cache when it shouldn't be (-30 minutes)") t.Fatalf("server B key is in cache when it shouldn't be (-30 minutes)")
} }
/* // We're now going to kick server C into renewing its key. Since we're
We're now going to kick server C into renewing its key. // happy at this point that the key that we already have is from the past
Since we've asserted by this point that the key isn't going // then repeating a key fetch should cause us to try and renew the key.
to be returned by the cache, then we should really spot that // If so, then the new key will end up in our cache.
the key needs to be renewed and then do so.
*/
serverC.renew() serverC.renew()
@ -286,11 +295,10 @@ func TestRenewalBehaviour(t *testing.T) {
t.Fatalf("server C isn't included in the key fetch response") t.Fatalf("server C isn't included in the key fetch response")
} }
/* // We're now going to ask the cache what the new key validity is. If
We're now going to ask the cache what the new key validity // it is still the same as the previous validity then we've failed to
is. If it is still the same as the previous validity then we've // retrieve the renewed key. If it's newer then we've successfully got
failed to retrieve the renewed key. // the renewed key.
*/
newcached, ok := serverA.cache.GetServerKey( newcached, ok := serverA.cache.GetServerKey(
req, req,
@ -299,8 +307,7 @@ func TestRenewalBehaviour(t *testing.T) {
if !ok { if !ok {
t.Fatalf("server B key isn't in cache when it shouldn't be (post-renewal)") t.Fatalf("server B key isn't in cache when it shouldn't be (post-renewal)")
} }
if oldcached.ValidUntilTS >= newcached.ValidUntilTS {
if oldcached.ValidUntilTS == newcached.ValidUntilTS {
t.Fatalf("the server B key should have been renewed but wasn't") t.Fatalf("the server B key should have been renewed but wasn't")
} }
} }