Fix race in pinecone monolith tests

This commit is contained in:
Devon Hudson 2022-12-21 09:13:09 -07:00
parent a91e33037c
commit f25986d8fd
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
2 changed files with 36 additions and 13 deletions

View file

@ -465,7 +465,7 @@ func (m *DendriteMonolith) Start() {
Context: context.Background(),
ServerName: gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()),
FederationAPI: m.federationAPI,
RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool),
relayServersQueried: make(map[gomatrixserverlib.ServerName]bool),
RelayAPI: monolith.RelayAPI,
running: *atomic.NewBool(false),
}
@ -515,9 +515,10 @@ func (m *DendriteMonolith) Stop() {
type RelayServerRetriever struct {
Context context.Context
ServerName gomatrixserverlib.ServerName
RelayServersQueried map[gomatrixserverlib.ServerName]bool
FederationAPI api.FederationInternalAPI
RelayAPI relayServerAPI.RelayInternalAPI
relayServersQueried map[gomatrixserverlib.ServerName]bool
queriedServersMutex sync.Mutex
running atomic.Bool
}
@ -529,7 +530,7 @@ func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) {
// TODO
}
for _, server := range response.RelayServers {
m.RelayServersQueried[server] = false
m.relayServersQueried[server] = false
}
eLog.Infof("Registered relay servers: %v", response.RelayServers)
@ -541,11 +542,15 @@ func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) {
t := time.NewTimer(relayServerRetryInterval)
for {
relayServersToQuery := []gomatrixserverlib.ServerName{}
for server, complete := range m.RelayServersQueried {
if !complete {
relayServersToQuery = append(relayServersToQuery, server)
func() {
m.queriedServersMutex.Lock()
defer m.queriedServersMutex.Unlock()
for server, complete := range m.relayServersQueried {
if !complete {
relayServersToQuery = append(relayServersToQuery, server)
}
}
}
}()
if len(relayServersToQuery) == 0 {
// All relay servers have been synced.
return
@ -564,6 +569,17 @@ func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) {
}
}
func (m *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool {
m.queriedServersMutex.Lock()
defer m.queriedServersMutex.Unlock()
result := map[gomatrixserverlib.ServerName]bool{}
for server, queried := range m.relayServersQueried {
result[server] = queried
}
return result
}
func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) {
logrus.Info("querying relay servers for async_events")
for _, server := range relayServers {
@ -578,7 +594,11 @@ func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverli
response := relayServerAPI.PerformRelayServerSyncResponse{}
err = m.RelayAPI.PerformRelayServerSync(context.Background(), &request, &response)
if err == nil {
m.RelayServersQueried[server] = true
func() {
m.queriedServersMutex.Lock()
defer m.queriedServersMutex.Unlock()
m.relayServersQueried[server] = true
}()
// TODO : What happens if your relay receives new messages after this point?
// Should you continue to check with them, or should they try and contact you?
// They could send a "new_async_events" message your way maybe?

View file

@ -150,32 +150,35 @@ func TestRelayRetrieverInitialization(t *testing.T) {
retriever := RelayServerRetriever{
Context: context.Background(),
ServerName: "server",
RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool),
relayServersQueried: make(map[gomatrixserverlib.ServerName]bool),
FederationAPI: &FakeFedAPI{},
RelayAPI: &FakeRelayAPI{},
}
retriever.InitializeRelayServers(logrus.WithField("test", "relay"))
assert.Equal(t, 2, len(retriever.RelayServersQueried))
relayServers := retriever.GetQueriedServerStatus()
assert.Equal(t, 2, len(relayServers))
}
func TestRelayRetrieverSync(t *testing.T) {
retriever := RelayServerRetriever{
Context: context.Background(),
ServerName: "server",
RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool),
relayServersQueried: make(map[gomatrixserverlib.ServerName]bool),
FederationAPI: &FakeFedAPI{},
RelayAPI: &FakeRelayAPI{},
}
retriever.InitializeRelayServers(logrus.WithField("test", "relay"))
assert.Equal(t, 2, len(retriever.RelayServersQueried))
relayServers := retriever.GetQueriedServerStatus()
assert.Equal(t, 2, len(relayServers))
stopRelayServerSync := make(chan bool)
go retriever.SyncRelayServers(stopRelayServerSync)
check := func(log poll.LogT) poll.Result {
for _, queried := range retriever.RelayServersQueried {
relayServers := retriever.GetQueriedServerStatus()
for _, queried := range relayServers {
if !queried {
return poll.Continue("waiting for all servers to be queried")
}