LEFT JOIN to extract origin_server_ts for children

This commit is contained in:
Kegan Dougal 2020-11-03 15:43:39 +00:00
parent 21e61636e1
commit 4dc9f3efd4
2 changed files with 52 additions and 32 deletions

View file

@ -20,7 +20,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"sort"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/hooks"
@ -187,7 +186,7 @@ func includeParent(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI,
// Apply history visibility checks to all these events and add the ones which pass into the response array, // Apply history visibility checks to all these events and add the ones which pass into the response array,
// honouring the recent_first flag and the limit. // honouring the recent_first flag and the limit.
func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, db Database, parentID string, limit int, recentFirst bool, userID string) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) { func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, db Database, parentID string, limit int, recentFirst bool, userID string) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) {
children, err := db.ChildrenForParent(ctx, parentID, "m.reference") children, err := db.ChildrenForParent(ctx, parentID, "m.reference", recentFirst)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to get ChildrenForParent") util.GetLogger(ctx).WithError(err).Error("failed to get ChildrenForParent")
resErr := jsonerror.InternalServerError() resErr := jsonerror.InternalServerError()
@ -195,18 +194,11 @@ func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI
} }
var childEvents []*gomatrixserverlib.HeaderedEvent var childEvents []*gomatrixserverlib.HeaderedEvent
for _, child := range children { for _, child := range children {
childEvent := getEventIfVisible(ctx, rsAPI, child, userID) childEvent := getEventIfVisible(ctx, rsAPI, child.EventID, userID)
if childEvent != nil { if childEvent != nil {
childEvents = append(childEvents, childEvent) childEvents = append(childEvents, childEvent)
} }
} }
// sort childEvents by origin_server_ts in ASC or DESC depending on recent_first
sort.SliceStable(childEvents, func(i, j int) bool {
if recentFirst {
return childEvents[i].OriginServerTS().Time().After(childEvents[j].OriginServerTS().Time())
}
return childEvents[i].OriginServerTS().Time().Before(childEvents[j].OriginServerTS().Time())
})
if len(childEvents) > limit { if len(childEvents) > limit {
return childEvents[:limit], nil return childEvents[:limit], nil
} }
@ -318,8 +310,8 @@ func newWalker(req *EventRelationshipRequest) walker {
} }
type depthWalker struct { type depthWalker struct {
req *EventRelationshipRequest req *EventRelationshipRequest
db Database // db Database
current string current string
} }

View file

@ -10,21 +10,29 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type eventInfo struct {
EventID string
OriginServerTS gomatrixserverlib.Timestamp
RoomID string
}
type Database interface { type Database interface {
// StoreRelation stores the parent->child and child->parent relationship for later querying. // StoreRelation stores the parent->child and child->parent relationship for later querying.
// Also stores the event metadata e.g timestamp // Also stores the event metadata e.g timestamp
StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error
// ChildrenForParent returns the event IDs of events who have the given `eventID` as an m.relationship with the // ChildrenForParent returns the events who have the given `eventID` as an m.relationship with the
// provided `relType`. // provided `relType`. The returned slice is sorted by origin_server_ts according to whether
ChildrenForParent(ctx context.Context, eventID, relType string) ([]string, error) // `recentFirst` is true or false.
ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error)
} }
type DB struct { type DB struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer writer sqlutil.Writer
insertEdgeStmt *sql.Stmt insertEdgeStmt *sql.Stmt
insertNodeStmt *sql.Stmt insertNodeStmt *sql.Stmt
selectChildrenForParentStmt *sql.Stmt selectChildrenForParentOldestFirstStmt *sql.Stmt
selectChildrenForParentRecentFirstStmt *sql.Stmt
} }
// NewDatabase loads the database for msc2836 // NewDatabase loads the database for msc2836
@ -70,9 +78,16 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
`); err != nil { `); err != nil {
return nil, err return nil, err
} }
if d.selectChildrenForParentStmt, err = d.db.Prepare(` selectChildrenQuery := `
SELECT child_event_id FROM msc2836_edges WHERE parent_event_id = $1 AND rel_type = $2 SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges
`); err != nil { LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id
WHERE parent_event_id = $1 AND rel_type = $2
ORDER BY origin_server_ts
`
if d.selectChildrenForParentOldestFirstStmt, err = d.db.Prepare(selectChildrenQuery + "ASC"); err != nil {
return nil, err
}
if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil {
return nil, err return nil, err
} }
return &d, err return &d, err
@ -113,9 +128,16 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
`); err != nil { `); err != nil {
return nil, err return nil, err
} }
if d.selectChildrenForParentStmt, err = d.db.Prepare(` selectChildrenQuery := `
SELECT child_event_id FROM msc2836_edges WHERE parent_event_id = $1 AND rel_type = $2 SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges
`); err != nil { LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id
WHERE parent_event_id = $1 AND rel_type = $2
ORDER BY origin_server_ts
`
if d.selectChildrenForParentOldestFirstStmt, err = d.db.Prepare(selectChildrenQuery + "ASC"); err != nil {
return nil, err
}
if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil {
return nil, err return nil, err
} }
return &d, nil return &d, nil
@ -136,19 +158,25 @@ func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEv
}) })
} }
func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string) ([]string, error) { func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) {
rows, err := p.selectChildrenForParentStmt.QueryContext(ctx, eventID, relType) var rows *sql.Rows
var err error
if recentFirst {
rows, err = p.selectChildrenForParentRecentFirstStmt.QueryContext(ctx, eventID, relType)
} else {
rows, err = p.selectChildrenForParentOldestFirstStmt.QueryContext(ctx, eventID, relType)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() // nolint: errcheck defer rows.Close() // nolint: errcheck
var children []string var children []eventInfo
for rows.Next() { for rows.Next() {
var childID string var evInfo eventInfo
if err := rows.Scan(&childID); err != nil { if err := rows.Scan(&evInfo.EventID, &evInfo.OriginServerTS, &evInfo.RoomID); err != nil {
return nil, err return nil, err
} }
children = append(children, childID) children = append(children, evInfo)
} }
return children, nil return children, nil
} }