diff --git a/cmd/dendrite-demo-pinecone/monolith/monolith.go b/cmd/dendrite-demo-pinecone/monolith/monolith.go index 27720369e..e8a29e418 100644 --- a/cmd/dendrite-demo-pinecone/monolith/monolith.go +++ b/cmd/dendrite-demo-pinecone/monolith/monolith.go @@ -68,12 +68,14 @@ type P2PMonolith struct { EventChannel chan pineconeEvents.Event RelayRetriever relay.RelayServerRetriever - dendrite setup.Monolith - port int - httpMux *mux.Router - pineconeMux *mux.Router - listener net.Listener - httpListenAddr string + dendrite setup.Monolith + port int + httpMux *mux.Router + pineconeMux *mux.Router + httpServer *http.Server + listener net.Listener + httpListenAddr string + stopHandlingEvents chan bool } func GenerateDefaultConfig(sk ed25519.PrivateKey, storageDir string, cacheDir string, dbPrefix string) *config.Dendrite { @@ -199,8 +201,10 @@ func (p *P2PMonolith) StartMonolith() { } func (p *P2PMonolith) Stop() { + logrus.Info("Stopping monolith") _ = p.BaseDendrite.Close() p.WaitForShutdown() + logrus.Info("Stopped monolith") } func (p *P2PMonolith) WaitForShutdown() { @@ -209,6 +213,16 @@ func (p *P2PMonolith) WaitForShutdown() { } func (p *P2PMonolith) closeAllResources() { + logrus.Info("Closing monolith resources") + if p.httpServer != nil { + p.httpServer.Shutdown(context.Background()) + } + + select { + case p.stopHandlingEvents <- true: + default: + } + if p.listener != nil { _ = p.listener.Close() } @@ -224,6 +238,7 @@ func (p *P2PMonolith) closeAllResources() { if p.Router != nil { _ = p.Router.Close() } + logrus.Info("Monolith resources closed") } func (p *P2PMonolith) Addr() string { @@ -280,7 +295,7 @@ func (p *P2PMonolith) setupHttpServers(userProvider *users.PineconeUserProvider, func (p *P2PMonolith) startHTTPServers() { go func() { // Build both ends of a HTTP multiplex. - httpServer := &http.Server{ + p.httpServer = &http.Server{ Addr: ":0", TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, ReadTimeout: 10 * time.Second, @@ -296,12 +311,13 @@ func (p *P2PMonolith) startHTTPServers() { pubkeyString := hex.EncodeToString(pubkey[:]) logrus.Info("Listening on ", pubkeyString) - switch httpServer.Serve(p.Sessions.Protocol(SessionProtocol)) { + switch p.httpServer.Serve(p.Sessions.Protocol(SessionProtocol)) { case net.ErrClosed, http.ErrServerClosed: logrus.Info("Stopped listening on ", pubkeyString) default: logrus.Error("Stopped listening on ", pubkeyString) } + logrus.Info("Stopped goroutine listening on ", pubkeyString) }() p.httpListenAddr = fmt.Sprintf(":%d", p.port) @@ -313,10 +329,12 @@ func (p *P2PMonolith) startHTTPServers() { default: logrus.Error("Stopped listening on ", p.httpListenAddr) } + logrus.Info("Stopped goroutine listening on ", p.httpListenAddr) }() } func (p *P2PMonolith) startEventHandler() { + p.stopHandlingEvents = make(chan bool) stopRelayServerSync := make(chan bool) eLog := logrus.WithField("pinecone", "events") p.RelayRetriever = relay.NewRelayServerRetriever( @@ -329,25 +347,40 @@ func (p *P2PMonolith) startEventHandler() { p.RelayRetriever.InitializeRelayServers(eLog) go func(ch <-chan pineconeEvents.Event) { - for event := range ch { - switch e := event.(type) { - case pineconeEvents.PeerAdded: - p.RelayRetriever.StartSync() - case pineconeEvents.PeerRemoved: - if p.RelayRetriever.IsRunning() && p.Router.TotalPeerCount() == 0 { - stopRelayServerSync <- true - } - case pineconeEvents.BroadcastReceived: - // eLog.Info("Broadcast received from: ", e.PeerID) + for { + select { + case event := <-ch: + switch e := event.(type) { + case pineconeEvents.PeerAdded: + p.RelayRetriever.StartSync() + case pineconeEvents.PeerRemoved: + if p.RelayRetriever.IsRunning() && p.Router.TotalPeerCount() == 0 { + // NOTE: Don't block on channel + select { + case stopRelayServerSync <- true: + default: + } + } + case pineconeEvents.BroadcastReceived: + // eLog.Info("Broadcast received from: ", e.PeerID) - req := &federationAPI.PerformWakeupServersRequest{ - ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, + req := &federationAPI.PerformWakeupServersRequest{ + ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, + } + res := &federationAPI.PerformWakeupServersResponse{} + if err := p.dendrite.FederationAPI.PerformWakeupServers(p.BaseDendrite.Context(), req, res); err != nil { + eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) + } } - res := &federationAPI.PerformWakeupServersResponse{} - if err := p.dendrite.FederationAPI.PerformWakeupServers(p.BaseDendrite.Context(), req, res); err != nil { - eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) + case <-p.stopHandlingEvents: + logrus.Info("Stopping processing pinecone events") + // NOTE: Don't block on channel + select { + case stopRelayServerSync <- true: + default: } - default: + logrus.Info("Stopped processing pinecone events") + return } } }(p.EventChannel) diff --git a/cmd/dendrite-demo-pinecone/relay/retriever.go b/cmd/dendrite-demo-pinecone/relay/retriever.go index 1b5c617ef..6b34f6416 100644 --- a/cmd/dendrite-demo-pinecone/relay/retriever.go +++ b/cmd/dendrite-demo-pinecone/relay/retriever.go @@ -38,7 +38,7 @@ type RelayServerRetriever struct { relayServersQueried map[gomatrixserverlib.ServerName]bool queriedServersMutex sync.Mutex running atomic.Bool - quit <-chan bool + quit chan bool } func NewRelayServerRetriever( @@ -46,7 +46,7 @@ func NewRelayServerRetriever( serverName gomatrixserverlib.ServerName, federationAPI federationAPI.FederationInternalAPI, relayAPI relayServerAPI.RelayInternalAPI, - quit <-chan bool, + quit chan bool, ) RelayServerRetriever { return RelayServerRetriever{ ctx: ctx, @@ -151,6 +151,7 @@ func (r *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { if !t.Stop() { <-t.C } + logrus.Info("Stopped relay server retriever") return case <-t.C: } diff --git a/cmd/dendrite-demo-pinecone/relay/retriever_test.go b/cmd/dendrite-demo-pinecone/relay/retriever_test.go index 8f86a3770..6c4c3a529 100644 --- a/cmd/dendrite-demo-pinecone/relay/retriever_test.go +++ b/cmd/dendrite-demo-pinecone/relay/retriever_test.go @@ -60,7 +60,7 @@ func TestRelayRetrieverInitialization(t *testing.T) { "server", &FakeFedAPI{}, &FakeRelayAPI{}, - make(<-chan bool), + make(chan bool), ) retriever.InitializeRelayServers(logrus.WithField("test", "relay")) @@ -74,7 +74,7 @@ func TestRelayRetrieverSync(t *testing.T) { "server", &FakeFedAPI{}, &FakeRelayAPI{}, - make(<-chan bool), + make(chan bool), ) retriever.InitializeRelayServers(logrus.WithField("test", "relay"))