diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index ff61ea6c8..5e8e5875c 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -78,19 +78,19 @@ const ( ) type DendriteMonolith struct { - logger logrus.Logger - baseDendrite *base.BaseDendrite - PineconeRouter *pineconeRouter.Router - PineconeMulticast *pineconeMulticast.Multicast - PineconeQUIC *pineconeSessions.Sessions - PineconeManager *pineconeConnections.ConnectionManager - StorageDirectory string - CacheDirectory string - listener net.Listener - httpServer *http.Server - userAPI userapiAPI.UserInternalAPI - federationAPI api.FederationInternalAPI - relayServersQueried map[gomatrixserverlib.ServerName]bool + logger logrus.Logger + baseDendrite *base.BaseDendrite + PineconeRouter *pineconeRouter.Router + PineconeMulticast *pineconeMulticast.Multicast + PineconeQUIC *pineconeSessions.Sessions + PineconeManager *pineconeConnections.ConnectionManager + StorageDirectory string + CacheDirectory string + listener net.Listener + httpServer *http.Server + userAPI userapiAPI.UserInternalAPI + federationAPI api.FederationInternalAPI + relayRetriever RelayServerRetriever } func (m *DendriteMonolith) PublicKey() string { @@ -167,6 +167,152 @@ func (m *DendriteMonolith) SetStaticPeer(uri string) { } } +func getServerKeyFromString(nodeID string) (gomatrixserverlib.ServerName, error) { + var nodeKey gomatrixserverlib.ServerName + if userID, err := gomatrixserverlib.NewUserID(nodeID, false); err == nil { + hexKey, decodeErr := hex.DecodeString(string(userID.Domain())) + if decodeErr != nil || len(hexKey) != ed25519.PublicKeySize { + return "", fmt.Errorf("UserID domain is not a valid ed25519 public key: %v", userID.Domain()) + } else { + nodeKey = userID.Domain() + } + } else { + hexKey, decodeErr := hex.DecodeString(nodeID) + if decodeErr != nil || len(hexKey) != ed25519.PublicKeySize { + return "", fmt.Errorf("Relay server uri is not a valid ed25519 public key: %v", nodeID) + } else { + nodeKey = gomatrixserverlib.ServerName(nodeID) + } + } + + return nodeKey, nil +} + +func updateNodeRelayServers( + node gomatrixserverlib.ServerName, + relays []gomatrixserverlib.ServerName, + ctx context.Context, + fedAPI api.FederationInternalAPI, +) { + // Get the current relay list + request := api.P2PQueryRelayServersRequest{Server: node} + response := api.P2PQueryRelayServersResponse{} + err := fedAPI.P2PQueryRelayServers(ctx, &request, &response) + if err != nil { + logrus.Warnf("Failed obtaining list of relay servers for %s: %s", node, err.Error()) + } + + // Remove old, non-matching relays + var serversToRemove []gomatrixserverlib.ServerName + for _, existingServer := range response.RelayServers { + shouldRemove := true + for _, newServer := range relays { + if newServer == existingServer { + shouldRemove = false + break + } + } + + if shouldRemove { + serversToRemove = append(serversToRemove, existingServer) + } + } + removeRequest := api.P2PRemoveRelayServersRequest{ + Server: node, + RelayServers: serversToRemove, + } + removeResponse := api.P2PRemoveRelayServersResponse{} + err = fedAPI.P2PRemoveRelayServers(ctx, &removeRequest, &removeResponse) + if err != nil { + logrus.Warnf("Failed removing old relay servers for %s: %s", node, err.Error()) + } + + // Add new relays + addRequest := api.P2PAddRelayServersRequest{ + Server: node, + RelayServers: relays, + } + addResponse := api.P2PAddRelayServersResponse{} + err = fedAPI.P2PAddRelayServers(ctx, &addRequest, &addResponse) + if err != nil { + logrus.Warnf("Failed adding relay servers for %s: %s", node, err.Error()) + } +} + +func (m *DendriteMonolith) SetRelayServers(nodeID string, uris string) { + relays := []gomatrixserverlib.ServerName{} + for _, uri := range strings.Split(uris, ",") { + uri = strings.TrimSpace(uri) + if len(uri) == 0 { + continue + } + + nodeKey, err := getServerKeyFromString(uri) + if err != nil { + logrus.Errorf(err.Error()) + continue + } + relays = append(relays, nodeKey) + } + + nodeKey, err := getServerKeyFromString(nodeID) + if err != nil { + logrus.Errorf(err.Error()) + return + } + + if string(nodeKey) == m.PublicKey() { + logrus.Infof("Setting own relay servers to: %v", relays) + m.relayRetriever.SetRelayServers(relays) + } else { + updateNodeRelayServers( + gomatrixserverlib.ServerName(nodeKey), + relays, + m.baseDendrite.Context(), + m.federationAPI, + ) + } +} + +func (m *DendriteMonolith) GetRelayServers(nodeID string) string { + nodeKey, err := getServerKeyFromString(nodeID) + if err != nil { + logrus.Errorf(err.Error()) + return "" + } + + relaysString := "" + if string(nodeKey) == m.PublicKey() { + relays := m.relayRetriever.GetRelayServers() + + for i, relay := range relays { + if i != 0 { + // Append a comma to the previous entry if there is one. + relaysString += "," + } + relaysString += string(relay) + } + } else { + request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(nodeKey)} + response := api.P2PQueryRelayServersResponse{} + err := m.federationAPI.P2PQueryRelayServers(m.baseDendrite.Context(), &request, &response) + if err != nil { + logrus.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) + return "" + } + + for i, relay := range response.RelayServers { + if i != 0 { + // Append a comma to the previous entry if there is one. + relaysString += "," + } + relaysString += string(relay) + } + } + + return relaysString +} + func (m *DendriteMonolith) DisconnectType(peertype int) { for _, p := range m.PineconeRouter.Peers() { if int(peertype) == p.PeerType { @@ -454,28 +600,28 @@ func (m *DendriteMonolith) Start() { } }() - go func(ch <-chan pineconeEvents.Event) { - eLog := logrus.WithField("pinecone", "events") - stopRelayServerSync := make(chan bool) + stopRelayServerSync := make(chan bool) - relayRetriever := RelayServerRetriever{ - Context: context.Background(), - ServerName: gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()), - FederationAPI: m.federationAPI, - relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), - RelayAPI: monolith.RelayAPI, - running: *atomic.NewBool(false), - } - relayRetriever.InitializeRelayServers(eLog) + eLog := logrus.WithField("pinecone", "events") + m.relayRetriever = RelayServerRetriever{ + Context: context.Background(), + ServerName: gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()), + FederationAPI: m.federationAPI, + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + RelayAPI: monolith.RelayAPI, + running: *atomic.NewBool(false), + quit: stopRelayServerSync, + } + m.relayRetriever.InitializeRelayServers(eLog) + + go func(ch <-chan pineconeEvents.Event) { for event := range ch { switch e := event.(type) { case pineconeEvents.PeerAdded: - if !relayRetriever.running.Load() { - go relayRetriever.SyncRelayServers(stopRelayServerSync) - } + m.relayRetriever.StartSync() case pineconeEvents.PeerRemoved: - if relayRetriever.running.Load() && m.PineconeRouter.TotalPeerCount() == 0 { + if m.relayRetriever.running.Load() && m.PineconeRouter.TotalPeerCount() == 0 { stopRelayServerSync <- true } case pineconeEvents.BroadcastReceived: @@ -495,7 +641,7 @@ func (m *DendriteMonolith) Start() { } func (m *DendriteMonolith) Stop() { - m.baseDendrite.Close() + _ = m.baseDendrite.Close() m.baseDendrite.WaitForShutdown() _ = m.listener.Close() m.PineconeMulticast.Stop() @@ -511,32 +657,68 @@ type RelayServerRetriever struct { relayServersQueried map[gomatrixserverlib.ServerName]bool queriedServersMutex sync.Mutex running atomic.Bool + quit <-chan bool } -func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { - request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(m.ServerName)} +func (r *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { + request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(r.ServerName)} response := api.P2PQueryRelayServersResponse{} - err := m.FederationAPI.P2PQueryRelayServers(m.Context, &request, &response) + err := r.FederationAPI.P2PQueryRelayServers(r.Context, &request, &response) if err != nil { eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) } + + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() for _, server := range response.RelayServers { - m.relayServersQueried[server] = false + r.relayServersQueried[server] = false } eLog.Infof("Registered relay servers: %v", response.RelayServers) } -func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { - defer m.running.Store(false) +func (r *RelayServerRetriever) SetRelayServers(servers []gomatrixserverlib.ServerName) { + updateNodeRelayServers(r.ServerName, servers, r.Context, r.FederationAPI) + + // Replace list of servers to sync with and mark them all as unsynced. + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + r.relayServersQueried = make(map[gomatrixserverlib.ServerName]bool) + for _, server := range servers { + r.relayServersQueried[server] = false + } + + r.StartSync() +} + +func (r *RelayServerRetriever) GetRelayServers() []gomatrixserverlib.ServerName { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + relayServers := []gomatrixserverlib.ServerName{} + for server := range r.relayServersQueried { + relayServers = append(relayServers, server) + } + + return relayServers +} + +func (r *RelayServerRetriever) StartSync() { + if !r.running.Load() { + logrus.Info("Starting relay server sync") + go r.SyncRelayServers(r.quit) + } +} + +func (r *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { + defer r.running.Store(false) t := time.NewTimer(relayServerRetryInterval) for { relayServersToQuery := []gomatrixserverlib.ServerName{} func() { - m.queriedServersMutex.Lock() - defer m.queriedServersMutex.Unlock() - for server, complete := range m.relayServersQueried { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + for server, complete := range r.relayServersQueried { if !complete { relayServersToQuery = append(relayServersToQuery, server) } @@ -544,9 +726,10 @@ func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { }() if len(relayServersToQuery) == 0 { // All relay servers have been synced. + logrus.Info("Finished syncing with all known relays") return } - m.queryRelayServers(relayServersToQuery) + r.queryRelayServers(relayServersToQuery) t.Reset(relayServerRetryInterval) select { @@ -560,30 +743,32 @@ func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { } } -func (m *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { - m.queriedServersMutex.Lock() - defer m.queriedServersMutex.Unlock() +func (r *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() result := map[gomatrixserverlib.ServerName]bool{} - for server, queried := range m.relayServersQueried { + for server, queried := range r.relayServersQueried { result[server] = queried } return result } -func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { - logrus.Info("querying relay servers for any available transactions") +func (r *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { + logrus.Info("Querying relay servers for any available transactions") for _, server := range relayServers { - userID, err := gomatrixserverlib.NewUserID("@user:"+string(m.ServerName), false) + userID, err := gomatrixserverlib.NewUserID("@user:"+string(r.ServerName), false) if err != nil { return } - err = m.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) + + logrus.Infof("Syncing with relay: %s", string(server)) + err = r.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) if err == nil { func() { - m.queriedServersMutex.Lock() - defer m.queriedServersMutex.Unlock() - m.relayServersQueried[server] = true + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + r.relayServersQueried[server] = true }() // TODO : What happens if your relay receives new messages after this point? // Should you continue to check with them, or should they try and contact you? diff --git a/build/gobind-pinecone/monolith_test.go b/build/gobind-pinecone/monolith_test.go index edcf22bbe..3c8873e09 100644 --- a/build/gobind-pinecone/monolith_test.go +++ b/build/gobind-pinecone/monolith_test.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "net" + "strings" "testing" "time" @@ -196,3 +197,126 @@ func TestMonolithStarts(t *testing.T) { monolith.PublicKey() monolith.Stop() } + +func TestMonolithSetRelayServers(t *testing.T) { + testCases := []struct { + name string + nodeID string + relays string + expectedRelays string + expectSelf bool + }{ + { + name: "assorted valid, invalid, empty & self keys", + nodeID: "@valid:abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + relays: "@valid:123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd,@invalid:notakey,,", + expectedRelays: "123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectSelf: true, + }, + { + name: "invalid node key", + nodeID: "@invalid:notakey", + relays: "@valid:123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd,@invalid:notakey,,", + expectedRelays: "", + expectSelf: false, + }, + { + name: "node is self", + nodeID: "self", + relays: "@valid:123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd,@invalid:notakey,,", + expectedRelays: "123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectSelf: false, + }, + } + + for _, tc := range testCases { + monolith := DendriteMonolith{} + monolith.Start() + + inputRelays := tc.relays + expectedRelays := tc.expectedRelays + if tc.expectSelf { + inputRelays += "," + monolith.PublicKey() + expectedRelays += "," + monolith.PublicKey() + } + nodeID := tc.nodeID + if nodeID == "self" { + nodeID = monolith.PublicKey() + } + + monolith.SetRelayServers(nodeID, inputRelays) + relays := monolith.GetRelayServers(nodeID) + monolith.Stop() + + if !containSameKeys(strings.Split(relays, ","), strings.Split(expectedRelays, ",")) { + t.Fatalf("%s: expected %s got %s", tc.name, expectedRelays, relays) + } + } +} + +func containSameKeys(expected []string, actual []string) bool { + if len(expected) != len(actual) { + return false + } + + for _, expectedKey := range expected { + hasMatch := false + for _, actualKey := range actual { + if actualKey == expectedKey { + hasMatch = true + } + } + + if !hasMatch { + return false + } + } + + return true +} + +func TestParseServerKey(t *testing.T) { + testCases := []struct { + name string + serverKey string + expectedErr bool + expectedKey gomatrixserverlib.ServerName + }{ + { + name: "valid userid as key", + serverKey: "@valid:abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectedErr: false, + expectedKey: "abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + }, + { + name: "valid key", + serverKey: "abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectedErr: false, + expectedKey: "abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + }, + { + name: "invalid userid key", + serverKey: "@invalid:notakey", + expectedErr: true, + expectedKey: "", + }, + { + name: "invalid key", + serverKey: "@invalid:notakey", + expectedErr: true, + expectedKey: "", + }, + } + + for _, tc := range testCases { + key, err := getServerKeyFromString(tc.serverKey) + if tc.expectedErr && err == nil { + t.Fatalf("%s: expected an error", tc.name) + } else if !tc.expectedErr && err != nil { + t.Fatalf("%s: didn't expect an error: %s", tc.name, err.Error()) + } + if tc.expectedKey != key { + t.Fatalf("%s: keys not equal. expected: %s got: %s", tc.name, tc.expectedKey, key) + } + } +} diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 417b08521..e4c0b2714 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -72,12 +72,26 @@ type RoomserverFederationAPI interface { } type P2PFederationAPI interface { - // Relay Server sync api used in the pinecone demos. + // Get the relay servers associated for the given server. P2PQueryRelayServers( ctx context.Context, request *P2PQueryRelayServersRequest, response *P2PQueryRelayServersResponse, ) error + + // Add relay server associations to the given server. + P2PAddRelayServers( + ctx context.Context, + request *P2PAddRelayServersRequest, + response *P2PAddRelayServersResponse, + ) error + + // Remove relay server associations from the given server. + P2PRemoveRelayServers( + ctx context.Context, + request *P2PRemoveRelayServersRequest, + response *P2PRemoveRelayServersResponse, + ) error } // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver @@ -256,3 +270,19 @@ type P2PQueryRelayServersRequest struct { type P2PQueryRelayServersResponse struct { RelayServers []gomatrixserverlib.ServerName } + +type P2PAddRelayServersRequest struct { + Server gomatrixserverlib.ServerName + RelayServers []gomatrixserverlib.ServerName +} + +type P2PAddRelayServersResponse struct { +} + +type P2PRemoveRelayServersRequest struct { + Server gomatrixserverlib.ServerName + RelayServers []gomatrixserverlib.ServerName +} + +type P2PRemoveRelayServersResponse struct { +} diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 552942f28..b9684f767 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -840,6 +840,36 @@ func (r *FederationInternalAPI) P2PQueryRelayServers( return nil } +// P2PAddRelayServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) P2PAddRelayServers( + ctx context.Context, + request *api.P2PAddRelayServersRequest, + response *api.P2PAddRelayServersResponse, +) error { + logrus.Infof("Adding relay servers for: %s", request.Server) + err := r.db.P2PAddRelayServersForServer(ctx, request.Server, request.RelayServers) + if err != nil { + return err + } + + return nil +} + +// P2PRemoveRelayServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) P2PRemoveRelayServers( + ctx context.Context, + request *api.P2PRemoveRelayServersRequest, + response *api.P2PRemoveRelayServersResponse, +) error { + logrus.Infof("Adding relay servers for: %s", request.Server) + err := r.db.P2PRemoveRelayServersForServer(ctx, request.Server, request.RelayServers) + if err != nil { + return err + } + + return nil +} + func (r *FederationInternalAPI) shouldAttemptDirectFederation( destination gomatrixserverlib.ServerName, ) bool { diff --git a/federationapi/internal/perform_test.go b/federationapi/internal/perform_test.go index e8e0d00a3..e6e366f99 100644 --- a/federationapi/internal/perform_test.go +++ b/federationapi/internal/perform_test.go @@ -123,6 +123,47 @@ func TestQueryRelayServers(t *testing.T) { assert.Equal(t, len(relayServers), len(res.RelayServers)) } +func TestRemoveRelayServers(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} + err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers) + assert.NoError(t, err) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.P2PRemoveRelayServersRequest{ + Server: server, + RelayServers: []gomatrixserverlib.ServerName{"relay1"}, + } + res := api.P2PRemoveRelayServersResponse{} + err = fedAPI.P2PRemoveRelayServers(context.Background(), &req, &res) + assert.NoError(t, err) + + finalRelays, err := testDB.P2PGetRelayServersForServer(context.Background(), server) + assert.NoError(t, err) + assert.Equal(t, 1, len(finalRelays)) + assert.Equal(t, gomatrixserverlib.ServerName("relay2"), finalRelays[0]) +} + func TestPerformDirectoryLookup(t *testing.T) { testDB := test.NewInMemoryFederationDatabase() diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 6130a567d..00e069d1e 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -25,6 +25,8 @@ const ( FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers" FederationAPIQueryRelayServers = "/federationapi/queryRelayServers" + FederationAPIAddRelayServers = "/federationapi/addRelayServers" + FederationAPIRemoveRelayServers = "/federationapi/removeRelayServers" FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" @@ -522,3 +524,25 @@ func (h *httpFederationInternalAPI) P2PQueryRelayServers( h.httpClient, ctx, request, response, ) } + +func (h *httpFederationInternalAPI) P2PAddRelayServers( + ctx context.Context, + request *api.P2PAddRelayServersRequest, + response *api.P2PAddRelayServersResponse, +) error { + return httputil.CallInternalRPCAPI( + "AddRelayServers", h.federationAPIURL+FederationAPIAddRelayServers, + h.httpClient, ctx, request, response, + ) +} + +func (h *httpFederationInternalAPI) P2PRemoveRelayServers( + ctx context.Context, + request *api.P2PRemoveRelayServersRequest, + response *api.P2PRemoveRelayServersResponse, +) error { + return httputil.CallInternalRPCAPI( + "RemoveRelayServers", h.federationAPIURL+FederationAPIRemoveRelayServers, + h.httpClient, ctx, request, response, + ) +} diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 51350916d..12e6db9fa 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -410,34 +410,49 @@ func (oq *destinationQueue) nextTransaction( defer cancel() relayServers := oq.statistics.KnownRelayServers() - if oq.statistics.AssumedOffline() && len(relayServers) > 0 { - sendMethod = statistics.SendViaRelay - relaySuccess := false - logrus.Infof("Sending to relay servers: %v", relayServers) - // TODO : how to pass through actual userID here?!?!?!?! - userID, userErr := gomatrixserverlib.NewUserID("@user:"+string(oq.destination), false) - if userErr != nil { - return userErr, sendMethod - } - - // Attempt sending to each known relay server. - for _, relayServer := range relayServers { - _, relayErr := oq.client.P2PSendTransactionToRelay(ctx, *userID, t, relayServer) - if relayErr != nil { - err = relayErr - } else { - // If sending to one of the relay servers succeeds, consider the send successful. - relaySuccess = true - } - } - - // Clear the error if sending to any of the relay servers succeeded. - if relaySuccess { - err = nil - } - } else { + hasRelayServers := len(relayServers) > 0 + shouldSendToRelays := oq.statistics.AssumedOffline() && hasRelayServers + if !shouldSendToRelays { sendMethod = statistics.SendDirect _, err = oq.client.SendTransaction(ctx, t) + } else { + // Try sending directly to the destination first in case they came back online. + sendMethod = statistics.SendDirect + _, err = oq.client.SendTransaction(ctx, t) + if err != nil { + // The destination is still offline, try sending to relays. + sendMethod = statistics.SendViaRelay + relaySuccess := false + logrus.Infof("Sending %q to relay servers: %v", t.TransactionID, relayServers) + // TODO : how to pass through actual userID here?!?!?!?! + userID, userErr := gomatrixserverlib.NewUserID("@user:"+string(oq.destination), false) + if userErr != nil { + return userErr, sendMethod + } + + // Attempt sending to each known relay server. + for _, relayServer := range relayServers { + _, relayErr := oq.client.P2PSendTransactionToRelay(ctx, *userID, t, relayServer) + if relayErr != nil { + err = relayErr + } else { + // If sending to one of the relay servers succeeds, consider the send successful. + relaySuccess = true + + // TODO : what about if the dest comes back online but can't see their relay? + // How do I sync with the dest in that case? + // Should change the database to have a "relay success" flag on events and if + // I see the node back online, maybe directly send through the backlog of events + // with "relay success"... could lead to duplicate events, but only those that + // I sent. And will lead to a much more consistent experience. + } + } + + // Clear the error if sending to any of the relay servers succeeded. + if relaySuccess { + err = nil + } + } } switch errResponse := err.(type) { case nil: diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 36e2ccbc2..bccfb3428 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -923,7 +923,7 @@ func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) { assert.NoError(t, err) check := func(log poll.LogT) poll.Result { - if fc.txCount.Load() == 1 { + if fc.txCount.Load() >= 1 { if fc.txRelayCount.Load() == 1 { data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) @@ -962,7 +962,7 @@ func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) { assert.NoError(t, err) check := func(log poll.LogT) poll.Result { - if fc.txCount.Load() == 1 { + if fc.txCount.Load() >= 1 { if fc.txRelayCount.Load() == 1 { data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index 866c09336..e29e3b140 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -164,6 +164,8 @@ func (s *ServerStatistics) Success(method SendMethod) { logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) } } + + s.removeAssumedOffline() } } diff --git a/setup/config/config_federationapi.go b/setup/config/config_federationapi.go index 6c198018d..cd7d90562 100644 --- a/setup/config/config_federationapi.go +++ b/setup/config/config_federationapi.go @@ -49,7 +49,7 @@ func (c *FederationAPI) Defaults(opts DefaultOpts) { c.Database.Defaults(10) } c.FederationMaxRetries = 16 - c.P2PFederationRetriesUntilAssumedOffline = 2 + c.P2PFederationRetriesUntilAssumedOffline = 1 c.DisableTLSValidation = false c.DisableHTTPKeepalives = false if opts.Generate { diff --git a/test/memory_federation_db.go b/test/memory_federation_db.go index cc9e1e8fd..de0dc54eb 100644 --- a/test/memory_federation_db.go +++ b/test/memory_federation_db.go @@ -399,6 +399,33 @@ func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer( return nil } +func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if knownRelayServers, ok := d.relayServers[serverName]; ok { + for _, relayServer := range relayServers { + for i, knownRelayServer := range knownRelayServers { + if relayServer == knownRelayServer { + d.relayServers[serverName] = append( + d.relayServers[serverName][:i], + d.relayServers[serverName][i+1:]..., + ) + break + } + } + } + } else { + d.relayServers[serverName] = relayServers + } + + return nil +} + func (d *InMemoryFederationDatabase) FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { return nil, nil } @@ -431,10 +458,6 @@ func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline(ctx context. return nil } -func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error { - return nil -} - func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error { return nil }