diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 0fe0f4f27..b56b41bdd 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -230,6 +230,31 @@ func (r *Queryer) QueryMembershipForSenderID( senderID spec.SenderID, response *api.QueryMembershipForUserResponse, ) error { + return r.queryMembershipForOptionalSenderID(ctx, roomID, &senderID, response) +} + +// QueryMembershipForUser implements api.RoomserverInternalAPI +func (r *Queryer) QueryMembershipForUser( + ctx context.Context, + request *api.QueryMembershipForUserRequest, + response *api.QueryMembershipForUserResponse, +) error { + roomID, err := spec.NewRoomID(request.RoomID) + if err != nil { + return err + } + senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID) + if err != nil { + return err + } + + return r.queryMembershipForOptionalSenderID(ctx, *roomID, senderID, response) +} + +// Query membership information for provided sender ID and room ID +// +// If sender ID is nil, then act as if the provided sender is not a member of the room. +func (r *Queryer) queryMembershipForOptionalSenderID(ctx context.Context, roomID spec.RoomID, senderID *spec.SenderID, response *api.QueryMembershipForUserResponse) error { info, err := r.DB.RoomInfo(ctx, roomID.String()) if err != nil { return err @@ -240,7 +265,11 @@ func (r *Queryer) QueryMembershipForSenderID( } response.RoomExists = true - membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, senderID) + if senderID == nil { + return nil + } + + membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, *senderID) if err != nil { return err } @@ -268,24 +297,6 @@ func (r *Queryer) QueryMembershipForSenderID( return err } -// QueryMembershipForUser implements api.RoomserverInternalAPI -func (r *Queryer) QueryMembershipForUser( - ctx context.Context, - request *api.QueryMembershipForUserRequest, - response *api.QueryMembershipForUserResponse, -) error { - roomID, err := spec.NewRoomID(request.RoomID) - if err != nil { - return err - } - senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID) - if err != nil { - return err - } - - return r.QueryMembershipForSenderID(ctx, *roomID, *senderID, response) -} - // QueryMembershipAtEvent returns the known memberships at a given event. // If the state before an event is not known, an empty list will be returned // for that event instead.