diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 99ec53719..ec1cd6abb 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -68,6 +68,7 @@ type InputRoomEventsAPI interface { type QuerySenderIDAPI interface { // Accepts either roomID or alias QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) + QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (spec.UserID, error) } // Query the latest events and state for a room from the room server. diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 7a12bc2cf..329395e63 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -97,11 +97,12 @@ func (r *Inviter) ProcessInviteMembership( ) ([]api.OutputEvent, error) { var outputUpdates []api.OutputEvent var updater *shared.MembershipUpdater - _, domain, err := gomatrixserverlib.SplitID('@', *inviteEvent.StateKey()) + + userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey()) if err != nil { return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} } - isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain) + isTargetLocal := r.Cfg.Matrix.IsLocalServerName(userID.Domain()) if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil { return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 15d226a74..d249b2896 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -1035,7 +1035,16 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query return nil } -func (r *Queryer) QuerySenderIDForRoom(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) { +func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) { // TODO: implement this properly with pseudoIDs return userID.String(), nil } + +func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (spec.UserID, error) { + // TODO: implement this properly with pseudoIDs + userID, err := spec.NewUserID(senderID, true) + if err != nil { + return spec.UserID{}, err + } + return *userID, err +}