mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-06 14:33:10 -06:00
Add a wrapper around compileACLRegex, introduce a compileRegexFunc,
cache regexes on startup
This commit is contained in:
parent
f4e77453cb
commit
f3430bb21d
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
|
@ -41,8 +42,9 @@ type ServerACLDatabase interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerACLs struct {
|
type ServerACLs struct {
|
||||||
acls map[string]*serverACL // room ID -> ACL
|
acls map[string]*serverACL // room ID -> ACL
|
||||||
aclsMutex sync.RWMutex // protects the above
|
aclsMutex sync.RWMutex // protects the above
|
||||||
|
compileRegexFunc func(orig string) (*regexp.Regexp, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServerACLs(db ServerACLDatabase) *ServerACLs {
|
func NewServerACLs(db ServerACLDatabase) *ServerACLs {
|
||||||
|
|
@ -50,6 +52,8 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs {
|
||||||
acls := &ServerACLs{
|
acls := &ServerACLs{
|
||||||
acls: make(map[string]*serverACL),
|
acls: make(map[string]*serverACL),
|
||||||
}
|
}
|
||||||
|
// temporary regex function to use on start up
|
||||||
|
acls.compileRegexFunc = cachedCompileACLRegex()
|
||||||
// Look up all of the rooms that the current state server knows about.
|
// Look up all of the rooms that the current state server knows about.
|
||||||
rooms, err := db.GetKnownRooms(ctx)
|
rooms, err := db.GetKnownRooms(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -67,6 +71,15 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs {
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
acls.OnServerACLUpdate(event)
|
acls.OnServerACLUpdate(event)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// All regexes are built, replace the function to use.
|
||||||
|
acls.compileRegexFunc = compileACLRegex
|
||||||
|
|
||||||
|
// Force a GC, as we potentially allocated hundreds of regexes
|
||||||
|
// which can now be GCed, since we replaced the regex function
|
||||||
|
// containing the cache.
|
||||||
|
runtime.GC()
|
||||||
|
|
||||||
return acls
|
return acls
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -89,6 +102,25 @@ func compileACLRegex(orig string) (*regexp.Regexp, error) {
|
||||||
return regexp.Compile(escaped)
|
return regexp.Compile(escaped)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cachedCompileACLRegex is a wrapper around compileACLRegex with added caching
|
||||||
|
func cachedCompileACLRegex() func(string) (*regexp.Regexp, error) {
|
||||||
|
// Be generous when creating the cache, as in reality
|
||||||
|
// there are hundreds of servers in an ACL.
|
||||||
|
cache := make(map[string]*regexp.Regexp, 100)
|
||||||
|
return func(orig string) (*regexp.Regexp, error) {
|
||||||
|
re, ok := cache[orig]
|
||||||
|
if ok {
|
||||||
|
return re, nil
|
||||||
|
}
|
||||||
|
re, err := compileACLRegex(orig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cache[orig] = re
|
||||||
|
return re, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) {
|
func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) {
|
||||||
acls := &serverACL{}
|
acls := &serverACL{}
|
||||||
if err := json.Unmarshal([]byte(strippedEvent.ContentValue), &acls.ServerACL); err != nil {
|
if err := json.Unmarshal([]byte(strippedEvent.ContentValue), &acls.ServerACL); err != nil {
|
||||||
|
|
@ -100,14 +132,14 @@ func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) {
|
||||||
// special characters and then replace * and ? with their regex counterparts.
|
// special characters and then replace * and ? with their regex counterparts.
|
||||||
// https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl
|
// https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl
|
||||||
for _, orig := range acls.Allowed {
|
for _, orig := range acls.Allowed {
|
||||||
if expr, err := compileACLRegex(orig); err != nil {
|
if expr, err := s.compileRegexFunc(orig); err != nil {
|
||||||
logrus.WithError(err).Errorf("Failed to compile allowed regex")
|
logrus.WithError(err).Errorf("Failed to compile allowed regex")
|
||||||
} else {
|
} else {
|
||||||
acls.allowedRegexes = append(acls.allowedRegexes, expr)
|
acls.allowedRegexes = append(acls.allowedRegexes, expr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, orig := range acls.Denied {
|
for _, orig := range acls.Denied {
|
||||||
if expr, err := compileACLRegex(orig); err != nil {
|
if expr, err := s.compileRegexFunc(orig); err != nil {
|
||||||
logrus.WithError(err).Errorf("Failed to compile denied regex")
|
logrus.WithError(err).Errorf("Failed to compile denied regex")
|
||||||
} else {
|
} else {
|
||||||
acls.deniedRegexes = append(acls.deniedRegexes, expr)
|
acls.deniedRegexes = append(acls.deniedRegexes, expr)
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,14 @@
|
||||||
package acls
|
package acls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"regexp"
|
"regexp"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOpenACLsWithBlacklist(t *testing.T) {
|
func TestOpenACLsWithBlacklist(t *testing.T) {
|
||||||
|
|
@ -103,3 +109,45 @@ func TestDefaultACLsWithWhitelist(t *testing.T) {
|
||||||
t.Fatal("Expected qux.com:4567 to be allowed but wasn't")
|
t.Fatal("Expected qux.com:4567 to be allowed but wasn't")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
content1 = `{"allow":["*"],"allow_ip_literals":false,"deny":["hello.world", "*.hello.world"]}`
|
||||||
|
)
|
||||||
|
|
||||||
|
type dummyACLDB struct{}
|
||||||
|
|
||||||
|
func (d dummyACLDB) GetKnownRooms(ctx context.Context) ([]string, error) {
|
||||||
|
return []string{"1", "2"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d dummyACLDB) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) {
|
||||||
|
return []tables.StrippedEvent{
|
||||||
|
{
|
||||||
|
RoomID: "1",
|
||||||
|
ContentValue: content1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
RoomID: "2",
|
||||||
|
ContentValue: content1,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCachedRegex(t *testing.T) {
|
||||||
|
db := dummyACLDB{}
|
||||||
|
wantBannedServer := spec.ServerName("hello.world")
|
||||||
|
|
||||||
|
acls := NewServerACLs(db)
|
||||||
|
|
||||||
|
// Check that hello.world is banned in room 1
|
||||||
|
banned := acls.IsServerBannedFromRoom(wantBannedServer, "1")
|
||||||
|
assert.True(t, banned)
|
||||||
|
|
||||||
|
// Check that hello.world is banned in room 2
|
||||||
|
banned = acls.IsServerBannedFromRoom(wantBannedServer, "2")
|
||||||
|
assert.True(t, banned)
|
||||||
|
|
||||||
|
// Check that matrix.hello.world is banned in room 2
|
||||||
|
banned = acls.IsServerBannedFromRoom("matrix."+wantBannedServer, "2")
|
||||||
|
assert.True(t, banned)
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue