diff --git a/cmd/roomserver-integration-tests/main.go b/cmd/roomserver-integration-tests/main.go index df5607bcb..682fc6224 100644 --- a/cmd/roomserver-integration-tests/main.go +++ b/cmd/roomserver-integration-tests/main.go @@ -28,6 +28,7 @@ import ( "net/http" + "github.com/matrix-org/dendrite/common/caching" "github.com/matrix-org/dendrite/common/test" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" @@ -253,6 +254,11 @@ func testRoomserver(input []string, wantOutput []string, checkQueries func(api.R panic(err) } + cache, err := caching.NewImmutableInMemoryLRUCache() + if err != nil { + panic(err) + } + doInput := func() { fmt.Printf("Roomserver is ready to receive input, sending %d events\n", len(input)) if err = writeToRoomServer(input, cfg.RoomServerURL()); err != nil { @@ -270,7 +276,7 @@ func testRoomserver(input []string, wantOutput []string, checkQueries func(api.R cmd.Args = []string{"dendrite-room-server", "--config", filepath.Join(dir, test.ConfigFile)} gotOutput, err := runAndReadFromTopic(cmd, cfg.RoomServerURL()+"/metrics", doInput, outputTopic, len(wantOutput), func() { - queryAPI, _ := api.NewRoomserverQueryAPIHTTP("http://"+string(cfg.Listen.RoomServer), &http.Client{Timeout: timeoutHTTP}) + queryAPI, _ := api.NewRoomserverQueryAPIHTTP("http://"+string(cfg.Listen.RoomServer), &http.Client{Timeout: timeoutHTTP}, cache) checkQueries(queryAPI) }) if err != nil { diff --git a/common/basecomponent/base.go b/common/basecomponent/base.go index 78894289e..68a77cf99 100644 --- a/common/basecomponent/base.go +++ b/common/basecomponent/base.go @@ -23,6 +23,7 @@ import ( "golang.org/x/crypto/ed25519" + "github.com/matrix-org/dendrite/common/caching" "github.com/matrix-org/dendrite/common/keydb" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" @@ -53,11 +54,12 @@ type BaseDendrite struct { tracerCloser io.Closer // APIMux should be used to register new public matrix api endpoints - APIMux *mux.Router - httpClient *http.Client - Cfg *config.Dendrite - KafkaConsumer sarama.Consumer - KafkaProducer sarama.SyncProducer + APIMux *mux.Router + httpClient *http.Client + Cfg *config.Dendrite + ImmutableCache caching.ImmutableCache + KafkaConsumer sarama.Consumer + KafkaProducer sarama.SyncProducer } const HTTPServerTimeout = time.Minute * 5 @@ -83,14 +85,20 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string) *BaseDendrite { kafkaConsumer, kafkaProducer = setupKafka(cfg) } + cache, err := caching.NewImmutableInMemoryLRUCache() + if err != nil { + logrus.WithError(err).Warnf("Failed to create cache") + } + return &BaseDendrite{ - componentName: componentName, - tracerCloser: closer, - Cfg: cfg, - APIMux: mux.NewRouter().UseEncodedPath(), - httpClient: &http.Client{Timeout: HTTPClientTimeout}, - KafkaConsumer: kafkaConsumer, - KafkaProducer: kafkaProducer, + componentName: componentName, + tracerCloser: closer, + Cfg: cfg, + ImmutableCache: cache, + APIMux: mux.NewRouter().UseEncodedPath(), + httpClient: &http.Client{Timeout: HTTPClientTimeout}, + KafkaConsumer: kafkaConsumer, + KafkaProducer: kafkaProducer, } } @@ -116,7 +124,6 @@ func (b *BaseDendrite) CreateHTTPRoomserverAPIs() ( roomserverAPI.RoomserverInputAPI, roomserverAPI.RoomserverQueryAPI, ) { - alias, err := roomserverAPI.NewRoomserverAliasAPIHTTP(b.Cfg.RoomServerURL(), b.httpClient) if err != nil { logrus.WithError(err).Panic("NewRoomserverAliasAPIHTTP failed") @@ -125,7 +132,7 @@ func (b *BaseDendrite) CreateHTTPRoomserverAPIs() ( if err != nil { logrus.WithError(err).Panic("NewRoomserverInputAPIHTTP failed", b.httpClient) } - query, err := roomserverAPI.NewRoomserverQueryAPIHTTP(b.Cfg.RoomServerURL(), b.httpClient) + query, err := roomserverAPI.NewRoomserverQueryAPIHTTP(b.Cfg.RoomServerURL(), b.httpClient, b.ImmutableCache) if err != nil { logrus.WithError(err).Panic("NewRoomserverQueryAPIHTTP failed", b.httpClient) } diff --git a/common/caching/immutablecache.go b/common/caching/immutablecache.go new file mode 100644 index 000000000..9620667a2 --- /dev/null +++ b/common/caching/immutablecache.go @@ -0,0 +1,12 @@ +package caching + +import "github.com/matrix-org/gomatrixserverlib" + +const ( + RoomVersionMaxCacheEntries = 128 +) + +type ImmutableCache interface { + GetRoomVersion(roomId string) (gomatrixserverlib.RoomVersion, bool) + StoreRoomVersion(roomId string, roomVersion gomatrixserverlib.RoomVersion) +} diff --git a/common/caching/immutableinmemorylru.go b/common/caching/immutableinmemorylru.go new file mode 100644 index 000000000..3e8f4aadb --- /dev/null +++ b/common/caching/immutableinmemorylru.go @@ -0,0 +1,43 @@ +package caching + +import ( + "fmt" + + lru "github.com/hashicorp/golang-lru" + "github.com/matrix-org/gomatrixserverlib" +) + +type ImmutableInMemoryLRUCache struct { + roomVersions *lru.Cache +} + +func NewImmutableInMemoryLRUCache() (*ImmutableInMemoryLRUCache, error) { + roomVersionCache, rvErr := lru.New(RoomVersionMaxCacheEntries) + if rvErr != nil { + return nil, rvErr + } + return &ImmutableInMemoryLRUCache{ + roomVersions: roomVersionCache, + }, nil +} + +func checkForInvalidMutation(cache *lru.Cache, key string, value interface{}) { + if peek, ok := cache.Peek(key); ok && peek != value { + panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key)) + } +} + +func (c *ImmutableInMemoryLRUCache) GetRoomVersion(roomID string) (gomatrixserverlib.RoomVersion, bool) { + val, found := c.roomVersions.Get(roomID) + if found && val != nil { + if roomVersion, ok := val.(gomatrixserverlib.RoomVersion); ok { + return roomVersion, true + } + } + return "", false +} + +func (c *ImmutableInMemoryLRUCache) StoreRoomVersion(roomID string, roomVersion gomatrixserverlib.RoomVersion) { + checkForInvalidMutation(c.roomVersions, roomID, roomVersion) + c.roomVersions.Add(roomID, roomVersion) +} diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 5f024d266..b272b1ebd 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -21,6 +21,7 @@ import ( "errors" "net/http" + "github.com/matrix-org/dendrite/common/caching" commonHTTP "github.com/matrix-org/dendrite/common/http" "github.com/matrix-org/gomatrixserverlib" opentracing "github.com/opentracing/opentracing-go" @@ -411,16 +412,17 @@ const RoomserverQueryRoomVersionForRoomPath = "/api/roomserver/queryRoomVersionF // NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API. // If httpClient is nil an error is returned -func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client) (RoomserverQueryAPI, error) { +func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client, cache caching.ImmutableCache) (RoomserverQueryAPI, error) { if httpClient == nil { return nil, errors.New("NewRoomserverQueryAPIHTTP: httpClient is ") } - return &httpRoomserverQueryAPI{roomserverURL, httpClient}, nil + return &httpRoomserverQueryAPI{roomserverURL, httpClient, cache}, nil } type httpRoomserverQueryAPI struct { - roomserverURL string - httpClient *http.Client + roomserverURL string + httpClient *http.Client + immutableCache caching.ImmutableCache } // QueryLatestEventsAndState implements RoomserverQueryAPI @@ -585,9 +587,18 @@ func (h *httpRoomserverQueryAPI) QueryRoomVersionForRoom( request *QueryRoomVersionForRoomRequest, response *QueryRoomVersionForRoomResponse, ) error { + if roomVersion, ok := h.immutableCache.GetRoomVersion(request.RoomID); ok { + response.RoomVersion = roomVersion + return nil + } + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionForRoom") defer span.Finish() apiURL := h.roomserverURL + RoomserverQueryRoomVersionForRoomPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + err := commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + if err == nil { + h.immutableCache.StoreRoomVersion(request.RoomID, response.RoomVersion) + } + return err } diff --git a/roomserver/query/query.go b/roomserver/query/query.go index 12d8436ef..224d9fa22 100644 --- a/roomserver/query/query.go +++ b/roomserver/query/query.go @@ -22,6 +22,7 @@ import ( "net/http" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/common/caching" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/auth" "github.com/matrix-org/dendrite/roomserver/state" @@ -97,7 +98,8 @@ type RoomserverQueryAPIDatabase interface { // RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI type RoomserverQueryAPI struct { - DB RoomserverQueryAPIDatabase + DB RoomserverQueryAPIDatabase + ImmutableCache caching.ImmutableCache } // QueryLatestEventsAndState implements api.RoomserverQueryAPI @@ -896,11 +898,17 @@ func (r *RoomserverQueryAPI) QueryRoomVersionForRoom( request *api.QueryRoomVersionForRoomRequest, response *api.QueryRoomVersionForRoomResponse, ) error { + if roomVersion, ok := r.ImmutableCache.GetRoomVersion(request.RoomID); ok { + response.RoomVersion = roomVersion + return nil + } + roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID) if err != nil { return err } response.RoomVersion = roomVersion + r.ImmutableCache.StoreRoomVersion(request.RoomID, response.RoomVersion) return nil } diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 2ffbf67de..fa4f20626 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -48,7 +48,10 @@ func SetupRoomServerComponent( inputAPI.SetupHTTP(http.DefaultServeMux) - queryAPI := query.RoomserverQueryAPI{DB: roomserverDB} + queryAPI := query.RoomserverQueryAPI{ + DB: roomserverDB, + ImmutableCache: base.ImmutableCache, + } queryAPI.SetupHTTP(http.DefaultServeMux)