mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-06 14:33:10 -06:00
Add a wrapper around compileACLRegex, introduce a compileRegexFunc,
cache regexes on startup
This commit is contained in:
parent
f4e77453cb
commit
f3430bb21d
|
|
@ -20,6 +20,7 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
|
|
@ -43,6 +44,7 @@ type ServerACLDatabase interface {
|
|||
type ServerACLs struct {
|
||||
acls map[string]*serverACL // room ID -> ACL
|
||||
aclsMutex sync.RWMutex // protects the above
|
||||
compileRegexFunc func(orig string) (*regexp.Regexp, error)
|
||||
}
|
||||
|
||||
func NewServerACLs(db ServerACLDatabase) *ServerACLs {
|
||||
|
|
@ -50,6 +52,8 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs {
|
|||
acls := &ServerACLs{
|
||||
acls: make(map[string]*serverACL),
|
||||
}
|
||||
// temporary regex function to use on start up
|
||||
acls.compileRegexFunc = cachedCompileACLRegex()
|
||||
// Look up all of the rooms that the current state server knows about.
|
||||
rooms, err := db.GetKnownRooms(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -67,6 +71,15 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs {
|
|||
for _, event := range events {
|
||||
acls.OnServerACLUpdate(event)
|
||||
}
|
||||
|
||||
// All regexes are built, replace the function to use.
|
||||
acls.compileRegexFunc = compileACLRegex
|
||||
|
||||
// Force a GC, as we potentially allocated hundreds of regexes
|
||||
// which can now be GCed, since we replaced the regex function
|
||||
// containing the cache.
|
||||
runtime.GC()
|
||||
|
||||
return acls
|
||||
}
|
||||
|
||||
|
|
@ -89,6 +102,25 @@ func compileACLRegex(orig string) (*regexp.Regexp, error) {
|
|||
return regexp.Compile(escaped)
|
||||
}
|
||||
|
||||
// cachedCompileACLRegex is a wrapper around compileACLRegex with added caching
|
||||
func cachedCompileACLRegex() func(string) (*regexp.Regexp, error) {
|
||||
// Be generous when creating the cache, as in reality
|
||||
// there are hundreds of servers in an ACL.
|
||||
cache := make(map[string]*regexp.Regexp, 100)
|
||||
return func(orig string) (*regexp.Regexp, error) {
|
||||
re, ok := cache[orig]
|
||||
if ok {
|
||||
return re, nil
|
||||
}
|
||||
re, err := compileACLRegex(orig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cache[orig] = re
|
||||
return re, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) {
|
||||
acls := &serverACL{}
|
||||
if err := json.Unmarshal([]byte(strippedEvent.ContentValue), &acls.ServerACL); err != nil {
|
||||
|
|
@ -100,14 +132,14 @@ func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) {
|
|||
// special characters and then replace * and ? with their regex counterparts.
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl
|
||||
for _, orig := range acls.Allowed {
|
||||
if expr, err := compileACLRegex(orig); err != nil {
|
||||
if expr, err := s.compileRegexFunc(orig); err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to compile allowed regex")
|
||||
} else {
|
||||
acls.allowedRegexes = append(acls.allowedRegexes, expr)
|
||||
}
|
||||
}
|
||||
for _, orig := range acls.Denied {
|
||||
if expr, err := compileACLRegex(orig); err != nil {
|
||||
if expr, err := s.compileRegexFunc(orig); err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to compile denied regex")
|
||||
} else {
|
||||
acls.deniedRegexes = append(acls.deniedRegexes, expr)
|
||||
|
|
|
|||
|
|
@ -15,8 +15,14 @@
|
|||
package acls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestOpenACLsWithBlacklist(t *testing.T) {
|
||||
|
|
@ -103,3 +109,45 @@ func TestDefaultACLsWithWhitelist(t *testing.T) {
|
|||
t.Fatal("Expected qux.com:4567 to be allowed but wasn't")
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
content1 = `{"allow":["*"],"allow_ip_literals":false,"deny":["hello.world", "*.hello.world"]}`
|
||||
)
|
||||
|
||||
type dummyACLDB struct{}
|
||||
|
||||
func (d dummyACLDB) GetKnownRooms(ctx context.Context) ([]string, error) {
|
||||
return []string{"1", "2"}, nil
|
||||
}
|
||||
|
||||
func (d dummyACLDB) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) {
|
||||
return []tables.StrippedEvent{
|
||||
{
|
||||
RoomID: "1",
|
||||
ContentValue: content1,
|
||||
},
|
||||
{
|
||||
RoomID: "2",
|
||||
ContentValue: content1,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestCachedRegex(t *testing.T) {
|
||||
db := dummyACLDB{}
|
||||
wantBannedServer := spec.ServerName("hello.world")
|
||||
|
||||
acls := NewServerACLs(db)
|
||||
|
||||
// Check that hello.world is banned in room 1
|
||||
banned := acls.IsServerBannedFromRoom(wantBannedServer, "1")
|
||||
assert.True(t, banned)
|
||||
|
||||
// Check that hello.world is banned in room 2
|
||||
banned = acls.IsServerBannedFromRoom(wantBannedServer, "2")
|
||||
assert.True(t, banned)
|
||||
|
||||
// Check that matrix.hello.world is banned in room 2
|
||||
banned = acls.IsServerBannedFromRoom("matrix."+wantBannedServer, "2")
|
||||
assert.True(t, banned)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue