Add a wrapper around compileACLRegex, introduce a compileRegexFunc,

cache regexes on startup
This commit is contained in:
Till Faelligen 2024-02-22 08:47:57 +01:00
parent f4e77453cb
commit f3430bb21d
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
2 changed files with 84 additions and 4 deletions

View file

@ -20,6 +20,7 @@ import (
"fmt"
"net"
"regexp"
"runtime"
"strings"
"sync"
@ -41,8 +42,9 @@ type ServerACLDatabase interface {
}
type ServerACLs struct {
acls map[string]*serverACL // room ID -> ACL
aclsMutex sync.RWMutex // protects the above
acls map[string]*serverACL // room ID -> ACL
aclsMutex sync.RWMutex // protects the above
compileRegexFunc func(orig string) (*regexp.Regexp, error)
}
func NewServerACLs(db ServerACLDatabase) *ServerACLs {
@ -50,6 +52,8 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs {
acls := &ServerACLs{
acls: make(map[string]*serverACL),
}
// temporary regex function to use on start up
acls.compileRegexFunc = cachedCompileACLRegex()
// Look up all of the rooms that the current state server knows about.
rooms, err := db.GetKnownRooms(ctx)
if err != nil {
@ -67,6 +71,15 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs {
for _, event := range events {
acls.OnServerACLUpdate(event)
}
// All regexes are built, replace the function to use.
acls.compileRegexFunc = compileACLRegex
// Force a GC, as we potentially allocated hundreds of regexes
// which can now be GCed, since we replaced the regex function
// containing the cache.
runtime.GC()
return acls
}
@ -89,6 +102,25 @@ func compileACLRegex(orig string) (*regexp.Regexp, error) {
return regexp.Compile(escaped)
}
// cachedCompileACLRegex is a wrapper around compileACLRegex with added caching
func cachedCompileACLRegex() func(string) (*regexp.Regexp, error) {
// Be generous when creating the cache, as in reality
// there are hundreds of servers in an ACL.
cache := make(map[string]*regexp.Regexp, 100)
return func(orig string) (*regexp.Regexp, error) {
re, ok := cache[orig]
if ok {
return re, nil
}
re, err := compileACLRegex(orig)
if err != nil {
return nil, err
}
cache[orig] = re
return re, nil
}
}
func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) {
acls := &serverACL{}
if err := json.Unmarshal([]byte(strippedEvent.ContentValue), &acls.ServerACL); err != nil {
@ -100,14 +132,14 @@ func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) {
// special characters and then replace * and ? with their regex counterparts.
// https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl
for _, orig := range acls.Allowed {
if expr, err := compileACLRegex(orig); err != nil {
if expr, err := s.compileRegexFunc(orig); err != nil {
logrus.WithError(err).Errorf("Failed to compile allowed regex")
} else {
acls.allowedRegexes = append(acls.allowedRegexes, expr)
}
}
for _, orig := range acls.Denied {
if expr, err := compileACLRegex(orig); err != nil {
if expr, err := s.compileRegexFunc(orig); err != nil {
logrus.WithError(err).Errorf("Failed to compile denied regex")
} else {
acls.deniedRegexes = append(acls.deniedRegexes, expr)

View file

@ -15,8 +15,14 @@
package acls
import (
"context"
"regexp"
"testing"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert"
)
func TestOpenACLsWithBlacklist(t *testing.T) {
@ -103,3 +109,45 @@ func TestDefaultACLsWithWhitelist(t *testing.T) {
t.Fatal("Expected qux.com:4567 to be allowed but wasn't")
}
}
var (
content1 = `{"allow":["*"],"allow_ip_literals":false,"deny":["hello.world", "*.hello.world"]}`
)
type dummyACLDB struct{}
func (d dummyACLDB) GetKnownRooms(ctx context.Context) ([]string, error) {
return []string{"1", "2"}, nil
}
func (d dummyACLDB) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) {
return []tables.StrippedEvent{
{
RoomID: "1",
ContentValue: content1,
},
{
RoomID: "2",
ContentValue: content1,
},
}, nil
}
func TestCachedRegex(t *testing.T) {
db := dummyACLDB{}
wantBannedServer := spec.ServerName("hello.world")
acls := NewServerACLs(db)
// Check that hello.world is banned in room 1
banned := acls.IsServerBannedFromRoom(wantBannedServer, "1")
assert.True(t, banned)
// Check that hello.world is banned in room 2
banned = acls.IsServerBannedFromRoom(wantBannedServer, "2")
assert.True(t, banned)
// Check that matrix.hello.world is banned in room 2
banned = acls.IsServerBannedFromRoom("matrix."+wantBannedServer, "2")
assert.True(t, banned)
}