Use UserID in membership query API to reduce senderID code duplication

This commit is contained in:
Devon Hudson 2023-06-12 08:45:23 +01:00
parent da7afe2e82
commit 8fb8d5a743
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
17 changed files with 84 additions and 178 deletions

View file

@ -145,23 +145,16 @@ func SaveReadMarker(
userAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, userAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string, syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string,
) util.JSONResponse { ) util.JSONResponse {
fullUserID, err := spec.NewUserID(device.UserID, true) deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: spec.BadJSON("userID for this device is invalid"), JSON: spec.BadJSON("userID for this device is invalid"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown("failed to find senderID for this user"),
}
}
// Verify that the user is a member of this room // Verify that the user is a member of this room
resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }

View file

@ -55,14 +55,7 @@ func GetAliases(
visibility = content.HistoryVisibility visibility = content.HistoryVisibility
} }
if visibility != spec.WorldReadable { if visibility != spec.WorldReadable {
fullUserID, err := spec.NewUserID(device.UserID, true) deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -70,8 +63,8 @@ func GetAliases(
} }
} }
queryReq := api.QueryMembershipForUserRequest{ queryReq := api.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
SenderID: senderID, UserID: *deviceUserID,
} }
var queryRes api.QueryMembershipForUserResponse var queryRes api.QueryMembershipForUserResponse
if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil {

View file

@ -314,14 +314,14 @@ func SetVisibility(
req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device, req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device,
roomID string, roomID string,
) util.JSONResponse { ) util.JSONResponse {
fullUserID, err := spec.NewUserID(dev.UserID, true) deviceUserID, err := spec.NewUserID(dev.UserID, true)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: spec.BadJSON("userID for this device is invalid"), JSON: spec.BadJSON("userID for this device is invalid"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -329,7 +329,7 @@ func SetVisibility(
} }
} }
resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }

View file

@ -57,14 +57,14 @@ func SendBan(
} }
} }
fullUserID, err := spec.NewUserID(device.UserID, true) deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"), JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -72,7 +72,7 @@ func SendBan(
} }
} }
errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
@ -148,14 +148,14 @@ func SendKick(
} }
} }
fullUserID, err := spec.NewUserID(device.UserID, true) deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -163,7 +163,7 @@ func SendKick(
} }
} }
errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
@ -187,17 +187,10 @@ func SendKick(
JSON: spec.BadJSON("body userID is invalid"), JSON: spec.BadJSON("body userID is invalid"),
} }
} }
bodySenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *bodyUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.NotFound("body userID has no matching senderID"),
}
}
var queryRes roomserverAPI.QueryMembershipForUserResponse var queryRes roomserverAPI.QueryMembershipForUserResponse
err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
SenderID: bodySenderID, UserID: *bodyUserID,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -229,22 +222,15 @@ func SendUnban(
} }
} }
fullUserID, err := spec.NewUserID(device.UserID, true) deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
}
}
errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
@ -256,17 +242,10 @@ func SendUnban(
JSON: spec.BadJSON("body userID is invalid"), JSON: spec.BadJSON("body userID is invalid"),
} }
} }
bodySenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *bodyUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.NotFound("body userID has no matching senderID"),
}
}
var queryRes roomserverAPI.QueryMembershipForUserResponse var queryRes roomserverAPI.QueryMembershipForUserResponse
err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
SenderID: bodySenderID, UserID: *bodyUserID,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -317,22 +296,15 @@ func SendInvite(
} }
} }
fullUserID, err := spec.NewUserID(device.UserID, true) deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
}
}
errRes := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
@ -568,11 +540,11 @@ func checkAndProcessThreepid(
return return
} }
func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, senderID spec.SenderID, roomID string) *util.JSONResponse { func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID spec.UserID, roomID string) *util.JSONResponse {
var membershipRes roomserverAPI.QueryMembershipForUserResponse var membershipRes roomserverAPI.QueryMembershipForUserResponse
err := rsAPI.QueryMembershipForUser(ctx, &roomserverAPI.QueryMembershipForUserRequest{ err := rsAPI.QueryMembershipForUser(ctx, &roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
SenderID: senderID, UserID: userID,
}, &membershipRes) }, &membershipRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("QueryMembershipForUser: could not query membership for user") util.GetLogger(ctx).WithError(err).Error("QueryMembershipForUser: could not query membership for user")
@ -597,25 +569,18 @@ func SendForget(
ctx := req.Context() ctx := req.Context()
logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID) logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID)
fullUserID, err := spec.NewUserID(device.UserID, true) deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
}
}
var membershipRes roomserverAPI.QueryMembershipForUserResponse var membershipRes roomserverAPI.QueryMembershipForUserResponse
membershipReq := roomserverAPI.QueryMembershipForUserRequest{ membershipReq := roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
SenderID: senderID, UserID: *deviceUserID,
} }
err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes)
if err != nil { if err != nil {

View file

@ -47,14 +47,14 @@ func SendRedaction(
txnID *string, txnID *string,
txnCache *transactions.Cache, txnCache *transactions.Cache,
) util.JSONResponse { ) util.JSONResponse {
fullUserID, userIDErr := spec.NewUserID(device.UserID, true) deviceUserID, userIDErr := spec.NewUserID(device.UserID, true)
if userIDErr != nil { if userIDErr != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to redact"), JSON: spec.Forbidden("userID doesn't have power level to redact"),
} }
} }
senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
if queryErr != nil { if queryErr != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -62,7 +62,7 @@ func SendRedaction(
} }
} }
resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }

