diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 7eacf9cc9..a6477fdea 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -145,8 +145,23 @@ func SaveReadMarker( userAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string, ) util.JSONResponse { + fullUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } + // Verify that the user is a member of this room - resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) if resErr != nil { return *resErr } diff --git a/clientapi/routing/aliases.go b/clientapi/routing/aliases.go index f6603be8b..dfc8325de 100644 --- a/clientapi/routing/aliases.go +++ b/clientapi/routing/aliases.go @@ -55,9 +55,23 @@ func GetAliases( visibility = content.HistoryVisibility } if visibility != spec.WorldReadable { + fullUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } queryReq := api.QueryMembershipForUserRequest{ - RoomID: roomID, - UserID: device.UserID, + RoomID: roomID, + SenderID: senderID, } var queryRes api.QueryMembershipForUserResponse if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 034296f45..516a158e1 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -314,30 +314,6 @@ func SetVisibility( req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device, roomID string, ) util.JSONResponse { - resErr := checkMemberInRoom(req.Context(), rsAPI, dev.UserID, roomID) - if resErr != nil { - return *resErr - } - - queryEventsReq := roomserverAPI.QueryLatestEventsAndStateRequest{ - RoomID: roomID, - StateToFetch: []gomatrixserverlib.StateKeyTuple{{ - EventType: spec.MRoomPowerLevels, - StateKey: "", - }}, - } - var queryEventsRes roomserverAPI.QueryLatestEventsAndStateResponse - err := rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes) - if err != nil || len(queryEventsRes.StateEvents) == 0 { - util.GetLogger(req.Context()).WithError(err).Error("could not query events from room") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - - // NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event - power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU) fullUserID, err := spec.NewUserID(dev.UserID, true) if err != nil { return util.JSONResponse{ @@ -352,6 +328,31 @@ func SetVisibility( JSON: spec.Forbidden("userID doesn't have power level to change visibility"), } } + + resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) + if resErr != nil { + return *resErr + } + + queryEventsReq := roomserverAPI.QueryLatestEventsAndStateRequest{ + RoomID: roomID, + StateToFetch: []gomatrixserverlib.StateKeyTuple{{ + EventType: spec.MRoomPowerLevels, + StateKey: "", + }}, + } + var queryEventsRes roomserverAPI.QueryLatestEventsAndStateResponse + err = rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes) + if err != nil || len(queryEventsRes.StateEvents) == 0 { + util.GetLogger(req.Context()).WithError(err).Error("could not query events from room") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + // NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event + power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU) if power.UserLevel(senderID) < power.EventLevel(spec.MRoomCanonicalAlias, true) { return util.JSONResponse{ Code: http.StatusForbidden, diff --git a/clientapi/routing/leaveroom.go b/clientapi/routing/leaveroom.go index fbf148264..6d215c826 100644 --- a/clientapi/routing/leaveroom.go +++ b/clientapi/routing/leaveroom.go @@ -29,10 +29,25 @@ func LeaveRoomByID( rsAPI roomserverAPI.ClientRoomserverAPI, roomID string, ) util.JSONResponse { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("device userID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("could not find senderID for this user"), + } + } + // Prepare to ask the roomserver to perform the room join. leaveReq := roomserverAPI.PerformLeaveRequest{ RoomID: roomID, - UserID: device.UserID, + Leaver: roomserverAPI.SenderUserIDPair{SenderID: senderID, UserID: *userID}, } leaveRes := roomserverAPI.PerformLeaveResponse{} diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index d5e1ce1f6..e23abf147 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -57,15 +57,6 @@ func SendBan( } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) - if errRes != nil { - return *errRes - } - - pl, errRes := getPowerlevels(req, rsAPI, roomID) - if errRes != nil { - return *errRes - } fullUserID, err := spec.NewUserID(device.UserID, true) if err != nil { return util.JSONResponse{ @@ -80,6 +71,16 @@ func SendBan( JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"), } } + + errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) + if errRes != nil { + return *errRes + } + + pl, errRes := getPowerlevels(req, rsAPI, roomID) + if errRes != nil { + return *errRes + } allowedToBan := pl.UserLevel(senderID) >= pl.Ban if !allowedToBan { return util.JSONResponse{ @@ -147,15 +148,6 @@ func SendKick( } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) - if errRes != nil { - return *errRes - } - - pl, errRes := getPowerlevels(req, rsAPI, roomID) - if errRes != nil { - return *errRes - } fullUserID, err := spec.NewUserID(device.UserID, true) if err != nil { return util.JSONResponse{ @@ -170,6 +162,16 @@ func SendKick( JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), } } + + errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) + if errRes != nil { + return *errRes + } + + pl, errRes := getPowerlevels(req, rsAPI, roomID) + if errRes != nil { + return *errRes + } allowedToKick := pl.UserLevel(senderID) >= pl.Kick if !allowedToKick { return util.JSONResponse{ @@ -178,10 +180,24 @@ func SendKick( } } + bodyUserID, err := spec.NewUserID(body.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("body userID is invalid"), + } + } + bodySenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *bodyUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.NotFound("body userID has no matching senderID"), + } + } var queryRes roomserverAPI.QueryMembershipForUserResponse err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ - RoomID: roomID, - UserID: body.UserID, + RoomID: roomID, + SenderID: bodySenderID, }, &queryRes) if err != nil { return util.ErrorResponse(err) @@ -213,15 +229,44 @@ func SendUnban( } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + fullUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), + } + } + + errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) if errRes != nil { return *errRes } + bodyUserID, err := spec.NewUserID(body.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("body userID is invalid"), + } + } + bodySenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *bodyUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.NotFound("body userID has no matching senderID"), + } + } var queryRes roomserverAPI.QueryMembershipForUserResponse - err := rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ - RoomID: roomID, - UserID: body.UserID, + err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ + RoomID: roomID, + SenderID: bodySenderID, }, &queryRes) if err != nil { return util.ErrorResponse(err) @@ -272,7 +317,22 @@ func SendInvite( } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + fullUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), + } + } + + errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) if errRes != nil { return *errRes } @@ -508,11 +568,11 @@ func checkAndProcessThreepid( return } -func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID, roomID string) *util.JSONResponse { +func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, senderID spec.SenderID, roomID string) *util.JSONResponse { var membershipRes roomserverAPI.QueryMembershipForUserResponse err := rsAPI.QueryMembershipForUser(ctx, &roomserverAPI.QueryMembershipForUserRequest{ - RoomID: roomID, - UserID: userID, + RoomID: roomID, + SenderID: senderID, }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("QueryMembershipForUser: could not query membership for user") @@ -536,12 +596,28 @@ func SendForget( ) util.JSONResponse { ctx := req.Context() logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID) + + fullUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), + } + } + var membershipRes roomserverAPI.QueryMembershipForUserResponse membershipReq := roomserverAPI.QueryMembershipForUserRequest{ - RoomID: roomID, - UserID: device.UserID, + RoomID: roomID, + SenderID: senderID, } - err := rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) + err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) if err != nil { logger.WithError(err).Error("QueryMembershipForUser: could not query membership for user") return util.JSONResponse{ diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 22474fc08..4d7f9a27a 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -47,7 +47,22 @@ func SendRedaction( txnID *string, txnCache *transactions.Cache, ) util.JSONResponse { - resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + fullUserID, userIDErr := spec.NewUserID(device.UserID, true) + if userIDErr != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to redact"), + } + } + senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + if queryErr != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to redact"), + } + } + + resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) if resErr != nil { return *resErr } @@ -73,21 +88,6 @@ func SendRedaction( } } - fullUserID, userIDErr := spec.NewUserID(device.UserID, true) - if userIDErr != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("userID doesn't have power level to redact"), - } - } - senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) - if queryErr != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("userID doesn't have power level to redact"), - } - } - // "Users may redact their own events, and any user with a power level greater than or equal // to the redact power level of the room may redact events there" // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index c5b29297a..a43fd162e 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -43,8 +43,23 @@ func SendTyping( } } + fullUserID, err := spec.NewUserID(userID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } + // Verify that the user is a member of this room - resErr := checkMemberInRoom(req.Context(), rsAPI, userID, roomID) + resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) if resErr != nil { return *resErr } diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index 06714ed1f..b1f2e63f1 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -52,6 +52,7 @@ type sendServerNoticeRequest struct { StateKey string `json:"state_key,omitempty"` } +// nolint:gocyclo // SendServerNotice sends a message to a specific user. It can only be invoked by an admin. func SendServerNotice( req *http.Request, @@ -187,9 +188,24 @@ func SendServerNotice( } } else { // we've found a room in common, check the membership + fullUserID, err := spec.NewUserID(r.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } + roomID = commonRooms[0] membershipRes := api.QueryMembershipForUserResponse{} - err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes) + err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{SenderID: senderID, RoomID: roomID}, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("unable to query membership for user") return util.JSONResponse{ @@ -234,7 +250,7 @@ func SendServerNotice( ctx, rsAPI, api.KindNew, []*types.HeaderedEvent{ - &types.HeaderedEvent{PDU: e}, + {PDU: e}, }, device.UserDomain(), cfgClient.Matrix.ServerName, diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index d74002928..8eea2f308 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -99,9 +99,25 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a if !worldReadable { // The room isn't world-readable so try to work out based on the // user's membership if we want the latest state or not. - err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ - RoomID: roomID, - UserID: device.UserID, + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("UserID is invalid") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("Device UserID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("No matching senderID for this device") + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("Unable to find senderID for user"), + } + } + err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ + RoomID: roomID, + SenderID: senderID, }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") @@ -265,11 +281,27 @@ func OnIncomingStateTypeRequest( // membershipRes will only be populated if the room is not world-readable. var membershipRes api.QueryMembershipForUserResponse if !worldReadable { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("UserID is invalid") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("Device UserID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("No matching senderID for this device") + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("Unable to find senderID for user"), + } + } // The room isn't world-readable so try to work out based on the // user's membership if we want the latest state or not. - err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ - RoomID: roomID, - UserID: device.UserID, + err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ + RoomID: roomID, + SenderID: senderID, }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") diff --git a/go.mod b/go.mod index 23da31ec1..af75040fe 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230607195007-f30c42b17b85 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230608195455-508c2cae4483 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 2e8d258e5..c99d9a07a 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230607195007-f30c42b17b85 h1:WkmHgdQvpPBxGo/huuCQVR9FWmdkKOcAuQHZBmcHK3g= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230607195007-f30c42b17b85/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230608195455-508c2cae4483 h1:T7y2bnp0l0JVGqegAmpOdkaENFiA7raSNSdYz0PXxaI= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230608195455-508c2cae4483/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 8c2cbd6b2..53b77e4c7 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -34,11 +34,11 @@ func (e ErrNotAllowed) Error() string { type RestrictedJoinAPI interface { CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) - InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) - RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) + InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) + RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error) QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error - UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error) + UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, senderID spec.SenderID) (bool, error) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) } @@ -238,15 +238,13 @@ type FederationRoomserverAPI interface { // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate // the state and auth chain to return. QueryStateAndAuthChain(ctx context.Context, req *QueryStateAndAuthChainRequest, res *QueryStateAndAuthChainResponse) error - // Query if we think we're still in a room. - QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error // Query missing events for a room from roomserver QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error // Query whether a server is allowed to see an event QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string) (allowed bool, err error) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error - QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) + QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error HandleInvite(ctx context.Context, event *types.HeaderedEvent) error @@ -254,12 +252,6 @@ type FederationRoomserverAPI interface { // Query a given amount (or less) of events prior to a given set of events. PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error - CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) - InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) - QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error) - UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error) - LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) - IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) StateQuerier() gomatrixserverlib.StateQuerier } diff --git a/roomserver/api/output.go b/roomserver/api/output.go index 16b504957..852b64206 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -215,8 +215,10 @@ type OutputNewInviteEvent struct { type OutputRetireInviteEvent struct { // The ID of the "m.room.member" invite event. EventID string - // The target user ID of the "m.room.member" invite event that was retired. - TargetUserID string + // The room ID of the "m.room.member" invite event. + RoomID string + // The target sender ID of the "m.room.member" invite event that was retired. + TargetSenderID spec.SenderID // Optional event ID of the event that replaced the invite. // This can be empty if the invite was rejected locally and we were unable // to reach the server that originally sent the invite. diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 6cbaf5b19..af775e350 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -11,6 +11,11 @@ import ( "github.com/matrix-org/util" ) +type SenderUserIDPair struct { + SenderID spec.SenderID + UserID spec.UserID +} + type PerformCreateRoomRequest struct { InvitedUsers []string RoomName string @@ -41,8 +46,8 @@ type PerformJoinRequest struct { } type PerformLeaveRequest struct { - RoomID string `json:"room_id"` - UserID string `json:"user_id"` + RoomID string + Leaver SenderUserIDPair } type PerformLeaveResponse struct { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index d79dcebbb..891797557 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -115,7 +115,7 @@ type QueryMembershipForUserRequest struct { // ID of the room to fetch membership from RoomID string `json:"room_id"` // ID of the user for whom membership is requested - UserID string `json:"user_id"` + SenderID spec.SenderID `json:"user_id"` } // QueryMembershipForUserResponse is a response to QueryMembership @@ -145,7 +145,7 @@ type QueryMembershipsForRoomRequest struct { // Optional - ID of the user sending the request, for checking if the // user is allowed to see the memberships. If not specified then all // room memberships will be returned. - Sender string `json:"sender"` + SenderID spec.SenderID `json:"sender"` } // QueryMembershipsForRoomResponse is a response to QueryMembershipsForRoom @@ -448,11 +448,11 @@ func (rq *JoinRoomQuerier) CurrentStateEvent(ctx context.Context, roomID spec.Ro return rq.Roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey) } -func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { - return rq.Roomserver.InvitePending(ctx, roomID, userID) +func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) { + return rq.Roomserver.InvitePending(ctx, roomID, senderID) } -func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { +func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { roomInfo, err := rq.Roomserver.QueryRoomInfo(ctx, roomID) if err != nil || roomInfo == nil || roomInfo.IsStub() { return nil, err @@ -468,7 +468,7 @@ func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID sp return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) } - userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) + userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID) if err != nil { util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") return nil, fmt.Errorf("InternalServerError: %w", err) @@ -493,8 +493,8 @@ type MembershipQuerier struct { func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { req := QueryMembershipForUserRequest{ - RoomID: roomID.String(), - UserID: string(senderID), + RoomID: roomID.String(), + SenderID: senderID, } res := QueryMembershipForUserResponse{} err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res) diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 95397cd5e..eee8e1a55 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -55,9 +55,10 @@ func UpdateToInviteMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, + RoomID: add.RoomID(), Membership: spec.Join, RetiredByEventID: add.EventID(), - TargetUserID: *add.StateKey(), + TargetSenderID: spec.SenderID(*add.StateKey()), }, }) } @@ -99,8 +100,8 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam func IsInvitePending( ctx context.Context, db storage.Database, - roomID, userID string, -) (bool, string, string, gomatrixserverlib.PDU, error) { + roomID string, senderID spec.SenderID, +) (bool, spec.SenderID, string, gomatrixserverlib.PDU, error) { // Look up the room NID for the supplied room ID. info, err := db.RoomInfo(ctx, roomID) if err != nil { @@ -111,13 +112,13 @@ func IsInvitePending( } // Look up the state key NID for the supplied user ID. - targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID}) + targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{string(senderID)}) if err != nil { return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) } - targetUserNID, targetUserFound := targetUserNIDs[userID] + targetUserNID, targetUserFound := targetUserNIDs[string(senderID)] if !targetUserFound { - return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs) + return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", senderID, targetUserNIDs) } // Let's see if we have an event active for the user in the room. If @@ -156,7 +157,7 @@ func IsInvitePending( event, err := verImpl.NewEventFromTrustedJSON(eventJSON, false) - return true, senderUser, userNIDToEventID[senderUserNIDs[0]], event, err + return true, spec.SenderID(senderUser), userNIDToEventID[senderUserNIDs[0]], event, err } // GetMembershipsAtState filters the state events to diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go index f1896277e..1cef83df7 100644 --- a/roomserver/internal/helpers/helpers_test.go +++ b/roomserver/internal/helpers/helpers_test.go @@ -8,6 +8,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/roomserver/types" @@ -58,12 +59,12 @@ func TestIsInvitePendingWithoutNID(t *testing.T) { } // Alice should have no pending invites and should have a NID - pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, alice.ID) + pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, spec.SenderID(alice.ID)) assert.NoError(t, err, "failed to get pending invites") assert.False(t, pendingInvite, "unexpected pending invite") // Bob should have no pending invites and receive a new NID - pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, bob.ID) + pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, spec.SenderID(bob.ID)) assert.NoError(t, err, "failed to get pending invites") assert.False(t, pendingInvite, "unexpected pending invite") }) diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 98d7d13b1..54cad14f7 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -161,9 +161,10 @@ func updateToJoinMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, + RoomID: add.RoomID(), Membership: spec.Join, RetiredByEventID: add.EventID(), - TargetUserID: *add.StateKey(), + TargetSenderID: spec.SenderID(*add.StateKey()), }, }) } @@ -187,9 +188,10 @@ func updateToLeaveMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, + RoomID: add.RoomID(), Membership: newMembership, RetiredByEventID: add.EventID(), - TargetUserID: *add.StateKey(), + TargetSenderID: spec.SenderID(*add.StateKey()), }, }) } diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index eeb1ac406..a9ce63043 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -149,11 +149,11 @@ func (r *Admin) PerformAdminEvacuateUser( ctx context.Context, userID string, ) (affected []string, err error) { - _, domain, err := gomatrixserverlib.SplitID('@', userID) + fullUserID, err := spec.NewUserID(userID, true) if err != nil { return nil, err } - if !r.Cfg.Matrix.IsLocalServerName(domain) { + if !r.Cfg.Matrix.IsLocalServerName(fullUserID.Domain()) { return nil, fmt.Errorf("can only evacuate local users using this endpoint") } @@ -170,9 +170,13 @@ func (r *Admin) PerformAdminEvacuateUser( allRooms := append(roomIDs, inviteRoomIDs...) affected = make([]string, 0, len(allRooms)) for _, roomID := range allRooms { + senderID, err := r.Queryer.QuerySenderIDForUser(ctx, roomID, *fullUserID) + if err != nil { + return nil, err + } leaveReq := &api.PerformLeaveRequest{ RoomID: roomID, - UserID: userID, + Leaver: api.SenderUserIDPair{SenderID: senderID, UserID: *fullUserID}, } leaveRes := &api.PerformLeaveResponse{} outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes) diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index d41cc214b..0052483ab 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -203,7 +203,7 @@ func (r *Joiner) performJoinRoomByID( req.Content = map[string]interface{}{} } req.Content["membership"] = spec.Join - if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req); aerr != nil { + if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil { return "", "", aerr } else if authorisedVia != "" { req.Content["join_authorised_via_users_server"] = authorisedVia @@ -226,17 +226,17 @@ func (r *Joiner) performJoinRoomByID( // Force a federated join if we're dealing with a pending invite // and we aren't in the room. - isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID) + isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID) if err == nil && !serverInRoom && isInvitePending { - _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) - if ierr != nil { - return "", "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomIDOrAlias, inviteSender) + if queryErr != nil { + return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr) } // If we were invited by someone from another server then we can // assume they are in the room so we can join via them. - if !r.Cfg.Matrix.IsLocalServerName(inviterDomain) { - req.ServerNames = append(req.ServerNames, inviterDomain) + if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) { + req.ServerNames = append(req.ServerNames, inviter.Domain()) forceFederatedJoin = true memberEvent := gjson.Parse(string(inviteEvent.JSON())) // only set unsigned if we've got a content.membership, which we _should_ @@ -299,8 +299,8 @@ func (r *Joiner) performJoinRoomByID( // fail if we can't find the existing membership) because there // is really no harm in just sending another membership event. membershipReq := &api.QueryMembershipForUserRequest{ - RoomID: req.RoomIDOrAlias, - UserID: userID.String(), + RoomID: req.RoomIDOrAlias, + SenderID: senderID, } membershipRes := &api.QueryMembershipForUserResponse{} _ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes) @@ -376,15 +376,12 @@ func (r *Joiner) performFederatedJoinRoomByID( func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin( ctx context.Context, joinReq *rsAPI.PerformJoinRequest, + senderID spec.SenderID, ) (string, error) { roomID, err := spec.NewRoomID(joinReq.RoomIDOrAlias) if err != nil { return "", err } - userID, err := spec.NewUserID(joinReq.UserID, true) - if err != nil { - return "", err - } - return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, *userID) + return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, senderID) } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 094537f8b..8fa324865 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -53,16 +53,12 @@ func (r *Leaver) PerformLeave( req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, ) ([]api.OutputEvent, error) { - _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) - if err != nil { - return nil, fmt.Errorf("supplied user ID %q in incorrect format", req.UserID) - } - if !r.Cfg.Matrix.IsLocalServerName(domain) { - return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID) + if !r.Cfg.Matrix.IsLocalServerName(req.Leaver.UserID.Domain()) { + return nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.UserID.String()) } logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ "room_id": req.RoomID, - "user_id": req.UserID, + "user_id": req.Leaver.UserID.String(), }) logger.Info("User requested to leave join") if strings.HasPrefix(req.RoomID, "!") { @@ -84,19 +80,20 @@ func (r *Leaver) performLeaveRoomByID( ) ([]api.OutputEvent, error) { // If there's an invite outstanding for the room then respond to // that. - isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID) + isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.Leaver.SenderID) if err == nil && isInvitePending { - _, senderDomain, serr := gomatrixserverlib.SplitID('@', senderUser) - if serr != nil { - return nil, fmt.Errorf("sender %q is invalid", senderUser) + sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser) + if serr != nil || sender == nil { + return nil, fmt.Errorf("sender %q has no matching userID", senderUser) } - if !r.Cfg.Matrix.IsLocalServerName(senderDomain) { - return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) + inviteSender := api.SenderUserIDPair{SenderID: senderUser, UserID: *sender} + if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { + return r.performFederatedRejectInvite(ctx, req, res, inviteSender.UserID, eventID) } // check that this is not a "server notice room" accData := &userapi.QueryAccountDataResponse{} if err = r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{ - UserID: req.UserID, + UserID: req.Leaver.UserID.String(), RoomID: req.RoomID, DataType: "m.tag", }, accData); err != nil { @@ -127,7 +124,7 @@ func (r *Leaver) performLeaveRoomByID( StateToFetch: []gomatrixserverlib.StateKeyTuple{ { EventType: spec.MRoomMember, - StateKey: req.UserID, + StateKey: string(req.Leaver.SenderID), }, }, } @@ -141,26 +138,18 @@ func (r *Leaver) performLeaveRoomByID( // Now let's see if the user is in the room. if len(latestRes.StateEvents) == 0 { - return nil, fmt.Errorf("user %q is not a member of room %q", req.UserID, req.RoomID) + return nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.UserID.String(), req.RoomID) } membership, err := latestRes.StateEvents[0].Membership() if err != nil { return nil, fmt.Errorf("error getting membership: %w", err) } if membership != spec.Join && membership != spec.Invite { - return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.UserID, membership) + return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.UserID.String(), membership) } // Prepare the template for the leave event. - fullUserID, err := spec.NewUserID(req.UserID, true) - if err != nil { - return nil, err - } - senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, *fullUserID) - if err != nil { - return nil, err - } - senderIDString := string(senderID) + senderIDString := string(req.Leaver.SenderID) proto := gomatrixserverlib.ProtoEvent{ Type: spec.MRoomMember, SenderID: senderIDString, @@ -175,16 +164,13 @@ func (r *Leaver) performLeaveRoomByID( return nil, fmt.Errorf("eb.SetUnsigned: %w", err) } - // Get the sender domain. - senderDomain := fullUserID.Domain() - // We know that the user is in the room at this point so let's build // a leave event. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. var buildRes rsAPI.QueryLatestEventsAndStateResponse - identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + identity, err := r.Cfg.Matrix.SigningIdentityFor(req.Leaver.UserID.Domain()) if err != nil { return nil, fmt.Errorf("SigningIdentityFor: %w", err) } @@ -201,8 +187,8 @@ func (r *Leaver) performLeaveRoomByID( { Kind: api.KindNew, Event: event, - Origin: senderDomain, - SendAsServer: string(senderDomain), + Origin: req.Leaver.UserID.Domain(), + SendAsServer: string(req.Leaver.UserID.Domain()), }, }, } @@ -219,21 +205,16 @@ func (r *Leaver) performFederatedRejectInvite( ctx context.Context, req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, // nolint:unparam - senderUser, eventID string, + inviteSender spec.UserID, eventID string, ) ([]api.OutputEvent, error) { - _, domain, err := gomatrixserverlib.SplitID('@', senderUser) - if err != nil { - return nil, fmt.Errorf("user ID %q invalid: %w", senderUser, err) - } - // Ask the federation sender to perform a federated leave for us. leaveReq := fsAPI.PerformLeaveRequest{ RoomID: req.RoomID, - UserID: req.UserID, - ServerNames: []spec.ServerName{domain}, + UserID: req.Leaver.UserID.String(), + ServerNames: []spec.ServerName{inviteSender.Domain()}, } leaveRes := fsAPI.PerformLeaveResponse{} - if err = r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { + if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { // failures in PerformLeave should NEVER stop us from telling other components like the // sync API that the invite was withdrawn. Otherwise we can end up with stuck invites. util.GetLogger(ctx).WithError(err).Errorf("failed to PerformLeave, still retiring invite event") @@ -244,7 +225,7 @@ func (r *Leaver) performFederatedRejectInvite( util.GetLogger(ctx).WithError(err).Errorf("failed to get RoomInfo, still retiring invite event") } - updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, req.UserID, true, info.RoomVersion) + updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, string(req.Leaver.SenderID), true, info.RoomVersion) if err != nil { util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event") } @@ -267,9 +248,10 @@ func (r *Leaver) performFederatedRejectInvite( { Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ - EventID: eventID, - Membership: "leave", - TargetUserID: req.UserID, + EventID: eventID, + RoomID: req.RoomID, + Membership: "leave", + TargetSenderID: req.Leaver.SenderID, }, }, }, nil diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 08af41c1e..511c149d8 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -48,7 +48,7 @@ type Queryer struct { Cfg *config.Dendrite } -func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { +func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { roomInfo, err := r.QueryRoomInfo(ctx, roomID) if err != nil || roomInfo == nil || roomInfo.IsStub() { return nil, err @@ -64,7 +64,7 @@ func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) } - userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) + userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID) if err != nil { util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") return nil, fmt.Errorf("InternalServerError: %w", err) @@ -236,7 +236,7 @@ func (r *Queryer) QueryMembershipForUser( } response.RoomExists = true - membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) + membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.SenderID) if err != nil { return err } @@ -373,7 +373,7 @@ func (r *Queryer) QueryMembershipsForRoom( // If no sender is specified then we will just return the entire // set of memberships for the room, regardless of whether a specific // user is allowed to see them or not. - if request.Sender == "" { + if request.SenderID == "" { var events []types.Event var eventNIDs []types.EventNID eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.LocalOnly) @@ -396,7 +396,7 @@ func (r *Queryer) QueryMembershipsForRoom( return nil } - membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) + membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.SenderID) if err != nil { return err } @@ -903,8 +903,8 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq return nil } -func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { - pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), userID.String()) +func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) { + pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), senderID) return pending, err } @@ -920,8 +920,8 @@ func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eve return res, err } -func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, userID spec.UserID) (bool, error) { - _, isIn, _, err := r.DB.GetMembership(ctx, roomNID, userID.String()) +func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, senderID spec.SenderID) (bool, error) { + _, isIn, _, err := r.DB.GetMembership(ctx, roomNID, senderID) return isIn, err } @@ -951,7 +951,7 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse } // nolint:gocyclo -func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { +func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { // Look up if we know anything about the room. If it doesn't exist // or is a stub entry then we can't do anything. roomInfo, err := r.DB.RoomInfo(ctx, roomID.String()) @@ -966,7 +966,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro return "", err } - return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID) + return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID) } func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 5e6ba7d4e..73460ffd4 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -722,7 +722,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { roomID, _ := spec.NewRoomID(testRoom.ID) userID, _ := spec.NewUserID(bob.ID, true) - got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, *userID) + got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, spec.SenderID(userID.String())) if tc.wantError && err == nil { t.Fatal("expected error, got none") } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 523cc361a..e9f952490 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -129,7 +129,7 @@ type Database interface { // in this room, along a boolean set to true if the user is still in this room, // false if not. // Returns an error if there was a problem talking to the database. - GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error) + GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error) // Lookup the membership event numeric IDs for all user that are or have // been members of a given room. Only lookup events of "join" membership if // joinOnly is set to true. diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index f2f842357..30bb4dedc 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -485,10 +485,10 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { }) } -func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) { +func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) { var requestSenderUserNID types.EventStateKeyNID err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, requestSenderUserID) + requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, string(requestSenderID)) return err }) if err != nil { diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 47eb544ea..fdf3d0799 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -155,6 +155,7 @@ type reqCtx struct { db Database req *EventRelationshipRequest userID string + senderID spec.SenderID roomVersion gomatrixserverlib.RoomVersion // federated request args @@ -173,10 +174,25 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), } } + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: 400, + JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), relation.RoomID, *userID) + if err != nil { + return util.JSONResponse{ + Code: 400, + JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), + } + } rc := reqCtx{ ctx: req.Context(), req: relation, userID: device.UserID, + senderID: senderID, rsAPI: rsAPI, fsAPI: fsAPI, isFederatedRequest: false, @@ -337,8 +353,8 @@ func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *types.HeaderedEvent // check the user is joined to that room var queryMemRes roomserver.QueryMembershipForUserResponse err = rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{ - RoomID: roomID, - UserID: rc.userID, + RoomID: roomID, + SenderID: rc.senderID, }, &queryMemRes) if err != nil { logger.WithError(err).Warn("failed to query membership for user in room") @@ -538,8 +554,8 @@ func (rc *reqCtx) authorisedToSeeEvent(event *types.HeaderedEvent) bool { // TODO: This does not honour m.room.create content var queryMembershipRes roomserver.QueryMembershipForUserResponse err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{ - RoomID: event.RoomID(), - UserID: rc.userID, + RoomID: event.RoomID(), + SenderID: rc.senderID, }, &queryMembershipRes) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryMembershipForUser") diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 551d7ad45..040e6d958 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -529,6 +529,10 @@ func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID str return spec.NewUserID(string(senderID), true) } +func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { + return spec.SenderID(userID.String()), nil +} + func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { for _, eventID := range req.EventIDs { ev := r.events[eventID] @@ -540,7 +544,7 @@ func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver } func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error { - rooms := r.userToJoinedRooms[req.UserID] + rooms := r.userToJoinedRooms[string(req.SenderID)] for _, roomID := range rooms { if roomID == req.RoomID { res.IsInRoom = true diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 3cd0e77bc..c5f2db9c8 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -454,7 +454,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( // Notify any active sync requests that the invite has been retired. s.inviteStream.Advance(pduPos) - s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID) + userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID) + if err != nil || userID == nil { + log.WithFields(log.Fields{ + "event_id": msg.EventID, + "sender_id": msg.TargetSenderID, + log.ErrorKey: err, + }).Errorf("failed to find userID for sender") + return + } + s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, userID.String()) } func (s *OutputRoomEventConsumer) onNewPeek( diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 144baff1e..1a8d33e38 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -85,9 +85,23 @@ func Context( *filter.Rooms = append(*filter.Rooms, roomID) } + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Device UserID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.Unknown("SenderID for this device is unknown"), + } + } ctx := req.Context() membershipRes := roomserver.QueryMembershipForUserResponse{} - membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID} + membershipReq := roomserver.QueryMembershipForUserRequest{SenderID: senderID, RoomID: roomID} if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil { logrus.WithError(err).Error("unable to query membership") return util.JSONResponse{ diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 813167a5e..bd2b37a14 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -59,14 +59,28 @@ func GetMemberships( syncDB storage.Database, rsAPI api.SyncRoomserverAPI, joinedOnly bool, membership, notMembership *string, at string, ) util.JSONResponse { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Device UserID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.Unknown("SenderID for this device is unknown"), + } + } queryReq := api.QueryMembershipForUserRequest{ - RoomID: roomID, - UserID: device.UserID, + RoomID: roomID, + SenderID: senderID, } var queryRes api.QueryMembershipForUserResponse - if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") + if queryErr := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); queryErr != nil { + util.GetLogger(req.Context()).WithError(queryErr).Error("rsAPI.QueryMembershipsForRoom failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 781fd53e7..91cf289f2 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -296,9 +296,17 @@ func OnIncomingMessagesRequest( } func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) { + fullUserID, err := spec.NewUserID(userID, true) + if err != nil { + return resp, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) + if err != nil { + return resp, err + } req := api.QueryMembershipForUserRequest{ - RoomID: roomID, - UserID: userID, + RoomID: roomID, + SenderID: senderID, } if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil { return api.QueryMembershipForUserResponse{}, err