diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index 6bc2b3733..6c5da00c9 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -75,15 +75,9 @@ func Relations(req *http.Request, device *api.Device, syncDB storage.Database, r var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) - if to == 0 { - if to, err = snapshot.MaxStreamPositionForRelations(req.Context()); err != nil { - return util.ErrorResponse(err) - } - } - res := &RelationsResponse{} res.Chunk, res.PrevBatch, res.NextBatch, err = snapshot.RelationsFor( - req.Context(), roomID, eventID, relType, eventType, from, to, limit, + req.Context(), roomID, eventID, relType, eventType, from, to, dir == "b", limit, ) if err != nil { return util.ErrorResponse(err) diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 3135f1a88..99a9db540 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -108,7 +108,7 @@ type DatabaseTransaction interface { GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) - RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, limit int) (clientEvents []gomatrixserverlib.ClientEvent, prevBatch, nextBatch string, err error) + RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (clientEvents []gomatrixserverlib.ClientEvent, prevBatch, nextBatch string, err error) } type Database interface { diff --git a/syncapi/storage/postgres/relations_table.go b/syncapi/storage/postgres/relations_table.go index 934fddbf0..0b2be7632 100644 --- a/syncapi/storage/postgres/relations_table.go +++ b/syncapi/storage/postgres/relations_table.go @@ -50,12 +50,12 @@ const deleteRelationSQL = "" + const selectRelationsInRangeSQL = "" + "SELECT id, room_id, child_event_id, rel_type FROM syncapi_relations" + - " WHERE room_id = $1 AND event_id = $2 AND id > $3 AND id <= $4" + + " WHERE room_id = $1 AND event_id = $2 AND id >= $3 AND id <= $4" + " ORDER BY id DESC LIMIT $5" const selectRelationsByTypeInRangeSQL = "" + "SELECT id, room_id, child_event_id, rel_type FROM syncapi_relations" + - " WHERE room_id = $1 AND event_id = $2 AND rel_type = $3 AND id > $4 AND id <= $5" + + " WHERE room_id = $1 AND event_id = $2 AND rel_type = $3 AND id >= $4 AND id <= $5" + " ORDER BY id DESC LIMIT $6" const selectMaxRelationIDSQL = "" + diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 199182d09..19e95d655 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -595,7 +595,7 @@ func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context) return types.StreamPosition(id), err } -func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, limit int) ( +func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) ( clientEvents []gomatrixserverlib.ClientEvent, prevBatch, nextBatch string, err error, ) { clientEvents = []gomatrixserverlib.ClientEvent{} @@ -603,7 +603,16 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, r := types.Range{ From: from, To: to, - Backwards: from > to, + Backwards: backwards, + } + if r.To == 0 && !backwards { + if r.To, err = d.MaxStreamPositionForRelations(ctx); err != nil { + return nil, "", "", fmt.Errorf("d.MaxStreamPositionForRelations: %w", err) + } + } else if r.From == 0 && backwards { + if r.From, err = d.MaxStreamPositionForRelations(ctx); err != nil { + return nil, "", "", fmt.Errorf("d.MaxStreamPositionForRelations: %w", err) + } } // First up look up any relations from the database. We add one to the limit here @@ -633,7 +642,7 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, // Otherwise, let's try and work out what sensible prev_batch and next_batch values // could be. if from > 0 { - prevBatch = fmt.Sprintf("%d", entries[0].Position-1) + prevBatch = fmt.Sprintf("%d", r.Low()-1) } if len(entries) > limit { nextBatch = fmt.Sprintf("%d", entries[len(entries)-1].Position) diff --git a/syncapi/storage/sqlite3/relations_table.go b/syncapi/storage/sqlite3/relations_table.go index 4f25777ea..e5a1fad62 100644 --- a/syncapi/storage/sqlite3/relations_table.go +++ b/syncapi/storage/sqlite3/relations_table.go @@ -48,12 +48,12 @@ const deleteRelationSQL = "" + const selectRelationsInRangeSQL = "" + "SELECT id, room_id, child_event_id, rel_type FROM syncapi_relations" + - " WHERE room_id = $1 AND event_id = $2 AND id > $3 AND id <= $4" + + " WHERE room_id = $1 AND event_id = $2 AND id >= $3 AND id <= $4" + " ORDER BY id DESC LIMIT $5" const selectRelationsByTypeInRangeSQL = "" + "SELECT id, room_id, child_event_id, rel_type FROM syncapi_relations" + - " WHERE room_id = $1 AND event_id = $2 AND rel_type = $3 AND id > $4 AND id <= $5" + + " WHERE room_id = $1 AND event_id = $2 AND rel_type = $3 AND id >= $4 AND id <= $5" + " ORDER BY id DESC LIMIT $6" const selectMaxRelationIDSQL = "" +