diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go index 818bcb223..017682e05 100644 --- a/roomserver/acls/acls.go +++ b/roomserver/acls/acls.go @@ -20,7 +20,6 @@ import ( "fmt" "net" "regexp" - "runtime" "strings" "sync" @@ -42,18 +41,21 @@ type ServerACLDatabase interface { } type ServerACLs struct { - acls map[string]*serverACL // room ID -> ACL - aclsMutex sync.RWMutex // protects the above - compileRegexFunc func(orig string) (*regexp.Regexp, error) + acls map[string]*serverACL // room ID -> ACL + aclsMutex sync.RWMutex // protects the above + aclRegexCache map[string]**regexp.Regexp // Cache from "serverName" -> pointer to a regex + aclRegexCacheMutex sync.RWMutex // protects the above } func NewServerACLs(db ServerACLDatabase) *ServerACLs { ctx := context.TODO() acls := &ServerACLs{ acls: make(map[string]*serverACL), + // Be generous when creating the cache, as in reality + // there are hundreds of servers in an ACL. + aclRegexCache: make(map[string]**regexp.Regexp, 100), } - // temporary regex function to use on start up - acls.compileRegexFunc = cachedCompileACLRegex() + // Look up all of the rooms that the current state server knows about. rooms, err := db.GetKnownRooms(ctx) if err != nil { @@ -72,14 +74,6 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs { acls.OnServerACLUpdate(event) } - // All regexes are built, replace the function to use. - acls.compileRegexFunc = compileACLRegex - - // Force a GC, as we potentially allocated hundreds of regexes - // which can now be GCed, since we replaced the regex function - // containing the cache. - runtime.GC() - return acls } @@ -91,8 +85,8 @@ type ServerACL struct { type serverACL struct { ServerACL - allowedRegexes []*regexp.Regexp - deniedRegexes []*regexp.Regexp + allowedRegexes []**regexp.Regexp + deniedRegexes []**regexp.Regexp } func compileACLRegex(orig string) (*regexp.Regexp, error) { @@ -103,22 +97,22 @@ func compileACLRegex(orig string) (*regexp.Regexp, error) { } // cachedCompileACLRegex is a wrapper around compileACLRegex with added caching -func cachedCompileACLRegex() func(string) (*regexp.Regexp, error) { - // Be generous when creating the cache, as in reality - // there are hundreds of servers in an ACL. - cache := make(map[string]*regexp.Regexp, 100) - return func(orig string) (*regexp.Regexp, error) { - re, ok := cache[orig] - if ok { - return re, nil - } - re, err := compileACLRegex(orig) - if err != nil { - return nil, err - } - cache[orig] = re +func (s *ServerACLs) cachedCompileACLRegex(orig string) (**regexp.Regexp, error) { + s.aclRegexCacheMutex.RLock() + re, ok := s.aclRegexCache[orig] + if ok { + s.aclRegexCacheMutex.RUnlock() return re, nil } + s.aclRegexCacheMutex.RUnlock() + compiled, err := compileACLRegex(orig) + if err != nil { + return nil, err + } + s.aclRegexCacheMutex.Lock() + defer s.aclRegexCacheMutex.Unlock() + s.aclRegexCache[orig] = &compiled + return &compiled, nil } func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) { @@ -132,14 +126,14 @@ func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) { // special characters and then replace * and ? with their regex counterparts. // https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl for _, orig := range acls.Allowed { - if expr, err := s.compileRegexFunc(orig); err != nil { + if expr, err := s.cachedCompileACLRegex(orig); err != nil { logrus.WithError(err).Errorf("Failed to compile allowed regex") } else { acls.allowedRegexes = append(acls.allowedRegexes, expr) } } for _, orig := range acls.Denied { - if expr, err := s.compileRegexFunc(orig); err != nil { + if expr, err := s.cachedCompileACLRegex(orig); err != nil { logrus.WithError(err).Errorf("Failed to compile denied regex") } else { acls.deniedRegexes = append(acls.deniedRegexes, expr) @@ -187,14 +181,14 @@ func (s *ServerACLs) IsServerBannedFromRoom(serverName spec.ServerName, roomID s // Check if the hostname matches one of the denied regexes. If it does then // the server is banned from the room. for _, expr := range acls.deniedRegexes { - if expr.MatchString(string(serverName)) { + if (*expr).MatchString(string(serverName)) { return true } } // Check if the hostname matches one of the allowed regexes. If it does then // the server is NOT banned from the room. for _, expr := range acls.allowedRegexes { - if expr.MatchString(string(serverName)) { + if (*expr).MatchString(string(serverName)) { return false } } diff --git a/roomserver/acls/acls_test.go b/roomserver/acls/acls_test.go index bd59db4dd..efe1d2093 100644 --- a/roomserver/acls/acls_test.go +++ b/roomserver/acls/acls_test.go @@ -44,8 +44,8 @@ func TestOpenACLsWithBlacklist(t *testing.T) { ServerACL: ServerACL{ AllowIPLiterals: true, }, - allowedRegexes: []*regexp.Regexp{allowRegex}, - deniedRegexes: []*regexp.Regexp{denyRegex}, + allowedRegexes: []**regexp.Regexp{&allowRegex}, + deniedRegexes: []**regexp.Regexp{&denyRegex}, } if acls.IsServerBannedFromRoom("1.2.3.4", roomID) { @@ -83,8 +83,8 @@ func TestDefaultACLsWithWhitelist(t *testing.T) { ServerACL: ServerACL{ AllowIPLiterals: false, }, - allowedRegexes: []*regexp.Regexp{allowRegex}, - deniedRegexes: []*regexp.Regexp{}, + allowedRegexes: []**regexp.Regexp{&allowRegex}, + deniedRegexes: []**regexp.Regexp{}, } if !acls.IsServerBannedFromRoom("1.2.3.4", roomID) {