Handle direction: up

This commit is contained in:
Kegan Dougal 2020-11-25 12:04:23 +00:00
parent e278ce93af
commit 095cc3e7bf
3 changed files with 65 additions and 8 deletions

View file

@ -322,10 +322,6 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen
func walkThread( func walkThread(
ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int, ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int,
) ([]*gomatrixserverlib.HeaderedEvent, bool) { ) ([]*gomatrixserverlib.HeaderedEvent, bool) {
if rc.req.Direction != "down" {
util.GetLogger(ctx).Error("not implemented: direction=up")
return nil, false
}
var result []*gomatrixserverlib.HeaderedEvent var result []*gomatrixserverlib.HeaderedEvent
eventWalker := walker{ eventWalker := walker{
ctx: ctx, ctx: ctx,
@ -620,9 +616,9 @@ type walker struct {
// WalkFrom the event ID given // WalkFrom the event ID given
func (w *walker) WalkFrom(eventID string) (limited bool, err error) { func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst) children, err := w.childrenForParent(eventID)
if err != nil { if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() childrenForParent failed, cannot walk")
return false, err return false, err
} }
var next *walkInfo var next *walkInfo
@ -634,9 +630,9 @@ func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
return true, nil return true, nil
} }
// find the children's children // find the children's children
children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, w.req.RecentFirst) children, err = w.childrenForParent(next.EventID)
if err != nil { if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() childrenForParent failed, cannot walk")
return false, err return false, err
} }
toWalk = w.addChildren(toWalk, children, next.Depth+1) toWalk = w.addChildren(toWalk, children, next.Depth+1)
@ -695,3 +691,20 @@ func (w *walker) nextChild(toWalk []walkInfo) (*walkInfo, []walkInfo) {
child, toWalk = toWalk[0], toWalk[1:] child, toWalk = toWalk[0], toWalk[1:]
return &child, toWalk return &child, toWalk
} }
// childrenForParent returns the children events for this event ID, honouring the direction: up|down flags
// meaning this can actually be returning the parent for the event instead of the children.
func (w *walker) childrenForParent(eventID string) ([]eventInfo, error) {
if w.req.Direction == "down" {
return w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst)
}
// find the event to pull out the parent
ei, err := w.db.ParentForChild(w.ctx, eventID, constRelType)
if err != nil {
return nil, err
}
if ei != nil {
return []eventInfo{*ei}, nil
}
return nil, nil
}

View file

@ -349,6 +349,24 @@ func TestMSC2836(t *testing.T) {
})) }))
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()}) assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()})
}) })
t.Run("can navigate up the graph with direction: up", func(t *testing.T) {
// A4
// |
// B3
// / \
// C D2
// /| \
// E F1 G
// |
// H
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventF.EventID(),
"recent_first": false,
"depth_first": true,
"direction": "up",
}))
assertContains(t, body, []string{eventF.EventID(), eventD.EventID(), eventB.EventID(), eventA.EventID()})
})
} }
// TODO: TestMSC2836TerminatesLoops (short and long) // TODO: TestMSC2836TerminatesLoops (short and long)

View file

@ -25,6 +25,10 @@ type Database interface {
// provided `relType`. The returned slice is sorted by origin_server_ts according to whether // provided `relType`. The returned slice is sorted by origin_server_ts according to whether
// `recentFirst` is true or false. // `recentFirst` is true or false.
ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error)
// ParentForChild returns the parent event for the given child `eventID`. The eventInfo should be nil if
// there is no parent for this child event, with no error. The parent eventInfo can be missing the
// timestamp if the event is not known to the server.
ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error)
} }
type DB struct { type DB struct {
@ -34,6 +38,7 @@ type DB struct {
insertNodeStmt *sql.Stmt insertNodeStmt *sql.Stmt
selectChildrenForParentOldestFirstStmt *sql.Stmt selectChildrenForParentOldestFirstStmt *sql.Stmt
selectChildrenForParentRecentFirstStmt *sql.Stmt selectChildrenForParentRecentFirstStmt *sql.Stmt
selectParentForChildStmt *sql.Stmt
} }
// NewDatabase loads the database for msc2836 // NewDatabase loads the database for msc2836
@ -93,6 +98,11 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil { if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil {
return nil, err return nil, err
} }
if d.selectParentForChildStmt, err = d.db.Prepare(`
SELECT parent_event_id, parent_room_id FROM msc2836_edges WHERE child_event_id = $1 AND rel_type = $2
`); err != nil {
return nil, err
}
return &d, err return &d, err
} }
@ -145,6 +155,11 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil { if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil {
return nil, err return nil, err
} }
if d.selectParentForChildStmt, err = d.db.Prepare(`
SELECT parent_event_id, parent_room_id FROM msc2836_edges WHERE child_event_id = $1 AND rel_type = $2
`); err != nil {
return nil, err
}
return &d, nil return &d, nil
} }
@ -191,6 +206,17 @@ func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, rec
return children, nil return children, nil
} }
func (p *DB) ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error) {
var ei eventInfo
err := p.selectParentForChildStmt.QueryRowContext(ctx, eventID, relType).Scan(&ei.EventID, &ei.RoomID)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
}
return &ei, nil
}
func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) { func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) {
if ev == nil { if ev == nil {
return return