diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 98b9a3ad6..dd8e5381f 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -79,17 +79,19 @@ func InviteV2( } input := HandleInviteInput{ - Context: httpReq.Context(), - RoomVersion: inviteReq.RoomVersion(), - RoomID: roomID, - EventID: eventID, - InvitedUser: *invitedUser, - KeyID: cfg.Matrix.KeyID, - PrivateKey: cfg.Matrix.PrivateKey, - Verifier: keys, - InviteQuerier: rsAPI, - InviteEvent: inviteReq.Event(), - StrippedState: inviteReq.InviteRoomState(), + Context: httpReq.Context(), + RoomVersion: inviteReq.RoomVersion(), + RoomID: roomID, + EventID: eventID, + InvitedUser: *invitedUser, + KeyID: cfg.Matrix.KeyID, + PrivateKey: cfg.Matrix.PrivateKey, + Verifier: keys, + InviteQuerier: rsAPI, + MembershipQuerier: MembershipQuerier{roomserver: rsAPI}, + GenerateStrippedState: rsAPI.GenerateInviteStrippedState, + InviteEvent: inviteReq.Event(), + StrippedState: inviteReq.InviteRoomState(), } event, jsonErr := handleInvite(input, rsAPI) if err != nil { @@ -162,17 +164,19 @@ func InviteV1( } input := HandleInviteInput{ - Context: httpReq.Context(), - RoomVersion: roomVer, - RoomID: roomID, - EventID: eventID, - InvitedUser: *invitedUser, - KeyID: cfg.Matrix.KeyID, - PrivateKey: cfg.Matrix.PrivateKey, - Verifier: keys, - InviteQuerier: rsAPI, - InviteEvent: event, - StrippedState: strippedState, + Context: httpReq.Context(), + RoomVersion: roomVer, + RoomID: roomID, + EventID: eventID, + InvitedUser: *invitedUser, + KeyID: cfg.Matrix.KeyID, + PrivateKey: cfg.Matrix.PrivateKey, + Verifier: keys, + InviteQuerier: rsAPI, + MembershipQuerier: MembershipQuerier{roomserver: rsAPI}, + GenerateStrippedState: rsAPI.GenerateInviteStrippedState, + InviteEvent: event, + StrippedState: strippedState, } event, jsonErr := handleInvite(input, rsAPI) if err != nil { @@ -231,24 +235,23 @@ func handleInvite(input HandleInviteInput, rsAPI api.FederationRoomserverAPI) (g // TODO: Migrate to GMSL -type InviteQuerier interface { +type RoomQuerier interface { IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) - // TODO: get rid of fclient references if possible - GenerateInviteStrippedState(ctx context.Context, roomID spec.RoomID, stateWanted []gomatrixserverlib.StateKeyTuple, inviteEvent gomatrixserverlib.PDU) ([]fclient.InviteV2StrippedState, error) - // TODO: Change this api shape dramatically for gmsl - QueryMembershipForUser(ctx context.Context, req *api.QueryMembershipForUserRequest, res *api.QueryMembershipForUserResponse) error } type HandleInviteInput struct { - Context context.Context - RoomVersion gomatrixserverlib.RoomVersion - RoomID spec.RoomID - EventID string - InvitedUser spec.UserID - KeyID gomatrixserverlib.KeyID - PrivateKey ed25519.PrivateKey - Verifier gomatrixserverlib.JSONVerifier - InviteQuerier InviteQuerier + Context context.Context + RoomVersion gomatrixserverlib.RoomVersion + RoomID spec.RoomID + EventID string + InvitedUser spec.UserID + KeyID gomatrixserverlib.KeyID + PrivateKey ed25519.PrivateKey + Verifier gomatrixserverlib.JSONVerifier + InviteQuerier RoomQuerier + MembershipQuerier MembershipQuerier + // TODO: get rid of fclient references if possible + GenerateStrippedState func(ctx context.Context, roomID spec.RoomID, stateWanted []gomatrixserverlib.StateKeyTuple, inviteEvent gomatrixserverlib.PDU) ([]fclient.InviteV2StrippedState, error) InviteEvent gomatrixserverlib.PDU StrippedState []fclient.InviteV2StrippedState @@ -333,7 +336,7 @@ func processInvite( StateKey: "", }) } - if is, err := input.InviteQuerier.GenerateInviteStrippedState(input.Context, input.RoomID, stateWanted, inviteEvent); err == nil { + if is, err := input.GenerateStrippedState(input.Context, input.RoomID, stateWanted, inviteEvent); err == nil { inviteState = is } else { util.GetLogger(input.Context).WithError(err).Error("failed querying known room") @@ -365,17 +368,13 @@ func processInvite( } if isKnownRoom { - req := api.QueryMembershipForUserRequest{ - RoomID: input.RoomID.String(), - UserID: input.InvitedUser.String(), - } - res := api.QueryMembershipForUserResponse{} - err = input.InviteQuerier.QueryMembershipForUser(input.Context, &req, &res) + membership, err := input.MembershipQuerier.CurrentMembership(input.Context, input.RoomID, input.InvitedUser) if err != nil { util.GetLogger(input.Context).WithError(err).Error("failed getting user membership") return nil, spec.InternalServerError{} + } - isAlreadyJoined := (res.Membership == spec.Join) + isAlreadyJoined := (membership == spec.Join) if isAlreadyJoined { // If the user is joined to the room then that takes precedence over this diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 7ec69e333..b832bc9f3 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -50,24 +50,6 @@ func (r *Inviter) IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, er return (info != nil && !info.IsStub()), nil } -func (r *Inviter) generateInviteStrippedState( - ctx context.Context, roomID spec.RoomID, inviteEvent *types.HeaderedEvent, inviteState []fclient.InviteV2StrippedState, -) (*types.RoomInfo, []fclient.InviteV2StrippedState, error) { - info, err := r.DB.RoomInfo(ctx, roomID.String()) - if err != nil { - return nil, nil, fmt.Errorf("failed to load RoomInfo: %w", err) - } - strippedState := inviteState - if len(strippedState) == 0 && info != nil { - var is []fclient.InviteV2StrippedState - if is, err = buildInviteStrippedState(ctx, r.DB, info, inviteEvent); err == nil { - strippedState = is - } - } - - return info, strippedState, nil -} - func (r *Inviter) GenerateInviteStrippedState( ctx context.Context, roomID spec.RoomID, stateWanted []gomatrixserverlib.StateKeyTuple, inviteEvent gomatrixserverlib.PDU, ) ([]fclient.InviteV2StrippedState, error) { @@ -286,6 +268,24 @@ func (r *Inviter) PerformInvite( return outputUpdates, nil } +func (r *Inviter) generateInviteStrippedState( + ctx context.Context, roomID spec.RoomID, inviteEvent *types.HeaderedEvent, inviteState []fclient.InviteV2StrippedState, +) (*types.RoomInfo, []fclient.InviteV2StrippedState, error) { + info, err := r.DB.RoomInfo(ctx, roomID.String()) + if err != nil { + return nil, nil, fmt.Errorf("failed to load RoomInfo: %w", err) + } + strippedState := inviteState + if len(strippedState) == 0 && info != nil { + var is []fclient.InviteV2StrippedState + if is, err = buildInviteStrippedState(ctx, r.DB, info, inviteEvent); err == nil { + strippedState = is + } + } + + return info, strippedState, nil +} + func buildInviteStrippedState( ctx context.Context, db storage.Database,