diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go index 660f4f3bb..299824327 100644 --- a/roomserver/acls/acls.go +++ b/roomserver/acls/acls.go @@ -20,6 +20,7 @@ import ( "fmt" "net" "regexp" + "runtime" "strings" "sync" @@ -41,8 +42,9 @@ type ServerACLDatabase interface { } type ServerACLs struct { - acls map[string]*serverACL // room ID -> ACL - aclsMutex sync.RWMutex // protects the above + acls map[string]*serverACL // room ID -> ACL + aclsMutex sync.RWMutex // protects the above + compileRegexFunc func(orig string) (*regexp.Regexp, error) } func NewServerACLs(db ServerACLDatabase) *ServerACLs { @@ -50,6 +52,8 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs { acls := &ServerACLs{ acls: make(map[string]*serverACL), } + // temporary regex function to use on start up + acls.compileRegexFunc = cachedCompileACLRegex() // Look up all of the rooms that the current state server knows about. rooms, err := db.GetKnownRooms(ctx) if err != nil { @@ -67,6 +71,15 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs { for _, event := range events { acls.OnServerACLUpdate(event) } + + // All regexes are built, replace the function to use. + acls.compileRegexFunc = compileACLRegex + + // Force a GC, as we potentially allocated hundreds of regexes + // which can now be GCed, since we replaced the regex function + // containing the cache. + runtime.GC() + return acls } @@ -89,6 +102,25 @@ func compileACLRegex(orig string) (*regexp.Regexp, error) { return regexp.Compile(escaped) } +// cachedCompileACLRegex is a wrapper around compileACLRegex with added caching +func cachedCompileACLRegex() func(string) (*regexp.Regexp, error) { + // Be generous when creating the cache, as in reality + // there are hundreds of servers in an ACL. + cache := make(map[string]*regexp.Regexp, 100) + return func(orig string) (*regexp.Regexp, error) { + re, ok := cache[orig] + if ok { + return re, nil + } + re, err := compileACLRegex(orig) + if err != nil { + return nil, err + } + cache[orig] = re + return re, nil + } +} + func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) { acls := &serverACL{} if err := json.Unmarshal([]byte(strippedEvent.ContentValue), &acls.ServerACL); err != nil { @@ -100,14 +132,14 @@ func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) { // special characters and then replace * and ? with their regex counterparts. // https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl for _, orig := range acls.Allowed { - if expr, err := compileACLRegex(orig); err != nil { + if expr, err := s.compileRegexFunc(orig); err != nil { logrus.WithError(err).Errorf("Failed to compile allowed regex") } else { acls.allowedRegexes = append(acls.allowedRegexes, expr) } } for _, orig := range acls.Denied { - if expr, err := compileACLRegex(orig); err != nil { + if expr, err := s.compileRegexFunc(orig); err != nil { logrus.WithError(err).Errorf("Failed to compile denied regex") } else { acls.deniedRegexes = append(acls.deniedRegexes, expr) diff --git a/roomserver/acls/acls_test.go b/roomserver/acls/acls_test.go index 9fb6a5581..bd59db4dd 100644 --- a/roomserver/acls/acls_test.go +++ b/roomserver/acls/acls_test.go @@ -15,8 +15,14 @@ package acls import ( + "context" "regexp" "testing" + + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/stretchr/testify/assert" ) func TestOpenACLsWithBlacklist(t *testing.T) { @@ -103,3 +109,45 @@ func TestDefaultACLsWithWhitelist(t *testing.T) { t.Fatal("Expected qux.com:4567 to be allowed but wasn't") } } + +var ( + content1 = `{"allow":["*"],"allow_ip_literals":false,"deny":["hello.world", "*.hello.world"]}` +) + +type dummyACLDB struct{} + +func (d dummyACLDB) GetKnownRooms(ctx context.Context) ([]string, error) { + return []string{"1", "2"}, nil +} + +func (d dummyACLDB) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) { + return []tables.StrippedEvent{ + { + RoomID: "1", + ContentValue: content1, + }, + { + RoomID: "2", + ContentValue: content1, + }, + }, nil +} + +func TestCachedRegex(t *testing.T) { + db := dummyACLDB{} + wantBannedServer := spec.ServerName("hello.world") + + acls := NewServerACLs(db) + + // Check that hello.world is banned in room 1 + banned := acls.IsServerBannedFromRoom(wantBannedServer, "1") + assert.True(t, banned) + + // Check that hello.world is banned in room 2 + banned = acls.IsServerBannedFromRoom(wantBannedServer, "2") + assert.True(t, banned) + + // Check that matrix.hello.world is banned in room 2 + banned = acls.IsServerBannedFromRoom("matrix."+wantBannedServer, "2") + assert.True(t, banned) +}