From f01ff850ab9bb09b6d61f7990840b6f3319a6e9c Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Wed, 26 Oct 2022 14:57:09 +0200 Subject: [PATCH] Add mutex, rename variables, update logging --- appservice/appservice.go | 7 +++++-- appservice/query/query.go | 32 +++++++++++++++++++------------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/appservice/appservice.go b/appservice/appservice.go index 9000adb1d..0c778b6ca 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -18,6 +18,7 @@ import ( "context" "crypto/tls" "net/http" + "sync" "time" "github.com/gorilla/mux" @@ -58,8 +59,10 @@ func NewInternalAPI( // Create appserivce query API with an HTTP client that will be used for all // outbound and inbound requests (inbound only for the internal API) appserviceQueryAPI := &query.AppServiceQueryAPI{ - HTTPClient: client, - Cfg: &base.Cfg.AppServiceAPI, + HTTPClient: client, + Cfg: &base.Cfg.AppServiceAPI, + ProtocolCache: map[string]appserviceAPI.ASProtocolResponse{}, + CacheMu: sync.Mutex{}, } if len(base.Cfg.Derived.ApplicationServices) == 0 { diff --git a/appservice/query/query.go b/appservice/query/query.go index 126cf0e7a..83e31b755 100644 --- a/appservice/query/query.go +++ b/appservice/query/query.go @@ -23,6 +23,7 @@ import ( "net/http" "net/url" "strings" + "sync" "github.com/opentracing/opentracing-go" "github.com/sirupsen/logrus" @@ -39,7 +40,8 @@ const userIDExistsPath = "/users/" type AppServiceQueryAPI struct { HTTPClient *http.Client Cfg *config.AppServiceAPI - protocolCache map[string]api.ASProtocolResponse + ProtocolCache map[string]api.ASProtocolResponse + CacheMu sync.Mutex } // RoomAliasExists performs a request to '/room/{roomAlias}' on all known @@ -208,7 +210,7 @@ func (a *AppServiceQueryAPI) Locations( } for _, as := range a.Cfg.Derived.ApplicationServices { - var proto []api.ASLocationResponse + var asLocations []api.ASLocationResponse params.Set("access_token", as.HSToken) url := as.URL + api.ASLocationPath @@ -216,12 +218,12 @@ func (a *AppServiceQueryAPI) Locations( url += "/" + req.Protocol } - if err := requestDo[[]api.ASLocationResponse](a.HTTPClient, url+"?"+params.Encode(), &proto); err != nil { - log.WithError(err).Error("unable to get protocolResponse from application service") + if err := requestDo[[]api.ASLocationResponse](a.HTTPClient, url+"?"+params.Encode(), &asLocations); err != nil { + log.WithError(err).Error("unable to get 'locations' from application service") continue } - resp.Locations = append(resp.Locations, proto...) + resp.Locations = append(resp.Locations, asLocations...) } if len(resp.Locations) == 0 { @@ -243,7 +245,7 @@ func (a *AppServiceQueryAPI) User( } for _, as := range a.Cfg.Derived.ApplicationServices { - var proto []api.ASUserResponse + var asUsers []api.ASUserResponse params.Set("access_token", as.HSToken) url := as.URL + api.ASUserPath @@ -251,12 +253,12 @@ func (a *AppServiceQueryAPI) User( url += "/" + req.Protocol } - if err := requestDo[[]api.ASUserResponse](a.HTTPClient, url+"?"+params.Encode(), &proto); err != nil { - log.WithError(err).Error("unable to get protocolResponse from application service") + if err := requestDo[[]api.ASUserResponse](a.HTTPClient, url+"?"+params.Encode(), &asUsers); err != nil { + log.WithError(err).Error("unable to get 'user' from application service") continue } - resp.Users = append(resp.Users, proto...) + resp.Users = append(resp.Users, asUsers...) } if len(resp.Users) == 0 { @@ -276,7 +278,9 @@ func (a *AppServiceQueryAPI) Protocols( // get a single protocol response if req.Protocol != "" { - if proto, ok := a.protocolCache[req.Protocol]; ok { + a.CacheMu.Lock() + defer a.CacheMu.Unlock() + if proto, ok := a.ProtocolCache[req.Protocol]; ok { resp.Exists = true resp.Protocols = map[string]api.ASProtocolResponse{ req.Protocol: proto, @@ -285,11 +289,10 @@ func (a *AppServiceQueryAPI) Protocols( } response := api.ASProtocolResponse{} - log.Debugf("XXX: getting single protocol") for _, as := range a.Cfg.Derived.ApplicationServices { var proto api.ASProtocolResponse if err := requestDo[api.ASProtocolResponse](a.HTTPClient, as.URL+api.ASProtocolPath+req.Protocol, &proto); err != nil { - logrus.WithError(err).Error("unable to get protocolResponse from application service") + logrus.WithError(err).Error("unable to get 'protocol' from application service") continue } @@ -337,7 +340,10 @@ func (a *AppServiceQueryAPI) Protocols( return nil } - a.protocolCache = response + a.CacheMu.Lock() + defer a.CacheMu.Unlock() + a.ProtocolCache = response + resp.Exists = true resp.Protocols = response return nil