mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-09 07:03:10 -06:00
138 lines
3.8 KiB
Go
138 lines
3.8 KiB
Go
package httputil
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
|
"github.com/matrix-org/dendrite/setup/config"
|
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
|
"github.com/matrix-org/util"
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
type RateLimits struct {
|
|
limits map[string]deviceRatelimit
|
|
limitsMutex sync.RWMutex
|
|
cleanMutex sync.RWMutex
|
|
enabled bool
|
|
requestThreshold int64
|
|
cooloffDuration time.Duration
|
|
exemptUserIDs map[string]struct{}
|
|
}
|
|
|
|
type deviceRatelimit struct {
|
|
*rate.Limiter
|
|
lastUsed time.Time
|
|
}
|
|
|
|
func NewRateLimits(cfg *config.RateLimiting) *RateLimits {
|
|
l := &RateLimits{
|
|
limits: make(map[string]deviceRatelimit),
|
|
enabled: cfg.Enabled,
|
|
requestThreshold: cfg.Threshold,
|
|
cooloffDuration: time.Duration(cfg.CooloffMS) * time.Millisecond,
|
|
exemptUserIDs: map[string]struct{}{},
|
|
}
|
|
for _, userID := range cfg.ExemptUserIDs {
|
|
l.exemptUserIDs[userID] = struct{}{}
|
|
}
|
|
if l.enabled {
|
|
go l.clean()
|
|
}
|
|
return l
|
|
}
|
|
|
|
func (l *RateLimits) clean() {
|
|
for {
|
|
// On a 30 second interval, we'll take an exclusive write
|
|
// lock of the entire map and see if any of the limiters were used
|
|
// more than one minute ago. If they are then we delete them, freeing up memory.
|
|
time.Sleep(time.Second * 30)
|
|
l.cleanMutex.Lock()
|
|
l.limitsMutex.Lock()
|
|
for k, c := range l.limits {
|
|
if s := time.Since(c.lastUsed); s > time.Minute {
|
|
delete(l.limits, k)
|
|
}
|
|
}
|
|
l.limitsMutex.Unlock()
|
|
l.cleanMutex.Unlock()
|
|
}
|
|
}
|
|
|
|
func (l *RateLimits) Limit(req *http.Request, device *userapi.Device) *util.JSONResponse {
|
|
// If rate limiting is disabled then do nothing.
|
|
if !l.enabled {
|
|
return nil
|
|
}
|
|
|
|
// Take a read lock out on the cleaner mutex. The cleaner expects to
|
|
// be able to take a write lock, which isn't possible while there are
|
|
// readers, so this has the effect of blocking the cleaner goroutine
|
|
// from doing its work until there are no requests in flight.
|
|
l.cleanMutex.RLock()
|
|
defer l.cleanMutex.RUnlock()
|
|
|
|
// First of all, work out if X-Forwarded-For was sent to us. If not
|
|
// then we'll just use the IP address of the caller.
|
|
var caller string
|
|
if device != nil {
|
|
switch device.AccountType {
|
|
case userapi.AccountTypeAdmin:
|
|
return nil // don't rate-limit server administrators
|
|
case userapi.AccountTypeAppService:
|
|
return nil // don't rate-limit appservice users
|
|
default:
|
|
if _, ok := l.exemptUserIDs[device.UserID]; ok {
|
|
// If the user is exempt from rate limiting then do nothing.
|
|
return nil
|
|
}
|
|
caller = device.UserID + device.ID
|
|
}
|
|
} else {
|
|
if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" {
|
|
caller = forwardedFor
|
|
} else {
|
|
caller = req.RemoteAddr
|
|
}
|
|
}
|
|
|
|
// Look up the caller's channel, if they have one.
|
|
l.limitsMutex.RLock()
|
|
rateLimit, ok := l.limits[caller]
|
|
l.limitsMutex.RUnlock()
|
|
|
|
// If the caller doesn't have a rate limit yet, create one and write it
|
|
// back to the map.
|
|
if !ok {
|
|
// Create a new limiter allowing 20 burst events, recovering 1 token every l.cooloffDuration
|
|
lim := rate.NewLimiter(rate.Every(l.cooloffDuration), 20)
|
|
rateLimit = deviceRatelimit{Limiter: lim, lastUsed: time.Now()}
|
|
|
|
l.limitsMutex.Lock()
|
|
l.limits[caller] = rateLimit
|
|
l.limitsMutex.Unlock()
|
|
}
|
|
|
|
l.limitsMutex.Lock()
|
|
rateLimit.lastUsed = time.Now()
|
|
l.limits[caller] = rateLimit
|
|
l.limitsMutex.Unlock()
|
|
|
|
// Check if the user has got free resource slots for this request.
|
|
// If they don't then we'll try to wait until one is. If the
|
|
// context is canceled/done, return an error.
|
|
if err := rateLimit.Wait(req.Context()); err != nil {
|
|
// We hit the rate limit. Tell the client to back off.
|
|
return &util.JSONResponse{
|
|
Code: http.StatusRequestTimeout,
|
|
JSON: jsonerror.Unknown(fmt.Sprintf("Request timed out: %s", err)),
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|