diff --git a/internal/caching/cache_roominfo.go b/internal/caching/cache_roominfo.go index d03a61077..5dfed3c85 100644 --- a/internal/caching/cache_roominfo.go +++ b/internal/caching/cache_roominfo.go @@ -16,18 +16,18 @@ import ( // a room Info cache. It must only be used from the roomserver only // It is not safe for use from other components. type RoomInfoCache interface { - GetRoomInfo(roomID string) (roomInfo types.RoomInfo, ok bool) - StoreRoomInfo(roomID string, roomInfo types.RoomInfo) + GetRoomInfo(roomID string) (roomInfo *types.RoomInfo, ok bool) + StoreRoomInfo(roomID string, roomInfo *types.RoomInfo) } // GetRoomInfo must only be called from the roomserver only. It is not // safe for use from other components. -func (c Caches) GetRoomInfo(roomID string) (types.RoomInfo, bool) { +func (c Caches) GetRoomInfo(roomID string) (*types.RoomInfo, bool) { return c.RoomInfos.Get(roomID) } // StoreRoomInfo must only be called from the roomserver only. It is not // safe for use from other components. -func (c Caches) StoreRoomInfo(roomID string, roomInfo types.RoomInfo) { +func (c Caches) StoreRoomInfo(roomID string, roomInfo *types.RoomInfo) { c.RoomInfos.Set(roomID, roomInfo) } diff --git a/internal/caching/caches.go b/internal/caching/caches.go index 14b232dd0..e7914ce7d 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -28,7 +28,7 @@ type Caches struct { RoomServerRoomNIDs Cache[string, types.RoomNID] // room ID -> room NID RoomServerRoomIDs Cache[int64, string] // room NID -> room ID RoomServerEvents Cache[int64, *gomatrixserverlib.Event] // event NID -> event - RoomInfos Cache[string, types.RoomInfo] // room ID -> room info + RoomInfos Cache[string, *types.RoomInfo] // room ID -> room info FederationPDUs Cache[int64, *gomatrixserverlib.HeaderedEvent] // queue NID -> PDU FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU SpaceSummaryRooms Cache[string, gomatrixserverlib.MSC2946SpacesResponse] // room ID -> space response diff --git a/internal/caching/impl_ristretto.go b/internal/caching/impl_ristretto.go index 6d625b552..677218b5e 100644 --- a/internal/caching/impl_ristretto.go +++ b/internal/caching/impl_ristretto.go @@ -100,7 +100,7 @@ func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enableProm MaxAge: maxAge, }, }, - RoomInfos: &RistrettoCachePartition[string, types.RoomInfo]{ // room ID -> room info + RoomInfos: &RistrettoCachePartition[string, *types.RoomInfo]{ // room ID -> room info cache: cache, Prefix: roomInfosCache, Mutable: true, diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 8f4e011bf..c35ac653c 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -225,13 +225,12 @@ func (u *RoomUpdater) SetLatestEvents( if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil { return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err) } - if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok { - if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok { - roomInfo.StateSnapshotNID = currentStateSnapshotNID - roomInfo.IsStub = false - u.d.Cache.StoreRoomInfo(roomID, roomInfo) - } - } + + // Since it's entirely possible that this types.RoomInfo came from the + // cache, we should make sure to update that entry so that the next run + // works from live data. + u.roomInfo.StateSnapshotNID = currentStateSnapshotNID + u.roomInfo.IsStub = false return nil }) } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 692af1f6c..d8d5f67c8 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -139,13 +139,13 @@ func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo } func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { - if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok { - return &roomInfo, nil + if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok && roomInfo != nil { + return roomInfo, nil } roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID) if err == nil && roomInfo != nil { d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID) - d.Cache.StoreRoomInfo(roomID, *roomInfo) + d.Cache.StoreRoomInfo(roomID, roomInfo) } return roomInfo, err }