From 095cc3e7bf5114bc89f2306aded4f7744057f1d5 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Wed, 25 Nov 2020 12:04:23 +0000 Subject: [PATCH] Handle direction: up --- internal/mscs/msc2836/msc2836.go | 29 +++++++++++++++++++-------- internal/mscs/msc2836/msc2836_test.go | 18 +++++++++++++++++ internal/mscs/msc2836/storage.go | 26 ++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 8 deletions(-) diff --git a/internal/mscs/msc2836/msc2836.go b/internal/mscs/msc2836/msc2836.go index cbff7acdd..c68a169e1 100644 --- a/internal/mscs/msc2836/msc2836.go +++ b/internal/mscs/msc2836/msc2836.go @@ -322,10 +322,6 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen func walkThread( ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int, ) ([]*gomatrixserverlib.HeaderedEvent, bool) { - if rc.req.Direction != "down" { - util.GetLogger(ctx).Error("not implemented: direction=up") - return nil, false - } var result []*gomatrixserverlib.HeaderedEvent eventWalker := walker{ ctx: ctx, @@ -620,9 +616,9 @@ type walker struct { // WalkFrom the event ID given func (w *walker) WalkFrom(eventID string) (limited bool, err error) { - children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst) + children, err := w.childrenForParent(eventID) if err != nil { - util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") + util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() childrenForParent failed, cannot walk") return false, err } var next *walkInfo @@ -634,9 +630,9 @@ func (w *walker) WalkFrom(eventID string) (limited bool, err error) { return true, nil } // find the children's children - children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, w.req.RecentFirst) + children, err = w.childrenForParent(next.EventID) if err != nil { - util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") + util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() childrenForParent failed, cannot walk") return false, err } toWalk = w.addChildren(toWalk, children, next.Depth+1) @@ -695,3 +691,20 @@ func (w *walker) nextChild(toWalk []walkInfo) (*walkInfo, []walkInfo) { child, toWalk = toWalk[0], toWalk[1:] return &child, toWalk } + +// childrenForParent returns the children events for this event ID, honouring the direction: up|down flags +// meaning this can actually be returning the parent for the event instead of the children. +func (w *walker) childrenForParent(eventID string) ([]eventInfo, error) { + if w.req.Direction == "down" { + return w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst) + } + // find the event to pull out the parent + ei, err := w.db.ParentForChild(w.ctx, eventID, constRelType) + if err != nil { + return nil, err + } + if ei != nil { + return []eventInfo{*ei}, nil + } + return nil, nil +} diff --git a/internal/mscs/msc2836/msc2836_test.go b/internal/mscs/msc2836/msc2836_test.go index 66c0ae4d6..d91f368ca 100644 --- a/internal/mscs/msc2836/msc2836_test.go +++ b/internal/mscs/msc2836/msc2836_test.go @@ -349,6 +349,24 @@ func TestMSC2836(t *testing.T) { })) assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()}) }) + t.Run("can navigate up the graph with direction: up", func(t *testing.T) { + // A4 + // | + // B3 + // / \ + // C D2 + // /| \ + // E F1 G + // | + // H + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventF.EventID(), + "recent_first": false, + "depth_first": true, + "direction": "up", + })) + assertContains(t, body, []string{eventF.EventID(), eventD.EventID(), eventB.EventID(), eventA.EventID()}) + }) } // TODO: TestMSC2836TerminatesLoops (short and long) diff --git a/internal/mscs/msc2836/storage.go b/internal/mscs/msc2836/storage.go index f524165fa..3c128a114 100644 --- a/internal/mscs/msc2836/storage.go +++ b/internal/mscs/msc2836/storage.go @@ -25,6 +25,10 @@ type Database interface { // provided `relType`. The returned slice is sorted by origin_server_ts according to whether // `recentFirst` is true or false. ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) + // ParentForChild returns the parent event for the given child `eventID`. The eventInfo should be nil if + // there is no parent for this child event, with no error. The parent eventInfo can be missing the + // timestamp if the event is not known to the server. + ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error) } type DB struct { @@ -34,6 +38,7 @@ type DB struct { insertNodeStmt *sql.Stmt selectChildrenForParentOldestFirstStmt *sql.Stmt selectChildrenForParentRecentFirstStmt *sql.Stmt + selectParentForChildStmt *sql.Stmt } // NewDatabase loads the database for msc2836 @@ -93,6 +98,11 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil { return nil, err } + if d.selectParentForChildStmt, err = d.db.Prepare(` + SELECT parent_event_id, parent_room_id FROM msc2836_edges WHERE child_event_id = $1 AND rel_type = $2 + `); err != nil { + return nil, err + } return &d, err } @@ -145,6 +155,11 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil { return nil, err } + if d.selectParentForChildStmt, err = d.db.Prepare(` + SELECT parent_event_id, parent_room_id FROM msc2836_edges WHERE child_event_id = $1 AND rel_type = $2 + `); err != nil { + return nil, err + } return &d, nil } @@ -191,6 +206,17 @@ func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, rec return children, nil } +func (p *DB) ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error) { + var ei eventInfo + err := p.selectParentForChildStmt.QueryRowContext(ctx, eventID, relType).Scan(&ei.EventID, &ei.RoomID) + if err == sql.ErrNoRows { + return nil, nil + } else if err != nil { + return nil, err + } + return &ei, nil +} + func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) { if ev == nil { return