diff --git a/internal/httputil/rate_limiting.go b/internal/httputil/rate_limiting.go
index dab36481e..23e0dc9ab 100644
--- a/internal/httputil/rate_limiting.go
+++ b/internal/httputil/rate_limiting.go
@@ -1,6 +1,7 @@
 package httputil
 
 import (
+	"fmt"
 	"net/http"
 	"sync"
 	"time"
@@ -9,10 +10,11 @@ import (
 	"github.com/matrix-org/dendrite/setup/config"
 	userapi "github.com/matrix-org/dendrite/userapi/api"
 	"github.com/matrix-org/util"
+	"golang.org/x/time/rate"
 )
 
 type RateLimits struct {
-	limits           map[string]chan struct{}
+	limits           map[string]deviceRatelimit
 	limitsMutex      sync.RWMutex
 	cleanMutex       sync.RWMutex
 	enabled          bool
@@ -21,9 +23,14 @@ type RateLimits struct {
 	exemptUserIDs    map[string]struct{}
 }
 
+type deviceRatelimit struct {
+	*rate.Limiter
+	lastUsed time.Time
+}
+
 func NewRateLimits(cfg *config.RateLimiting) *RateLimits {
 	l := &RateLimits{
-		limits:           make(map[string]chan struct{}),
+		limits:           make(map[string]deviceRatelimit),
 		enabled:          cfg.Enabled,
 		requestThreshold: cfg.Threshold,
 		cooloffDuration:  time.Duration(cfg.CooloffMS) * time.Millisecond,
@@ -41,15 +48,13 @@ func NewRateLimits(cfg *config.RateLimiting) *RateLimits {
 func (l *RateLimits) clean() {
 	for {
 		// On a 30 second interval, we'll take an exclusive write
-		// lock of the entire map and see if any of the channels are
-		// empty. If they are then we will close and delete them,
-		// freeing up memory.
+		// lock of the entire map and see if any of the limiters were last
+		// used more than one minute ago. If so, we delete them, freeing up memory.
		time.Sleep(time.Second * 30)
 		l.cleanMutex.Lock()
 		l.limitsMutex.Lock()
 		for k, c := range l.limits {
-			if len(c) == 0 {
-				close(c)
+			if s := time.Since(c.lastUsed); s > time.Minute {
 				delete(l.limits, k)
 			}
 		}
@@ -64,13 +69,6 @@ func (l *RateLimits) Limit(req *http.Request, device *userapi.Device) *util.JSON
 		return nil
 	}
 
-	// Take a read lock out on the cleaner mutex. The cleaner expects to
-	// be able to take a write lock, which isn't possible while there are
-	// readers, so this has the effect of blocking the cleaner goroutine
-	// from doing its work until there are no requests in flight.
-	l.cleanMutex.RLock()
-	defer l.cleanMutex.RUnlock()
-
 	// First of all, work out if X-Forwarded-For was sent to us. If not
 	// then we'll just use the IP address of the caller.
 	var caller string
@@ -100,33 +98,33 @@ func (l *RateLimits) Limit(req *http.Request, device *userapi.Device) *util.JSON
 	rateLimit, ok := l.limits[caller]
 	l.limitsMutex.RUnlock()
 
-	// If the caller doesn't have a channel, create one and write it
+	// If the caller doesn't have a rate limiter yet, create one and write it
 	// back to the map.
 	if !ok {
-		rateLimit = make(chan struct{}, l.requestThreshold)
+		// Create a new limiter allowing a burst of 20 events, recovering one token every l.cooloffDuration.
+		lim := rate.NewLimiter(rate.Every(l.cooloffDuration), 20)
+		rateLimit = deviceRatelimit{Limiter: lim, lastUsed: time.Now()}
 		l.limitsMutex.Lock()
 		l.limits[caller] = rateLimit
 		l.limitsMutex.Unlock()
 	}
 
+	l.limitsMutex.Lock()
+	rateLimit.lastUsed = time.Now()
+	l.limits[caller] = rateLimit
+	l.limitsMutex.Unlock()
+
 	// Check if the user has got free resource slots for this request.
-	// If they don't then we'll return an error.
-	select {
-	case rateLimit <- struct{}{}:
-	default:
+	// If they don't then we'll wait until one becomes available. If the
+	// request context is cancelled before then, return an error.
+	if err := rateLimit.Wait(req.Context()); err != nil {
 		// We hit the rate limit. Tell the client to back off.
 		return &util.JSONResponse{
-			Code: http.StatusTooManyRequests,
-			JSON: jsonerror.LimitExceeded("You are sending too many requests too quickly!", l.cooloffDuration.Milliseconds()),
+			Code: http.StatusRequestTimeout,
+			JSON: jsonerror.Unknown(fmt.Sprintf("Request timed out: %s", err)),
 		}
 	}
 
-	// After the time interval, drain a resource from the rate limiting
-	// channel. This will free up space in the channel for new requests.
-	go func() {
-		<-time.After(l.cooloffDuration)
-		<-rateLimit
-	}()
 	return nil
 }
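
For reviewers unfamiliar with golang.org/x/time/rate, the standalone sketch below shows how the token-bucket limiter used above behaves in isolation: NewLimiter(rate.Every(cooloff), 20) refills one token per cooloff interval up to a burst of 20, Allow() reports whether a token is free right now, and Wait(ctx) blocks until a token is available or the context ends. The cooloff value and the demo program are illustrative only; they are not taken from the Dendrite code or its configuration.

```go
package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

func main() {
	// Illustrative cooloff; Dendrite derives its value from config.RateLimiting.
	cooloff := 500 * time.Millisecond

	// One token is replenished every cooloff interval; up to 20 requests
	// can be served back-to-back while the bucket is full.
	lim := rate.NewLimiter(rate.Every(cooloff), 20)

	for i := 0; i < 25; i++ {
		// Allow reports whether a token is available right now, without blocking.
		// After the burst is exhausted, it returns false until tokens refill.
		fmt.Printf("request %d allowed immediately: %v\n", i, lim.Allow())
	}

	// Wait blocks until a token is available or the context is cancelled,
	// which mirrors how Limit() now handles callers that exceed the burst.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	if err := lim.Wait(ctx); err != nil {
		fmt.Println("rate limited, context expired:", err)
	} else {
		fmt.Println("token granted after waiting")
	}
}
```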