Tweak caching behaviour, update tests

This commit is contained in:
Neil Alexander 2020-06-15 15:03:06 +01:00
parent aabd45995a
commit e630190272
5 changed files with 117 additions and 40 deletions

View file

@ -2,7 +2,6 @@ package caching
import ( import (
"fmt" "fmt"
"time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -16,22 +15,21 @@ const (
// ServerKeyCache contains the subset of functions needed for // ServerKeyCache contains the subset of functions needed for
// a server key cache. // a server key cache.
type ServerKeyCache interface { type ServerKeyCache interface {
GetServerKey(request gomatrixserverlib.PublicKeyLookupRequest) (response gomatrixserverlib.PublicKeyLookupResult, ok bool) GetServerKey(request gomatrixserverlib.PublicKeyLookupRequest, timestamp gomatrixserverlib.Timestamp) (response gomatrixserverlib.PublicKeyLookupResult, ok bool)
StoreServerKey(request gomatrixserverlib.PublicKeyLookupRequest, response gomatrixserverlib.PublicKeyLookupResult) StoreServerKey(request gomatrixserverlib.PublicKeyLookupRequest, response gomatrixserverlib.PublicKeyLookupResult)
} }
func (c Caches) GetServerKey( func (c Caches) GetServerKey(
request gomatrixserverlib.PublicKeyLookupRequest, request gomatrixserverlib.PublicKeyLookupRequest,
timestamp gomatrixserverlib.Timestamp,
) (gomatrixserverlib.PublicKeyLookupResult, bool) { ) (gomatrixserverlib.PublicKeyLookupResult, bool) {
key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID) key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID)
now := gomatrixserverlib.AsTimestamp(time.Now())
val, found := c.ServerKeys.Get(key) val, found := c.ServerKeys.Get(key)
if found && val != nil { if found && val != nil {
if keyLookupResult, ok := val.(gomatrixserverlib.PublicKeyLookupResult); ok { if keyLookupResult, ok := val.(gomatrixserverlib.PublicKeyLookupResult); ok {
if !keyLookupResult.WasValidAt(now, true) { if !keyLookupResult.WasValidAt(timestamp, true) {
// We appear to be past the key validity so don't return this // The key wasn't valid at the requested timestamp so don't
// with the results. This ensures that the cache doesn't return // return it. The caller will have to work out what to do.
// values that are not useful to us.
c.ServerKeys.Unset(key) c.ServerKeys.Unset(key)
return gomatrixserverlib.PublicKeyLookupResult{}, false return gomatrixserverlib.PublicKeyLookupResult{}, false
} }

View file

@ -173,9 +173,18 @@ func (s *ServerKeyAPI) FetchKeys(
} }
} }
} }
// If we failed to fetch any keys then we should report an error. // Check that we've actually satisfied all of the key requests that we
if len(requests) > 0 { // were given. We should report an error if we didn't.
return results, fmt.Errorf("server key API failed to fetch %d keys", len(requests)) for req := range requests {
if _, ok := results[req]; !ok {
// The results don't contain anything for this specific request, so
// we've failed to satisfy it from local keys, database keys or from
// all of the fetchers. Report an error.
return results, fmt.Errorf(
"server key API failed to satisfy key request for server %q key ID %q",
req.ServerName, req.KeyID,
)
}
} }
// Return the keys. // Return the keys.
return results, nil return results, nil

View file

@ -91,7 +91,7 @@ func (s *httpServerKeyInternalAPI) FetchKeys(
Results: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult), Results: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult),
} }
for req, ts := range requests { for req, ts := range requests {
if res, ok := s.cache.GetServerKey(req); ok { if res, ok := s.cache.GetServerKey(req, ts); ok {
result[req] = res result[req] = res
continue continue
} }

View file

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -27,6 +28,11 @@ type server struct {
api api.ServerKeyInternalAPI api api.ServerKeyInternalAPI
} }
func (s *server) renew() {
s.validity = time.Hour
s.config.Matrix.KeyValidityPeriod = s.validity
}
var ( var (
serverKeyID = gomatrixserverlib.KeyID("ed25519:auto") serverKeyID = gomatrixserverlib.KeyID("ed25519:auto")
serverA = &server{name: "a.com", validity: time.Duration(0)} // expires now serverA = &server{name: "a.com", validity: time.Duration(0)} // expires now
@ -70,7 +76,7 @@ func TestMain(m *testing.M) {
s.api = NewInternalAPI(s.config, s.fedclient, s.cache) s.api = NewInternalAPI(s.config, s.fedclient, s.cache)
} }
//os.Exit(m.Run()) os.Exit(m.Run())
} }
type MockRoundTripper struct{} type MockRoundTripper struct{}
@ -92,8 +98,6 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err
return nil, err return nil, err
} }
fmt.Println("Round-tripper says:", string(body))
res = &http.Response{ res = &http.Response{
StatusCode: 200, StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(body)), Body: ioutil.NopCloser(bytes.NewReader(body)),
@ -124,11 +128,11 @@ func TestServersRequestOwnKeys(t *testing.T) {
if _, ok := res[req]; !ok { if _, ok := res[req]; !ok {
t.Fatalf("server didn't return its own key in the results") t.Fatalf("server didn't return its own key in the results")
} }
fmt.Printf("%s's key expires at %d\n", name, res[req].ValidUntilTS) t.Logf("%s's key expires at %s\n", name, res[req].ValidUntilTS.Time())
} }
} }
func TestServerARequestsServerBKey(t *testing.T) { func TestCachingBehaviour(t *testing.T) {
/* /*
Server A will request Server B's key, which has a validity Server A will request Server B's key, which has a validity
period of an hour from now. We should retrieve the key and period of an hour from now. We should retrieve the key and
@ -139,11 +143,12 @@ func TestServerARequestsServerBKey(t *testing.T) {
ServerName: serverB.name, ServerName: serverB.name,
KeyID: serverKeyID, KeyID: serverKeyID,
} }
ts := gomatrixserverlib.AsTimestamp(time.Now())
res, err := serverA.api.FetchKeys( res, err := serverA.api.FetchKeys(
context.Background(), context.Background(),
map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{
req: gomatrixserverlib.AsTimestamp(time.Now()), req: ts,
}, },
) )
if err != nil { if err != nil {
@ -163,7 +168,7 @@ func TestServerARequestsServerBKey(t *testing.T) {
the cache implementation. the cache implementation.
*/ */
cres, ok := serverA.cache.GetServerKey(req) cres, ok := serverA.cache.GetServerKey(req, ts)
if !ok { if !ok {
t.Fatalf("server B key should be in cache but isn't") t.Fatalf("server B key should be in cache but isn't")
} }
@ -172,27 +177,38 @@ func TestServerARequestsServerBKey(t *testing.T) {
} }
/* /*
Server A will then request Server B's key for an event that If we ask the cache for the server key + 30 minutes, then
happened two hours ago, which *should* pass since the key was it should still be valid, as server B's validity period is
valid then too and it's already in the cache. an hour.
*/ */
_, err = serverA.api.FetchKeys( _, ok = serverA.cache.GetServerKey(
context.Background(), req,
map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute*30)),
req: gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Hour * 2)),
},
) )
if err != nil { if !ok {
t.Fatalf("server A failed to retrieve server B key: %s", err) t.Fatalf("server B key isn't in cache when it should be (+30 minutes)")
}
/*
If we ask the cache for the server key + 90 minutes, then
it will have passed the validity by that point, so we should
expect to get no response.
*/
_, ok = serverA.cache.GetServerKey(
req,
gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute*90)),
)
if ok {
t.Fatalf("server B key is in cache when it shouldn't be (+90 minutes)")
} }
} }
func TestServerARequestsServerCKey(t *testing.T) { func TestRenewalBehaviour(t *testing.T) {
/* /*
Server A will request Server C's key for an event that came Server A will request Server C's key but their validity period
in just now, but their validity period is an hour in the is an hour in the past. We'll retrieve the key.
past. This *should* fail since the key isn't valid now.
*/ */
req := gomatrixserverlib.PublicKeyLookupRequest{ req := gomatrixserverlib.PublicKeyLookupRequest{
@ -219,18 +235,72 @@ func TestServerARequestsServerCKey(t *testing.T) {
t.Log("server C's key expires at", res[req].ValidUntilTS.Time()) t.Log("server C's key expires at", res[req].ValidUntilTS.Time())
/* /*
Server A will then request Server C's key for an event that If we ask the cache for the server key - 90 minutes, then
happened two hours ago, which *should* pass since the key was it should be valid as the key hadn't expired by that point.
valid then.
*/ */
_, err = serverA.api.FetchKeys( oldcached, ok := serverA.cache.GetServerKey(
req,
gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Minute*90)),
)
if !ok {
t.Fatalf("server C key isn't in cache when it should be (-90 minutes)")
}
/*
If we ask the cache for the server key - 30 minutes, then
it will have passed the validity by that point, so we should
expect to get no response.
*/
_, ok = serverA.cache.GetServerKey(
req,
gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Minute*30)),
)
if ok {
t.Fatalf("server B key is in cache when it shouldn't be (-30 minutes)")
}
/*
We're now going to kick server C into renewing its key.
Since we've asserted by this point that the key isn't going
to be returned by the cache, then we should really spot that
the key needs to be renewed and then do so.
*/
serverC.renew()
res, err = serverA.api.FetchKeys(
context.Background(), context.Background(),
map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{
req: gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Hour)), req: gomatrixserverlib.AsTimestamp(time.Now()),
}, },
) )
if err != nil { if err != nil {
t.Fatalf("serverKeyAPI.FetchKeys: %s", err) t.Fatalf("server A failed to retrieve server C key: %s", err)
}
if len(res) != 1 {
t.Fatalf("server C should have returned one key but instead returned %d keys", len(res))
}
if _, ok = res[req]; !ok {
t.Fatalf("server C isn't included in the key fetch response")
}
/*
We're now going to ask the cache what the new key validity
is. If it is still the same as the previous validity then we've
failed to retrieve the renewed key.
*/
newcached, ok := serverA.cache.GetServerKey(
req,
gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Minute*30)),
)
if !ok {
t.Fatalf("server B key isn't in cache when it shouldn't be (post-renewal)")
}
if oldcached.ValidUntilTS == newcached.ValidUntilTS {
t.Fatalf("the server B key should have been renewed but wasn't")
} }
} }

View file

@ -39,8 +39,8 @@ func (d *KeyDatabase) FetchKeys(
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult)
for req := range requests { for req, ts := range requests {
if res, cached := d.cache.GetServerKey(req); cached { if res, cached := d.cache.GetServerKey(req, ts); cached {
results[req] = res results[req] = res
delete(requests, req) delete(requests, req)
} }