Clean up a bit

This commit is contained in:
Neil Alexander 2022-10-11 16:33:47 +01:00
parent 6811998611
commit 9907ffaf2c
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -598,17 +598,24 @@ func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context)
func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, limit int) (
clientEvents []gomatrixserverlib.ClientEvent, prevBatch, nextBatch string, err error,
) {
var eventIDs []string
clientEvents = []gomatrixserverlib.ClientEvent{}
eventIDs := []string{}
r := types.Range{
From: from,
To: to,
Backwards: from > to,
}
// First up look up any relations from the database. We add one to the limit here
// so that we can tell if we're overflowing, as we will only set the "next_batch"
// in the response if we are.
relations, _, err := d.Relations.SelectRelationsInRange(ctx, d.txn, roomID, eventID, relType, r, limit+1)
if err != nil {
return nil, "", "", fmt.Errorf("d.Relations.SelectRelationsInRange: %w", err)
}
// If we specified a relation type then just get those results, otherwise collate
// them from all of the returned relation types.
entries := []types.RelationEntry{}
if relType != "" {
entries = relations[relType]
@ -618,10 +625,13 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID,
}
}
// If there were no entries returned, there were no relations, so stop at this point.
if len(entries) == 0 {
return nil, "", "", nil
}
// Otherwise, let's try and work out what sensible prev_batch and next_batch values
// could be.
if from > 0 {
prevBatch = fmt.Sprintf("%d", entries[0].Position-1)
}
@ -630,15 +640,18 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID,
entries = entries[:len(entries)-1]
}
// Extract all of the event IDs from the relation entries so that we can pull the
// events out of the database.
for _, entry := range entries {
eventIDs = append(eventIDs, entry.EventID)
}
events, err := d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, true)
if err != nil {
return nil, "", "", fmt.Errorf("d.OutputEvents.SelectEvents: %w", err)
}
// Convert the events into client events, and optionally filter based on the event
// type if it was specified.
for _, event := range events {
if eventType != "" && event.Type() != eventType {
continue