diff --git a/internal/mscs/msc2836/msc2836.go b/internal/mscs/msc2836/msc2836.go index cf57658da..febc43c74 100644 --- a/internal/mscs/msc2836/msc2836.go +++ b/internal/mscs/msc2836/msc2836.go @@ -20,7 +20,6 @@ import ( "encoding/json" "fmt" "net/http" - "sort" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/hooks" @@ -187,7 +186,7 @@ func includeParent(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, // Apply history visibility checks to all these events and add the ones which pass into the response array, // honouring the recent_first flag and the limit. func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, db Database, parentID string, limit int, recentFirst bool, userID string) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) { - children, err := db.ChildrenForParent(ctx, parentID, "m.reference") + children, err := db.ChildrenForParent(ctx, parentID, "m.reference", recentFirst) if err != nil { util.GetLogger(ctx).WithError(err).Error("failed to get ChildrenForParent") resErr := jsonerror.InternalServerError() @@ -195,18 +194,11 @@ func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI } var childEvents []*gomatrixserverlib.HeaderedEvent for _, child := range children { - childEvent := getEventIfVisible(ctx, rsAPI, child, userID) + childEvent := getEventIfVisible(ctx, rsAPI, child.EventID, userID) if childEvent != nil { childEvents = append(childEvents, childEvent) } } - // sort childEvents by origin_server_ts in ASC or DESC depending on recent_first - sort.SliceStable(childEvents, func(i, j int) bool { - if recentFirst { - return childEvents[i].OriginServerTS().Time().After(childEvents[j].OriginServerTS().Time()) - } - return childEvents[i].OriginServerTS().Time().Before(childEvents[j].OriginServerTS().Time()) - }) if len(childEvents) > limit { return childEvents[:limit], nil } @@ -318,8 +310,8 @@ func newWalker(req *EventRelationshipRequest) walker { } type depthWalker struct { - req *EventRelationshipRequest - db Database + req *EventRelationshipRequest + // db Database current string } diff --git a/internal/mscs/msc2836/storage.go b/internal/mscs/msc2836/storage.go index 4e975054a..3eabcabf3 100644 --- a/internal/mscs/msc2836/storage.go +++ b/internal/mscs/msc2836/storage.go @@ -10,21 +10,29 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) +type eventInfo struct { + EventID string + OriginServerTS gomatrixserverlib.Timestamp + RoomID string +} + type Database interface { // StoreRelation stores the parent->child and child->parent relationship for later querying. // Also stores the event metadata e.g timestamp StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error - // ChildrenForParent returns the event IDs of events who have the given `eventID` as an m.relationship with the - // provided `relType`. - ChildrenForParent(ctx context.Context, eventID, relType string) ([]string, error) + // ChildrenForParent returns the events who have the given `eventID` as an m.relationship with the + // provided `relType`. The returned slice is sorted by origin_server_ts according to whether + // `recentFirst` is true or false. + ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) } type DB struct { - db *sql.DB - writer sqlutil.Writer - insertEdgeStmt *sql.Stmt - insertNodeStmt *sql.Stmt - selectChildrenForParentStmt *sql.Stmt + db *sql.DB + writer sqlutil.Writer + insertEdgeStmt *sql.Stmt + insertNodeStmt *sql.Stmt + selectChildrenForParentOldestFirstStmt *sql.Stmt + selectChildrenForParentRecentFirstStmt *sql.Stmt } // NewDatabase loads the database for msc2836 @@ -70,9 +78,16 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { `); err != nil { return nil, err } - if d.selectChildrenForParentStmt, err = d.db.Prepare(` - SELECT child_event_id FROM msc2836_edges WHERE parent_event_id = $1 AND rel_type = $2 - `); err != nil { + selectChildrenQuery := ` + SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges + LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id + WHERE parent_event_id = $1 AND rel_type = $2 + ORDER BY origin_server_ts + ` + if d.selectChildrenForParentOldestFirstStmt, err = d.db.Prepare(selectChildrenQuery + "ASC"); err != nil { + return nil, err + } + if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil { return nil, err } return &d, err @@ -113,9 +128,16 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { `); err != nil { return nil, err } - if d.selectChildrenForParentStmt, err = d.db.Prepare(` - SELECT child_event_id FROM msc2836_edges WHERE parent_event_id = $1 AND rel_type = $2 - `); err != nil { + selectChildrenQuery := ` + SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges + LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id + WHERE parent_event_id = $1 AND rel_type = $2 + ORDER BY origin_server_ts + ` + if d.selectChildrenForParentOldestFirstStmt, err = d.db.Prepare(selectChildrenQuery + "ASC"); err != nil { + return nil, err + } + if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil { return nil, err } return &d, nil @@ -136,19 +158,25 @@ func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEv }) } -func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string) ([]string, error) { - rows, err := p.selectChildrenForParentStmt.QueryContext(ctx, eventID, relType) +func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) { + var rows *sql.Rows + var err error + if recentFirst { + rows, err = p.selectChildrenForParentRecentFirstStmt.QueryContext(ctx, eventID, relType) + } else { + rows, err = p.selectChildrenForParentOldestFirstStmt.QueryContext(ctx, eventID, relType) + } if err != nil { return nil, err } defer rows.Close() // nolint: errcheck - var children []string + var children []eventInfo for rows.Next() { - var childID string - if err := rows.Scan(&childID); err != nil { + var evInfo eventInfo + if err := rows.Scan(&evInfo.EventID, &evInfo.OriginServerTS, &evInfo.RoomID); err != nil { return nil, err } - children = append(children, childID) + children = append(children, evInfo) } return children, nil }