From 8fb8d5a7437b51ab134a5cbbba3d104752758635 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Mon, 12 Jun 2023 08:45:23 +0100 Subject: [PATCH] Use UserID in membership query API to reduce senderID code duplication --- clientapi/routing/account_data.go | 11 +-- clientapi/routing/aliases.go | 13 +--- clientapi/routing/directory.go | 6 +- clientapi/routing/membership.go | 75 ++++++--------------- clientapi/routing/redaction.go | 6 +- clientapi/routing/sendtyping.go | 11 +-- clientapi/routing/server_notices.go | 11 +-- clientapi/routing/state.go | 24 ++----- roomserver/api/api.go | 1 + roomserver/api/query.go | 10 +-- roomserver/internal/perform/perform_join.go | 14 ++-- roomserver/internal/query/query.go | 29 ++++++-- setup/mscs/msc2836/msc2836.go | 21 ++---- setup/mscs/msc2836/msc2836_test.go | 2 +- syncapi/routing/context.go | 9 +-- syncapi/routing/memberships.go | 11 +-- syncapi/routing/messages.go | 8 +-- 17 files changed, 84 insertions(+), 178 deletions(-) diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index a2c1596e0..81afc3b13 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -145,23 +145,16 @@ func SaveReadMarker( userAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string, ) util.JSONResponse { - fullUserID, err := spec.NewUserID(device.UserID, true) + deviceUserID, err := spec.NewUserID(device.UserID, true) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("userID for this device is invalid"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.Unknown("failed to find senderID for this user"), - } - } // Verify that the user is a member of this room - resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } diff --git a/clientapi/routing/aliases.go b/clientapi/routing/aliases.go index dfc8325de..2d6b72d3e 100644 --- a/clientapi/routing/aliases.go +++ b/clientapi/routing/aliases.go @@ -55,14 +55,7 @@ func GetAliases( visibility = content.HistoryVisibility } if visibility != spec.WorldReadable { - fullUserID, err := spec.NewUserID(device.UserID, true) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("userID doesn't have power level to change visibility"), - } - } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + deviceUserID, err := spec.NewUserID(device.UserID, true) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -70,8 +63,8 @@ func GetAliases( } } queryReq := api.QueryMembershipForUserRequest{ - RoomID: roomID, - SenderID: senderID, + RoomID: roomID, + UserID: *deviceUserID, } var queryRes api.QueryMembershipForUserResponse if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index e9fc8e466..f01e24eca 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -314,14 +314,14 @@ func SetVisibility( req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device, roomID string, ) util.JSONResponse { - fullUserID, err := spec.NewUserID(dev.UserID, true) + deviceUserID, err := spec.NewUserID(dev.UserID, true) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("userID for this device is invalid"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -329,7 +329,7 @@ func SetVisibility( } } - resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index e23abf147..03e85edbf 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -57,14 +57,14 @@ func SendBan( } } - fullUserID, err := spec.NewUserID(device.UserID, true) + deviceUserID, err := spec.NewUserID(device.UserID, true) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -72,7 +72,7 @@ func SendBan( } } - errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } @@ -148,14 +148,14 @@ func SendKick( } } - fullUserID, err := spec.NewUserID(device.UserID, true) + deviceUserID, err := spec.NewUserID(device.UserID, true) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -163,7 +163,7 @@ func SendKick( } } - errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } @@ -187,17 +187,10 @@ func SendKick( JSON: spec.BadJSON("body userID is invalid"), } } - bodySenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *bodyUserID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.NotFound("body userID has no matching senderID"), - } - } var queryRes roomserverAPI.QueryMembershipForUserResponse err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ - RoomID: roomID, - SenderID: bodySenderID, + RoomID: roomID, + UserID: *bodyUserID, }, &queryRes) if err != nil { return util.ErrorResponse(err) @@ -229,22 +222,15 @@ func SendUnban( } } - fullUserID, err := spec.NewUserID(device.UserID, true) + deviceUserID, err := spec.NewUserID(device.UserID, true) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), - } - } - errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } @@ -256,17 +242,10 @@ func SendUnban( JSON: spec.BadJSON("body userID is invalid"), } } - bodySenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *bodyUserID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.NotFound("body userID has no matching senderID"), - } - } var queryRes roomserverAPI.QueryMembershipForUserResponse err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ - RoomID: roomID, - SenderID: bodySenderID, + RoomID: roomID, + UserID: *bodyUserID, }, &queryRes) if err != nil { return util.ErrorResponse(err) @@ -317,22 +296,15 @@ func SendInvite( } } - fullUserID, err := spec.NewUserID(device.UserID, true) + deviceUserID, err := spec.NewUserID(device.UserID, true) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), - } - } - errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } @@ -568,11 +540,11 @@ func checkAndProcessThreepid( return } -func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, senderID spec.SenderID, roomID string) *util.JSONResponse { +func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID spec.UserID, roomID string) *util.JSONResponse { var membershipRes roomserverAPI.QueryMembershipForUserResponse err := rsAPI.QueryMembershipForUser(ctx, &roomserverAPI.QueryMembershipForUserRequest{ - RoomID: roomID, - SenderID: senderID, + RoomID: roomID, + UserID: userID, }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("QueryMembershipForUser: could not query membership for user") @@ -597,25 +569,18 @@ func SendForget( ctx := req.Context() logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID) - fullUserID, err := spec.NewUserID(device.UserID, true) + deviceUserID, err := spec.NewUserID(device.UserID, true) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), - } - } var membershipRes roomserverAPI.QueryMembershipForUserResponse membershipReq := roomserverAPI.QueryMembershipForUserRequest{ - RoomID: roomID, - SenderID: senderID, + RoomID: roomID, + UserID: *deviceUserID, } err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) if err != nil { diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 4d7f9a27a..94b6a90f1 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -47,14 +47,14 @@ func SendRedaction( txnID *string, txnCache *transactions.Cache, ) util.JSONResponse { - fullUserID, userIDErr := spec.NewUserID(device.UserID, true) + deviceUserID, userIDErr := spec.NewUserID(device.UserID, true) if userIDErr != nil { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("userID doesn't have power level to redact"), } } - senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) if queryErr != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -62,7 +62,7 @@ func SendRedaction( } } - resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index a43fd162e..979bced3b 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -43,14 +43,7 @@ func SendTyping( } } - fullUserID, err := spec.NewUserID(userID, true) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("userID doesn't have power level to change visibility"), - } - } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + deviceUserID, err := spec.NewUserID(userID, true) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -59,7 +52,7 @@ func SendTyping( } // Verify that the user is a member of this room - resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index b1f2e63f1..7006ced46 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -188,14 +188,7 @@ func SendServerNotice( } } else { // we've found a room in common, check the membership - fullUserID, err := spec.NewUserID(r.UserID, true) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("userID doesn't have power level to change visibility"), - } - } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + deviceUserID, err := spec.NewUserID(r.UserID, true) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -205,7 +198,7 @@ func SendServerNotice( roomID = commonRooms[0] membershipRes := api.QueryMembershipForUserResponse{} - err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{SenderID: senderID, RoomID: roomID}, &membershipRes) + err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: *deviceUserID, RoomID: roomID}, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("unable to query membership for user") return util.JSONResponse{ diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 8eea2f308..e3a209b6e 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -107,17 +107,9 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a JSON: spec.Unknown("Device UserID is invalid"), } } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("No matching senderID for this device") - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound("Unable to find senderID for user"), - } - } err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ - RoomID: roomID, - SenderID: senderID, + RoomID: roomID, + UserID: *userID, }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") @@ -289,19 +281,11 @@ func OnIncomingStateTypeRequest( JSON: spec.Unknown("Device UserID is invalid"), } } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("No matching senderID for this device") - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound("Unable to find senderID for user"), - } - } // The room isn't world-readable so try to work out based on the // user's membership if we want the latest state or not. err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ - RoomID: roomID, - SenderID: senderID, + RoomID: roomID, + UserID: *userID, }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 61562b305..cbc4a8931 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -228,6 +228,7 @@ type FederationRoomserverAPI interface { // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error + QueryMembershipForSenderID(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, res *QueryMembershipForUserResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 891797557..684a5b0e3 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -113,9 +113,9 @@ type QueryEventsByIDResponse struct { // QueryMembershipForUserRequest is a request to QueryMembership type QueryMembershipForUserRequest struct { // ID of the room to fetch membership from - RoomID string `json:"room_id"` + RoomID string // ID of the user for whom membership is requested - SenderID spec.SenderID `json:"user_id"` + UserID spec.UserID } // QueryMembershipForUserResponse is a response to QueryMembership @@ -492,12 +492,8 @@ type MembershipQuerier struct { } func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { - req := QueryMembershipForUserRequest{ - RoomID: roomID.String(), - SenderID: senderID, - } res := QueryMembershipForUserResponse{} - err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res) + err := mq.Roomserver.QueryMembershipForSenderID(ctx, roomID, senderID, &res) membership := "" if err == nil { diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 0052483ab..83c3b7c3e 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -162,7 +162,7 @@ func (r *Joiner) performJoinRoomByID( } // Get the domain part of the room ID. - _, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias) + roomID, err := spec.NewRoomID(req.RoomIDOrAlias) if err != nil { return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)} } @@ -170,8 +170,8 @@ func (r *Joiner) performJoinRoomByID( // If the server name in the room ID isn't ours then it's a // possible candidate for finding the room via federation. Add // it to the list of servers to try. - if !r.Cfg.Matrix.IsLocalServerName(domain) { - req.ServerNames = append(req.ServerNames, domain) + if !r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) { + req.ServerNames = append(req.ServerNames, roomID.Domain()) } // Prepare the template for the join event. @@ -298,12 +298,8 @@ func (r *Joiner) performJoinRoomByID( // a member of the room. This is best-effort (as in we won't // fail if we can't find the existing membership) because there // is really no harm in just sending another membership event. - membershipReq := &api.QueryMembershipForUserRequest{ - RoomID: req.RoomIDOrAlias, - SenderID: senderID, - } membershipRes := &api.QueryMembershipForUserResponse{} - _ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes) + _ = r.Queryer.QueryMembershipForSenderID(ctx, *roomID, senderID, membershipRes) // If we haven't already joined the room then send an event // into the room changing our membership status. @@ -328,7 +324,7 @@ func (r *Joiner) performJoinRoomByID( // The room doesn't exist locally. If the room ID looks like it should // be ours then this probably means that we've nuked our database at // some point. - if r.Cfg.Matrix.IsLocalServerName(domain) { + if r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) { // If there are no more server names to try then give up here. // Otherwise we'll try a federated join as normal, since it's quite // possible that the room still exists on other servers. diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 89a64dc9f..caea6b526 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -220,13 +220,14 @@ func (r *Queryer) QueryEventsByID( return nil } -// QueryMembershipForUser implements api.RoomserverInternalAPI -func (r *Queryer) QueryMembershipForUser( +// QueryMembershipForSenderID implements api.RoomserverInternalAPI +func (r *Queryer) QueryMembershipForSenderID( ctx context.Context, - request *api.QueryMembershipForUserRequest, + roomID spec.RoomID, + senderID spec.SenderID, response *api.QueryMembershipForUserResponse, ) error { - info, err := r.DB.RoomInfo(ctx, request.RoomID) + info, err := r.DB.RoomInfo(ctx, roomID.String()) if err != nil { return err } @@ -236,7 +237,7 @@ func (r *Queryer) QueryMembershipForUser( } response.RoomExists = true - membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.SenderID) + membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, senderID) if err != nil { return err } @@ -264,6 +265,24 @@ func (r *Queryer) QueryMembershipForUser( return err } +// QueryMembershipForUser implements api.RoomserverInternalAPI +func (r *Queryer) QueryMembershipForUser( + ctx context.Context, + request *api.QueryMembershipForUserRequest, + response *api.QueryMembershipForUserResponse, +) error { + senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID) + if err != nil { + return err + } + + roomID, err := spec.NewRoomID(request.RoomID) + if err != nil { + return err + } + return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response) +} + // QueryMembershipAtEvent returns the known memberships at a given event. // If the state before an event is not known, an empty list will be returned // for that event instead. diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index fdf3d0799..d3f1c9dd2 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -154,8 +154,7 @@ type reqCtx struct { rsAPI roomserver.RoomserverInternalAPI db Database req *EventRelationshipRequest - userID string - senderID spec.SenderID + userID spec.UserID roomVersion gomatrixserverlib.RoomVersion // federated request args @@ -181,18 +180,10 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), relation.RoomID, *userID) - if err != nil { - return util.JSONResponse{ - Code: 400, - JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), - } - } rc := reqCtx{ ctx: req.Context(), req: relation, - userID: device.UserID, - senderID: senderID, + userID: *userID, rsAPI: rsAPI, fsAPI: fsAPI, isFederatedRequest: false, @@ -353,8 +344,8 @@ func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *types.HeaderedEvent // check the user is joined to that room var queryMemRes roomserver.QueryMembershipForUserResponse err = rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{ - RoomID: roomID, - SenderID: rc.senderID, + RoomID: roomID, + UserID: rc.userID, }, &queryMemRes) if err != nil { logger.WithError(err).Warn("failed to query membership for user in room") @@ -554,8 +545,8 @@ func (rc *reqCtx) authorisedToSeeEvent(event *types.HeaderedEvent) bool { // TODO: This does not honour m.room.create content var queryMembershipRes roomserver.QueryMembershipForUserResponse err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{ - RoomID: event.RoomID(), - SenderID: rc.senderID, + RoomID: event.RoomID(), + UserID: rc.userID, }, &queryMembershipRes) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryMembershipForUser") diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 040e6d958..e32d6a9f2 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -544,7 +544,7 @@ func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver } func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error { - rooms := r.userToJoinedRooms[string(req.SenderID)] + rooms := r.userToJoinedRooms[req.UserID.String()] for _, roomID := range rooms { if roomID == req.RoomID { res.IsInRoom = true diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 1a8d33e38..55fd3c5a2 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -92,16 +92,9 @@ func Context( JSON: spec.InvalidParam("Device UserID is invalid"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *userID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.Unknown("SenderID for this device is unknown"), - } - } ctx := req.Context() membershipRes := roomserver.QueryMembershipForUserResponse{} - membershipReq := roomserver.QueryMembershipForUserRequest{SenderID: senderID, RoomID: roomID} + membershipReq := roomserver.QueryMembershipForUserRequest{UserID: *userID, RoomID: roomID} if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil { logrus.WithError(err).Error("unable to query membership") return util.JSONResponse{ diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index bd2b37a14..cf6769ba4 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -66,16 +66,9 @@ func GetMemberships( JSON: spec.InvalidParam("Device UserID is invalid"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *userID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.Unknown("SenderID for this device is unknown"), - } - } queryReq := api.QueryMembershipForUserRequest{ - RoomID: roomID, - SenderID: senderID, + RoomID: roomID, + UserID: *userID, } var queryRes api.QueryMembershipForUserResponse diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 91cf289f2..6784a27bd 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -300,13 +300,9 @@ func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api. if err != nil { return resp, err } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) - if err != nil { - return resp, err - } req := api.QueryMembershipForUserRequest{ - RoomID: roomID, - SenderID: senderID, + RoomID: roomID, + UserID: *fullUserID, } if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil { return api.QueryMembershipForUserResponse{}, err