From 168140f82da7ad760744815a5e0a4e87546c4a60 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 8 Mar 2022 11:14:52 +0000 Subject: [PATCH] Get MSC2946 working for restricted rooms locally --- setup/mscs/msc2946/msc2946.go | 75 +++++++++++++++++++++++++++-------- 1 file changed, 59 insertions(+), 16 deletions(-) diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 7ab50c32e..0c1ad7c37 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -199,13 +199,14 @@ func (w *walker) storePaginationCache(paginationToken string, cache paginationIn } type roomVisit struct { - roomID string - depth int - vias []string // vias to query this room by + roomID string + parentRoomID string + depth int + vias []string // vias to query this room by } func (w *walker) walk() util.JSONResponse { - if !w.authorised(w.rootRoomID) { + if !w.authorised(w.rootRoomID, "") { if w.caller != nil { // CS API format return util.JSONResponse{ @@ -238,8 +239,9 @@ func (w *walker) walk() util.JSONResponse { w.paginationToken = tok // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms c.unvisited = append(c.unvisited, roomVisit{ - roomID: w.rootRoomID, - depth: 0, + roomID: w.rootRoomID, + parentRoomID: "", + depth: 0, }) } @@ -277,7 +279,7 @@ func (w *walker) walk() util.JSONResponse { // If we know about this room and the caller is authorised (joined/world_readable) then pull // events locally - if w.roomExists(rv.roomID) && w.authorised(rv.roomID) { + if w.roomExists(rv.roomID) && w.authorised(rv.roomID, rv.parentRoomID) { // Get all `m.space.child` state events for this room events, err := w.childReferences(rv.roomID) if err != nil { @@ -332,9 +334,10 @@ func (w *walker) walk() util.JSONResponse { ev := discoveredChildEvents[i] _ = json.Unmarshal(ev.Content, &spaceContent) unvisited = append(unvisited, roomVisit{ - roomID: ev.StateKey, - depth: rv.depth + 1, - vias: spaceContent.Via, + roomID: ev.StateKey, + parentRoomID: rv.roomID, + depth: rv.depth + 1, + vias: spaceContent.Via, }) } } @@ -465,9 +468,9 @@ func (w *walker) roomExists(roomID string) bool { } // authorised returns true iff the user is joined this room or the room is world_readable -func (w *walker) authorised(roomID string) bool { +func (w *walker) authorised(roomID, parentRoomID string) bool { if w.caller != nil { - return w.authorisedUser(roomID) + return w.authorisedUser(roomID, parentRoomID) } return w.authorisedServer(roomID) } @@ -514,12 +517,17 @@ func (w *walker) authorisedServer(roomID string) bool { return false } -// authorisedUser returns true iff the user is joined this room or the room is world_readable -func (w *walker) authorisedUser(roomID string) bool { +// authorisedUser returns true iff the user is invited/joined this room or the room is world_readable. +// Failing that, if the room has a restricted join rule and belongs to the space parent listed, it will return true. +func (w *walker) authorisedUser(roomID, parentRoomID string) bool { hisVisTuple := gomatrixserverlib.StateKeyTuple{ EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: "", } + joinRuleTuple := gomatrixserverlib.StateKeyTuple{ + EventType: gomatrixserverlib.MRoomJoinRules, + StateKey: "", + } roomMemberTuple := gomatrixserverlib.StateKeyTuple{ EventType: gomatrixserverlib.MRoomMember, StateKey: w.caller.UserID, @@ -528,7 +536,7 @@ func (w *walker) authorisedUser(roomID string) bool { err := w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{ RoomID: roomID, StateTuples: []gomatrixserverlib.StateKeyTuple{ - hisVisTuple, roomMemberTuple, + hisVisTuple, joinRuleTuple, roomMemberTuple, }, }, &queryRes) if err != nil { @@ -536,19 +544,54 @@ func (w *walker) authorisedUser(roomID string) bool { return false } memberEv := queryRes.StateEvents[roomMemberTuple] - hisVisEv := queryRes.StateEvents[hisVisTuple] if memberEv != nil { membership, _ := memberEv.Membership() if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite { return true } } + hisVisEv := queryRes.StateEvents[hisVisTuple] if hisVisEv != nil { hisVis, _ := hisVisEv.HistoryVisibility() if hisVis == "world_readable" { return true } } + joinRuleEv := queryRes.StateEvents[joinRuleTuple] + if parentRoomID != "" && joinRuleEv != nil { + rule, _ := joinRuleEv.JoinRule() + if rule == "restricted" { + var jrContent gomatrixserverlib.JoinRuleContent + if err := json.Unmarshal(joinRuleEv.Content(), &jrContent); err != nil { + util.GetLogger(w.ctx).Warnf("failed to check join_rule on room %s: %s", roomID, err) + return false + } + // check the allow section + for _, allow := range jrContent.Allow { + if allow.Type == "m.room_membership" && allow.RoomID == parentRoomID { + // ensure caller is joined to the parent room + var queryRes2 roomserver.QueryCurrentStateResponse + err = w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{ + RoomID: parentRoomID, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + roomMemberTuple, + }, + }, &queryRes2) + if err != nil { + util.GetLogger(w.ctx).WithError(err).WithField("parent_room_id", parentRoomID).Warn("failed to check user is joined to parent room") + continue + } + memberEv = queryRes2.StateEvents[roomMemberTuple] + if memberEv != nil { + membership, _ := memberEv.Membership() + if membership == gomatrixserverlib.Join { + return true + } + } + } + } + } + } return false }