Fix senderID usage in join/leave

This commit is contained in:
Devon Hudson 2023-06-08 13:55:38 -06:00
parent 087b2b8727
commit f6b6f5e8db
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
31 changed files with 447 additions and 212 deletions

View file

@ -145,8 +145,23 @@ func SaveReadMarker(
userAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, userAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string, syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string,
) util.JSONResponse { ) util.JSONResponse {
fullUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
// Verify that the user is a member of this room // Verify that the user is a member of this room
resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }

View file

@ -55,9 +55,23 @@ func GetAliases(
visibility = content.HistoryVisibility visibility = content.HistoryVisibility
} }
if visibility != spec.WorldReadable { if visibility != spec.WorldReadable {
fullUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
queryReq := api.QueryMembershipForUserRequest{ queryReq := api.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: device.UserID, SenderID: senderID,
} }
var queryRes api.QueryMembershipForUserResponse var queryRes api.QueryMembershipForUserResponse
if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil {

View file

@ -314,30 +314,6 @@ func SetVisibility(
req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device, req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device,
roomID string, roomID string,
) util.JSONResponse { ) util.JSONResponse {
resErr := checkMemberInRoom(req.Context(), rsAPI, dev.UserID, roomID)
if resErr != nil {
return *resErr
}
queryEventsReq := roomserverAPI.QueryLatestEventsAndStateRequest{
RoomID: roomID,
StateToFetch: []gomatrixserverlib.StateKeyTuple{{
EventType: spec.MRoomPowerLevels,
StateKey: "",
}},
}
var queryEventsRes roomserverAPI.QueryLatestEventsAndStateResponse
err := rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes)
if err != nil || len(queryEventsRes.StateEvents) == 0 {
util.GetLogger(req.Context()).WithError(err).Error("could not query events from room")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
// NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event
power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU)
fullUserID, err := spec.NewUserID(dev.UserID, true) fullUserID, err := spec.NewUserID(dev.UserID, true)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -352,6 +328,31 @@ func SetVisibility(
JSON: spec.Forbidden("userID doesn't have power level to change visibility"), JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
} }
} }
resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID)
if resErr != nil {
return *resErr
}
queryEventsReq := roomserverAPI.QueryLatestEventsAndStateRequest{
RoomID: roomID,
StateToFetch: []gomatrixserverlib.StateKeyTuple{{
EventType: spec.MRoomPowerLevels,
StateKey: "",
}},
}
var queryEventsRes roomserverAPI.QueryLatestEventsAndStateResponse
err = rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes)
if err != nil || len(queryEventsRes.StateEvents) == 0 {
util.GetLogger(req.Context()).WithError(err).Error("could not query events from room")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
// NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event
power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU)
if power.UserLevel(senderID) < power.EventLevel(spec.MRoomCanonicalAlias, true) { if power.UserLevel(senderID) < power.EventLevel(spec.MRoomCanonicalAlias, true) {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,

View file

@ -29,10 +29,25 @@ func LeaveRoomByID(
rsAPI roomserverAPI.ClientRoomserverAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
roomID string, roomID string,
) util.JSONResponse { ) util.JSONResponse {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown("device userID is invalid"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *userID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown("could not find senderID for this user"),
}
}
// Prepare to ask the roomserver to perform the room join. // Prepare to ask the roomserver to perform the room join.
leaveReq := roomserverAPI.PerformLeaveRequest{ leaveReq := roomserverAPI.PerformLeaveRequest{
RoomID: roomID, RoomID: roomID,
UserID: device.UserID, Leaver: roomserverAPI.SenderUserIDPair{SenderID: senderID, UserID: *userID},
} }
leaveRes := roomserverAPI.PerformLeaveResponse{} leaveRes := roomserverAPI.PerformLeaveResponse{}

View file

@ -57,15 +57,6 @@ func SendBan(
} }
} }
errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
if errRes != nil {
return *errRes
}
pl, errRes := getPowerlevels(req, rsAPI, roomID)
if errRes != nil {
return *errRes
}
fullUserID, err := spec.NewUserID(device.UserID, true) fullUserID, err := spec.NewUserID(device.UserID, true)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -80,6 +71,16 @@ func SendBan(
JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"), JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"),
} }
} }
errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID)
if errRes != nil {
return *errRes
}
pl, errRes := getPowerlevels(req, rsAPI, roomID)
if errRes != nil {
return *errRes
}
allowedToBan := pl.UserLevel(senderID) >= pl.Ban allowedToBan := pl.UserLevel(senderID) >= pl.Ban
if !allowedToBan { if !allowedToBan {
return util.JSONResponse{ return util.JSONResponse{
@ -147,15 +148,6 @@ func SendKick(
} }
} }
errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
if errRes != nil {
return *errRes
}
pl, errRes := getPowerlevels(req, rsAPI, roomID)
if errRes != nil {
return *errRes
}
fullUserID, err := spec.NewUserID(device.UserID, true) fullUserID, err := spec.NewUserID(device.UserID, true)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -170,6 +162,16 @@ func SendKick(
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
} }
} }
errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID)
if errRes != nil {
return *errRes
}
pl, errRes := getPowerlevels(req, rsAPI, roomID)
if errRes != nil {
return *errRes
}
allowedToKick := pl.UserLevel(senderID) >= pl.Kick allowedToKick := pl.UserLevel(senderID) >= pl.Kick
if !allowedToKick { if !allowedToKick {
return util.JSONResponse{ return util.JSONResponse{
@ -178,10 +180,24 @@ func SendKick(
} }
} }
bodyUserID, err := spec.NewUserID(body.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("body userID is invalid"),
}
}
bodySenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *bodyUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.NotFound("body userID has no matching senderID"),
}
}
var queryRes roomserverAPI.QueryMembershipForUserResponse var queryRes roomserverAPI.QueryMembershipForUserResponse
err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: body.UserID, SenderID: bodySenderID,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -213,15 +229,44 @@ func SendUnban(
} }
} }
errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) fullUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
}
}
errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID)
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
bodyUserID, err := spec.NewUserID(body.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("body userID is invalid"),
}
}
bodySenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *bodyUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.NotFound("body userID has no matching senderID"),
}
}
var queryRes roomserverAPI.QueryMembershipForUserResponse var queryRes roomserverAPI.QueryMembershipForUserResponse
err := rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: body.UserID, SenderID: bodySenderID,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -272,7 +317,22 @@ func SendInvite(
} }
} }
errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) fullUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
}
}
errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID)
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
@ -508,11 +568,11 @@ func checkAndProcessThreepid(
return return
} }
func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID, roomID string) *util.JSONResponse { func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, senderID spec.SenderID, roomID string) *util.JSONResponse {
var membershipRes roomserverAPI.QueryMembershipForUserResponse var membershipRes roomserverAPI.QueryMembershipForUserResponse
err := rsAPI.QueryMembershipForUser(ctx, &roomserverAPI.QueryMembershipForUserRequest{ err := rsAPI.QueryMembershipForUser(ctx, &roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: userID, SenderID: senderID,
}, &membershipRes) }, &membershipRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("QueryMembershipForUser: could not query membership for user") util.GetLogger(ctx).WithError(err).Error("QueryMembershipForUser: could not query membership for user")
@ -536,12 +596,28 @@ func SendForget(
) util.JSONResponse { ) util.JSONResponse {
ctx := req.Context() ctx := req.Context()
logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID) logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID)
fullUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
}
}
var membershipRes roomserverAPI.QueryMembershipForUserResponse var membershipRes roomserverAPI.QueryMembershipForUserResponse
membershipReq := roomserverAPI.QueryMembershipForUserRequest{ membershipReq := roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: device.UserID, SenderID: senderID,
} }
err := rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes)
if err != nil { if err != nil {
logger.WithError(err).Error("QueryMembershipForUser: could not query membership for user") logger.WithError(err).Error("QueryMembershipForUser: could not query membership for user")
return util.JSONResponse{ return util.JSONResponse{

View file

@ -47,7 +47,22 @@ func SendRedaction(
txnID *string, txnID *string,
txnCache *transactions.Cache, txnCache *transactions.Cache,
) util.JSONResponse { ) util.JSONResponse {
resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) fullUserID, userIDErr := spec.NewUserID(device.UserID, true)
if userIDErr != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to redact"),
}
}
senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if queryErr != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to redact"),
}
}
resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
@ -73,21 +88,6 @@ func SendRedaction(
} }
} }
fullUserID, userIDErr := spec.NewUserID(device.UserID, true)
if userIDErr != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to redact"),
}
}
senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if queryErr != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to redact"),
}
}
// "Users may redact their own events, and any user with a power level greater than or equal // "Users may redact their own events, and any user with a power level greater than or equal
// to the redact power level of the room may redact events there" // to the redact power level of the room may redact events there"
// https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid

