Get MSC2946 working for restricted rooms locally

This commit is contained in:
Kegan Dougal 2022-03-08 11:14:52 +00:00
parent 67de4dbd0c
commit 168140f82d

View file

@ -199,13 +199,14 @@ func (w *walker) storePaginationCache(paginationToken string, cache paginationIn
} }
type roomVisit struct { type roomVisit struct {
roomID string roomID string
depth int parentRoomID string
vias []string // vias to query this room by depth int
vias []string // vias to query this room by
} }
func (w *walker) walk() util.JSONResponse { func (w *walker) walk() util.JSONResponse {
if !w.authorised(w.rootRoomID) { if !w.authorised(w.rootRoomID, "") {
if w.caller != nil { if w.caller != nil {
// CS API format // CS API format
return util.JSONResponse{ return util.JSONResponse{
@ -238,8 +239,9 @@ func (w *walker) walk() util.JSONResponse {
w.paginationToken = tok w.paginationToken = tok
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
c.unvisited = append(c.unvisited, roomVisit{ c.unvisited = append(c.unvisited, roomVisit{
roomID: w.rootRoomID, roomID: w.rootRoomID,
depth: 0, parentRoomID: "",
depth: 0,
}) })
} }
@ -277,7 +279,7 @@ func (w *walker) walk() util.JSONResponse {
// If we know about this room and the caller is authorised (joined/world_readable) then pull // If we know about this room and the caller is authorised (joined/world_readable) then pull
// events locally // events locally
if w.roomExists(rv.roomID) && w.authorised(rv.roomID) { if w.roomExists(rv.roomID) && w.authorised(rv.roomID, rv.parentRoomID) {
// Get all `m.space.child` state events for this room // Get all `m.space.child` state events for this room
events, err := w.childReferences(rv.roomID) events, err := w.childReferences(rv.roomID)
if err != nil { if err != nil {
@ -332,9 +334,10 @@ func (w *walker) walk() util.JSONResponse {
ev := discoveredChildEvents[i] ev := discoveredChildEvents[i]
_ = json.Unmarshal(ev.Content, &spaceContent) _ = json.Unmarshal(ev.Content, &spaceContent)
unvisited = append(unvisited, roomVisit{ unvisited = append(unvisited, roomVisit{
roomID: ev.StateKey, roomID: ev.StateKey,
depth: rv.depth + 1, parentRoomID: rv.roomID,
vias: spaceContent.Via, depth: rv.depth + 1,
vias: spaceContent.Via,
}) })
} }
} }
@ -465,9 +468,9 @@ func (w *walker) roomExists(roomID string) bool {
} }
// authorised returns true iff the user is joined this room or the room is world_readable // authorised returns true iff the user is joined this room or the room is world_readable
func (w *walker) authorised(roomID string) bool { func (w *walker) authorised(roomID, parentRoomID string) bool {
if w.caller != nil { if w.caller != nil {
return w.authorisedUser(roomID) return w.authorisedUser(roomID, parentRoomID)
} }
return w.authorisedServer(roomID) return w.authorisedServer(roomID)
} }
@ -514,12 +517,17 @@ func (w *walker) authorisedServer(roomID string) bool {
return false return false
} }
// authorisedUser returns true iff the user is joined this room or the room is world_readable // authorisedUser returns true iff the user is invited/joined this room or the room is world_readable.
func (w *walker) authorisedUser(roomID string) bool { // Failing that, if the room has a restricted join rule and belongs to the space parent listed, it will return true.
func (w *walker) authorisedUser(roomID, parentRoomID string) bool {
hisVisTuple := gomatrixserverlib.StateKeyTuple{ hisVisTuple := gomatrixserverlib.StateKeyTuple{
EventType: gomatrixserverlib.MRoomHistoryVisibility, EventType: gomatrixserverlib.MRoomHistoryVisibility,
StateKey: "", StateKey: "",
} }
joinRuleTuple := gomatrixserverlib.StateKeyTuple{
EventType: gomatrixserverlib.MRoomJoinRules,
StateKey: "",
}
roomMemberTuple := gomatrixserverlib.StateKeyTuple{ roomMemberTuple := gomatrixserverlib.StateKeyTuple{
EventType: gomatrixserverlib.MRoomMember, EventType: gomatrixserverlib.MRoomMember,
StateKey: w.caller.UserID, StateKey: w.caller.UserID,
@ -528,7 +536,7 @@ func (w *walker) authorisedUser(roomID string) bool {
err := w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{ err := w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{
RoomID: roomID, RoomID: roomID,
StateTuples: []gomatrixserverlib.StateKeyTuple{ StateTuples: []gomatrixserverlib.StateKeyTuple{
hisVisTuple, roomMemberTuple, hisVisTuple, joinRuleTuple, roomMemberTuple,
}, },
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
@ -536,19 +544,54 @@ func (w *walker) authorisedUser(roomID string) bool {
return false return false
} }
memberEv := queryRes.StateEvents[roomMemberTuple] memberEv := queryRes.StateEvents[roomMemberTuple]
hisVisEv := queryRes.StateEvents[hisVisTuple]
if memberEv != nil { if memberEv != nil {
membership, _ := memberEv.Membership() membership, _ := memberEv.Membership()
if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite { if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite {
return true return true
} }
} }
hisVisEv := queryRes.StateEvents[hisVisTuple]
if hisVisEv != nil { if hisVisEv != nil {
hisVis, _ := hisVisEv.HistoryVisibility() hisVis, _ := hisVisEv.HistoryVisibility()
if hisVis == "world_readable" { if hisVis == "world_readable" {
return true return true
} }
} }
joinRuleEv := queryRes.StateEvents[joinRuleTuple]
if parentRoomID != "" && joinRuleEv != nil {
rule, _ := joinRuleEv.JoinRule()
if rule == "restricted" {
var jrContent gomatrixserverlib.JoinRuleContent
if err := json.Unmarshal(joinRuleEv.Content(), &jrContent); err != nil {
util.GetLogger(w.ctx).Warnf("failed to check join_rule on room %s: %s", roomID, err)
return false
}
// check the allow section
for _, allow := range jrContent.Allow {
if allow.Type == "m.room_membership" && allow.RoomID == parentRoomID {
// ensure caller is joined to the parent room
var queryRes2 roomserver.QueryCurrentStateResponse
err = w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{
RoomID: parentRoomID,
StateTuples: []gomatrixserverlib.StateKeyTuple{
roomMemberTuple,
},
}, &queryRes2)
if err != nil {
util.GetLogger(w.ctx).WithError(err).WithField("parent_room_id", parentRoomID).Warn("failed to check user is joined to parent room")
continue
}
memberEv = queryRes2.StateEvents[roomMemberTuple]
if memberEv != nil {
membership, _ := memberEv.Membership()
if membership == gomatrixserverlib.Join {
return true
}
}
}
}
}
}
return false return false
} }