diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index aedd85010..0845127ab 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -107,6 +107,8 @@ type DatabaseSnapshot interface { SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) + GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) + PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) } type Database interface { @@ -172,9 +174,8 @@ type Database interface { } type Presence interface { - UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) - PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) + UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) } type SharedUsers interface { diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 73042d3d1..c48d1d651 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -1104,8 +1104,12 @@ func (d *Database) GetPresence(ctx context.Context, userID string) (*types.Prese return d.Presence.GetPresenceForUser(ctx, nil, userID) } -func (d *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { - return d.Presence.GetPresenceAfter(ctx, nil, after, filter) +func (d *DatabaseSnapshot) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUser(ctx, d.txn, userID) +} + +func (d *DatabaseSnapshot) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { + return d.Presence.GetPresenceAfter(ctx, d.txn, after, filter) } func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index b6ea28d3b..b2f9a0b47 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -61,7 +61,7 @@ func (p *PresenceStreamProvider) IncrementalSync( from, to types.StreamPosition, ) types.StreamPosition { // We pull out a larger number than the filter asks for, since we're filtering out events later - presences, err := p.DB.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000}) + presences, err := snapshot.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000}) if err != nil { req.Log.WithError(err).Error("p.DB.PresenceAfter failed") return from @@ -89,7 +89,7 @@ func (p *PresenceStreamProvider) IncrementalSync( } // Bear in mind that this might return nil, but at least populating // a nil means that there's a map entry so we won't repeat this call. - presences[roomUsers[i]], err = p.DB.GetPresence(ctx, roomUsers[i]) + presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i]) if err != nil { req.Log.WithError(err).Error("unable to query presence for user") return from