Add auto_join for local rooms

This commit is contained in:
Kegan Dougal 2020-11-05 16:04:43 +00:00
parent c320261ca2
commit 8a7ccd5d1d
2 changed files with 59 additions and 27 deletions

View file

@ -44,6 +44,7 @@ type EventRelationshipRequest struct {
IncludeChildren *bool `json:"include_children"` IncludeChildren *bool `json:"include_children"`
Direction string `json:"direction"` Direction string `json:"direction"`
Batch string `json:"batch"` Batch string `json:"batch"`
AutoJoin bool `json:"auto_join"`
} }
func (r *EventRelationshipRequest) applyDefaults() { func (r *EventRelationshipRequest) applyDefaults() {
@ -104,6 +105,13 @@ func Enable(base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, us
return nil return nil
} }
type reqCtx struct {
ctx context.Context
rsAPI roomserver.RoomserverInternalAPI
req *EventRelationshipRequest
userID string
}
func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse { func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse {
return func(req *http.Request, device *userapi.Device) util.JSONResponse { return func(req *http.Request, device *userapi.Device) util.JSONResponse {
var relation EventRelationshipRequest var relation EventRelationshipRequest
@ -118,9 +126,15 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
relation.applyDefaults() relation.applyDefaults()
var res EventRelationshipResponse var res EventRelationshipResponse
var returnEvents []*gomatrixserverlib.HeaderedEvent var returnEvents []*gomatrixserverlib.HeaderedEvent
rc := reqCtx{
ctx: req.Context(),
req: &relation,
userID: device.UserID,
rsAPI: rsAPI,
}
// Can the user see (according to history visibility) event_id? If no, reject the request, else continue. // Can the user see (according to history visibility) event_id? If no, reject the request, else continue.
event := getEventIfVisible(req.Context(), rsAPI, relation.EventID, device.UserID) event := rc.getEventIfVisible(relation.EventID)
if event == nil { if event == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: 403, Code: 403,
@ -132,7 +146,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
returnEvents = append(returnEvents, event) returnEvents = append(returnEvents, event)
if *relation.IncludeParent { if *relation.IncludeParent {
if parentEvent := includeParent(req.Context(), rsAPI, event, device.UserID); parentEvent != nil { if parentEvent := rc.includeParent(event); parentEvent != nil {
returnEvents = append(returnEvents, parentEvent) returnEvents = append(returnEvents, parentEvent)
} }
} }
@ -140,7 +154,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
if *relation.IncludeChildren { if *relation.IncludeChildren {
remaining := relation.Limit - len(returnEvents) remaining := relation.Limit - len(returnEvents)
if remaining > 0 { if remaining > 0 {
children, resErr := includeChildren(req.Context(), rsAPI, db, event.EventID(), remaining, *relation.RecentFirst, device.UserID) children, resErr := rc.includeChildren(db, event.EventID(), remaining, *relation.RecentFirst)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
@ -157,7 +171,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
} }
var events []*gomatrixserverlib.HeaderedEvent var events []*gomatrixserverlib.HeaderedEvent
events, walkLimited = walkThread( events, walkLimited = walkThread(
req.Context(), db, rsAPI, device.UserID, &relation, included, remaining, req.Context(), db, &rc, included, remaining,
) )
returnEvents = append(returnEvents, events...) returnEvents = append(returnEvents, events...)
} }
@ -176,27 +190,27 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
// If include_parent: true and there is a valid m.relationship field in the event, // If include_parent: true and there is a valid m.relationship field in the event,
// retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array. // retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array.
func includeParent(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, event *gomatrixserverlib.HeaderedEvent, userID string) (parent *gomatrixserverlib.HeaderedEvent) { func (rc *reqCtx) includeParent(event *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) {
parentID, _, _ := parentChildEventIDs(event) parentID, _, _ := parentChildEventIDs(event)
if parentID == "" { if parentID == "" {
return nil return nil
} }
return getEventIfVisible(ctx, rsAPI, parentID, userID) return rc.getEventIfVisible(parentID)
} }
// If include_children: true, lookup all events which have event_id as an m.relationship // If include_children: true, lookup all events which have event_id as an m.relationship
// Apply history visibility checks to all these events and add the ones which pass into the response array, // Apply history visibility checks to all these events and add the ones which pass into the response array,
// honouring the recent_first flag and the limit. // honouring the recent_first flag and the limit.
func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, db Database, parentID string, limit int, recentFirst bool, userID string) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) { func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) {
children, err := db.ChildrenForParent(ctx, parentID, constRelType, recentFirst) children, err := db.ChildrenForParent(rc.ctx, parentID, constRelType, recentFirst)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to get ChildrenForParent") util.GetLogger(rc.ctx).WithError(err).Error("failed to get ChildrenForParent")
resErr := jsonerror.InternalServerError() resErr := jsonerror.InternalServerError()
return nil, &resErr return nil, &resErr
} }
var childEvents []*gomatrixserverlib.HeaderedEvent var childEvents []*gomatrixserverlib.HeaderedEvent
for _, child := range children { for _, child := range children {
childEvent := getEventIfVisible(ctx, rsAPI, child.EventID, userID) childEvent := rc.getEventIfVisible(child.EventID)
if childEvent != nil { if childEvent != nil {
childEvents = append(childEvents, childEvent) childEvents = append(childEvents, childEvent)
} }
@ -211,16 +225,16 @@ func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI
// honouring the limit, max_depth and max_breadth values according to the following rules // honouring the limit, max_depth and max_breadth values according to the following rules
// nolint: unparam // nolint: unparam
func walkThread( func walkThread(
ctx context.Context, db Database, rsAPI roomserver.RoomserverInternalAPI, userID string, req *EventRelationshipRequest, included map[string]bool, limit int, ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int,
) ([]*gomatrixserverlib.HeaderedEvent, bool) { ) ([]*gomatrixserverlib.HeaderedEvent, bool) {
if req.Direction != "down" { if rc.req.Direction != "down" {
util.GetLogger(ctx).Error("not implemented: direction=up") util.GetLogger(ctx).Error("not implemented: direction=up")
return nil, false return nil, false
} }
var result []*gomatrixserverlib.HeaderedEvent var result []*gomatrixserverlib.HeaderedEvent
eventWalker := walker{ eventWalker := walker{
ctx: ctx, ctx: ctx,
req: req, req: rc.req,
db: db, db: db,
fn: func(wi *walkInfo) bool { fn: func(wi *walkInfo) bool {
// If already processed event, skip. // If already processed event, skip.
@ -234,7 +248,7 @@ func walkThread(
} }
// Process the event. // Process the event.
event := getEventIfVisible(ctx, rsAPI, wi.EventID, userID) event := rc.getEventIfVisible(wi.EventID)
if event != nil { if event != nil {
result = append(result, event) result = append(result, event)
} }
@ -242,24 +256,24 @@ func walkThread(
return false return false
}, },
} }
limited, err := eventWalker.WalkFrom(req.EventID) limited, err := eventWalker.WalkFrom(rc.req.EventID)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Errorf("Failed to WalkFrom %s", req.EventID) util.GetLogger(ctx).WithError(err).Errorf("Failed to WalkFrom %s", rc.req.EventID)
} }
return result, limited return result, limited
} }
func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, eventID, userID string) *gomatrixserverlib.HeaderedEvent { func (rc *reqCtx) getEventIfVisible(eventID string) *gomatrixserverlib.HeaderedEvent {
var queryEventsRes roomserver.QueryEventsByIDResponse var queryEventsRes roomserver.QueryEventsByIDResponse
err := rsAPI.QueryEventsByID(ctx, &roomserver.QueryEventsByIDRequest{ err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{
EventIDs: []string{eventID}, EventIDs: []string{eventID},
}, &queryEventsRes) }, &queryEventsRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID") util.GetLogger(rc.ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID")
return nil return nil
} }
if len(queryEventsRes.Events) == 0 { if len(queryEventsRes.Events) == 0 {
util.GetLogger(ctx).Infof("event does not exist") util.GetLogger(rc.ctx).Infof("event does not exist")
return nil // event does not exist return nil // event does not exist
} }
event := queryEventsRes.Events[0] event := queryEventsRes.Events[0]
@ -268,19 +282,33 @@ func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalA
// TODO: This does not honour history_visibility // TODO: This does not honour history_visibility
// TODO: This does not honour m.room.create content // TODO: This does not honour m.room.create content
var queryMembershipRes roomserver.QueryMembershipForUserResponse var queryMembershipRes roomserver.QueryMembershipForUserResponse
err = rsAPI.QueryMembershipForUser(ctx, &roomserver.QueryMembershipForUserRequest{ err = rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{
RoomID: event.RoomID(), RoomID: event.RoomID(),
UserID: userID, UserID: rc.userID,
}, &queryMembershipRes) }, &queryMembershipRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser") util.GetLogger(rc.ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser")
return nil return nil
} }
if !queryMembershipRes.IsInRoom { if queryMembershipRes.IsInRoom {
util.GetLogger(ctx).Infof("user not in room") return &event
return nil
} }
return &event if rc.req.AutoJoin {
var joinRes roomserver.PerformJoinResponse
rc.rsAPI.PerformJoin(rc.ctx, &roomserver.PerformJoinRequest{
UserID: rc.userID,
Content: map[string]interface{}{},
RoomIDOrAlias: event.RoomID(),
// TODO: Add server_names from linked room, currently this join will only work if the HS is already in the room
}, &joinRes)
if joinRes.Error != nil {
util.GetLogger(rc.ctx).WithError(joinRes.Error).WithField("room_id", event.RoomID()).Error("Failed to auto-join room")
return nil
}
return &event
}
util.GetLogger(rc.ctx).Infof("user not in room and auto_join disabled")
return nil
} }
type walkInfo struct { type walkInfo struct {

View file

@ -353,6 +353,10 @@ func TestMSC2836(t *testing.T) {
}) })
} }
// TODO: TestMSC2836TerminatesLoops (short and long)
// TODO: TestMSC2836UnknownEventsSkipped
// TODO: TestMSC2836SkipEventIfNotInRoom
func runServer(t *testing.T, router *mux.Router) func() { func runServer(t *testing.T, router *mux.Router) func() {
t.Helper() t.Helper()
externalServ := &http.Server{ externalServ := &http.Server{