Move rate limiting to client API

This commit is contained in:
Neil Alexander 2020-09-02 17:06:37 +01:00
parent b2f61a31cd
commit e5cb4f58a9
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
3 changed files with 95 additions and 26 deletions

View file

@ -0,0 +1,44 @@
package routing
import (
"net/http"
"sync"
"time"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/util"
)
var clientRateLimits sync.Map // device ID -> chan bool buffered
var clientRateLimitMaxRequests = 10
var clientRateLimitTimeIntervalMS = time.Millisecond * 500
func rateLimit(req *http.Request) *util.JSONResponse {
// Check if the user has got free resource slots for this request.
// If they don't then we'll return an error.
rateLimit, _ := clientRateLimits.LoadOrStore(req.RemoteAddr, make(chan struct{}, clientRateLimitMaxRequests))
rateChan := rateLimit.(chan struct{})
select {
case rateChan <- struct{}{}:
default:
// We hit the rate limit. Tell the client to back off.
return &util.JSONResponse{
Code: http.StatusTooManyRequests,
JSON: jsonerror.LimitExceeded("You are sending too many requests too quickly!", clientRateLimitTimeIntervalMS.Milliseconds()),
}
}
// After the time interval, drain a resource from the rate limiting
// channel. This will free up space in the channel for new requests.
go func() {
<-time.After(clientRateLimitTimeIntervalMS)
<-rateChan
// TODO: racy?
if len(rateChan) == 0 {
close(rateChan)
clientRateLimits.Delete(req.RemoteAddr)
}
}()
return nil
}

View file

@ -92,6 +92,9 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/join/{roomIDOrAlias}", r0mux.Handle("/join/{roomIDOrAlias}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -108,6 +111,9 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/join", r0mux.Handle("/rooms/{roomID}/join",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -119,6 +125,9 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/leave", r0mux.Handle("/rooms/{roomID}/leave",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -139,6 +148,9 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/invite", r0mux.Handle("/rooms/{roomID}/invite",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -253,14 +265,23 @@ func Setup(
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
return Register(req, userAPI, accountDB, cfg) return Register(req, userAPI, accountDB, cfg)
})).Methods(http.MethodPost, http.MethodOptions) })).Methods(http.MethodPost, http.MethodOptions)
v1mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { v1mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
return LegacyRegister(req, userAPI, cfg) return LegacyRegister(req, userAPI, cfg)
})).Methods(http.MethodPost, http.MethodOptions) })).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
return RegisterAvailable(req, cfg, accountDB) return RegisterAvailable(req, cfg, accountDB)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
@ -332,6 +353,9 @@ func Setup(
r0mux.Handle("/rooms/{roomID}/typing/{userID}", r0mux.Handle("/rooms/{roomID}/typing/{userID}",
httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -385,6 +409,9 @@ func Setup(
r0mux.Handle("/account/whoami", r0mux.Handle("/account/whoami",
httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
return Whoami(req, device) return Whoami(req, device)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -393,6 +420,9 @@ func Setup(
r0mux.Handle("/login", r0mux.Handle("/login",
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
return Login(req, accountDB, userAPI, cfg) return Login(req, accountDB, userAPI, cfg)
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
@ -447,6 +477,9 @@ func Setup(
r0mux.Handle("/profile/{userID}/avatar_url", r0mux.Handle("/profile/{userID}/avatar_url",
httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -469,6 +502,9 @@ func Setup(
r0mux.Handle("/profile/{userID}/displayname", r0mux.Handle("/profile/{userID}/displayname",
httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -506,6 +542,9 @@ func Setup(
// Riot logs get flooded unless this is handled // Riot logs get flooded unless this is handled
r0mux.Handle("/presence/{userID}/status", r0mux.Handle("/presence/{userID}/status",
httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
// TODO: Set presence (probably the responsibility of a presence server not clientapi) // TODO: Set presence (probably the responsibility of a presence server not clientapi)
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
@ -516,6 +555,9 @@ func Setup(
r0mux.Handle("/voip/turnServer", r0mux.Handle("/voip/turnServer",
httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
return RequestTurnServer(req, device, cfg) return RequestTurnServer(req, device, cfg)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -582,6 +624,9 @@ func Setup(
r0mux.Handle("/user_directory/search", r0mux.Handle("/user_directory/search",
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
postContent := struct { postContent := struct {
SearchString string `json:"search_term"` SearchString string `json:"search_term"`
Limit int `json:"limit"` Limit int `json:"limit"`
@ -623,6 +668,9 @@ func Setup(
r0mux.Handle("/rooms/{roomID}/read_markers", r0mux.Handle("/rooms/{roomID}/read_markers",
httputil.MakeExternalAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
// TODO: return the read_markers. // TODO: return the read_markers.
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
}), }),
@ -721,6 +769,9 @@ func Setup(
r0mux.Handle("/capabilities", r0mux.Handle("/capabilities",
httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimit(req); r != nil {
return *r
}
return GetCapabilities(req, rsAPI) return GetCapabilities(req, rsAPI)
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)

View file

@ -27,7 +27,6 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
federationsenderAPI "github.com/matrix-org/dendrite/federationsender/api" federationsenderAPI "github.com/matrix-org/dendrite/federationsender/api"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -46,10 +45,6 @@ type BasicAuth struct {
Password string `yaml:"password"` Password string `yaml:"password"`
} }
var clientRateLimits sync.Map // device ID -> chan bool buffered
var clientRateLimitMaxRequests = 10
var clientRateLimitTimeIntervalMS = time.Millisecond * 500
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
func MakeAuthAPI( func MakeAuthAPI(
metricsName string, userAPI userapi.UserInternalAPI, metricsName string, userAPI userapi.UserInternalAPI,
@ -61,27 +56,6 @@ func MakeAuthAPI(
return *err return *err
} }
// Check if the user has got free resource slots for this request.
// If they don't then we'll return an error.
rateLimit, _ := clientRateLimits.LoadOrStore(device.ID, make(chan struct{}, clientRateLimitMaxRequests))
rateChan := rateLimit.(chan struct{})
select {
case rateChan <- struct{}{}:
default:
// We hit the rate limit. Tell the client to back off.
return util.JSONResponse{
Code: http.StatusTooManyRequests,
JSON: jsonerror.LimitExceeded("You are sending too many requests too quickly!", clientRateLimitTimeIntervalMS.Milliseconds()),
}
}
// After the time interval, drain a resource from the rate limiting
// channel. This will free up space in the channel for new requests.
go func() {
<-time.After(clientRateLimitTimeIntervalMS)
<-rateChan
}()
// add the user ID to the logger // add the user ID to the logger
logger := util.GetLogger((req.Context())) logger := util.GetLogger((req.Context()))
logger = logger.WithField("user_id", device.UserID) logger = logger.WithField("user_id", device.UserID)