From 9907ffaf2cb0ffd10a7024fafc4c26c420e7a0d8 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 11 Oct 2022 16:33:47 +0100 Subject: [PATCH] Clean up a bit --- syncapi/storage/shared/storage_sync.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 00f62d8ff..199182d09 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -598,17 +598,24 @@ func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context) func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, limit int) ( clientEvents []gomatrixserverlib.ClientEvent, prevBatch, nextBatch string, err error, ) { - var eventIDs []string + clientEvents = []gomatrixserverlib.ClientEvent{} + eventIDs := []string{} r := types.Range{ From: from, To: to, Backwards: from > to, } + + // First up look up any relations from the database. We add one to the limit here + // so that we can tell if we're overflowing, as we will only set the "next_batch" + // in the response if we are. relations, _, err := d.Relations.SelectRelationsInRange(ctx, d.txn, roomID, eventID, relType, r, limit+1) if err != nil { return nil, "", "", fmt.Errorf("d.Relations.SelectRelationsInRange: %w", err) } + // If we specified a relation type then just get those results, otherwise collate + // them from all of the returned relation types. entries := []types.RelationEntry{} if relType != "" { entries = relations[relType] @@ -618,10 +625,13 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, } } + // If there were no entries returned, there were no relations, so stop at this point. if len(entries) == 0 { return nil, "", "", nil } + // Otherwise, let's try and work out what sensible prev_batch and next_batch values + // could be. if from > 0 { prevBatch = fmt.Sprintf("%d", entries[0].Position-1) } @@ -630,15 +640,18 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, entries = entries[:len(entries)-1] } + // Extract all of the event IDs from the relation entries so that we can pull the + // events out of the database. for _, entry := range entries { eventIDs = append(eventIDs, entry.EventID) } - events, err := d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, true) if err != nil { return nil, "", "", fmt.Errorf("d.OutputEvents.SelectEvents: %w", err) } + // Convert the events into client events, and optionally filter based on the event + // type if it was specified. for _, event := range events { if eventType != "" && event.Type() != eventType { continue