diff --git a/clientapi/auth/login_publickey_ethereum_test.go b/clientapi/auth/login_publickey_ethereum_test.go index 12fae2654..73842f9a0 100644 --- a/clientapi/auth/login_publickey_ethereum_test.go +++ b/clientapi/auth/login_publickey_ethereum_test.go @@ -34,7 +34,7 @@ type loginContext struct { userInteractive *UserInteractive } -func createLoginContext(t *testing.T) *loginContext { +func createLoginContext(_ *testing.T) *loginContext { chainIds := []int{4} cfg := &config.ClientAPI{ diff --git a/clientapi/routing/register_publickey_test.go b/clientapi/routing/register_publickey_test.go index 688769e89..c51ee1846 100644 --- a/clientapi/routing/register_publickey_test.go +++ b/clientapi/routing/register_publickey_test.go @@ -40,7 +40,7 @@ type registerContext struct { userInteractive *auth.UserInteractive } -func createRegisterContext(t *testing.T) *registerContext { +func createRegisterContext(_ *testing.T) *registerContext { chainIds := []int{4} cfg := &config.ClientAPI{ @@ -173,8 +173,8 @@ func (*fakePublicKeyUserApi) QueryLoginToken(ctx context.Context, req *uapi.Quer func newRegistrationSession( t *testing.T, userId string, - cfg *config.ClientAPI, - userInteractive *auth.UserInteractive, + _ *config.ClientAPI, + _ *auth.UserInteractive, userAPI *fakePublicKeyUserApi, ) string { body := fmt.Sprintf(`{ diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 61520d50e..f2ff48175 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -45,6 +45,9 @@ const ( ConstCreateEventContentValueSpace = "m.space" ConstSpaceChildEventType = "m.space.child" ConstSpaceParentEventType = "m.space.parent" + ConstJoinRulePublic = "public" + ConstJoinRuleKnock = "knock" + ConstJoinRuleRestricted = "restricted" ) type MSC2946ClientResponse struct { @@ -479,7 +482,7 @@ func (w *walker) authorised(roomID, parentRoomID string) (authed, isJoinedOrInvi return w.authorisedServer(roomID), false } -// authorisedServer returns true iff the server is joined this room or the room is world_readable +// authorisedServer returns true iff the server is joined this room or the room is world_readable, public, or knockable func (w *walker) authorisedServer(roomID string) bool { // Check history visibility / join rules first hisVisTuple := gomatrixserverlib.StateKeyTuple{ @@ -513,8 +516,21 @@ func (w *walker) authorisedServer(roomID string) bool { // in addition to the actual room ID (but always do the actual one first as it's quicker in the common case) allowJoinedToRoomIDs := []string{roomID} joinRuleEv := queryRoomRes.StateEvents[joinRuleTuple] + if joinRuleEv != nil { - allowJoinedToRoomIDs = append(allowJoinedToRoomIDs, w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership")...) + rule, ruleErr := joinRuleEv.JoinRule() + if ruleErr != nil { + util.GetLogger(w.ctx).WithError(ruleErr).WithField("parent_room_id", roomID).Warn("failed to get join rule") + return false + } + + if rule == ConstJoinRulePublic || rule == ConstJoinRuleKnock { + return true + } + + if rule == ConstJoinRuleRestricted { + allowJoinedToRoomIDs = append(allowJoinedToRoomIDs, w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership")...) + } } // check if server is joined to any allowed room @@ -537,7 +553,8 @@ func (w *walker) authorisedServer(roomID string) bool { return false } -// authorisedUser returns true iff the user is invited/joined this room or the room is world_readable. +// authorisedUser returns true iff the user is invited/joined this room or the room is world_readable +// or if the room has a public or knock join rule. // Failing that, if the room has a restricted join rule and belongs to the space parent listed, it will return true. func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoinedOrInvited bool) { hisVisTuple := gomatrixserverlib.StateKeyTuple{ @@ -579,13 +596,20 @@ func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoi } joinRuleEv := queryRes.StateEvents[joinRuleTuple] if parentRoomID != "" && joinRuleEv != nil { - allowedRoomIDs := w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership") - // check parent is in the allowed set var allowed bool - for _, a := range allowedRoomIDs { - if parentRoomID == a { - allowed = true - break + rule, ruleErr := joinRuleEv.JoinRule() + if ruleErr != nil { + util.GetLogger(w.ctx).WithError(ruleErr).WithField("parent_room_id", parentRoomID).Warn("failed to get join rule") + } else if rule == ConstJoinRulePublic || rule == ConstJoinRuleKnock { + allowed = true + } else if rule == ConstJoinRuleRestricted { + allowedRoomIDs := w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership") + // check parent is in the allowed set + for _, a := range allowedRoomIDs { + if parentRoomID == a { + allowed = true + break + } } } if allowed { @@ -615,7 +639,7 @@ func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoi func (w *walker) restrictedJoinRuleAllowedRooms(joinRuleEv *gomatrixserverlib.HeaderedEvent, allowType string) (allows []string) { rule, _ := joinRuleEv.JoinRule() - if rule != "restricted" { + if rule != ConstJoinRuleRestricted { return nil } var jrContent gomatrixserverlib.JoinRuleContent