View file

@ -43,14 +43,7 @@ func SendTyping(
} }
} }
fullUserID, err := spec.NewUserID(userID, true) deviceUserID, err := spec.NewUserID(userID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -59,7 +52,7 @@ func SendTyping(
} }
// Verify that the user is a member of this room // Verify that the user is a member of this room
resErr := checkMemberInRoom(req.Context(), rsAPI, senderID, roomID) resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }

View file

@ -188,14 +188,7 @@ func SendServerNotice(
} }
} else { } else {
// we've found a room in common, check the membership // we've found a room in common, check the membership
fullUserID, err := spec.NewUserID(r.UserID, true) deviceUserID, err := spec.NewUserID(r.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -205,7 +198,7 @@ func SendServerNotice(
roomID = commonRooms[0] roomID = commonRooms[0]
membershipRes := api.QueryMembershipForUserResponse{} membershipRes := api.QueryMembershipForUserResponse{}
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{SenderID: senderID, RoomID: roomID}, &membershipRes) err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: *deviceUserID, RoomID: roomID}, &membershipRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("unable to query membership for user") util.GetLogger(ctx).WithError(err).Error("unable to query membership for user")
return util.JSONResponse{ return util.JSONResponse{

View file

@ -107,17 +107,9 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
JSON: spec.Unknown("Device UserID is invalid"), JSON: spec.Unknown("Device UserID is invalid"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("No matching senderID for this device")
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound("Unable to find senderID for user"),
}
}
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
SenderID: senderID, UserID: *userID,
}, &membershipRes) }, &membershipRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser")
@ -289,19 +281,11 @@ func OnIncomingStateTypeRequest(
JSON: spec.Unknown("Device UserID is invalid"), JSON: spec.Unknown("Device UserID is invalid"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("No matching senderID for this device")
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound("Unable to find senderID for user"),
}
}
// The room isn't world-readable so try to work out based on the // The room isn't world-readable so try to work out based on the
// user's membership if we want the latest state or not. // user's membership if we want the latest state or not.
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
SenderID: senderID, UserID: *userID,
}, &membershipRes) }, &membershipRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser")

View file