View file

@ -43,8 +43,23 @@ func SendTyping(
} }
} }
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
// Verify that the user is a member of this room // Verify that the user is a member of this room
resErr := checkMemberInRoom(req.Context(), rsAPI, userID, roomID) resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }

View file

@ -52,6 +52,7 @@ type sendServerNoticeRequest struct {
StateKey string `json:"state_key,omitempty"` StateKey string `json:"state_key,omitempty"`
} }
// nolint:gocyclo
// SendServerNotice sends a message to a specific user. It can only be invoked by an admin. // SendServerNotice sends a message to a specific user. It can only be invoked by an admin.
func SendServerNotice( func SendServerNotice(
req *http.Request, req *http.Request,
@ -187,9 +188,24 @@ func SendServerNotice(
} }
} else { } else {
// we've found a room in common, check the membership // we've found a room in common, check the membership
fullUserID, err := spec.NewUserID(r.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
roomID = commonRooms[0] roomID = commonRooms[0]
membershipRes := api.QueryMembershipForUserResponse{} membershipRes := api.QueryMembershipForUserResponse{}
err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes) err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{SenderID: senderID, RoomID: roomID}, &membershipRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("unable to query membership for user") util.GetLogger(ctx).WithError(err).Error("unable to query membership for user")
return util.JSONResponse{ return util.JSONResponse{
@ -234,7 +250,7 @@ func SendServerNotice(
ctx, rsAPI, ctx, rsAPI,
api.KindNew, api.KindNew,
[]*types.HeaderedEvent{ []*types.HeaderedEvent{
&types.HeaderedEvent{PDU: e}, {PDU: e},
}, },
device.UserDomain(), device.UserDomain(),
cfgClient.Matrix.ServerName, cfgClient.Matrix.ServerName,

View file

@ -99,9 +99,25 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
if !worldReadable { if !worldReadable {
// The room isn't world-readable so try to work out based on the // The room isn't world-readable so try to work out based on the
// user's membership if we want the latest state or not. // user's membership if we want the latest state or not.
err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ userID, err := spec.NewUserID(device.UserID, true)
RoomID: roomID, if err != nil {
UserID: device.UserID, util.GetLogger(ctx).WithError(err).Error("UserID is invalid")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown("Device UserID is invalid"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("No matching senderID for this device")
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound("Unable to find senderID for user"),
}
}
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{
RoomID: roomID,
SenderID: senderID,
}, &membershipRes) }, &membershipRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser")
@ -265,11 +281,27 @@ func OnIncomingStateTypeRequest(
// membershipRes will only be populated if the room is not world-readable. // membershipRes will only be populated if the room is not world-readable.
var membershipRes api.QueryMembershipForUserResponse var membershipRes api.QueryMembershipForUserResponse
if !worldReadable { if !worldReadable {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("UserID is invalid")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown("Device UserID is invalid"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("No matching senderID for this device")
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound("Unable to find senderID for user"),
}
}
// The room isn't world-readable so try to work out based on the // The room isn't world-readable so try to work out based on the
// user's membership if we want the latest state or not. // user's membership if we want the latest state or not.
err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: device.UserID, SenderID: senderID,
}, &membershipRes) }, &membershipRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser")

2
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607195007-f30c42b17b85 github.com/matrix-org/gomatrixserverlib v0.0.0-20230608195455-508c2cae4483
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16 github.com/mattn/go-sqlite3 v1.14.16

4
go.sum
View file

@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607195007-f30c42b17b85 h1:WkmHgdQvpPBxGo/huuCQVR9FWmdkKOcAuQHZBmcHK3g= github.com/matrix-org/gomatrixserverlib v0.0.0-20230608195455-508c2cae4483 h1:T7y2bnp0l0JVGqegAmpOdkaENFiA7raSNSdYz0PXxaI=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607195007-f30c42b17b85/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230608195455-508c2cae4483/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=

View file

@ -34,11 +34,11 @@ func (e ErrNotAllowed) Error() string {
type RestrictedJoinAPI interface { type RestrictedJoinAPI interface {
CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error)
InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error)
RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error)
QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error) QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error)
QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error
UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error) UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, senderID spec.SenderID) (bool, error)
LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error)
} }
@ -238,15 +238,13 @@ type FederationRoomserverAPI interface {
// Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate
// the state and auth chain to return. // the state and auth chain to return.
QueryStateAndAuthChain(ctx context.Context, req *QueryStateAndAuthChainRequest, res *QueryStateAndAuthChainResponse) error QueryStateAndAuthChain(ctx context.Context, req *QueryStateAndAuthChainRequest, res *QueryStateAndAuthChainResponse) error
// Query if we think we're still in a room.
QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error
QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error
// Query missing events for a room from roomserver // Query missing events for a room from roomserver
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
// Query whether a server is allowed to see an event // Query whether a server is allowed to see an event
QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string) (allowed bool, err error) QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string) (allowed bool, err error)
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error)
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error HandleInvite(ctx context.Context, event *types.HeaderedEvent) error
@ -254,12 +252,6 @@ type FederationRoomserverAPI interface {
// Query a given amount (or less) of events prior to a given set of events. // Query a given amount (or less) of events prior to a given set of events.
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error)
InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error)
QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error)
UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error)
LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error)
IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error)
StateQuerier() gomatrixserverlib.StateQuerier StateQuerier() gomatrixserverlib.StateQuerier
} }

