diff --git a/common/basecomponent/base.go b/common/basecomponent/base.go index 743e034fa..68a77cf99 100644 --- a/common/basecomponent/base.go +++ b/common/basecomponent/base.go @@ -54,12 +54,12 @@ type BaseDendrite struct { tracerCloser io.Closer // APIMux should be used to register new public matrix api endpoints - APIMux *mux.Router - httpClient *http.Client - Cfg *config.Dendrite - Cache caching.Cache - KafkaConsumer sarama.Consumer - KafkaProducer sarama.SyncProducer + APIMux *mux.Router + httpClient *http.Client + Cfg *config.Dendrite + ImmutableCache caching.ImmutableCache + KafkaConsumer sarama.Consumer + KafkaProducer sarama.SyncProducer } const HTTPServerTimeout = time.Minute * 5 @@ -85,20 +85,20 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string) *BaseDendrite { kafkaConsumer, kafkaProducer = setupKafka(cfg) } - cache, err := caching.NewInMemoryLRUCache() + cache, err := caching.NewImmutableInMemoryLRUCache() if err != nil { logrus.WithError(err).Warnf("Failed to create cache") } return &BaseDendrite{ - componentName: componentName, - tracerCloser: closer, - Cfg: cfg, - Cache: cache, - APIMux: mux.NewRouter().UseEncodedPath(), - httpClient: &http.Client{Timeout: HTTPClientTimeout}, - KafkaConsumer: kafkaConsumer, - KafkaProducer: kafkaProducer, + componentName: componentName, + tracerCloser: closer, + Cfg: cfg, + ImmutableCache: cache, + APIMux: mux.NewRouter().UseEncodedPath(), + httpClient: &http.Client{Timeout: HTTPClientTimeout}, + KafkaConsumer: kafkaConsumer, + KafkaProducer: kafkaProducer, } } @@ -132,7 +132,7 @@ func (b *BaseDendrite) CreateHTTPRoomserverAPIs() ( if err != nil { logrus.WithError(err).Panic("NewRoomserverInputAPIHTTP failed", b.httpClient) } - query, err := roomserverAPI.NewRoomserverQueryAPIHTTP(b.Cfg.RoomServerURL(), b.httpClient, b.Cache) + query, err := roomserverAPI.NewRoomserverQueryAPIHTTP(b.Cfg.RoomServerURL(), b.httpClient, b.ImmutableCache) if err != nil { logrus.WithError(err).Panic("NewRoomserverQueryAPIHTTP failed", b.httpClient) } diff --git a/common/caching/cache.go b/common/caching/immutablecache.go similarity index 69% rename from common/caching/cache.go rename to common/caching/immutablecache.go index 83404fecf..9b967f3f6 100644 --- a/common/caching/cache.go +++ b/common/caching/immutablecache.go @@ -3,10 +3,11 @@ package caching import "github.com/matrix-org/gomatrixserverlib" const ( - MaxRoomVersionCacheEntries = 128 + RoomVersionCachingEnabled = true + RoomVersionMaxCacheEntries = 128 ) -type Cache interface { +type ImmutableCache interface { GetRoomVersion(roomId string) (gomatrixserverlib.RoomVersion, bool) StoreRoomVersion(roomId string, roomVersion gomatrixserverlib.RoomVersion) } diff --git a/common/caching/immutableinmemorylru.go b/common/caching/immutableinmemorylru.go new file mode 100644 index 000000000..1d43a267c --- /dev/null +++ b/common/caching/immutableinmemorylru.go @@ -0,0 +1,47 @@ +package caching + +import ( + lru "github.com/hashicorp/golang-lru" + "github.com/matrix-org/gomatrixserverlib" +) + +type ImmutableInMemoryLRUCache struct { + roomVersions *lru.Cache +} + +func NewImmutableInMemoryLRUCache() (*ImmutableInMemoryLRUCache, error) { + roomVersionCache, rvErr := lru.New(RoomVersionMaxCacheEntries) + if rvErr != nil { + return nil, rvErr + } + return &ImmutableInMemoryLRUCache{ + roomVersions: roomVersionCache, + }, nil +} + +func checkForInvalidMutation(cache *lru.Cache, key string, value interface{}) { + if peek, ok := cache.Peek(key); ok && peek != value { + panic("invalid use of immutable cache tries to mutate existing value") + } +} + +func (c *ImmutableInMemoryLRUCache) GetRoomVersion(roomID string) (gomatrixserverlib.RoomVersion, bool) { + if c == nil || !RoomVersionCachingEnabled { + return "", false + } + val, found := c.roomVersions.Get(roomID) + if found && val != nil { + if roomVersion, ok := val.(gomatrixserverlib.RoomVersion); ok { + return roomVersion, true + } + } + return "", false +} + +func (c *ImmutableInMemoryLRUCache) StoreRoomVersion(roomID string, roomVersion gomatrixserverlib.RoomVersion) { + if c == nil || !RoomVersionCachingEnabled { + return + } + checkForInvalidMutation(c.roomVersions, roomID, roomVersion) + c.roomVersions.Add(roomID, roomVersion) +} diff --git a/common/caching/inmemorylru.go b/common/caching/inmemorylru.go deleted file mode 100644 index 89101f4a2..000000000 --- a/common/caching/inmemorylru.go +++ /dev/null @@ -1,40 +0,0 @@ -package caching - -import ( - lru "github.com/hashicorp/golang-lru" - "github.com/matrix-org/gomatrixserverlib" -) - -type InMemoryLRUCache struct { - roomVersions *lru.Cache -} - -func NewInMemoryLRUCache() (*InMemoryLRUCache, error) { - roomVersionCache, rvErr := lru.New(MaxRoomVersionCacheEntries) - if rvErr != nil { - return nil, rvErr - } - return &InMemoryLRUCache{ - roomVersions: roomVersionCache, - }, nil -} - -func (c *InMemoryLRUCache) GetRoomVersion(roomID string) (gomatrixserverlib.RoomVersion, bool) { - if c == nil { - return "", false - } - val, found := c.roomVersions.Get(roomID) - if found && val != nil { - if roomVersion, ok := val.(gomatrixserverlib.RoomVersion); ok { - return roomVersion, true - } - } - return "", false -} - -func (c *InMemoryLRUCache) StoreRoomVersion(roomID string, roomVersion gomatrixserverlib.RoomVersion) { - if c == nil { - return - } - c.roomVersions.Add(roomID, roomVersion) -} diff --git a/roomserver/api/query.go b/roomserver/api/query.go index b78bc4a67..b272b1ebd 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -412,7 +412,7 @@ const RoomserverQueryRoomVersionForRoomPath = "/api/roomserver/queryRoomVersionF // NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API. // If httpClient is nil an error is returned -func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client, cache caching.Cache) (RoomserverQueryAPI, error) { +func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client, cache caching.ImmutableCache) (RoomserverQueryAPI, error) { if httpClient == nil { return nil, errors.New("NewRoomserverQueryAPIHTTP: httpClient is ") } @@ -420,9 +420,9 @@ func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client, ca } type httpRoomserverQueryAPI struct { - roomserverURL string - httpClient *http.Client - cache caching.Cache + roomserverURL string + httpClient *http.Client + immutableCache caching.ImmutableCache } // QueryLatestEventsAndState implements RoomserverQueryAPI @@ -587,7 +587,7 @@ func (h *httpRoomserverQueryAPI) QueryRoomVersionForRoom( request *QueryRoomVersionForRoomRequest, response *QueryRoomVersionForRoomResponse, ) error { - if roomVersion, ok := h.cache.GetRoomVersion(request.RoomID); ok { + if roomVersion, ok := h.immutableCache.GetRoomVersion(request.RoomID); ok { response.RoomVersion = roomVersion return nil } @@ -598,7 +598,7 @@ func (h *httpRoomserverQueryAPI) QueryRoomVersionForRoom( apiURL := h.roomserverURL + RoomserverQueryRoomVersionForRoomPath err := commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) if err == nil { - h.cache.StoreRoomVersion(request.RoomID, response.RoomVersion) + h.immutableCache.StoreRoomVersion(request.RoomID, response.RoomVersion) } return err } diff --git a/roomserver/query/query.go b/roomserver/query/query.go index 73fc37e98..224d9fa22 100644 --- a/roomserver/query/query.go +++ b/roomserver/query/query.go @@ -98,8 +98,8 @@ type RoomserverQueryAPIDatabase interface { // RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI type RoomserverQueryAPI struct { - DB RoomserverQueryAPIDatabase - Cache caching.Cache + DB RoomserverQueryAPIDatabase + ImmutableCache caching.ImmutableCache } // QueryLatestEventsAndState implements api.RoomserverQueryAPI @@ -898,7 +898,7 @@ func (r *RoomserverQueryAPI) QueryRoomVersionForRoom( request *api.QueryRoomVersionForRoomRequest, response *api.QueryRoomVersionForRoomResponse, ) error { - if roomVersion, ok := r.Cache.GetRoomVersion(request.RoomID); ok { + if roomVersion, ok := r.ImmutableCache.GetRoomVersion(request.RoomID); ok { response.RoomVersion = roomVersion return nil } @@ -908,7 +908,7 @@ func (r *RoomserverQueryAPI) QueryRoomVersionForRoom( return err } response.RoomVersion = roomVersion - r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion) + r.ImmutableCache.StoreRoomVersion(request.RoomID, response.RoomVersion) return nil } diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index f5b658968..fa4f20626 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -49,8 +49,8 @@ func SetupRoomServerComponent( inputAPI.SetupHTTP(http.DefaultServeMux) queryAPI := query.RoomserverQueryAPI{ - DB: roomserverDB, - Cache: base.Cache, + DB: roomserverDB, + ImmutableCache: base.ImmutableCache, } queryAPI.SetupHTTP(http.DefaultServeMux)