diff --git a/internal/httputil/rate_limiting.go b/internal/httputil/rate_limiting.go
index 699d3688a..dab36481e 100644
--- a/internal/httputil/rate_limiting.go
+++ b/internal/httputil/rate_limiting.go
@@ -1,7 +1,6 @@
 package httputil
 
 import (
-	"fmt"
 	"net/http"
 	"sync"
 	"time"
@@ -10,11 +9,10 @@ import (
 	"github.com/matrix-org/dendrite/setup/config"
 	userapi "github.com/matrix-org/dendrite/userapi/api"
 	"github.com/matrix-org/util"
-	"golang.org/x/time/rate"
 )
 
 type RateLimits struct {
-	limits           map[string]deviceRatelimit
+	limits           map[string]chan struct{}
 	limitsMutex      sync.RWMutex
 	cleanMutex       sync.RWMutex
 	enabled          bool
@@ -23,14 +21,9 @@ type RateLimits struct {
 	exemptUserIDs    map[string]struct{}
 }
 
-type deviceRatelimit struct {
-	*rate.Limiter
-	lastUsed time.Time
-}
-
 func NewRateLimits(cfg *config.RateLimiting) *RateLimits {
 	l := &RateLimits{
-		limits:           make(map[string]deviceRatelimit),
+		limits:           make(map[string]chan struct{}),
 		enabled:          cfg.Enabled,
 		requestThreshold: cfg.Threshold,
 		cooloffDuration:  time.Duration(cfg.CooloffMS) * time.Millisecond,
@@ -48,13 +41,15 @@ func NewRateLimits(cfg *config.RateLimiting) *RateLimits {
 func (l *RateLimits) clean() {
 	for {
 		// On a 30 second interval, we'll take an exclusive write
-		// lock of the entire map and see if any of the limiters were used
-		// more than one minute ago. If they are then we delete them, freeing up memory.
+		// lock of the entire map and see if any of the channels are
+		// empty. If they are then we will close and delete them,
+		// freeing up memory.
 		time.Sleep(time.Second * 30)
 		l.cleanMutex.Lock()
 		l.limitsMutex.Lock()
 		for k, c := range l.limits {
-			if s := time.Since(c.lastUsed); s > time.Minute {
+			if len(c) == 0 {
+				close(c)
 				delete(l.limits, k)
 			}
 		}
@@ -105,33 +100,33 @@ func (l *RateLimits) Limit(req *http.Request, device *userapi.Device) *util.JSON
 	rateLimit, ok := l.limits[caller]
 	l.limitsMutex.RUnlock()
 
-	// If the caller doesn't have a rate limit yet, create one and write it
+	// If the caller doesn't have a channel, create one and write it
 	// back to the map.
 	if !ok {
-		// Create a new limiter allowing 20 burst events, recovering 1 token every l.cooloffDuration
-		lim := rate.NewLimiter(rate.Every(l.cooloffDuration), 20)
-		rateLimit = deviceRatelimit{Limiter: lim, lastUsed: time.Now()}
+		rateLimit = make(chan struct{}, l.requestThreshold)
 		l.limitsMutex.Lock()
 		l.limits[caller] = rateLimit
 		l.limitsMutex.Unlock()
 	}
 
-	l.limitsMutex.Lock()
-	rateLimit.lastUsed = time.Now()
-	l.limits[caller] = rateLimit
-	l.limitsMutex.Unlock()
-
 	// Check if the user has got free resource slots for this request.
-	// If they don't then we'll try to wait until one is. If the
-	// context is canceled/done, return an error.
-	if err := rateLimit.Wait(req.Context()); err != nil {
+	// If they don't then we'll return an error.
+	select {
+	case rateLimit <- struct{}{}:
+	default:
 		// We hit the rate limit. Tell the client to back off.
 		return &util.JSONResponse{
-			Code: http.StatusRequestTimeout,
-			JSON: jsonerror.Unknown(fmt.Sprintf("Request timed out: %s", err)),
+			Code: http.StatusTooManyRequests,
+			JSON: jsonerror.LimitExceeded("You are sending too many requests too quickly!", l.cooloffDuration.Milliseconds()),
 		}
 	}
 
+	// After the time interval, drain a resource from the rate limiting
+	// channel. This will free up space in the channel for new requests.
+	go func() {
+		<-time.After(l.cooloffDuration)
+		<-rateLimit
+	}()
 	return nil
 }
 
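For illustration, the pattern the patch adopts -- a buffered channel per caller acting as a token counter -- can be sketched as a standalone program. This is a minimal, hypothetical example rather than Dendrite code: the limiter type and the newLimiter/allow names are invented here. A non-blocking send in a select either claims a slot or signals that the threshold has been hit, and a goroutine hands the slot back after the cooloff period, mirroring the select/default and drain goroutine added to Limit above.

package main

import (
	"fmt"
	"time"
)

// limiter holds one buffered channel whose capacity is the request
// threshold; each in-flight request occupies one slot.
type limiter struct {
	slots   chan struct{}
	cooloff time.Duration
}

func newLimiter(threshold int, cooloff time.Duration) *limiter {
	return &limiter{
		slots:   make(chan struct{}, threshold),
		cooloff: cooloff,
	}
}

// allow reports whether a request may proceed right now.
func (l *limiter) allow() bool {
	select {
	case l.slots <- struct{}{}:
		// A slot was free. Return it after the cooloff period so
		// capacity recovers over time.
		go func() {
			<-time.After(l.cooloff)
			<-l.slots
		}()
		return true
	default:
		// The channel is full: the caller has hit the threshold.
		return false
	}
}

func main() {
	lim := newLimiter(5, 500*time.Millisecond)
	for i := 0; i < 8; i++ {
		// The first 5 requests are allowed; the rest are rejected.
		fmt.Printf("request %d allowed: %v\n", i, lim.allow())
	}
	// After the cooloff has elapsed, the slots drain and requests pass again.
	time.Sleep(600 * time.Millisecond)
	fmt.Println("after cooloff, allowed:", lim.allow())
}

Compared with the rate.Limiter approach the patch removes, the select/default form never blocks the request: a caller over the threshold is rejected immediately with 429 Too Many Requests and the cooloff interval in milliseconds, instead of waiting on Wait(req.Context()) and surfacing a 408 when the context is cancelled.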