From 919d9942b6ecad047b0707d0fca34306767f587c Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 23 May 2022 12:13:34 +0100 Subject: [PATCH] Add `QueryRestrictedJoinAllowed` --- roomserver/api/api.go | 1 + roomserver/api/api_trace.go | 10 ++++ roomserver/api/query.go | 9 ++++ roomserver/internal/query/query.go | 74 ++++++++++++++++++++++++++++++ roomserver/inthttp/client.go | 11 +++++ roomserver/inthttp/server.go | 13 ++++++ 6 files changed, 118 insertions(+) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 80e7aed64..6d28a79e2 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -134,6 +134,7 @@ type ClientRoomserverAPI interface { QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error QueryRoomVersionCapabilities(ctx context.Context, req *QueryRoomVersionCapabilitiesRequest, res *QueryRoomVersionCapabilitiesResponse) error + QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 711324644..92c5c1b1d 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -354,6 +354,16 @@ func (t *RoomserverInternalAPITrace) QueryAuthChain( return err } +func (t *RoomserverInternalAPITrace) QueryRestrictedJoinAllowed( + ctx context.Context, + request *QueryRestrictedJoinAllowedRequest, + response *QueryRestrictedJoinAllowedResponse, +) error { + err := t.Impl.QueryRestrictedJoinAllowed(ctx, request, response) + util.GetLogger(ctx).WithError(err).Infof("QueryRestrictedJoinAllowed req=%+v res=%+v", js(request), js(response)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index afafb87c3..e6460ab21 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -348,6 +348,15 @@ type QueryServerBannedFromRoomResponse struct { Banned bool `json:"banned"` } +type QueryRestrictedJoinAllowedRequest struct { + UserID string `json:"user_id"` + RoomID string `json:"room_id"` +} + +type QueryRestrictedJoinAllowedResponse struct { + Allowed bool `json:"allowed"` +} + // MarshalJSON stringifies the room ID and StateKeyTuple keys so they can be sent over the wire in HTTP API mode. func (r *QueryBulkStateContentResponse) MarshalJSON() ([]byte, error) { se := make(map[string]string) diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index d25bdc378..103f1cb86 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -16,6 +16,7 @@ package query import ( "context" + "encoding/json" "errors" "fmt" @@ -757,3 +758,76 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq res.AuthChain = hchain return nil } + +func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.QueryRestrictedJoinAllowedRequest, res *api.QueryRestrictedJoinAllowedResponse) error { + // Look up if we know anything about the room. If it doesn't exist + // or is a stub entry then we can't do anything. + roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) + if err != nil { + return fmt.Errorf("r.DB.RoomInfo: %w", err) + } + if roomInfo == nil || roomInfo.IsStub { + return fmt.Errorf("room %q doesn't exist or is stub room", req.RoomID) + } + // If the room version doesn't allow restricted joins then don't + // try to process any further. + allowRestrictedJoins, err := roomInfo.RoomVersion.AllowRestrictedJoinsInEventAuth() + if err != nil { + return fmt.Errorf("roomInfo.RoomVersion.AllowRestrictedJoinsInEventAuth: %w", err) + } else if !allowRestrictedJoins { + return nil + } + // Get the join rules to work out if the join rule is "restricted". + joinRulesEvent, err := r.DB.GetStateEvent(ctx, req.RoomID, gomatrixserverlib.MRoomJoinRules, "") + if err != nil { + return fmt.Errorf("r.DB.GetStateEvent: %w", err) + } + var joinRules gomatrixserverlib.JoinRuleContent + if err = json.Unmarshal(joinRulesEvent.Content(), &joinRules); err != nil { + return fmt.Errorf("json.Unmarshal: %w", err) + } + // If the join rule isn't "restricted" then there's nothing more to do. + if joinRules.JoinRule != gomatrixserverlib.Restricted { + return nil + } + // Step through the join rules and see if the user matches any of them. + for _, rule := range joinRules.Allow { + // We only understand "m.room_membership" rules at this point in + // time, so skip any rule that doesn't match those. + if rule.Type != gomatrixserverlib.MRoomMembership { + continue + } + // See if the room exists. If it doesn't exist or if it's a stub + // room entry then we can't check memberships. + targetRoomInfo, err := r.DB.RoomInfo(ctx, rule.RoomID) + if err != nil { + continue + } + if targetRoomInfo == nil || targetRoomInfo.IsStub { + continue + } + // First of all work out if *we* are still in the room, otherwise + // it's possible that the memberships will be out of date. + isIn, err := r.DB.GetLocalServerInRoom(ctx, targetRoomInfo.RoomNID) + if err != nil { + continue + } + if !isIn { + // We aren't in the room, so we can no longer tell if the room + // memberships are up-to-date. + continue + } + // At this point we're happy that we are in the room, so now let's + // see if the target user is in the room. + _, isIn, _, err = r.DB.GetMembership(ctx, targetRoomInfo.RoomNID, req.UserID) + if err != nil { + continue + } + // If the user is in the room then we will allow the membership. + if isIn { + res.Allowed = true + return nil + } + } + return nil +} diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 09358001b..7b10ae657 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -61,6 +61,7 @@ const ( RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers" RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom" RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" + RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed" ) type httpRoomserverInternalAPI struct { @@ -557,6 +558,16 @@ func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } +func (h *httpRoomserverInternalAPI) QueryRestrictedJoinAllowed( + ctx context.Context, req *api.QueryRestrictedJoinAllowedRequest, res *api.QueryRestrictedJoinAllowedResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRestrictedJoinAllowed") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryRestrictedJoinAllowed + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, res *api.PerformForgetResponse) error { span, ctx := opentracing.StartSpanFromContext(ctx, "PerformForget") defer span.Finish() diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 9042e341b..ad4fdc469 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -472,4 +472,17 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(RoomserverQueryRestrictedJoinAllowed, + httputil.MakeInternalAPI("queryRestrictedJoinAllowed", func(req *http.Request) util.JSONResponse { + request := api.QueryRestrictedJoinAllowedRequest{} + response := api.QueryRestrictedJoinAllowedResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryRestrictedJoinAllowed(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) }