@ -228,6 +228,7 @@ type FederationRoomserverAPI interface {
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
QueryMembershipForSenderID(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error

View file

@ -113,9 +113,9 @@ type QueryEventsByIDResponse struct {
// QueryMembershipForUserRequest is a request to QueryMembership // QueryMembershipForUserRequest is a request to QueryMembership
type QueryMembershipForUserRequest struct { type QueryMembershipForUserRequest struct {
// ID of the room to fetch membership from // ID of the room to fetch membership from
RoomID string `json:"room_id"` RoomID string
// ID of the user for whom membership is requested // ID of the user for whom membership is requested
SenderID spec.SenderID `json:"user_id"` UserID spec.UserID
} }
// QueryMembershipForUserResponse is a response to QueryMembership // QueryMembershipForUserResponse is a response to QueryMembership
@ -492,12 +492,8 @@ type MembershipQuerier struct {
} }
func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
req := QueryMembershipForUserRequest{
RoomID: roomID.String(),
SenderID: senderID,
}
res := QueryMembershipForUserResponse{} res := QueryMembershipForUserResponse{}
err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res) err := mq.Roomserver.QueryMembershipForSenderID(ctx, roomID, senderID, &res)
membership := "" membership := ""
if err == nil { if err == nil {

View file

@ -162,7 +162,7 @@ func (r *Joiner) performJoinRoomByID(
} }
// Get the domain part of the room ID. // Get the domain part of the room ID.
_, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias) roomID, err := spec.NewRoomID(req.RoomIDOrAlias)
if err != nil { if err != nil {
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)} return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)}
} }
@ -170,8 +170,8 @@ func (r *Joiner) performJoinRoomByID(
// If the server name in the room ID isn't ours then it's a // If the server name in the room ID isn't ours then it's a
// possible candidate for finding the room via federation. Add // possible candidate for finding the room via federation. Add
// it to the list of servers to try. // it to the list of servers to try.
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) {
req.ServerNames = append(req.ServerNames, domain) req.ServerNames = append(req.ServerNames, roomID.Domain())
} }
// Prepare the template for the join event. // Prepare the template for the join event.
@ -298,12 +298,8 @@ func (r *Joiner) performJoinRoomByID(
// a member of the room. This is best-effort (as in we won't // a member of the room. This is best-effort (as in we won't
// fail if we can't find the existing membership) because there // fail if we can't find the existing membership) because there
// is really no harm in just sending another membership event. // is really no harm in just sending another membership event.
membershipReq := &api.QueryMembershipForUserRequest{
RoomID: req.RoomIDOrAlias,
SenderID: senderID,
}
membershipRes := &api.QueryMembershipForUserResponse{} membershipRes := &api.QueryMembershipForUserResponse{}
_ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes) _ = r.Queryer.QueryMembershipForSenderID(ctx, *roomID, senderID, membershipRes)
// If we haven't already joined the room then send an event // If we haven't already joined the room then send an event
// into the room changing our membership status. // into the room changing our membership status.
@ -328,7 +324,7 @@ func (r *Joiner) performJoinRoomByID(
// The room doesn't exist locally. If the room ID looks like it should // The room doesn't exist locally. If the room ID looks like it should
// be ours then this probably means that we've nuked our database at // be ours then this probably means that we've nuked our database at
// some point. // some point.
if r.Cfg.Matrix.IsLocalServerName(domain) { if r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) {
// If there are no more server names to try then give up here. // If there are no more server names to try then give up here.
// Otherwise we'll try a federated join as normal, since it's quite // Otherwise we'll try a federated join as normal, since it's quite
// possible that the room still exists on other servers. // possible that the room still exists on other servers.

View file

@ -220,13 +220,14 @@ func (r *Queryer) QueryEventsByID(
return nil return nil
} }
// QueryMembershipForUser implements api.RoomserverInternalAPI // QueryMembershipForSenderID implements api.RoomserverInternalAPI
func (r *Queryer) QueryMembershipForUser( func (r *Queryer) QueryMembershipForSenderID(
ctx context.Context, ctx context.Context,
request *api.QueryMembershipForUserRequest, roomID spec.RoomID,
senderID spec.SenderID,
response *api.QueryMembershipForUserResponse, response *api.QueryMembershipForUserResponse,
) error { ) error {
info, err := r.DB.RoomInfo(ctx, request.RoomID) info, err := r.DB.RoomInfo(ctx, roomID.String())
if err != nil { if err != nil {
return err return err
} }
@ -236,7 +237,7 @@ func (r *Queryer) QueryMembershipForUser(
} }
response.RoomExists = true response.RoomExists = true
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.SenderID) membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, senderID)
if err != nil { if err != nil {
return err return err
} }
@ -264,6 +265,24 @@ func (r *Queryer) QueryMembershipForUser(
return err return err
} }
// QueryMembershipForUser implements api.RoomserverInternalAPI
func (r *Queryer) QueryMembershipForUser(
ctx context.Context,
request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse,
) error {
senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID)
if err != nil {
return err
}
roomID, err := spec.NewRoomID(request.RoomID)
if err != nil {
return err
}
return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response)
}
// QueryMembershipAtEvent returns the known memberships at a given event. // QueryMembershipAtEvent returns the known memberships at a given event.
// If the state before an event is not known, an empty list will be returned // If the state before an event is not known, an empty list will be returned
// for that event instead. // for that event instead.

