Use rate.Limiter for rate limiting

This commit is contained in:
Till Faelligen 2022-08-02 18:58:20 +02:00
parent 900d5fc031
commit 91bffae350
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E

View file

@ -1,6 +1,7 @@
package httputil package httputil
import ( import (
"fmt"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@ -9,10 +10,11 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"golang.org/x/time/rate"
) )
type RateLimits struct { type RateLimits struct {
limits map[string]chan struct{} limits map[string]deviceRatelimit
limitsMutex sync.RWMutex limitsMutex sync.RWMutex
cleanMutex sync.RWMutex cleanMutex sync.RWMutex
enabled bool enabled bool
@ -21,9 +23,14 @@ type RateLimits struct {
exemptUserIDs map[string]struct{} exemptUserIDs map[string]struct{}
} }
type deviceRatelimit struct {
*rate.Limiter
lastUsed time.Time
}
func NewRateLimits(cfg *config.RateLimiting) *RateLimits { func NewRateLimits(cfg *config.RateLimiting) *RateLimits {
l := &RateLimits{ l := &RateLimits{
limits: make(map[string]chan struct{}), limits: make(map[string]deviceRatelimit),
enabled: cfg.Enabled, enabled: cfg.Enabled,
requestThreshold: cfg.Threshold, requestThreshold: cfg.Threshold,
cooloffDuration: time.Duration(cfg.CooloffMS) * time.Millisecond, cooloffDuration: time.Duration(cfg.CooloffMS) * time.Millisecond,
@ -41,15 +48,13 @@ func NewRateLimits(cfg *config.RateLimiting) *RateLimits {
func (l *RateLimits) clean() { func (l *RateLimits) clean() {
for { for {
// On a 30 second interval, we'll take an exclusive write // On a 30 second interval, we'll take an exclusive write
// lock of the entire map and see if any of the channels are // lock of the entire map and see if any of the limiters were used
// empty. If they are then we will close and delete them, // more than one minute ago. If they are then we delete them, freeing up memory.
// freeing up memory.
time.Sleep(time.Second * 30) time.Sleep(time.Second * 30)
l.cleanMutex.Lock() l.cleanMutex.Lock()
l.limitsMutex.Lock() l.limitsMutex.Lock()
for k, c := range l.limits { for k, c := range l.limits {
if len(c) == 0 { if s := time.Since(c.lastUsed); s > time.Minute {
close(c)
delete(l.limits, k) delete(l.limits, k)
} }
} }
@ -64,13 +69,6 @@ func (l *RateLimits) Limit(req *http.Request, device *userapi.Device) *util.JSON
return nil return nil
} }
// Take a read lock out on the cleaner mutex. The cleaner expects to
// be able to take a write lock, which isn't possible while there are
// readers, so this has the effect of blocking the cleaner goroutine
// from doing its work until there are no requests in flight.
l.cleanMutex.RLock()
defer l.cleanMutex.RUnlock()
// First of all, work out if X-Forwarded-For was sent to us. If not // First of all, work out if X-Forwarded-For was sent to us. If not
// then we'll just use the IP address of the caller. // then we'll just use the IP address of the caller.
var caller string var caller string
@ -100,33 +98,33 @@ func (l *RateLimits) Limit(req *http.Request, device *userapi.Device) *util.JSON
rateLimit, ok := l.limits[caller] rateLimit, ok := l.limits[caller]
l.limitsMutex.RUnlock() l.limitsMutex.RUnlock()
// If the caller doesn't have a channel, create one and write it // If the caller doesn't have a rate limit yet, create one and write it
// back to the map. // back to the map.
if !ok { if !ok {
rateLimit = make(chan struct{}, l.requestThreshold) // Create a new limiter allowing 20 burst events, recovering 1 token every l.cooloffDuration
lim := rate.NewLimiter(rate.Every(l.cooloffDuration), 20)
rateLimit = deviceRatelimit{Limiter: lim, lastUsed: time.Now()}
l.limitsMutex.Lock() l.limitsMutex.Lock()
l.limits[caller] = rateLimit l.limits[caller] = rateLimit
l.limitsMutex.Unlock() l.limitsMutex.Unlock()
} }
l.limitsMutex.Lock()
rateLimit.lastUsed = time.Now()
l.limits[caller] = rateLimit
l.limitsMutex.Unlock()
// Check if the user has got free resource slots for this request. // Check if the user has got free resource slots for this request.
// If they don't then we'll return an error. // If they don't then we'll try to wait until one is. If the
select { // context is canceled/done, return an error.
case rateLimit <- struct{}{}: if err := rateLimit.Wait(req.Context()); err != nil {
default:
// We hit the rate limit. Tell the client to back off. // We hit the rate limit. Tell the client to back off.
return &util.JSONResponse{ return &util.JSONResponse{
Code: http.StatusTooManyRequests, Code: http.StatusRequestTimeout,
JSON: jsonerror.LimitExceeded("You are sending too many requests too quickly!", l.cooloffDuration.Milliseconds()), JSON: jsonerror.Unknown(fmt.Sprintf("Request timed out: %s", err)),
} }
} }
// After the time interval, drain a resource from the rate limiting
// channel. This will free up space in the channel for new requests.
go func() {
<-time.After(l.cooloffDuration)
<-rateLimit
}()
return nil return nil
} }