Simplify perform_leave api to reduce duplicate senderID query code

This commit is contained in:
Devon Hudson 2023-06-09 12:47:51 -06:00
parent fcf857402b
commit da7afe2e82
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
4 changed files with 25 additions and 30 deletions

View file

@ -36,18 +36,11 @@ func LeaveRoomByID(
JSON: spec.Unknown("device userID is invalid"), JSON: spec.Unknown("device userID is invalid"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *userID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown("could not find senderID for this user"),
}
}
// Prepare to ask the roomserver to perform the room join. // Prepare to ask the roomserver to perform the room join.
leaveReq := roomserverAPI.PerformLeaveRequest{ leaveReq := roomserverAPI.PerformLeaveRequest{
RoomID: roomID, RoomID: roomID,
Leaver: roomserverAPI.SenderUserIDPair{SenderID: senderID, UserID: *userID}, Leaver: *userID,
} }
leaveRes := roomserverAPI.PerformLeaveResponse{} leaveRes := roomserverAPI.PerformLeaveResponse{}

View file

@ -47,7 +47,7 @@ type PerformJoinRequest struct {
type PerformLeaveRequest struct { type PerformLeaveRequest struct {
RoomID string RoomID string
Leaver SenderUserIDPair Leaver spec.UserID
} }
type PerformLeaveResponse struct { type PerformLeaveResponse struct {

View file

@ -170,13 +170,9 @@ func (r *Admin) PerformAdminEvacuateUser(
allRooms := append(roomIDs, inviteRoomIDs...) allRooms := append(roomIDs, inviteRoomIDs...)
affected = make([]string, 0, len(allRooms)) affected = make([]string, 0, len(allRooms))
for _, roomID := range allRooms { for _, roomID := range allRooms {
senderID, err := r.Queryer.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return nil, err
}
leaveReq := &api.PerformLeaveRequest{ leaveReq := &api.PerformLeaveRequest{
RoomID: roomID, RoomID: roomID,
Leaver: api.SenderUserIDPair{SenderID: senderID, UserID: *fullUserID}, Leaver: *fullUserID,
} }
leaveRes := &api.PerformLeaveResponse{} leaveRes := &api.PerformLeaveResponse{}
outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes) outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes)

View file

@ -53,12 +53,12 @@ func (r *Leaver) PerformLeave(
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, res *api.PerformLeaveResponse,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
if !r.Cfg.Matrix.IsLocalServerName(req.Leaver.UserID.Domain()) { if !r.Cfg.Matrix.IsLocalServerName(req.Leaver.Domain()) {
return nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.UserID.String()) return nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.String())
} }
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomID, "room_id": req.RoomID,
"user_id": req.Leaver.UserID.String(), "user_id": req.Leaver.String(),
}) })
logger.Info("User requested to leave join") logger.Info("User requested to leave join")
if strings.HasPrefix(req.RoomID, "!") { if strings.HasPrefix(req.RoomID, "!") {
@ -78,9 +78,14 @@ func (r *Leaver) performLeaveRoomByID(
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, // nolint:unparam res *api.PerformLeaveResponse, // nolint:unparam
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, req.Leaver)
if err != nil {
return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
}
// If there's an invite outstanding for the room then respond to // If there's an invite outstanding for the room then respond to
// that. // that.
isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.Leaver.SenderID) isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver)
if err == nil && isInvitePending { if err == nil && isInvitePending {
sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser) sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser)
if serr != nil || sender == nil { if serr != nil || sender == nil {
@ -88,12 +93,12 @@ func (r *Leaver) performLeaveRoomByID(
} }
inviteSender := api.SenderUserIDPair{SenderID: senderUser, UserID: *sender} inviteSender := api.SenderUserIDPair{SenderID: senderUser, UserID: *sender}
if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) {
return r.performFederatedRejectInvite(ctx, req, res, inviteSender.UserID, eventID) return r.performFederatedRejectInvite(ctx, req, res, inviteSender.UserID, eventID, leaver)
} }
// check that this is not a "server notice room" // check that this is not a "server notice room"
accData := &userapi.QueryAccountDataResponse{} accData := &userapi.QueryAccountDataResponse{}
if err = r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{ if err = r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{
UserID: req.Leaver.UserID.String(), UserID: req.Leaver.String(),
RoomID: req.RoomID, RoomID: req.RoomID,
DataType: "m.tag", DataType: "m.tag",
}, accData); err != nil { }, accData); err != nil {
@ -124,7 +129,7 @@ func (r *Leaver) performLeaveRoomByID(
StateToFetch: []gomatrixserverlib.StateKeyTuple{ StateToFetch: []gomatrixserverlib.StateKeyTuple{
{ {
EventType: spec.MRoomMember, EventType: spec.MRoomMember,
StateKey: string(req.Leaver.SenderID), StateKey: string(leaver),
}, },
}, },
} }
@ -138,18 +143,18 @@ func (r *Leaver) performLeaveRoomByID(
// Now let's see if the user is in the room. // Now let's see if the user is in the room.
if len(latestRes.StateEvents) == 0 { if len(latestRes.StateEvents) == 0 {
return nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.UserID.String(), req.RoomID) return nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.String(), req.RoomID)
} }
membership, err := latestRes.StateEvents[0].Membership() membership, err := latestRes.StateEvents[0].Membership()
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting membership: %w", err) return nil, fmt.Errorf("error getting membership: %w", err)
} }
if membership != spec.Join && membership != spec.Invite { if membership != spec.Join && membership != spec.Invite {
return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.UserID.String(), membership) return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.String(), membership)
} }
// Prepare the template for the leave event. // Prepare the template for the leave event.
senderIDString := string(req.Leaver.SenderID) senderIDString := string(leaver)
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Type: spec.MRoomMember, Type: spec.MRoomMember,
SenderID: senderIDString, SenderID: senderIDString,
@ -170,7 +175,7 @@ func (r *Leaver) performLeaveRoomByID(
// but everyone has since left. I suspect it does the wrong thing. // but everyone has since left. I suspect it does the wrong thing.
var buildRes rsAPI.QueryLatestEventsAndStateResponse var buildRes rsAPI.QueryLatestEventsAndStateResponse
identity, err := r.Cfg.Matrix.SigningIdentityFor(req.Leaver.UserID.Domain()) identity, err := r.Cfg.Matrix.SigningIdentityFor(req.Leaver.Domain())
if err != nil { if err != nil {
return nil, fmt.Errorf("SigningIdentityFor: %w", err) return nil, fmt.Errorf("SigningIdentityFor: %w", err)
} }
@ -187,8 +192,8 @@ func (r *Leaver) performLeaveRoomByID(
{ {
Kind: api.KindNew, Kind: api.KindNew,
Event: event, Event: event,
Origin: req.Leaver.UserID.Domain(), Origin: req.Leaver.Domain(),
SendAsServer: string(req.Leaver.UserID.Domain()), SendAsServer: string(req.Leaver.Domain()),
}, },
}, },
} }
@ -206,11 +211,12 @@ func (r *Leaver) performFederatedRejectInvite(
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, // nolint:unparam res *api.PerformLeaveResponse, // nolint:unparam
inviteSender spec.UserID, eventID string, inviteSender spec.UserID, eventID string,
leaver spec.SenderID,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
// Ask the federation sender to perform a federated leave for us. // Ask the federation sender to perform a federated leave for us.
leaveReq := fsAPI.PerformLeaveRequest{ leaveReq := fsAPI.PerformLeaveRequest{
RoomID: req.RoomID, RoomID: req.RoomID,
UserID: req.Leaver.UserID.String(), UserID: req.Leaver.String(),
ServerNames: []spec.ServerName{inviteSender.Domain()}, ServerNames: []spec.ServerName{inviteSender.Domain()},
} }
leaveRes := fsAPI.PerformLeaveResponse{} leaveRes := fsAPI.PerformLeaveResponse{}
@ -225,7 +231,7 @@ func (r *Leaver) performFederatedRejectInvite(
util.GetLogger(ctx).WithError(err).Errorf("failed to get RoomInfo, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to get RoomInfo, still retiring invite event")
} }
updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, string(req.Leaver.SenderID), true, info.RoomVersion) updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, string(leaver), true, info.RoomVersion)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event")
} }
@ -251,7 +257,7 @@ func (r *Leaver) performFederatedRejectInvite(
EventID: eventID, EventID: eventID,
RoomID: req.RoomID, RoomID: req.RoomID,
Membership: "leave", Membership: "leave",
TargetSenderID: req.Leaver.SenderID, TargetSenderID: leaver,
}, },
}, },
}, nil }, nil