diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 2fac3df91..913b18a90 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -397,7 +397,6 @@ func Setup( return GetJoinedRooms(req, device, rsAPI) }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) - // TODO: update for cryptoIDs v3mux.Handle("/rooms/{roomID}/join", httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { @@ -420,6 +419,28 @@ func Setup( return resp.(util.JSONResponse) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) + unstableMux.Handle("/org.matrix.msc_cryptoids/rooms/{roomID}/join", + httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.Limit(req, device); r != nil { + return *r + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + // Only execute a join for roomID and UserID once. If there is a join in progress + // it waits for it to complete and returns that result for subsequent requests. + resp, _, _ := sf.Do(vars["roomID"]+device.UserID, func() (any, error) { + return JoinRoomByIDOrAliasCryptoIDs( + req, device, rsAPI, userAPI, vars["roomID"], + ), nil + }) + // once all joins are processed, drop them from the cache. Further requests + // will be processed as usual. + sf.Forget(vars["roomID"] + device.UserID) + return resp.(util.JSONResponse) + }, httputil.WithAllowGuests()), + ).Methods(http.MethodPost, http.MethodOptions) // TODO: update for cryptoIDs v3mux.Handle("/rooms/{roomID}/leave", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/clientapi/routing/send_pdus.go b/clientapi/routing/send_pdus.go index ad2a7119e..9fe240475 100644 --- a/clientapi/routing/send_pdus.go +++ b/clientapi/routing/send_pdus.go @@ -125,33 +125,35 @@ func SendPDUs( JSON: spec.Forbidden("userID doesn't have power level to change visibility"), } } - queryReq := roomserverAPI.QueryMembershipForUserRequest{ - RoomID: pdu.RoomID().String(), - UserID: *deviceUserID, - } - var queryRes roomserverAPI.QueryMembershipForUserResponse - if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, + if !cfg.Matrix.IsLocalServerName(pdu.RoomID().Domain()) { + queryReq := roomserverAPI.QueryMembershipForUserRequest{ + RoomID: pdu.RoomID().String(), + UserID: *deviceUserID, } - } - if !queryRes.IsInRoom { - // This is a join event to a remote room - // TODO: cryptoIDs - figure out how to obtain unsigned contents for outstanding federated invites - joinReq := roomserverAPI.PerformJoinRequestCryptoIDs{ - RoomID: pdu.RoomID().String(), - UserID: device.UserID, - IsGuest: device.AccountType == api.AccountTypeGuest, - ServerNames: []spec.ServerName{spec.ServerName(pdus.ViaServer)}, - JoinEvent: pdu, + var queryRes roomserverAPI.QueryMembershipForUserResponse + if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } - err := rsAPI.PerformSendJoinCryptoIDs(req.Context(), &joinReq) - if err != nil { - util.GetLogger(req.Context()).Errorf("Failed processing %s event (%s): %v", pdu.Type(), pdu.EventID(), err) + if !queryRes.IsInRoom { + // This is a join event to a remote room + // TODO: cryptoIDs - figure out how to obtain unsigned contents for outstanding federated invites + joinReq := roomserverAPI.PerformJoinRequestCryptoIDs{ + RoomID: pdu.RoomID().String(), + UserID: device.UserID, + IsGuest: device.AccountType == api.AccountTypeGuest, + ServerNames: []spec.ServerName{spec.ServerName(pdus.ViaServer)}, + JoinEvent: pdu, + } + err := rsAPI.PerformSendJoinCryptoIDs(req.Context(), &joinReq) + if err != nil { + util.GetLogger(req.Context()).Errorf("Failed processing %s event (%s): %v", pdu.Type(), pdu.EventID(), err) + } + continue // NOTE: don't send this event on to the roomserver } - continue // NOTE: don't send this event on to the roomserver } } }