Add auto_join for local rooms

This commit is contained in:
Kegan Dougal 2020-11-05 16:04:43 +00:00
parent c320261ca2
commit 8a7ccd5d1d
2 changed files with 59 additions and 27 deletions

View file

@ -44,6 +44,7 @@ type EventRelationshipRequest struct {
IncludeChildren *bool `json:"include_children"`
Direction string `json:"direction"`
Batch string `json:"batch"`
AutoJoin bool `json:"auto_join"`
}
func (r *EventRelationshipRequest) applyDefaults() {
@ -104,6 +105,13 @@ func Enable(base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, us
return nil
}
type reqCtx struct {
ctx context.Context
rsAPI roomserver.RoomserverInternalAPI
req *EventRelationshipRequest
userID string
}
func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse {
return func(req *http.Request, device *userapi.Device) util.JSONResponse {
var relation EventRelationshipRequest
@ -118,9 +126,15 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
relation.applyDefaults()
var res EventRelationshipResponse
var returnEvents []*gomatrixserverlib.HeaderedEvent
rc := reqCtx{
ctx: req.Context(),
req: &relation,
userID: device.UserID,
rsAPI: rsAPI,
}
// Can the user see (according to history visibility) event_id? If no, reject the request, else continue.
event := getEventIfVisible(req.Context(), rsAPI, relation.EventID, device.UserID)
event := rc.getEventIfVisible(relation.EventID)
if event == nil {
return util.JSONResponse{
Code: 403,
@ -132,7 +146,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
returnEvents = append(returnEvents, event)
if *relation.IncludeParent {
if parentEvent := includeParent(req.Context(), rsAPI, event, device.UserID); parentEvent != nil {
if parentEvent := rc.includeParent(event); parentEvent != nil {
returnEvents = append(returnEvents, parentEvent)
}
}
@ -140,7 +154,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
if *relation.IncludeChildren {
remaining := relation.Limit - len(returnEvents)
if remaining > 0 {
children, resErr := includeChildren(req.Context(), rsAPI, db, event.EventID(), remaining, *relation.RecentFirst, device.UserID)
children, resErr := rc.includeChildren(db, event.EventID(), remaining, *relation.RecentFirst)
if resErr != nil {
return *resErr
}
@ -157,7 +171,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
}
var events []*gomatrixserverlib.HeaderedEvent
events, walkLimited = walkThread(
req.Context(), db, rsAPI, device.UserID, &relation, included, remaining,
req.Context(), db, &rc, included, remaining,
)
returnEvents = append(returnEvents, events...)
}
@ -176,27 +190,27 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
// If include_parent: true and there is a valid m.relationship field in the event,
// retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array.
func includeParent(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, event *gomatrixserverlib.HeaderedEvent, userID string) (parent *gomatrixserverlib.HeaderedEvent) {
func (rc *reqCtx) includeParent(event *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) {
parentID, _, _ := parentChildEventIDs(event)
if parentID == "" {
return nil
}
return getEventIfVisible(ctx, rsAPI, parentID, userID)
return rc.getEventIfVisible(parentID)
}
// If include_children: true, lookup all events which have event_id as an m.relationship
// Apply history visibility checks to all these events and add the ones which pass into the response array,
// honouring the recent_first flag and the limit.
func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, db Database, parentID string, limit int, recentFirst bool, userID string) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) {
children, err := db.ChildrenForParent(ctx, parentID, constRelType, recentFirst)
func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) {
children, err := db.ChildrenForParent(rc.ctx, parentID, constRelType, recentFirst)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to get ChildrenForParent")
util.GetLogger(rc.ctx).WithError(err).Error("failed to get ChildrenForParent")
resErr := jsonerror.InternalServerError()
return nil, &resErr
}
var childEvents []*gomatrixserverlib.HeaderedEvent
for _, child := range children {
childEvent := getEventIfVisible(ctx, rsAPI, child.EventID, userID)
childEvent := rc.getEventIfVisible(child.EventID)
if childEvent != nil {
childEvents = append(childEvents, childEvent)
}
@ -211,16 +225,16 @@ func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI
// honouring the limit, max_depth and max_breadth values according to the following rules
// nolint: unparam
func walkThread(
ctx context.Context, db Database, rsAPI roomserver.RoomserverInternalAPI, userID string, req *EventRelationshipRequest, included map[string]bool, limit int,
ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int,
) ([]*gomatrixserverlib.HeaderedEvent, bool) {
if req.Direction != "down" {
if rc.req.Direction != "down" {
util.GetLogger(ctx).Error("not implemented: direction=up")
return nil, false
}
var result []*gomatrixserverlib.HeaderedEvent
eventWalker := walker{
ctx: ctx,
req: req,
req: rc.req,
db: db,
fn: func(wi *walkInfo) bool {
// If already processed event, skip.
@ -234,7 +248,7 @@ func walkThread(
}
// Process the event.
event := getEventIfVisible(ctx, rsAPI, wi.EventID, userID)
event := rc.getEventIfVisible(wi.EventID)
if event != nil {
result = append(result, event)
}
@ -242,24 +256,24 @@ func walkThread(
return false
},
}
limited, err := eventWalker.WalkFrom(req.EventID)
limited, err := eventWalker.WalkFrom(rc.req.EventID)
if err != nil {
util.GetLogger(ctx).WithError(err).Errorf("Failed to WalkFrom %s", req.EventID)
util.GetLogger(ctx).WithError(err).Errorf("Failed to WalkFrom %s", rc.req.EventID)
}
return result, limited
}
func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, eventID, userID string) *gomatrixserverlib.HeaderedEvent {
func (rc *reqCtx) getEventIfVisible(eventID string) *gomatrixserverlib.HeaderedEvent {
var queryEventsRes roomserver.QueryEventsByIDResponse
err := rsAPI.QueryEventsByID(ctx, &roomserver.QueryEventsByIDRequest{
err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{
EventIDs: []string{eventID},
}, &queryEventsRes)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID")
util.GetLogger(rc.ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID")
return nil
}
if len(queryEventsRes.Events) == 0 {
util.GetLogger(ctx).Infof("event does not exist")
util.GetLogger(rc.ctx).Infof("event does not exist")
return nil // event does not exist
}
event := queryEventsRes.Events[0]
@ -268,20 +282,34 @@ func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalA
// TODO: This does not honour history_visibility
// TODO: This does not honour m.room.create content
var queryMembershipRes roomserver.QueryMembershipForUserResponse
err = rsAPI.QueryMembershipForUser(ctx, &roomserver.QueryMembershipForUserRequest{
err = rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{
RoomID: event.RoomID(),
UserID: userID,
UserID: rc.userID,
}, &queryMembershipRes)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser")
util.GetLogger(rc.ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser")
return nil
}
if !queryMembershipRes.IsInRoom {
util.GetLogger(ctx).Infof("user not in room")
if queryMembershipRes.IsInRoom {
return &event
}
if rc.req.AutoJoin {
var joinRes roomserver.PerformJoinResponse
rc.rsAPI.PerformJoin(rc.ctx, &roomserver.PerformJoinRequest{
UserID: rc.userID,
Content: map[string]interface{}{},
RoomIDOrAlias: event.RoomID(),
// TODO: Add server_names from linked room, currently this join will only work if the HS is already in the room
}, &joinRes)
if joinRes.Error != nil {
util.GetLogger(rc.ctx).WithError(joinRes.Error).WithField("room_id", event.RoomID()).Error("Failed to auto-join room")
return nil
}
return &event
}
util.GetLogger(rc.ctx).Infof("user not in room and auto_join disabled")
return nil
}
type walkInfo struct {
eventInfo

View file

@ -353,6 +353,10 @@ func TestMSC2836(t *testing.T) {
})
}
// TODO: TestMSC2836TerminatesLoops (short and long)
// TODO: TestMSC2836UnknownEventsSkipped
// TODO: TestMSC2836SkipEventIfNotInRoom
func runServer(t *testing.T, router *mux.Router) func() {
t.Helper()
externalServ := &http.Server{