From ebb70ef68fd5b4aae9981c5537e6b854219197e9 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 3 Nov 2020 10:51:15 +0000 Subject: [PATCH] Flesh out walkThread a bit --- internal/mscs/msc2836/msc2836.go | 209 ++++++++++++++++++-------- internal/mscs/msc2836/msc2836_test.go | 21 ++- 2 files changed, 164 insertions(+), 66 deletions(-) diff --git a/internal/mscs/msc2836/msc2836.go b/internal/mscs/msc2836/msc2836.go index 28619557c..2c91367ed 100644 --- a/internal/mscs/msc2836/msc2836.go +++ b/internal/mscs/msc2836/msc2836.go @@ -98,73 +98,83 @@ func Enable(base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, us }) base.PublicClientAPIMux.Handle("/unstable/event_relationships", - httputil.MakeAuthAPI("eventRelationships", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - var relation EventRelationshipRequest - if err := json.NewDecoder(req.Body).Decode(&relation); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("failed to decode HTTP request as JSON") - return util.JSONResponse{ - Code: 400, - JSON: jsonerror.BadJSON(fmt.Sprintf("invalid json: %s", err)), - } - } - // Sanity check request and set defaults. - relation.applyDefaults() - var res EventRelationshipResponse - var returnEvents []*gomatrixserverlib.HeaderedEvent - - // Can the user see (according to history visibility) event_id? If no, reject the request, else continue. - event := getEventIfVisible(req.Context(), rsAPI, relation.EventID, device.UserID) - if event == nil { - return util.JSONResponse{ - Code: 403, - JSON: jsonerror.Forbidden("Event does not exist or you are not authorised to see it"), - } - } - - // Retrieve the event. Add it to response array. - returnEvents = append(returnEvents, event) - - if *relation.IncludeParent { - if parentEvent := includeParent(req.Context(), rsAPI, event, device.UserID); parentEvent != nil { - returnEvents = append(returnEvents, parentEvent) - } - } - - if *relation.IncludeChildren { - remaining := relation.Limit - len(returnEvents) - if remaining > 0 { - children, resErr := includeChildren(req.Context(), rsAPI, db, event.EventID(), remaining, *relation.RecentFirst, device.UserID) - if resErr != nil { - return *resErr - } - returnEvents = append(returnEvents, children...) - } - } - - remaining := relation.Limit - len(returnEvents) - var walkLimited bool - if remaining > 0 { - // Begin to walk the thread DAG in the direction specified, either depth or breadth first according to the depth_first flag, - // honouring the limit, max_depth and max_breadth values according to the following rules - var events []*gomatrixserverlib.HeaderedEvent - events, walkLimited = walkThread(req.Context(), db, remaining) - returnEvents = append(returnEvents, events...) - } - res.Events = make([]gomatrixserverlib.ClientEvent, len(returnEvents)) - for i, ev := range returnEvents { - res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(*ev, gomatrixserverlib.FormatAll) - } - res.Limited = remaining == 0 || walkLimited - - return util.JSONResponse{ - Code: 200, - JSON: res, - } - }), + httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI)), ).Methods(http.MethodPost, http.MethodOptions) return nil } +func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse { + return func(req *http.Request, device *userapi.Device) util.JSONResponse { + var relation EventRelationshipRequest + if err := json.NewDecoder(req.Body).Decode(&relation); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("failed to decode HTTP request as JSON") + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON(fmt.Sprintf("invalid json: %s", err)), + } + } + // Sanity check request and set defaults. + relation.applyDefaults() + var res EventRelationshipResponse + var returnEvents []*gomatrixserverlib.HeaderedEvent + + // Can the user see (according to history visibility) event_id? If no, reject the request, else continue. + event := getEventIfVisible(req.Context(), rsAPI, relation.EventID, device.UserID) + if event == nil { + return util.JSONResponse{ + Code: 403, + JSON: jsonerror.Forbidden("Event does not exist or you are not authorised to see it"), + } + } + + // Retrieve the event. Add it to response array. + returnEvents = append(returnEvents, event) + + if *relation.IncludeParent { + if parentEvent := includeParent(req.Context(), rsAPI, event, device.UserID); parentEvent != nil { + returnEvents = append(returnEvents, parentEvent) + } + } + + if *relation.IncludeChildren { + remaining := relation.Limit - len(returnEvents) + if remaining > 0 { + children, resErr := includeChildren(req.Context(), rsAPI, db, event.EventID(), remaining, *relation.RecentFirst, device.UserID) + if resErr != nil { + return *resErr + } + returnEvents = append(returnEvents, children...) + } + } + + remaining := relation.Limit - len(returnEvents) + var walkLimited bool + if remaining > 0 { + depths := make(map[string]int, len(returnEvents)) + for _, ev := range returnEvents { + depths[ev.EventID()] = 1 + } + // Begin to walk the thread DAG in the direction specified, either depth or breadth first according to the depth_first flag, + // honouring the limit, max_depth and max_breadth values according to the following rules + var events []*gomatrixserverlib.HeaderedEvent + events, walkLimited = walkThread( + req.Context(), db, rsAPI, device.UserID, &relation, depths, remaining, + ) + returnEvents = append(returnEvents, events...) + } + res.Events = make([]gomatrixserverlib.ClientEvent, len(returnEvents)) + for i, ev := range returnEvents { + res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(*ev, gomatrixserverlib.FormatAll) + } + res.Limited = remaining == 0 || walkLimited + + return util.JSONResponse{ + Code: 200, + JSON: res, + } + } +} + // If include_parent: true and there is a valid m.relationship field in the event, // retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array. func includeParent(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, event *gomatrixserverlib.HeaderedEvent, userID string) (parent *gomatrixserverlib.HeaderedEvent) { @@ -205,8 +215,50 @@ func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI return childEvents, nil } -func walkThread(ctx context.Context, db Database, limit int) ([]*gomatrixserverlib.HeaderedEvent, bool) { - return nil, false +// nolint: unparam +func walkThread( + ctx context.Context, db Database, rsAPI roomserver.RoomserverInternalAPI, userID string, req *EventRelationshipRequest, included map[string]int, limit int, +) ([]*gomatrixserverlib.HeaderedEvent, bool) { + var result []*gomatrixserverlib.HeaderedEvent + eventsToWalk := newWalker(req) + parent, siblingNum, current := eventsToWalk.Next() + for current != "" { + // If the response array is >= limit, stop. + if len(result) >= limit { + return result, true + } + // If already processed event, skip. + if included[current] > 0 { + continue + } + + // Check how deep the event is compared to event_id, does it exceed (greater than) max_depth? If yes, skip. + parentDepth := included[parent] + if parentDepth == 0 { + util.GetLogger(ctx).Errorf("parent has unknown depth; this should be impossible, parent=%s curr=%s map=%v", parent, current, included) + // set these at the max to stop walking this part of the DAG + included[parent] = req.MaxDepth + included[current] = req.MaxDepth + continue + } + depth := parentDepth + 1 + if depth > req.MaxDepth { + continue + } + + // Check what number child this event is (ordered by recent_first) compared to its parent, does it exceed (greater than) max_breadth? If yes, skip. + if siblingNum > req.MaxBreadth { + continue + } + + // Process the event. + event := getEventIfVisible(ctx, rsAPI, current, userID) + if event != nil { + result = append(result, event) + } + included[current] = depth + } + return result, false } func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, eventID, userID string) *gomatrixserverlib.HeaderedEvent { @@ -242,3 +294,30 @@ func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalA } return &event } + +type walker interface { + Next() (parent string, siblingNum int, current string) +} + +func newWalker(req *EventRelationshipRequest) walker { + if *req.DepthFirst { + return &depthWalker{req} + } + return &breadthWalker{req} +} + +type depthWalker struct { + req *EventRelationshipRequest +} + +func (w *depthWalker) Next() (parent string, siblingNum int, current string) { + return "", 0, "" +} + +type breadthWalker struct { + req *EventRelationshipRequest +} + +func (w *breadthWalker) Next() (parent string, siblingNum int, current string) { + return "", 0, "" +} diff --git a/internal/mscs/msc2836/msc2836_test.go b/internal/mscs/msc2836/msc2836_test.go index be6af4068..1b09117dd 100644 --- a/internal/mscs/msc2836/msc2836_test.go +++ b/internal/mscs/msc2836/msc2836_test.go @@ -26,7 +26,8 @@ var ( client = &http.Client{ Timeout: 10 * time.Second, } - constTrue = true + constTrue = true + constFalse = false ) // Basic sanity check of MSC2836 logic. Injects a thread that looks like: @@ -222,6 +223,22 @@ func TestMSC2836(t *testing.T) { }) assertContains(t, body, []string{eventB.EventID(), eventA.EventID()}) }) + t.Run("returns the children in the right order if include_children is true", func(t *testing.T) { + body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ + EventID: eventD.EventID(), + IncludeChildren: &constTrue, + RecentFirst: &constTrue, + Limit: 10, + }) + assertContains(t, body, []string{eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()}) + body = postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ + EventID: eventD.EventID(), + IncludeChildren: &constTrue, + RecentFirst: &constFalse, + Limit: 10, + }) + assertContains(t, body, []string{eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()}) + }) } func runServer(t *testing.T, router *mux.Router) func() { @@ -416,6 +433,8 @@ func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib if err != nil { t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content) } + // make sure the origin_server_ts changes so we can test recency + time.Sleep(1 * time.Millisecond) signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer) if err != nil { t.Fatalf("mustCreateEvent: failed to sign event: %s", err)