From f25986d8fd8a8aedabc355184063735af8c822bf Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 21 Dec 2022 09:13:09 -0700 Subject: [PATCH] Fix race in pinecone monolith tests --- build/gobind-pinecone/monolith.go | 36 ++++++++++++++++++++------ build/gobind-pinecone/monolith_test.go | 13 ++++++---- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 0576d6ecb..aaf35a667 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -465,7 +465,7 @@ func (m *DendriteMonolith) Start() { Context: context.Background(), ServerName: gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()), FederationAPI: m.federationAPI, - RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), RelayAPI: monolith.RelayAPI, running: *atomic.NewBool(false), } @@ -515,9 +515,10 @@ func (m *DendriteMonolith) Stop() { type RelayServerRetriever struct { Context context.Context ServerName gomatrixserverlib.ServerName - RelayServersQueried map[gomatrixserverlib.ServerName]bool FederationAPI api.FederationInternalAPI RelayAPI relayServerAPI.RelayInternalAPI + relayServersQueried map[gomatrixserverlib.ServerName]bool + queriedServersMutex sync.Mutex running atomic.Bool } @@ -529,7 +530,7 @@ func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { // TODO } for _, server := range response.RelayServers { - m.RelayServersQueried[server] = false + m.relayServersQueried[server] = false } eLog.Infof("Registered relay servers: %v", response.RelayServers) @@ -541,11 +542,15 @@ func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { t := time.NewTimer(relayServerRetryInterval) for { relayServersToQuery := []gomatrixserverlib.ServerName{} - for server, complete := range m.RelayServersQueried { - if !complete { - relayServersToQuery = append(relayServersToQuery, server) + func() { + m.queriedServersMutex.Lock() + defer m.queriedServersMutex.Unlock() + for server, complete := range m.relayServersQueried { + if !complete { + relayServersToQuery = append(relayServersToQuery, server) + } } - } + }() if len(relayServersToQuery) == 0 { // All relay servers have been synced. return @@ -564,6 +569,17 @@ func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { } } +func (m *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { + m.queriedServersMutex.Lock() + defer m.queriedServersMutex.Unlock() + + result := map[gomatrixserverlib.ServerName]bool{} + for server, queried := range m.relayServersQueried { + result[server] = queried + } + return result +} + func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { logrus.Info("querying relay servers for async_events") for _, server := range relayServers { @@ -578,7 +594,11 @@ func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverli response := relayServerAPI.PerformRelayServerSyncResponse{} err = m.RelayAPI.PerformRelayServerSync(context.Background(), &request, &response) if err == nil { - m.RelayServersQueried[server] = true + func() { + m.queriedServersMutex.Lock() + defer m.queriedServersMutex.Unlock() + m.relayServersQueried[server] = true + }() // TODO : What happens if your relay receives new messages after this point? // Should you continue to check with them, or should they try and contact you? // They could send a "new_async_events" message your way maybe? diff --git a/build/gobind-pinecone/monolith_test.go b/build/gobind-pinecone/monolith_test.go index 8810a8f04..0aeef3084 100644 --- a/build/gobind-pinecone/monolith_test.go +++ b/build/gobind-pinecone/monolith_test.go @@ -150,32 +150,35 @@ func TestRelayRetrieverInitialization(t *testing.T) { retriever := RelayServerRetriever{ Context: context.Background(), ServerName: "server", - RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), FederationAPI: &FakeFedAPI{}, RelayAPI: &FakeRelayAPI{}, } retriever.InitializeRelayServers(logrus.WithField("test", "relay")) - assert.Equal(t, 2, len(retriever.RelayServersQueried)) + relayServers := retriever.GetQueriedServerStatus() + assert.Equal(t, 2, len(relayServers)) } func TestRelayRetrieverSync(t *testing.T) { retriever := RelayServerRetriever{ Context: context.Background(), ServerName: "server", - RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), FederationAPI: &FakeFedAPI{}, RelayAPI: &FakeRelayAPI{}, } retriever.InitializeRelayServers(logrus.WithField("test", "relay")) - assert.Equal(t, 2, len(retriever.RelayServersQueried)) + relayServers := retriever.GetQueriedServerStatus() + assert.Equal(t, 2, len(relayServers)) stopRelayServerSync := make(chan bool) go retriever.SyncRelayServers(stopRelayServerSync) check := func(log poll.LogT) poll.Result { - for _, queried := range retriever.RelayServersQueried { + relayServers := retriever.GetQueriedServerStatus() + for _, queried := range relayServers { if !queried { return poll.Continue("waiting for all servers to be queried") }