diff --git a/internal/mscs/msc2836/msc2836.go b/internal/mscs/msc2836/msc2836.go index 0b227c8e6..cbf01c07d 100644 --- a/internal/mscs/msc2836/msc2836.go +++ b/internal/mscs/msc2836/msc2836.go @@ -44,6 +44,7 @@ type EventRelationshipRequest struct { IncludeChildren *bool `json:"include_children"` Direction string `json:"direction"` Batch string `json:"batch"` + AutoJoin bool `json:"auto_join"` } func (r *EventRelationshipRequest) applyDefaults() { @@ -104,6 +105,13 @@ func Enable(base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, us return nil } +type reqCtx struct { + ctx context.Context + rsAPI roomserver.RoomserverInternalAPI + req *EventRelationshipRequest + userID string +} + func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse { return func(req *http.Request, device *userapi.Device) util.JSONResponse { var relation EventRelationshipRequest @@ -118,9 +126,15 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP relation.applyDefaults() var res EventRelationshipResponse var returnEvents []*gomatrixserverlib.HeaderedEvent + rc := reqCtx{ + ctx: req.Context(), + req: &relation, + userID: device.UserID, + rsAPI: rsAPI, + } // Can the user see (according to history visibility) event_id? If no, reject the request, else continue. - event := getEventIfVisible(req.Context(), rsAPI, relation.EventID, device.UserID) + event := rc.getEventIfVisible(relation.EventID) if event == nil { return util.JSONResponse{ Code: 403, @@ -132,7 +146,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP returnEvents = append(returnEvents, event) if *relation.IncludeParent { - if parentEvent := includeParent(req.Context(), rsAPI, event, device.UserID); parentEvent != nil { + if parentEvent := rc.includeParent(event); parentEvent != nil { returnEvents = append(returnEvents, parentEvent) } } @@ -140,7 +154,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP if *relation.IncludeChildren { remaining := relation.Limit - len(returnEvents) if remaining > 0 { - children, resErr := includeChildren(req.Context(), rsAPI, db, event.EventID(), remaining, *relation.RecentFirst, device.UserID) + children, resErr := rc.includeChildren(db, event.EventID(), remaining, *relation.RecentFirst) if resErr != nil { return *resErr } @@ -157,7 +171,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP } var events []*gomatrixserverlib.HeaderedEvent events, walkLimited = walkThread( - req.Context(), db, rsAPI, device.UserID, &relation, included, remaining, + req.Context(), db, &rc, included, remaining, ) returnEvents = append(returnEvents, events...) } @@ -176,27 +190,27 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP // If include_parent: true and there is a valid m.relationship field in the event, // retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array. -func includeParent(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, event *gomatrixserverlib.HeaderedEvent, userID string) (parent *gomatrixserverlib.HeaderedEvent) { +func (rc *reqCtx) includeParent(event *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) { parentID, _, _ := parentChildEventIDs(event) if parentID == "" { return nil } - return getEventIfVisible(ctx, rsAPI, parentID, userID) + return rc.getEventIfVisible(parentID) } // If include_children: true, lookup all events which have event_id as an m.relationship // Apply history visibility checks to all these events and add the ones which pass into the response array, // honouring the recent_first flag and the limit. -func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, db Database, parentID string, limit int, recentFirst bool, userID string) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) { - children, err := db.ChildrenForParent(ctx, parentID, constRelType, recentFirst) +func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) { + children, err := db.ChildrenForParent(rc.ctx, parentID, constRelType, recentFirst) if err != nil { - util.GetLogger(ctx).WithError(err).Error("failed to get ChildrenForParent") + util.GetLogger(rc.ctx).WithError(err).Error("failed to get ChildrenForParent") resErr := jsonerror.InternalServerError() return nil, &resErr } var childEvents []*gomatrixserverlib.HeaderedEvent for _, child := range children { - childEvent := getEventIfVisible(ctx, rsAPI, child.EventID, userID) + childEvent := rc.getEventIfVisible(child.EventID) if childEvent != nil { childEvents = append(childEvents, childEvent) } @@ -211,16 +225,16 @@ func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI // honouring the limit, max_depth and max_breadth values according to the following rules // nolint: unparam func walkThread( - ctx context.Context, db Database, rsAPI roomserver.RoomserverInternalAPI, userID string, req *EventRelationshipRequest, included map[string]bool, limit int, + ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int, ) ([]*gomatrixserverlib.HeaderedEvent, bool) { - if req.Direction != "down" { + if rc.req.Direction != "down" { util.GetLogger(ctx).Error("not implemented: direction=up") return nil, false } var result []*gomatrixserverlib.HeaderedEvent eventWalker := walker{ ctx: ctx, - req: req, + req: rc.req, db: db, fn: func(wi *walkInfo) bool { // If already processed event, skip. @@ -234,7 +248,7 @@ func walkThread( } // Process the event. - event := getEventIfVisible(ctx, rsAPI, wi.EventID, userID) + event := rc.getEventIfVisible(wi.EventID) if event != nil { result = append(result, event) } @@ -242,24 +256,24 @@ func walkThread( return false }, } - limited, err := eventWalker.WalkFrom(req.EventID) + limited, err := eventWalker.WalkFrom(rc.req.EventID) if err != nil { - util.GetLogger(ctx).WithError(err).Errorf("Failed to WalkFrom %s", req.EventID) + util.GetLogger(ctx).WithError(err).Errorf("Failed to WalkFrom %s", rc.req.EventID) } return result, limited } -func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, eventID, userID string) *gomatrixserverlib.HeaderedEvent { +func (rc *reqCtx) getEventIfVisible(eventID string) *gomatrixserverlib.HeaderedEvent { var queryEventsRes roomserver.QueryEventsByIDResponse - err := rsAPI.QueryEventsByID(ctx, &roomserver.QueryEventsByIDRequest{ + err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{ EventIDs: []string{eventID}, }, &queryEventsRes) if err != nil { - util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID") + util.GetLogger(rc.ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID") return nil } if len(queryEventsRes.Events) == 0 { - util.GetLogger(ctx).Infof("event does not exist") + util.GetLogger(rc.ctx).Infof("event does not exist") return nil // event does not exist } event := queryEventsRes.Events[0] @@ -268,19 +282,33 @@ func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalA // TODO: This does not honour history_visibility // TODO: This does not honour m.room.create content var queryMembershipRes roomserver.QueryMembershipForUserResponse - err = rsAPI.QueryMembershipForUser(ctx, &roomserver.QueryMembershipForUserRequest{ + err = rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{ RoomID: event.RoomID(), - UserID: userID, + UserID: rc.userID, }, &queryMembershipRes) if err != nil { - util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser") + util.GetLogger(rc.ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser") return nil } - if !queryMembershipRes.IsInRoom { - util.GetLogger(ctx).Infof("user not in room") - return nil + if queryMembershipRes.IsInRoom { + return &event } - return &event + if rc.req.AutoJoin { + var joinRes roomserver.PerformJoinResponse + rc.rsAPI.PerformJoin(rc.ctx, &roomserver.PerformJoinRequest{ + UserID: rc.userID, + Content: map[string]interface{}{}, + RoomIDOrAlias: event.RoomID(), + // TODO: Add server_names from linked room, currently this join will only work if the HS is already in the room + }, &joinRes) + if joinRes.Error != nil { + util.GetLogger(rc.ctx).WithError(joinRes.Error).WithField("room_id", event.RoomID()).Error("Failed to auto-join room") + return nil + } + return &event + } + util.GetLogger(rc.ctx).Infof("user not in room and auto_join disabled") + return nil } type walkInfo struct { diff --git a/internal/mscs/msc2836/msc2836_test.go b/internal/mscs/msc2836/msc2836_test.go index 17dd1d124..795ead6a6 100644 --- a/internal/mscs/msc2836/msc2836_test.go +++ b/internal/mscs/msc2836/msc2836_test.go @@ -353,6 +353,10 @@ func TestMSC2836(t *testing.T) { }) } +// TODO: TestMSC2836TerminatesLoops (short and long) +// TODO: TestMSC2836UnknownEventsSkipped +// TODO: TestMSC2836SkipEventIfNotInRoom + func runServer(t *testing.T, router *mux.Router) func() { t.Helper() externalServ := &http.Server{