diff --git a/internal/caching/cache_lazy_load_members.go b/internal/caching/cache_lazy_load_members.go index f0614f09f..903575430 100644 --- a/internal/caching/cache_lazy_load_members.go +++ b/internal/caching/cache_lazy_load_members.go @@ -3,6 +3,8 @@ package caching import ( "fmt" "time" + + userapi "github.com/matrix-org/dendrite/userapi/api" ) const ( @@ -13,11 +15,23 @@ const ( ) type LazyLoadCache struct { - *InMemoryLRUCachePartition + // Mapping from userID/deviceID to InMemoryLRUCachePartition + userCaches map[string]*InMemoryLRUCachePartition } -// NewLazyLoadCache creates a new InMemoryLRUCachePartition. -func NewLazyLoadCache() (*LazyLoadCache, error) { +// NewLazyLoadCache creates a new LazyLoadCache. +func NewLazyLoadCache() *LazyLoadCache { + return &LazyLoadCache{ + userCaches: make(map[string]*InMemoryLRUCachePartition), + } +} + +func (c *LazyLoadCache) lazyLoadCacheForUser(device *userapi.Device) (*InMemoryLRUCachePartition, error) { + cacheName := fmt.Sprintf("%s/%s", device.UserID, device.ID) + cache, ok := c.userCaches[cacheName] + if ok { + return cache, nil + } cache, err := NewInMemoryLRUCachePartition( LazyLoadCacheName, LazyLoadCacheMutable, @@ -28,18 +42,28 @@ func NewLazyLoadCache() (*LazyLoadCache, error) { if err != nil { return nil, err } + c.userCaches[cacheName] = cache go cacheCleaner(cache) - return &LazyLoadCache{cache}, err + return cache, nil } -func (c *LazyLoadCache) StoreLazyLoadedUser(reqUser, deviceID, roomID, userID, eventID string) { - cacheKey := fmt.Sprintf("%s/%s/%s/%s", reqUser, deviceID, roomID, userID) - c.Set(cacheKey, eventID) +func (c *LazyLoadCache) StoreLazyLoadedUser(device *userapi.Device, roomID, userID, eventID string) { + cache, err := c.lazyLoadCacheForUser(device) + if err != nil { + return + } + cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID) + cache.Set(cacheKey, eventID) } -func (c *LazyLoadCache) IsLazyLoadedUserCached(reqUser, deviceID, roomID, userID string) (string, bool) { - cacheKey := fmt.Sprintf("%s/%s/%s/%s", reqUser, deviceID, roomID, userID) - val, ok := c.Get(cacheKey) +func (c *LazyLoadCache) IsLazyLoadedUserCached(device *userapi.Device, roomID, userID string) (string, bool) { + cache, err := c.lazyLoadCacheForUser(device) + if err != nil { + return "", false + } + + cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID) + val, ok := cache.Get(cacheKey) if !ok { return "", ok } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 80233da13..a658b702b 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -28,7 +28,7 @@ type PDUStreamProvider struct { tasks chan func() workers atomic.Int32 // userID+deviceID -> lazy loading cache - lazyLoadCache map[string]*caching.LazyLoadCache + lazyLoadCache *caching.LazyLoadCache } func (p *PDUStreamProvider) worker() { @@ -267,14 +267,12 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } if stateFilter.LazyLoadMembers { - cache, err := p.getLazyLoadCache(device) if err != nil { return r.From, err } delta.StateEvents, err = p.lazyLoadMembers( ctx, delta.RoomID, true, limited, stateFilter.IncludeRedundantMembers, - device, cache, - recentEvents, delta.StateEvents, + device, recentEvents, delta.StateEvents, ) if err != nil { return r.From, err @@ -422,14 +420,12 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( stateEvents = removeDuplicates(stateEvents, recentEvents) if stateFilter.LazyLoadMembers { - cache, err := p.getLazyLoadCache(device) if err != nil { return nil, err } stateEvents, err = p.lazyLoadMembers(ctx, roomID, false, limited, stateFilter.IncludeRedundantMembers, - device, cache, - recentEvents, stateEvents, + device, recentEvents, stateEvents, ) if err != nil { return nil, err @@ -450,7 +446,6 @@ func (p *PDUStreamProvider) lazyLoadMembers( ctx context.Context, roomID string, incremental, limited, includeRedundant bool, device *userapi.Device, - cache *caching.LazyLoadCache, timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent, ) ([]*gomatrixserverlib.HeaderedEvent, error) { if len(timelineEvents) == 0 { @@ -464,7 +459,7 @@ func (p *PDUStreamProvider) lazyLoadMembers( // Add all users the client doesn't know about yet to a list for _, event := range timelineEvents { // Membership is not yet cached, add it to the list - if _, ok := cache.IsLazyLoadedUserCached(device.UserID, device.ID, roomID, event.Sender()); !ok { + if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.Sender()); !ok { timelineUsers[event.Sender()] = struct{}{} } } @@ -481,7 +476,7 @@ func (p *PDUStreamProvider) lazyLoadMembers( if wantMembership { newStateEvents = append(newStateEvents, event) if !includeRedundant { - cache.StoreLazyLoadedUser(device.UserID, device.ID, roomID, event.Sender(), event.EventID()) + p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, event.Sender(), event.EventID()) } delete(timelineUsers, event.Sender()) } @@ -504,27 +499,12 @@ func (p *PDUStreamProvider) lazyLoadMembers( } // cache the membership events for _, membership := range memberships { - cache.StoreLazyLoadedUser(device.UserID, device.ID, roomID, membership.Sender(), membership.EventID()) + p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, membership.Sender(), membership.EventID()) } stateEvents = append(newStateEvents, memberships...) return stateEvents, nil } -// getLazyLoadCache gets/creates a lazy load cache for a given device. -func (p *PDUStreamProvider) getLazyLoadCache(device *userapi.Device) (*caching.LazyLoadCache, error) { - var err error - cacheKey := device.UserID + device.ID - cache, ok := p.lazyLoadCache[cacheKey] - if !ok { - cache, err = caching.NewLazyLoadCache() - if err != nil { - return nil, err - } - p.lazyLoadCache[cacheKey] = cache - } - return cache, nil -} - // addIgnoredUsersToFilter adds ignored users to the eventfilter and // the syncreq itself for further use in streams. func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error { diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index 99560966b..d3195b78f 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -27,7 +27,7 @@ type Streams struct { func NewSyncStreamProviders( d storage.Database, userAPI userapi.UserInternalAPI, rsAPI rsapi.RoomserverInternalAPI, keyAPI keyapi.KeyInternalAPI, - eduCache *caching.EDUCache, lazyLoadCache map[string]*caching.LazyLoadCache, notifier *notifier.Notifier, + eduCache *caching.EDUCache, lazyLoadCache *caching.LazyLoadCache, notifier *notifier.Notifier, ) *Streams { streams := &Streams{ PDUStreamProvider: &PDUStreamProvider{ diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 4ecf9076b..90c2b57dc 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -57,8 +57,8 @@ func AddPublicRoutes( } eduCache := caching.NewTypingCache() + lazyLoadCache := caching.NewLazyLoadCache() notifier := notifier.NewNotifier() - lazyLoadCache := make(map[string]*caching.LazyLoadCache) streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, lazyLoadCache, notifier) notifier.SetCurrentPosition(streams.Latest(context.Background())) if err = notifier.Load(context.Background(), syncDB); err != nil {