View file

@ -215,8 +215,10 @@ type OutputNewInviteEvent struct {
type OutputRetireInviteEvent struct { type OutputRetireInviteEvent struct {
// The ID of the "m.room.member" invite event. // The ID of the "m.room.member" invite event.
EventID string EventID string
// The target user ID of the "m.room.member" invite event that was retired. // The room ID of the "m.room.member" invite event.
TargetUserID string RoomID string
// The target sender ID of the "m.room.member" invite event that was retired.
TargetSenderID spec.SenderID
// Optional event ID of the event that replaced the invite. // Optional event ID of the event that replaced the invite.
// This can be empty if the invite was rejected locally and we were unable // This can be empty if the invite was rejected locally and we were unable
// to reach the server that originally sent the invite. // to reach the server that originally sent the invite.

View file

@ -11,6 +11,11 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
type SenderUserIDPair struct {
SenderID spec.SenderID
UserID spec.UserID
}
type PerformCreateRoomRequest struct { type PerformCreateRoomRequest struct {
InvitedUsers []string InvitedUsers []string
RoomName string RoomName string
@ -41,8 +46,8 @@ type PerformJoinRequest struct {
} }
type PerformLeaveRequest struct { type PerformLeaveRequest struct {
RoomID string `json:"room_id"` RoomID string
UserID string `json:"user_id"` Leaver SenderUserIDPair
} }
type PerformLeaveResponse struct { type PerformLeaveResponse struct {

View file

@ -115,7 +115,7 @@ type QueryMembershipForUserRequest struct {
// ID of the room to fetch membership from // ID of the room to fetch membership from
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
// ID of the user for whom membership is requested // ID of the user for whom membership is requested
UserID string `json:"user_id"` SenderID spec.SenderID `json:"user_id"`
} }
// QueryMembershipForUserResponse is a response to QueryMembership // QueryMembershipForUserResponse is a response to QueryMembership
@ -145,7 +145,7 @@ type QueryMembershipsForRoomRequest struct {
// Optional - ID of the user sending the request, for checking if the // Optional - ID of the user sending the request, for checking if the
// user is allowed to see the memberships. If not specified then all // user is allowed to see the memberships. If not specified then all
// room memberships will be returned. // room memberships will be returned.
Sender string `json:"sender"` SenderID spec.SenderID `json:"sender"`
} }
// QueryMembershipsForRoomResponse is a response to QueryMembershipsForRoom // QueryMembershipsForRoomResponse is a response to QueryMembershipsForRoom
@ -448,11 +448,11 @@ func (rq *JoinRoomQuerier) CurrentStateEvent(ctx context.Context, roomID spec.Ro
return rq.Roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey) return rq.Roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey)
} }
func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) {
return rq.Roomserver.InvitePending(ctx, roomID, userID) return rq.Roomserver.InvitePending(ctx, roomID, senderID)
} }
func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
roomInfo, err := rq.Roomserver.QueryRoomInfo(ctx, roomID) roomInfo, err := rq.Roomserver.QueryRoomInfo(ctx, roomID)
if err != nil || roomInfo == nil || roomInfo.IsStub() { if err != nil || roomInfo == nil || roomInfo.IsStub() {
return nil, err return nil, err
@ -468,7 +468,7 @@ func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID sp
return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err)
} }
userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed")
return nil, fmt.Errorf("InternalServerError: %w", err) return nil, fmt.Errorf("InternalServerError: %w", err)
@ -493,8 +493,8 @@ type MembershipQuerier struct {
func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
req := QueryMembershipForUserRequest{ req := QueryMembershipForUserRequest{
RoomID: roomID.String(), RoomID: roomID.String(),
UserID: string(senderID), SenderID: senderID,
} }
res := QueryMembershipForUserResponse{} res := QueryMembershipForUserResponse{}
err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res) err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res)

View file

@ -55,9 +55,10 @@ func UpdateToInviteMembership(
Type: api.OutputTypeRetireInviteEvent, Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{ RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
RoomID: add.RoomID(),
Membership: spec.Join, Membership: spec.Join,
RetiredByEventID: add.EventID(), RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(), TargetSenderID: spec.SenderID(*add.StateKey()),
}, },
}) })
} }
@ -99,8 +100,8 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
func IsInvitePending( func IsInvitePending(
ctx context.Context, db storage.Database, ctx context.Context, db storage.Database,
roomID, userID string, roomID string, senderID spec.SenderID,
) (bool, string, string, gomatrixserverlib.PDU, error) { ) (bool, spec.SenderID, string, gomatrixserverlib.PDU, error) {
// Look up the room NID for the supplied room ID. // Look up the room NID for the supplied room ID.
info, err := db.RoomInfo(ctx, roomID) info, err := db.RoomInfo(ctx, roomID)
if err != nil { if err != nil {
@ -111,13 +112,13 @@ func IsInvitePending(
} }
// Look up the state key NID for the supplied user ID. // Look up the state key NID for the supplied user ID.
targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID}) targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{string(senderID)})
if err != nil { if err != nil {
return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
} }
targetUserNID, targetUserFound := targetUserNIDs[userID] targetUserNID, targetUserFound := targetUserNIDs[string(senderID)]
if !targetUserFound { if !targetUserFound {
return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs) return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", senderID, targetUserNIDs)
} }
// Let's see if we have an event active for the user in the room. If // Let's see if we have an event active for the user in the room. If
@ -156,7 +157,7 @@ func IsInvitePending(
event, err := verImpl.NewEventFromTrustedJSON(eventJSON, false) event, err := verImpl.NewEventFromTrustedJSON(eventJSON, false)
return true, senderUser, userNIDToEventID[senderUserNIDs[0]], event, err return true, spec.SenderID(senderUser), userNIDToEventID[senderUserNIDs[0]], event, err
} }
// GetMembershipsAtState filters the state events to // GetMembershipsAtState filters the state events to

View file

@ -8,6 +8,7 @@ import (
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -58,12 +59,12 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
} }
// Alice should have no pending invites and should have a NID // Alice should have no pending invites and should have a NID
pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, alice.ID) pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, spec.SenderID(alice.ID))
assert.NoError(t, err, "failed to get pending invites") assert.NoError(t, err, "failed to get pending invites")
assert.False(t, pendingInvite, "unexpected pending invite") assert.False(t, pendingInvite, "unexpected pending invite")
// Bob should have no pending invites and receive a new NID // Bob should have no pending invites and receive a new NID
pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, bob.ID) pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, spec.SenderID(bob.ID))
assert.NoError(t, err, "failed to get pending invites") assert.NoError(t, err, "failed to get pending invites")
assert.False(t, pendingInvite, "unexpected pending invite") assert.False(t, pendingInvite, "unexpected pending invite")
}) })

View file

@ -161,9 +161,10 @@ func updateToJoinMembership(
Type: api.OutputTypeRetireInviteEvent, Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{ RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
RoomID: add.RoomID(),
Membership: spec.Join, Membership: spec.Join,
RetiredByEventID: add.EventID(), RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(), TargetSenderID: spec.SenderID(*add.StateKey()),
}, },
}) })
} }
@ -187,9 +188,10 @@ func updateToLeaveMembership(
Type: api.OutputTypeRetireInviteEvent, Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{ RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
RoomID: add.RoomID(),
Membership: newMembership, Membership: newMembership,
RetiredByEventID: add.EventID(), RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(), TargetSenderID: spec.SenderID(*add.StateKey()),
}, },
}) })
} }

View file

@ -149,11 +149,11 @@ func (r *Admin) PerformAdminEvacuateUser(
ctx context.Context, ctx context.Context,
userID string, userID string,
) (affected []string, err error) { ) (affected []string, err error) {
_, domain, err := gomatrixserverlib.SplitID('@', userID) fullUserID, err := spec.NewUserID(userID, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(fullUserID.Domain()) {
return nil, fmt.Errorf("can only evacuate local users using this endpoint") return nil, fmt.Errorf("can only evacuate local users using this endpoint")
} }
@ -170,9 +170,13 @@ func (r *Admin) PerformAdminEvacuateUser(
allRooms := append(roomIDs, inviteRoomIDs...) allRooms := append(roomIDs, inviteRoomIDs...)
affected = make([]string, 0, len(allRooms)) affected = make([]string, 0, len(allRooms))
for _, roomID := range allRooms { for _, roomID := range allRooms {
senderID, err := r.Queryer.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return nil, err
}
leaveReq := &api.PerformLeaveRequest{ leaveReq := &api.PerformLeaveRequest{
RoomID: roomID, RoomID: roomID,
UserID: userID, Leaver: api.SenderUserIDPair{SenderID: senderID, UserID: *fullUserID},
} }
leaveRes := &api.PerformLeaveResponse{} leaveRes := &api.PerformLeaveResponse{}
outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes) outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes)

View file

@ -203,7 +203,7 @@ func (r *Joiner) performJoinRoomByID(
req.Content = map[string]interface{}{} req.Content = map[string]interface{}{}
} }
req.Content["membership"] = spec.Join req.Content["membership"] = spec.Join
if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req); aerr != nil { if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil {
return "", "", aerr return "", "", aerr
} else if authorisedVia != "" { } else if authorisedVia != "" {
req.Content["join_authorised_via_users_server"] = authorisedVia req.Content["join_authorised_via_users_server"] = authorisedVia
@ -226,17 +226,17 @@ func (r *Joiner) performJoinRoomByID(
// Force a federated join if we're dealing with a pending invite // Force a federated join if we're dealing with a pending invite
// and we aren't in the room. // and we aren't in the room.
isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID) isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
if err == nil && !serverInRoom && isInvitePending { if err == nil && !serverInRoom && isInvitePending {
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomIDOrAlias, inviteSender)
if ierr != nil { if queryErr != nil {
return "", "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err) return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
} }
// If we were invited by someone from another server then we can // If we were invited by someone from another server then we can
// assume they are in the room so we can join via them. // assume they are in the room so we can join via them.
if !r.Cfg.Matrix.IsLocalServerName(inviterDomain) { if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) {
req.ServerNames = append(req.ServerNames, inviterDomain) req.ServerNames = append(req.ServerNames, inviter.Domain())
forceFederatedJoin = true forceFederatedJoin = true
memberEvent := gjson.Parse(string(inviteEvent.JSON())) memberEvent := gjson.Parse(string(inviteEvent.JSON()))
// only set unsigned if we've got a content.membership, which we _should_ // only set unsigned if we've got a content.membership, which we _should_
@ -299,8 +299,8 @@ func (r *Joiner) performJoinRoomByID(
// fail if we can't find the existing membership) because there // fail if we can't find the existing membership) because there
// is really no harm in just sending another membership event. // is really no harm in just sending another membership event.
membershipReq := &api.QueryMembershipForUserRequest{ membershipReq := &api.QueryMembershipForUserRequest{
RoomID: req.RoomIDOrAlias, RoomID: req.RoomIDOrAlias,
UserID: userID.String(), SenderID: senderID,
} }
membershipRes := &api.QueryMembershipForUserResponse{} membershipRes := &api.QueryMembershipForUserResponse{}
_ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes) _ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes)
@ -376,15 +376,12 @@ func (r *Joiner) performFederatedJoinRoomByID(
func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin( func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin(
ctx context.Context, ctx context.Context,
joinReq *rsAPI.PerformJoinRequest, joinReq *rsAPI.PerformJoinRequest,
senderID spec.SenderID,
) (string, error) { ) (string, error) {
roomID, err := spec.NewRoomID(joinReq.RoomIDOrAlias) roomID, err := spec.NewRoomID(joinReq.RoomIDOrAlias)
if err != nil { if err != nil {
return "", err return "", err
} }
userID, err := spec.NewUserID(joinReq.UserID, true)
if err != nil {
return "", err
}
return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, *userID) return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, senderID)
} }

View file

@ -53,16 +53,12 @@ func (r *Leaver) PerformLeave(
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, res *api.PerformLeaveResponse,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if !r.Cfg.Matrix.IsLocalServerName(req.Leaver.UserID.Domain()) {
if err != nil { return nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.UserID.String())
return nil, fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)
}
if !r.Cfg.Matrix.IsLocalServerName(domain) {
return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID)
} }
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomID, "room_id": req.RoomID,
"user_id": req.UserID, "user_id": req.Leaver.UserID.String(),
}) })
logger.Info("User requested to leave join") logger.Info("User requested to leave join")
if strings.HasPrefix(req.RoomID, "!") { if strings.HasPrefix(req.RoomID, "!") {
@ -84,19 +80,20 @@ func (r *Leaver) performLeaveRoomByID(
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
// If there's an invite outstanding for the room then respond to // If there's an invite outstanding for the room then respond to
// that. // that.
isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID) isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.Leaver.SenderID)
if err == nil && isInvitePending { if err == nil && isInvitePending {
_, senderDomain, serr := gomatrixserverlib.SplitID('@', senderUser) sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser)
if serr != nil { if serr != nil || sender == nil {
return nil, fmt.Errorf("sender %q is invalid", senderUser) return nil, fmt.Errorf("sender %q has no matching userID", senderUser)
} }
if !r.Cfg.Matrix.IsLocalServerName(senderDomain) { inviteSender := api.SenderUserIDPair{SenderID: senderUser, UserID: *sender}
return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) {
return r.performFederatedRejectInvite(ctx, req, res, inviteSender.UserID, eventID)
} }
// check that this is not a "server notice room" // check that this is not a "server notice room"
accData := &userapi.QueryAccountDataResponse{} accData := &userapi.QueryAccountDataResponse{}
if err = r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{ if err = r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{
UserID: req.UserID, UserID: req.Leaver.UserID.String(),
RoomID: req.RoomID, RoomID: req.RoomID,
DataType: "m.tag", DataType: "m.tag",
}, accData); err != nil { }, accData); err != nil {
@ -127,7 +124,7 @@ func (r *Leaver) performLeaveRoomByID(
StateToFetch: []gomatrixserverlib.StateKeyTuple{ StateToFetch: []gomatrixserverlib.StateKeyTuple{
{ {
EventType: spec.MRoomMember, EventType: spec.MRoomMember,
StateKey: req.UserID, StateKey: string(req.Leaver.SenderID),
}, },
}, },
} }
@ -141,26 +138,18 @@ func (r *Leaver) performLeaveRoomByID(
// Now let's see if the user is in the room. // Now let's see if the user is in the room.
if len(latestRes.StateEvents) == 0 { if len(latestRes.StateEvents) == 0 {
return nil, fmt.Errorf("user %q is not a member of room %q", req.UserID, req.RoomID) return nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.UserID.String(), req.RoomID)
} }
membership, err := latestRes.StateEvents[0].Membership() membership, err := latestRes.StateEvents[0].Membership()
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting membership: %w", err) return nil, fmt.Errorf("error getting membership: %w", err)
} }
if membership != spec.Join && membership != spec.Invite { if membership != spec.Join && membership != spec.Invite {
return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.UserID, membership) return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.UserID.String(), membership)
} }
// Prepare the template for the leave event. // Prepare the template for the leave event.
fullUserID, err := spec.NewUserID(req.UserID, true) senderIDString := string(req.Leaver.SenderID)
if err != nil {
return nil, err
}
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, *fullUserID)
if err != nil {
return nil, err
}
senderIDString := string(senderID)
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Type: spec.MRoomMember, Type: spec.MRoomMember,
SenderID: senderIDString, SenderID: senderIDString,
@ -175,16 +164,13 @@ func (r *Leaver) performLeaveRoomByID(
return nil, fmt.Errorf("eb.SetUnsigned: %w", err) return nil, fmt.Errorf("eb.SetUnsigned: %w", err)
} }
// Get the sender domain.
senderDomain := fullUserID.Domain()
// We know that the user is in the room at this point so let's build // We know that the user is in the room at this point so let's build
// a leave event. // a leave event.
// TODO: Check what happens if the room exists on the server // TODO: Check what happens if the room exists on the server
// but everyone has since left. I suspect it does the wrong thing. // but everyone has since left. I suspect it does the wrong thing.
var buildRes rsAPI.QueryLatestEventsAndStateResponse var buildRes rsAPI.QueryLatestEventsAndStateResponse
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) identity, err := r.Cfg.Matrix.SigningIdentityFor(req.Leaver.UserID.Domain())
if err != nil { if err != nil {
return nil, fmt.Errorf("SigningIdentityFor: %w", err) return nil, fmt.Errorf("SigningIdentityFor: %w", err)
} }
@ -201,8 +187,8 @@ func (r *Leaver) performLeaveRoomByID(
{ {
Kind: api.KindNew, Kind: api.KindNew,
Event: event, Event: event,
Origin: senderDomain, Origin: req.Leaver.UserID.Domain(),
SendAsServer: string(senderDomain), SendAsServer: string(req.Leaver.UserID.Domain()),
}, },
}, },
} }
@ -219,21 +205,16 @@ func (r *Leaver) performFederatedRejectInvite(
ctx context.Context, ctx context.Context,
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, // nolint:unparam res *api.PerformLeaveResponse, // nolint:unparam
senderUser, eventID string, inviteSender spec.UserID, eventID string,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
_, domain, err := gomatrixserverlib.SplitID('@', senderUser)
if err != nil {
return nil, fmt.Errorf("user ID %q invalid: %w", senderUser, err)
}
// Ask the federation sender to perform a federated leave for us. // Ask the federation sender to perform a federated leave for us.
leaveReq := fsAPI.PerformLeaveRequest{ leaveReq := fsAPI.PerformLeaveRequest{
RoomID: req.RoomID, RoomID: req.RoomID,
UserID: req.UserID, UserID: req.Leaver.UserID.String(),
ServerNames: []spec.ServerName{domain}, ServerNames: []spec.ServerName{inviteSender.Domain()},
} }
leaveRes := fsAPI.PerformLeaveResponse{} leaveRes := fsAPI.PerformLeaveResponse{}
if err = r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil {
// failures in PerformLeave should NEVER stop us from telling other components like the // failures in PerformLeave should NEVER stop us from telling other components like the
// sync API that the invite was withdrawn. Otherwise we can end up with stuck invites. // sync API that the invite was withdrawn. Otherwise we can end up with stuck invites.
util.GetLogger(ctx).WithError(err).Errorf("failed to PerformLeave, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to PerformLeave, still retiring invite event")
@ -244,7 +225,7 @@ func (r *Leaver) performFederatedRejectInvite(
util.GetLogger(ctx).WithError(err).Errorf("failed to get RoomInfo, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to get RoomInfo, still retiring invite event")
} }
updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, req.UserID, true, info.RoomVersion) updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, string(req.Leaver.SenderID), true, info.RoomVersion)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event")
} }
@ -267,9 +248,10 @@ func (r *Leaver) performFederatedRejectInvite(
{ {
Type: api.OutputTypeRetireInviteEvent, Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{ RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
Membership: "leave", RoomID: req.RoomID,
TargetUserID: req.UserID, Membership: "leave",
TargetSenderID: req.Leaver.SenderID,
}, },
}, },
}, nil }, nil

View file

@ -48,7 +48,7 @@ type Queryer struct {
Cfg *config.Dendrite Cfg *config.Dendrite
} }
func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
roomInfo, err := r.QueryRoomInfo(ctx, roomID) roomInfo, err := r.QueryRoomInfo(ctx, roomID)
if err != nil || roomInfo == nil || roomInfo.IsStub() { if err != nil || roomInfo == nil || roomInfo.IsStub() {
return nil, err return nil, err
@ -64,7 +64,7 @@ func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID
return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err)
} }
userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed")
return nil, fmt.Errorf("InternalServerError: %w", err) return nil, fmt.Errorf("InternalServerError: %w", err)
@ -236,7 +236,7 @@ func (r *Queryer) QueryMembershipForUser(
} }
response.RoomExists = true response.RoomExists = true
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.SenderID)
if err != nil { if err != nil {
return err return err
} }
@ -373,7 +373,7 @@ func (r *Queryer) QueryMembershipsForRoom(
// If no sender is specified then we will just return the entire // If no sender is specified then we will just return the entire
// set of memberships for the room, regardless of whether a specific // set of memberships for the room, regardless of whether a specific
// user is allowed to see them or not. // user is allowed to see them or not.
if request.Sender == "" { if request.SenderID == "" {
var events []types.Event var events []types.Event
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.LocalOnly) eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.LocalOnly)
@ -396,7 +396,7 @@ func (r *Queryer) QueryMembershipsForRoom(
return nil return nil
} }
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.SenderID)
if err != nil { if err != nil {
return err return err
} }
@ -903,8 +903,8 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq
return nil return nil
} }
func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) {
pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), userID.String()) pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), senderID)
return pending, err return pending, err
} }
@ -920,8 +920,8 @@ func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eve
return res, err return res, err
} }
func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, userID spec.UserID) (bool, error) { func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, senderID spec.SenderID) (bool, error) {
_, isIn, _, err := r.DB.GetMembership(ctx, roomNID, userID.String()) _, isIn, _, err := r.DB.GetMembership(ctx, roomNID, senderID)
return isIn, err return isIn, err
} }
@ -951,7 +951,7 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse
} }
// nolint:gocyclo // nolint:gocyclo
func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
// Look up if we know anything about the room. If it doesn't exist // Look up if we know anything about the room. If it doesn't exist
// or is a stub entry then we can't do anything. // or is a stub entry then we can't do anything.
roomInfo, err := r.DB.RoomInfo(ctx, roomID.String()) roomInfo, err := r.DB.RoomInfo(ctx, roomID.String())
@ -966,7 +966,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro
return "", err return "", err
} }
return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID) return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID)
} }
func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {

View file

@ -722,7 +722,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) {
roomID, _ := spec.NewRoomID(testRoom.ID) roomID, _ := spec.NewRoomID(testRoom.ID)
userID, _ := spec.NewUserID(bob.ID, true) userID, _ := spec.NewUserID(bob.ID, true)
got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, *userID) got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, spec.SenderID(userID.String()))
if tc.wantError && err == nil { if tc.wantError && err == nil {
t.Fatal("expected error, got none") t.Fatal("expected error, got none")
} }