View file

@ -154,8 +154,7 @@ type reqCtx struct {
rsAPI roomserver.RoomserverInternalAPI rsAPI roomserver.RoomserverInternalAPI
db Database db Database
req *EventRelationshipRequest req *EventRelationshipRequest
userID string userID spec.UserID
senderID spec.SenderID
roomVersion gomatrixserverlib.RoomVersion roomVersion gomatrixserverlib.RoomVersion
// federated request args // federated request args
@ -181,18 +180,10 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), relation.RoomID, *userID)
if err != nil {
return util.JSONResponse{
Code: 400,
JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)),
}
}
rc := reqCtx{ rc := reqCtx{
ctx: req.Context(), ctx: req.Context(),
req: relation, req: relation,
userID: device.UserID, userID: *userID,
senderID: senderID,
rsAPI: rsAPI, rsAPI: rsAPI,
fsAPI: fsAPI, fsAPI: fsAPI,
isFederatedRequest: false, isFederatedRequest: false,
@ -353,8 +344,8 @@ func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *types.HeaderedEvent
// check the user is joined to that room // check the user is joined to that room
var queryMemRes roomserver.QueryMembershipForUserResponse var queryMemRes roomserver.QueryMembershipForUserResponse
err = rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{ err = rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
SenderID: rc.senderID, UserID: rc.userID,
}, &queryMemRes) }, &queryMemRes)
if err != nil { if err != nil {
logger.WithError(err).Warn("failed to query membership for user in room") logger.WithError(err).Warn("failed to query membership for user in room")
@ -554,8 +545,8 @@ func (rc *reqCtx) authorisedToSeeEvent(event *types.HeaderedEvent) bool {
// TODO: This does not honour m.room.create content // TODO: This does not honour m.room.create content
var queryMembershipRes roomserver.QueryMembershipForUserResponse var queryMembershipRes roomserver.QueryMembershipForUserResponse
err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{ err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{
RoomID: event.RoomID(), RoomID: event.RoomID(),
SenderID: rc.senderID, UserID: rc.userID,
}, &queryMembershipRes) }, &queryMembershipRes)
if err != nil { if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryMembershipForUser") util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryMembershipForUser")

View file

@ -544,7 +544,7 @@ func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver
} }
func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error { func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error {
rooms := r.userToJoinedRooms[string(req.SenderID)] rooms := r.userToJoinedRooms[req.UserID.String()]
for _, roomID := range rooms { for _, roomID := range rooms {
if roomID == req.RoomID { if roomID == req.RoomID {
res.IsInRoom = true res.IsInRoom = true

View file

@ -92,16 +92,9 @@ func Context(
JSON: spec.InvalidParam("Device UserID is invalid"), JSON: spec.InvalidParam("Device UserID is invalid"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *userID)
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.Unknown("SenderID for this device is unknown"),
}
}
ctx := req.Context() ctx := req.Context()
membershipRes := roomserver.QueryMembershipForUserResponse{} membershipRes := roomserver.QueryMembershipForUserResponse{}
membershipReq := roomserver.QueryMembershipForUserRequest{SenderID: senderID, RoomID: roomID} membershipReq := roomserver.QueryMembershipForUserRequest{UserID: *userID, RoomID: roomID}
if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil { if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil {
logrus.WithError(err).Error("unable to query membership") logrus.WithError(err).Error("unable to query membership")
return util.JSONResponse{ return util.JSONResponse{

View file

@ -66,16 +66,9 @@ func GetMemberships(
JSON: spec.InvalidParam("Device UserID is invalid"), JSON: spec.InvalidParam("Device UserID is invalid"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *userID)
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.Unknown("SenderID for this device is unknown"),
}
}
queryReq := api.QueryMembershipForUserRequest{ queryReq := api.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
SenderID: senderID, UserID: *userID,
} }
var queryRes api.QueryMembershipForUserResponse var queryRes api.QueryMembershipForUserResponse

View file

@ -300,13 +300,9 @@ func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.
if err != nil { if err != nil {
return resp, err return resp, err
} }
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return resp, err
}
req := api.QueryMembershipForUserRequest{ req := api.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
SenderID: senderID, UserID: *fullUserID,
} }
if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil { if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil {
return api.QueryMembershipForUserResponse{}, err return api.QueryMembershipForUserResponse{}, err