diff --git a/clientapi/routing/rate_limiting.go b/clientapi/routing/rate_limiting.go new file mode 100644 index 000000000..f51d68f09 --- /dev/null +++ b/clientapi/routing/rate_limiting.go @@ -0,0 +1,44 @@ +package routing + +import ( + "net/http" + "sync" + "time" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/util" +) + +var clientRateLimits sync.Map // device ID -> chan bool buffered +var clientRateLimitMaxRequests = 10 +var clientRateLimitTimeIntervalMS = time.Millisecond * 500 + +func rateLimit(req *http.Request) *util.JSONResponse { + // Check if the user has got free resource slots for this request. + // If they don't then we'll return an error. + rateLimit, _ := clientRateLimits.LoadOrStore(req.RemoteAddr, make(chan struct{}, clientRateLimitMaxRequests)) + rateChan := rateLimit.(chan struct{}) + select { + case rateChan <- struct{}{}: + default: + // We hit the rate limit. Tell the client to back off. + return &util.JSONResponse{ + Code: http.StatusTooManyRequests, + JSON: jsonerror.LimitExceeded("You are sending too many requests too quickly!", clientRateLimitTimeIntervalMS.Milliseconds()), + } + } + + // After the time interval, drain a resource from the rate limiting + // channel. This will free up space in the channel for new requests. + go func() { + <-time.After(clientRateLimitTimeIntervalMS) + <-rateChan + + // TODO: racy? + if len(rateChan) == 0 { + close(rateChan) + clientRateLimits.Delete(req.RemoteAddr) + } + }() + return nil +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 24343ee19..f502105df 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -92,6 +92,9 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/join/{roomIDOrAlias}", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -108,6 +111,9 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/join", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -119,6 +125,9 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/leave", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -139,6 +148,9 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/invite", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -253,14 +265,23 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } return Register(req, userAPI, accountDB, cfg) })).Methods(http.MethodPost, http.MethodOptions) v1mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } return LegacyRegister(req, userAPI, cfg) })).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } return RegisterAvailable(req, cfg, accountDB) })).Methods(http.MethodGet, http.MethodOptions) @@ -332,6 +353,9 @@ func Setup( r0mux.Handle("/rooms/{roomID}/typing/{userID}", httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -385,6 +409,9 @@ func Setup( r0mux.Handle("/account/whoami", httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } return Whoami(req, device) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -393,6 +420,9 @@ func Setup( r0mux.Handle("/login", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } return Login(req, accountDB, userAPI, cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) @@ -447,6 +477,9 @@ func Setup( r0mux.Handle("/profile/{userID}/avatar_url", httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -469,6 +502,9 @@ func Setup( r0mux.Handle("/profile/{userID}/displayname", httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -506,6 +542,9 @@ func Setup( // Riot logs get flooded unless this is handled r0mux.Handle("/presence/{userID}/status", httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } // TODO: Set presence (probably the responsibility of a presence server not clientapi) return util.JSONResponse{ Code: http.StatusOK, @@ -516,6 +555,9 @@ func Setup( r0mux.Handle("/voip/turnServer", httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } return RequestTurnServer(req, device, cfg) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -582,6 +624,9 @@ func Setup( r0mux.Handle("/user_directory/search", httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } postContent := struct { SearchString string `json:"search_term"` Limit int `json:"limit"` @@ -623,6 +668,9 @@ func Setup( r0mux.Handle("/rooms/{roomID}/read_markers", httputil.MakeExternalAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } // TODO: return the read_markers. return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} }), @@ -721,6 +769,9 @@ func Setup( r0mux.Handle("/capabilities", httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimit(req); r != nil { + return *r + } return GetCapabilities(req, rsAPI) }), ).Methods(http.MethodGet) diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index ae1b8619b..f60405950 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -27,7 +27,6 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/jsonerror" federationsenderAPI "github.com/matrix-org/dendrite/federationsender/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -46,10 +45,6 @@ type BasicAuth struct { Password string `yaml:"password"` } -var clientRateLimits sync.Map // device ID -> chan bool buffered -var clientRateLimitMaxRequests = 10 -var clientRateLimitTimeIntervalMS = time.Millisecond * 500 - // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. func MakeAuthAPI( metricsName string, userAPI userapi.UserInternalAPI, @@ -61,27 +56,6 @@ func MakeAuthAPI( return *err } - // Check if the user has got free resource slots for this request. - // If they don't then we'll return an error. - rateLimit, _ := clientRateLimits.LoadOrStore(device.ID, make(chan struct{}, clientRateLimitMaxRequests)) - rateChan := rateLimit.(chan struct{}) - select { - case rateChan <- struct{}{}: - default: - // We hit the rate limit. Tell the client to back off. - return util.JSONResponse{ - Code: http.StatusTooManyRequests, - JSON: jsonerror.LimitExceeded("You are sending too many requests too quickly!", clientRateLimitTimeIntervalMS.Milliseconds()), - } - } - - // After the time interval, drain a resource from the rate limiting - // channel. This will free up space in the channel for new requests. - go func() { - <-time.After(clientRateLimitTimeIntervalMS) - <-rateChan - }() - // add the user ID to the logger logger := util.GetLogger((req.Context())) logger = logger.WithField("user_id", device.UserID)