View file

@ -129,7 +129,7 @@ type Database interface {
// in this room, along a boolean set to true if the user is still in this room, // in this room, along a boolean set to true if the user is still in this room,
// false if not. // false if not.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error)
// Lookup the membership event numeric IDs for all user that are or have // Lookup the membership event numeric IDs for all user that are or have
// been members of a given room. Only lookup events of "join" membership if // been members of a given room. Only lookup events of "join" membership if
// joinOnly is set to true. // joinOnly is set to true.

View file

@ -485,10 +485,10 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
}) })
} }
func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) { func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) {
var requestSenderUserNID types.EventStateKeyNID var requestSenderUserNID types.EventStateKeyNID
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, requestSenderUserID) requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, string(requestSenderID))
return err return err
}) })
if err != nil { if err != nil {

View file

@ -155,6 +155,7 @@ type reqCtx struct {
db Database db Database
req *EventRelationshipRequest req *EventRelationshipRequest
userID string userID string
senderID spec.SenderID
roomVersion gomatrixserverlib.RoomVersion roomVersion gomatrixserverlib.RoomVersion
// federated request args // federated request args
@ -173,10 +174,25 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)),
} }
} }
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: 400,
JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), relation.RoomID, *userID)
if err != nil {
return util.JSONResponse{
Code: 400,
JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)),
}
}
rc := reqCtx{ rc := reqCtx{
ctx: req.Context(), ctx: req.Context(),
req: relation, req: relation,
userID: device.UserID, userID: device.UserID,
senderID: senderID,
rsAPI: rsAPI, rsAPI: rsAPI,
fsAPI: fsAPI, fsAPI: fsAPI,
isFederatedRequest: false, isFederatedRequest: false,
@ -337,8 +353,8 @@ func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *types.HeaderedEvent
// check the user is joined to that room // check the user is joined to that room
var queryMemRes roomserver.QueryMembershipForUserResponse var queryMemRes roomserver.QueryMembershipForUserResponse
err = rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{ err = rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: rc.userID, SenderID: rc.senderID,
}, &queryMemRes) }, &queryMemRes)
if err != nil { if err != nil {
logger.WithError(err).Warn("failed to query membership for user in room") logger.WithError(err).Warn("failed to query membership for user in room")
@ -538,8 +554,8 @@ func (rc *reqCtx) authorisedToSeeEvent(event *types.HeaderedEvent) bool {
// TODO: This does not honour m.room.create content // TODO: This does not honour m.room.create content
var queryMembershipRes roomserver.QueryMembershipForUserResponse var queryMembershipRes roomserver.QueryMembershipForUserResponse
err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{ err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{
RoomID: event.RoomID(), RoomID: event.RoomID(),
UserID: rc.userID, SenderID: rc.senderID,
}, &queryMembershipRes) }, &queryMembershipRes)
if err != nil { if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryMembershipForUser") util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryMembershipForUser")

View file

@ -529,6 +529,10 @@ func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID str
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
return spec.SenderID(userID.String()), nil
}
func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error {
for _, eventID := range req.EventIDs { for _, eventID := range req.EventIDs {
ev := r.events[eventID] ev := r.events[eventID]
@ -540,7 +544,7 @@ func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver
} }
func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error { func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error {
rooms := r.userToJoinedRooms[req.UserID] rooms := r.userToJoinedRooms[string(req.SenderID)]
for _, roomID := range rooms { for _, roomID := range rooms {
if roomID == req.RoomID { if roomID == req.RoomID {
res.IsInRoom = true res.IsInRoom = true

View file

@ -454,7 +454,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
// Notify any active sync requests that the invite has been retired. // Notify any active sync requests that the invite has been retired.
s.inviteStream.Advance(pduPos) s.inviteStream.Advance(pduPos)
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID) userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID)
if err != nil || userID == nil {
log.WithFields(log.Fields{
"event_id": msg.EventID,
"sender_id": msg.TargetSenderID,
log.ErrorKey: err,
}).Errorf("failed to find userID for sender")
return
}
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, userID.String())
} }
func (s *OutputRoomEventConsumer) onNewPeek( func (s *OutputRoomEventConsumer) onNewPeek(

View file

@ -85,9 +85,23 @@ func Context(
*filter.Rooms = append(*filter.Rooms, roomID) *filter.Rooms = append(*filter.Rooms, roomID)
} }
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("Device UserID is invalid"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *userID)
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.Unknown("SenderID for this device is unknown"),
}
}
ctx := req.Context() ctx := req.Context()
membershipRes := roomserver.QueryMembershipForUserResponse{} membershipRes := roomserver.QueryMembershipForUserResponse{}
membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID} membershipReq := roomserver.QueryMembershipForUserRequest{SenderID: senderID, RoomID: roomID}
if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil { if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil {
logrus.WithError(err).Error("unable to query membership") logrus.WithError(err).Error("unable to query membership")
return util.JSONResponse{ return util.JSONResponse{

View file

@ -59,14 +59,28 @@ func GetMemberships(
syncDB storage.Database, rsAPI api.SyncRoomserverAPI, syncDB storage.Database, rsAPI api.SyncRoomserverAPI,
joinedOnly bool, membership, notMembership *string, at string, joinedOnly bool, membership, notMembership *string, at string,
) util.JSONResponse { ) util.JSONResponse {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("Device UserID is invalid"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *userID)
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.Unknown("SenderID for this device is unknown"),
}
}
queryReq := api.QueryMembershipForUserRequest{ queryReq := api.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: device.UserID, SenderID: senderID,
} }
var queryRes api.QueryMembershipForUserResponse var queryRes api.QueryMembershipForUserResponse
if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { if queryErr := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); queryErr != nil {
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") util.GetLogger(req.Context()).WithError(queryErr).Error("rsAPI.QueryMembershipsForRoom failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},

View file

@ -296,9 +296,17 @@ func OnIncomingMessagesRequest(
} }
func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) { func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) {
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return resp, err
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return resp, err
}
req := api.QueryMembershipForUserRequest{ req := api.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: userID, SenderID: senderID,
} }
if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil { if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil {
return api.QueryMembershipForUserResponse{}, err return api.QueryMembershipForUserResponse{}, err