mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 00:03:09 -06:00
Handle direction: up
This commit is contained in:
parent
e278ce93af
commit
095cc3e7bf
|
|
@ -322,10 +322,6 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen
|
||||||
func walkThread(
|
func walkThread(
|
||||||
ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int,
|
ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int,
|
||||||
) ([]*gomatrixserverlib.HeaderedEvent, bool) {
|
) ([]*gomatrixserverlib.HeaderedEvent, bool) {
|
||||||
if rc.req.Direction != "down" {
|
|
||||||
util.GetLogger(ctx).Error("not implemented: direction=up")
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
var result []*gomatrixserverlib.HeaderedEvent
|
var result []*gomatrixserverlib.HeaderedEvent
|
||||||
eventWalker := walker{
|
eventWalker := walker{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
|
@ -620,9 +616,9 @@ type walker struct {
|
||||||
|
|
||||||
// WalkFrom the event ID given
|
// WalkFrom the event ID given
|
||||||
func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
|
func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
|
||||||
children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst)
|
children, err := w.childrenForParent(eventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk")
|
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() childrenForParent failed, cannot walk")
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
var next *walkInfo
|
var next *walkInfo
|
||||||
|
|
@ -634,9 +630,9 @@ func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
// find the children's children
|
// find the children's children
|
||||||
children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, w.req.RecentFirst)
|
children, err = w.childrenForParent(next.EventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk")
|
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() childrenForParent failed, cannot walk")
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
toWalk = w.addChildren(toWalk, children, next.Depth+1)
|
toWalk = w.addChildren(toWalk, children, next.Depth+1)
|
||||||
|
|
@ -695,3 +691,20 @@ func (w *walker) nextChild(toWalk []walkInfo) (*walkInfo, []walkInfo) {
|
||||||
child, toWalk = toWalk[0], toWalk[1:]
|
child, toWalk = toWalk[0], toWalk[1:]
|
||||||
return &child, toWalk
|
return &child, toWalk
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// childrenForParent returns the children events for this event ID, honouring the direction: up|down flags
|
||||||
|
// meaning this can actually be returning the parent for the event instead of the children.
|
||||||
|
func (w *walker) childrenForParent(eventID string) ([]eventInfo, error) {
|
||||||
|
if w.req.Direction == "down" {
|
||||||
|
return w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst)
|
||||||
|
}
|
||||||
|
// find the event to pull out the parent
|
||||||
|
ei, err := w.db.ParentForChild(w.ctx, eventID, constRelType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if ei != nil {
|
||||||
|
return []eventInfo{*ei}, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -349,6 +349,24 @@ func TestMSC2836(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()})
|
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()})
|
||||||
})
|
})
|
||||||
|
t.Run("can navigate up the graph with direction: up", func(t *testing.T) {
|
||||||
|
// A4
|
||||||
|
// |
|
||||||
|
// B3
|
||||||
|
// / \
|
||||||
|
// C D2
|
||||||
|
// /| \
|
||||||
|
// E F1 G
|
||||||
|
// |
|
||||||
|
// H
|
||||||
|
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
|
||||||
|
"event_id": eventF.EventID(),
|
||||||
|
"recent_first": false,
|
||||||
|
"depth_first": true,
|
||||||
|
"direction": "up",
|
||||||
|
}))
|
||||||
|
assertContains(t, body, []string{eventF.EventID(), eventD.EventID(), eventB.EventID(), eventA.EventID()})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: TestMSC2836TerminatesLoops (short and long)
|
// TODO: TestMSC2836TerminatesLoops (short and long)
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,10 @@ type Database interface {
|
||||||
// provided `relType`. The returned slice is sorted by origin_server_ts according to whether
|
// provided `relType`. The returned slice is sorted by origin_server_ts according to whether
|
||||||
// `recentFirst` is true or false.
|
// `recentFirst` is true or false.
|
||||||
ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error)
|
ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error)
|
||||||
|
// ParentForChild returns the parent event for the given child `eventID`. The eventInfo should be nil if
|
||||||
|
// there is no parent for this child event, with no error. The parent eventInfo can be missing the
|
||||||
|
// timestamp if the event is not known to the server.
|
||||||
|
ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DB struct {
|
type DB struct {
|
||||||
|
|
@ -34,6 +38,7 @@ type DB struct {
|
||||||
insertNodeStmt *sql.Stmt
|
insertNodeStmt *sql.Stmt
|
||||||
selectChildrenForParentOldestFirstStmt *sql.Stmt
|
selectChildrenForParentOldestFirstStmt *sql.Stmt
|
||||||
selectChildrenForParentRecentFirstStmt *sql.Stmt
|
selectChildrenForParentRecentFirstStmt *sql.Stmt
|
||||||
|
selectParentForChildStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase loads the database for msc2836
|
// NewDatabase loads the database for msc2836
|
||||||
|
|
@ -93,6 +98,11 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
||||||
if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil {
|
if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if d.selectParentForChildStmt, err = d.db.Prepare(`
|
||||||
|
SELECT parent_event_id, parent_room_id FROM msc2836_edges WHERE child_event_id = $1 AND rel_type = $2
|
||||||
|
`); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return &d, err
|
return &d, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -145,6 +155,11 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
||||||
if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil {
|
if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if d.selectParentForChildStmt, err = d.db.Prepare(`
|
||||||
|
SELECT parent_event_id, parent_room_id FROM msc2836_edges WHERE child_event_id = $1 AND rel_type = $2
|
||||||
|
`); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return &d, nil
|
return &d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -191,6 +206,17 @@ func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, rec
|
||||||
return children, nil
|
return children, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *DB) ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error) {
|
||||||
|
var ei eventInfo
|
||||||
|
err := p.selectParentForChildStmt.QueryRowContext(ctx, eventID, relType).Scan(&ei.EventID, &ei.RoomID)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &ei, nil
|
||||||
|
}
|
||||||
|
|
||||||
func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) {
|
func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) {
|
||||||
if ev == nil {
|
if ev == nil {
|
||||||
return
|
return
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue