From 979738b2da2a3b800df86a875781c440a095b26c Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 8 Mar 2022 13:24:32 +0000 Subject: [PATCH] Get MSC2946 working for restricted rooms locally/over federation (#2260) * Get MSC2946 working for restricted rooms locally * Get MSC2946 working for restricted rooms over federation * Allow invited in addition to joined to enable child walking --- setup/mscs/msc2946/msc2946.go | 183 ++++++++++++++++++++++++---------- 1 file changed, 133 insertions(+), 50 deletions(-) diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 7ab50c32e..7fb043366 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -199,13 +199,14 @@ func (w *walker) storePaginationCache(paginationToken string, cache paginationIn } type roomVisit struct { - roomID string - depth int - vias []string // vias to query this room by + roomID string + parentRoomID string + depth int + vias []string // vias to query this room by } func (w *walker) walk() util.JSONResponse { - if !w.authorised(w.rootRoomID) { + if authorised, _ := w.authorised(w.rootRoomID, ""); !authorised { if w.caller != nil { // CS API format return util.JSONResponse{ @@ -238,8 +239,9 @@ func (w *walker) walk() util.JSONResponse { w.paginationToken = tok // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms c.unvisited = append(c.unvisited, roomVisit{ - roomID: w.rootRoomID, - depth: 0, + roomID: w.rootRoomID, + parentRoomID: "", + depth: 0, }) } @@ -277,23 +279,8 @@ func (w *walker) walk() util.JSONResponse { // If we know about this room and the caller is authorised (joined/world_readable) then pull // events locally - if w.roomExists(rv.roomID) && w.authorised(rv.roomID) { - // Get all `m.space.child` state events for this room - events, err := w.childReferences(rv.roomID) - if err != nil { - util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Error("failed to extract references for room") - continue - } - discoveredChildEvents = events - - pubRoom := w.publicRoomsChunk(rv.roomID) - - discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{ - PublicRoom: *pubRoom, - RoomType: roomType, - ChildrenState: events, - }) - } else { + roomExists := w.roomExists(rv.roomID) + if !roomExists { // attempt to query this room over federation, as either we've never heard of it before // or we've left it and hence are not authorised (but info may be exposed regardless) fedRes, err := w.federatedRoomInfo(rv.roomID, rv.vias) @@ -312,6 +299,29 @@ func (w *walker) walk() util.JSONResponse { // as these children may be rooms we do know about. roomType = ConstCreateEventContentValueSpace } + } else if authorised, isJoinedOrInvited := w.authorised(rv.roomID, rv.parentRoomID); authorised { + // Get all `m.space.child` state events for this room + events, err := w.childReferences(rv.roomID) + if err != nil { + util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Error("failed to extract references for room") + continue + } + discoveredChildEvents = events + + pubRoom := w.publicRoomsChunk(rv.roomID) + + discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{ + PublicRoom: *pubRoom, + RoomType: roomType, + ChildrenState: events, + }) + // don't walk children if the user is not joined/invited to the space + if !isJoinedOrInvited { + continue + } + } else { + // room exists but user is not authorised + continue } // don't walk the children @@ -332,9 +342,10 @@ func (w *walker) walk() util.JSONResponse { ev := discoveredChildEvents[i] _ = json.Unmarshal(ev.Content, &spaceContent) unvisited = append(unvisited, roomVisit{ - roomID: ev.StateKey, - depth: rv.depth + 1, - vias: spaceContent.Via, + roomID: ev.StateKey, + parentRoomID: rv.roomID, + depth: rv.depth + 1, + vias: spaceContent.Via, }) } } @@ -465,25 +476,29 @@ func (w *walker) roomExists(roomID string) bool { } // authorised returns true iff the user is joined this room or the room is world_readable -func (w *walker) authorised(roomID string) bool { +func (w *walker) authorised(roomID, parentRoomID string) (authed, isJoinedOrInvited bool) { if w.caller != nil { - return w.authorisedUser(roomID) + return w.authorisedUser(roomID, parentRoomID) } - return w.authorisedServer(roomID) + return w.authorisedServer(roomID), false } // authorisedServer returns true iff the server is joined this room or the room is world_readable func (w *walker) authorisedServer(roomID string) bool { - // Check history visibility first + // Check history visibility / join rules first hisVisTuple := gomatrixserverlib.StateKeyTuple{ EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: "", } + joinRuleTuple := gomatrixserverlib.StateKeyTuple{ + EventType: gomatrixserverlib.MRoomJoinRules, + StateKey: "", + } var queryRoomRes roomserver.QueryCurrentStateResponse err := w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{ RoomID: roomID, StateTuples: []gomatrixserverlib.StateKeyTuple{ - hisVisTuple, + hisVisTuple, joinRuleTuple, }, }, &queryRoomRes) if err != nil { @@ -497,29 +512,46 @@ func (w *walker) authorisedServer(roomID string) bool { return true } } - // check if server is joined to the room - var queryRes fs.QueryJoinedHostServerNamesInRoomResponse - err = w.fsAPI.QueryJoinedHostServerNamesInRoom(w.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{ - RoomID: roomID, - }, &queryRes) - if err != nil { - util.GetLogger(w.ctx).WithError(err).Error("failed to QueryJoinedHostServerNamesInRoom") - return false + + // check if this room is a restricted room and if so, we need to check if the server is joined to an allowed room ID + // in addition to the actual room ID (but always do the actual one first as it's quicker in the common case) + allowJoinedToRoomIDs := []string{roomID} + joinRuleEv := queryRoomRes.StateEvents[joinRuleTuple] + if joinRuleEv != nil { + allowJoinedToRoomIDs = append(allowJoinedToRoomIDs, w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership")...) } - for _, srv := range queryRes.ServerNames { - if srv == w.serverName { - return true + + // check if server is joined to any allowed room + for _, allowedRoomID := range allowJoinedToRoomIDs { + var queryRes fs.QueryJoinedHostServerNamesInRoomResponse + err = w.fsAPI.QueryJoinedHostServerNamesInRoom(w.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{ + RoomID: allowedRoomID, + }, &queryRes) + if err != nil { + util.GetLogger(w.ctx).WithError(err).Error("failed to QueryJoinedHostServerNamesInRoom") + continue + } + for _, srv := range queryRes.ServerNames { + if srv == w.serverName { + return true + } } } + return false } -// authorisedUser returns true iff the user is joined this room or the room is world_readable -func (w *walker) authorisedUser(roomID string) bool { +// authorisedUser returns true iff the user is invited/joined this room or the room is world_readable. +// Failing that, if the room has a restricted join rule and belongs to the space parent listed, it will return true. +func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoinedOrInvited bool) { hisVisTuple := gomatrixserverlib.StateKeyTuple{ EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: "", } + joinRuleTuple := gomatrixserverlib.StateKeyTuple{ + EventType: gomatrixserverlib.MRoomJoinRules, + StateKey: "", + } roomMemberTuple := gomatrixserverlib.StateKeyTuple{ EventType: gomatrixserverlib.MRoomMember, StateKey: w.caller.UserID, @@ -528,28 +560,79 @@ func (w *walker) authorisedUser(roomID string) bool { err := w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{ RoomID: roomID, StateTuples: []gomatrixserverlib.StateKeyTuple{ - hisVisTuple, roomMemberTuple, + hisVisTuple, joinRuleTuple, roomMemberTuple, }, }, &queryRes) if err != nil { util.GetLogger(w.ctx).WithError(err).Error("failed to QueryCurrentState") - return false + return false, false } memberEv := queryRes.StateEvents[roomMemberTuple] - hisVisEv := queryRes.StateEvents[hisVisTuple] if memberEv != nil { membership, _ := memberEv.Membership() if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite { - return true + return true, true } } + hisVisEv := queryRes.StateEvents[hisVisTuple] if hisVisEv != nil { hisVis, _ := hisVisEv.HistoryVisibility() if hisVis == "world_readable" { - return true + return true, false } } - return false + joinRuleEv := queryRes.StateEvents[joinRuleTuple] + if parentRoomID != "" && joinRuleEv != nil { + allowedRoomIDs := w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership") + // check parent is in the allowed set + var allowed bool + for _, a := range allowedRoomIDs { + if parentRoomID == a { + allowed = true + break + } + } + if allowed { + // ensure caller is joined to the parent room + var queryRes2 roomserver.QueryCurrentStateResponse + err = w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{ + RoomID: parentRoomID, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + roomMemberTuple, + }, + }, &queryRes2) + if err != nil { + util.GetLogger(w.ctx).WithError(err).WithField("parent_room_id", parentRoomID).Warn("failed to check user is joined to parent room") + } else { + memberEv = queryRes2.StateEvents[roomMemberTuple] + if memberEv != nil { + membership, _ := memberEv.Membership() + if membership == gomatrixserverlib.Join { + return true, false + } + } + } + } + } + return false, false +} + +func (w *walker) restrictedJoinRuleAllowedRooms(joinRuleEv *gomatrixserverlib.HeaderedEvent, allowType string) (allows []string) { + rule, _ := joinRuleEv.JoinRule() + if rule != "restricted" { + return nil + } + var jrContent gomatrixserverlib.JoinRuleContent + if err := json.Unmarshal(joinRuleEv.Content(), &jrContent); err != nil { + util.GetLogger(w.ctx).Warnf("failed to check join_rule on room %s: %s", joinRuleEv.RoomID(), err) + return nil + } + for _, allow := range jrContent.Allow { + if allow.Type == allowType { + allows = append(allows, allow.RoomID) + } + } + return } // references returns all child references pointing to or from this room.