diff --git a/internal/mscs/msc2836/msc2836.go b/internal/mscs/msc2836/msc2836.go index 2a952d328..0b227c8e6 100644 --- a/internal/mscs/msc2836/msc2836.go +++ b/internal/mscs/msc2836/msc2836.go @@ -151,13 +151,13 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP remaining := relation.Limit - len(returnEvents) var walkLimited bool if remaining > 0 { - depths := make(map[string]int, len(returnEvents)) + included := make(map[string]bool, len(returnEvents)) for _, ev := range returnEvents { - depths[ev.EventID()] = 1 + included[ev.EventID()] = true } var events []*gomatrixserverlib.HeaderedEvent events, walkLimited = walkThread( - req.Context(), db, rsAPI, device.UserID, &relation, depths, remaining, + req.Context(), db, rsAPI, device.UserID, &relation, included, remaining, ) returnEvents = append(returnEvents, events...) } @@ -211,7 +211,7 @@ func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI // honouring the limit, max_depth and max_breadth values according to the following rules // nolint: unparam func walkThread( - ctx context.Context, db Database, rsAPI roomserver.RoomserverInternalAPI, userID string, req *EventRelationshipRequest, included map[string]int, limit int, + ctx context.Context, db Database, rsAPI roomserver.RoomserverInternalAPI, userID string, req *EventRelationshipRequest, included map[string]bool, limit int, ) ([]*gomatrixserverlib.HeaderedEvent, bool) { if req.Direction != "down" { util.GetLogger(ctx).Error("not implemented: direction=up") @@ -222,45 +222,31 @@ func walkThread( ctx: ctx, req: req, db: db, + fn: func(wi *walkInfo) bool { + // If already processed event, skip. + if included[wi.EventID] { + return false + } + + // If the response array is >= limit, stop. + if len(result) >= limit { + return true + } + + // Process the event. + event := getEventIfVisible(ctx, rsAPI, wi.EventID, userID) + if event != nil { + result = append(result, event) + } + included[wi.EventID] = true + return false + }, } - parent, current := eventWalker.Next() - for current.EventID != "" { - // If the response array is >= limit, stop. - if len(result) >= limit { - return result, true - } - // If already processed event, skip. - if included[current.EventID] > 0 { - continue - } - - // Check how deep the event is compared to event_id, does it exceed (greater than) max_depth? If yes, skip. - parentDepth := included[parent] - if parentDepth == 0 { - util.GetLogger(ctx).Errorf("parent has unknown depth; this should be impossible, parent=%s curr=%v map=%v", parent, current, included) - // set these at the max to stop walking this part of the DAG - included[parent] = req.MaxDepth - included[current.EventID] = req.MaxDepth - continue - } - depth := parentDepth + 1 - if depth > req.MaxDepth { - continue - } - - // Check what number child this event is (ordered by recent_first) compared to its parent, does it exceed (greater than) max_breadth? If yes, skip. - if current.SiblingNumber > req.MaxBreadth { - continue - } - - // Process the event. - event := getEventIfVisible(ctx, rsAPI, current.EventID, userID) - if event != nil { - result = append(result, event) - } - included[current.EventID] = depth + limited, err := eventWalker.WalkFrom(req.EventID) + if err != nil { + util.GetLogger(ctx).WithError(err).Errorf("Failed to WalkFrom %s", req.EventID) } - return result, false + return result, limited } func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, eventID, userID string) *gomatrixserverlib.HeaderedEvent { @@ -300,25 +286,90 @@ func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalA type walkInfo struct { eventInfo SiblingNumber int + Depth int } type walker struct { - ctx context.Context - req *EventRelationshipRequest - db Database - current string - //toProcess []walkInfo + ctx context.Context + req *EventRelationshipRequest + db Database + fn func(wi *walkInfo) bool // callback invoked for each event walked, return true to terminate the walk } -// Next returns the next event to process. -func (w *walker) Next() (parent string, current walkInfo) { - //var events []string - - _, err := w.db.ChildrenForParent(w.ctx, w.current, constRelType, *w.req.RecentFirst) +// WalkFrom the event ID given +func (w *walker) WalkFrom(eventID string) (limited bool, err error) { + children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, *w.req.RecentFirst) if err != nil { - util.GetLogger(w.ctx).WithError(err).Error("Next() failed, cannot walk") - return + util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") + return false, err + } + var next *walkInfo + toWalk := w.addChildren(nil, children, 1) + next, toWalk = w.nextChild(toWalk) + for next != nil { + stop := w.fn(next) + if stop { + return true, nil + } + // find the children's children + children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, *w.req.RecentFirst) + if err != nil { + util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") + return false, err + } + toWalk = w.addChildren(toWalk, children, next.Depth+1) + next, toWalk = w.nextChild(toWalk) } - return + return false, nil +} + +// addChildren adds an event's children to the to walk data structure +func (w *walker) addChildren(toWalk []walkInfo, children []eventInfo, depthOfChildren int) []walkInfo { + // Check what number child this event is (ordered by recent_first) compared to its parent, does it exceed (greater than) max_breadth? If yes, skip. + if len(children) > w.req.MaxBreadth { + children = children[:w.req.MaxBreadth] + } + // Check how deep the event is compared to event_id, does it exceed (greater than) max_depth? If yes, skip. + if depthOfChildren > w.req.MaxDepth { + return toWalk + } + + if *w.req.DepthFirst { + // the slice is a stack so push them in reverse order so we pop them in the correct order + // e.g [3,2,1] => [3,2] , 1 => [3] , 2 => [] , 3 + for i := len(children) - 1; i >= 0; i-- { + toWalk = append(toWalk, walkInfo{ + eventInfo: children[i], + SiblingNumber: i + 1, // index from 1 + Depth: depthOfChildren, + }) + } + } else { + // the slice is a queue so push them in normal order to we dequeue them in the correct order + // e.g [1,2,3] => 1, [2, 3] => 2 , [3] => 3, [] + for i := range children { + toWalk = append(toWalk, walkInfo{ + eventInfo: children[i], + SiblingNumber: i + 1, // index from 1 + Depth: depthOfChildren, + }) + } + } + return toWalk +} + +func (w *walker) nextChild(toWalk []walkInfo) (*walkInfo, []walkInfo) { + if len(toWalk) == 0 { + return nil, nil + } + var child walkInfo + if *w.req.DepthFirst { + // toWalk is a stack so pop the child off + child, toWalk = toWalk[len(toWalk)-1], toWalk[:len(toWalk)-1] + return &child, toWalk + } + // toWalk is a queue so shift the child off + child, toWalk = toWalk[0], toWalk[1:] + return &child, toWalk } diff --git a/internal/mscs/msc2836/msc2836_test.go b/internal/mscs/msc2836/msc2836_test.go index 1b09117dd..91e20c716 100644 --- a/internal/mscs/msc2836/msc2836_test.go +++ b/internal/mscs/msc2836/msc2836_test.go @@ -228,14 +228,14 @@ func TestMSC2836(t *testing.T) { EventID: eventD.EventID(), IncludeChildren: &constTrue, RecentFirst: &constTrue, - Limit: 10, + Limit: 4, }) assertContains(t, body, []string{eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()}) body = postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ EventID: eventD.EventID(), IncludeChildren: &constTrue, RecentFirst: &constFalse, - Limit: 10, + Limit: 4, }) assertContains(t, body, []string{eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()}) })