diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 7a47c3610..3860f824e 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -167,43 +167,123 @@ func (m *DendriteMonolith) SetStaticPeer(uri string) { } } -func (m *DendriteMonolith) SetRelayServer(nodeKey string, uri string) { +func getServerKeyFromString(nodeID string) (gomatrixserverlib.ServerName, error) { + var nodeKey gomatrixserverlib.ServerName + if userID, err := gomatrixserverlib.NewUserID(nodeID, false); err == nil { + hexKey, decodeErr := hex.DecodeString(string(userID.Domain())) + if decodeErr != nil || len(hexKey) != ed25519.PublicKeySize { + return "", fmt.Errorf("UserID domain is not a valid ed25519 public key: %v", userID.Domain()) + } else { + nodeKey = userID.Domain() + } + } else { + hexKey, decodeErr := hex.DecodeString(nodeID) + if decodeErr != nil || len(hexKey) != ed25519.PublicKeySize { + return "", fmt.Errorf("Relay server uri is not a valid ed25519 public key: %v", nodeID) + } else { + nodeKey = gomatrixserverlib.ServerName(nodeID) + } + } + + return nodeKey, nil +} + +func updateNodeRelayServers( + node gomatrixserverlib.ServerName, + relays []gomatrixserverlib.ServerName, + ctx context.Context, + fedAPI api.FederationInternalAPI, +) { + // Get the current relay list + request := api.P2PQueryRelayServersRequest{Server: node} + response := api.P2PQueryRelayServersResponse{} + err := fedAPI.P2PQueryRelayServers(ctx, &request, &response) + if err != nil { + logrus.Warnf("Failed obtaining list of relay servers for %s: %s", node, err.Error()) + } + + // Remove old, non-matching relays + var serversToRemove []gomatrixserverlib.ServerName + for _, existingServer := range response.RelayServers { + shouldRemove := true + for _, newServer := range relays { + if newServer == existingServer { + shouldRemove = false + break + } + } + + if shouldRemove { + serversToRemove = append(serversToRemove, existingServer) + } + } + removeRequest := api.P2PRemoveRelayServersRequest{ + Server: node, + RelayServers: serversToRemove, + } + removeResponse := api.P2PRemoveRelayServersResponse{} + err = fedAPI.P2PRemoveRelayServers(ctx, &removeRequest, &removeResponse) + if err != nil { + logrus.Warnf("Failed removing old relay servers for %s: %s", node, err.Error()) + } + + // Add new relays + addRequest := api.P2PAddRelayServersRequest{ + Server: node, + RelayServers: relays, + } + addResponse := api.P2PAddRelayServersResponse{} + err = fedAPI.P2PAddRelayServers(ctx, &addRequest, &addResponse) + if err != nil { + logrus.Warnf("Failed adding relay servers for %s: %s", node, err.Error()) + } +} + +func (m *DendriteMonolith) SetRelayServers(nodeID string, uris string) { relays := []gomatrixserverlib.ServerName{} - for _, uri := range strings.Split(uri, ",") { + for _, uri := range strings.Split(uris, ",") { uri = strings.TrimSpace(uri) if len(uri) == 0 { continue } - if userID, err := gomatrixserverlib.NewUserID(uri, false); err == nil { - hexKey, decodeErr := hex.DecodeString(string(userID.Domain())) - if decodeErr != nil || len(hexKey) != ed25519.PublicKeySize { - logrus.Warnf("UserID domain is not a valid ed25519 public key: %v", userID.Domain()) - continue - } - relays = append(relays, userID.Domain()) - } else { - hexKey, decodeErr := hex.DecodeString(uri) - if decodeErr != nil || len(hexKey) != ed25519.PublicKeySize { - logrus.Warnf("Relay server uri is not a valid ed25519 public key: %v", uri) - continue - } - relays = append(relays, gomatrixserverlib.ServerName(uri)) + nodeKey, err := getServerKeyFromString(uri) + if err != nil { + logrus.Errorf(err.Error()) + continue } + relays = append(relays, nodeKey) } - if nodeKey == m.PublicKey() { - logrus.Infof("Setting relay servers to: %v", relays) + nodeKey, err := getServerKeyFromString(nodeID) + if err != nil { + logrus.Errorf(err.Error()) + return + } + + if string(nodeKey) == m.PublicKey() { + logrus.Infof("Setting own relay servers to: %v", relays) m.relayRetriever.SetRelayServers(relays) } else { - // TODO: add relay/s for other nodes + updateNodeRelayServers( + gomatrixserverlib.ServerName(nodeKey), + relays, + m.baseDendrite.Context(), + m.federationAPI, + ) } } -func (m *DendriteMonolith) GetRelayServers(nodeKey string) string { - if nodeKey == m.PublicKey() { +func (m *DendriteMonolith) GetRelayServers(nodeID string) string { + nodeKey, err := getServerKeyFromString(nodeID) + if err != nil { + logrus.Errorf(err.Error()) + return "" + } + + relaysString := "" + if string(nodeKey) == m.PublicKey() { relays := m.relayRetriever.GetRelayServers() - relaysString := "" for i, relay := range relays { if i != 0 { @@ -212,13 +292,25 @@ func (m *DendriteMonolith) GetRelayServers(nodeKey string) string { } relaysString += string(relay) } - - return relaysString } else { - // TODO: return relays for other nodes + request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(nodeKey)} + response := api.P2PQueryRelayServersResponse{} + err := m.federationAPI.P2PQueryRelayServers(m.baseDendrite.Context(), &request, &response) + if err != nil { + logrus.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) + return "" + } + + for i, relay := range response.RelayServers { + if i != 0 { + // Append a comma to the previous entry if there is one. + relaysString += "," + } + relaysString += string(relay) + } } - return "" + return relaysString } func (m *DendriteMonolith) DisconnectType(peertype int) { @@ -548,7 +640,7 @@ func (m *DendriteMonolith) Start() { } func (m *DendriteMonolith) Stop() { - m.baseDendrite.Close() + _ = m.baseDendrite.Close() m.baseDendrite.WaitForShutdown() _ = m.listener.Close() m.PineconeMulticast.Stop() @@ -585,19 +677,13 @@ func (r *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { } func (r *RelayServerRetriever) SetRelayServers(servers []gomatrixserverlib.ServerName) { + updateNodeRelayServers(r.ServerName, servers, r.Context, r.FederationAPI) + + // Replace list of servers to sync with and mark them all as unsynced. r.queriedServersMutex.Lock() defer r.queriedServersMutex.Unlock() r.relayServersQueried = make(map[gomatrixserverlib.ServerName]bool) for _, server := range servers { - request := api.P2PAddRelayServersRequest{ - Server: gomatrixserverlib.ServerName(r.ServerName), - RelayServers: servers, - } - response := api.P2PAddRelayServersResponse{} - err := r.FederationAPI.P2PAddRelayServers(r.Context, &request, &response) - if err != nil { - logrus.Warnf("Failed adding relay servers for this node: %s", err.Error()) - } r.relayServersQueried[server] = false } diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 1f1ce71a4..e4c0b2714 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -85,6 +85,13 @@ type P2PFederationAPI interface { request *P2PAddRelayServersRequest, response *P2PAddRelayServersResponse, ) error + + // Remove relay server associations from the given server. + P2PRemoveRelayServers( + ctx context.Context, + request *P2PRemoveRelayServersRequest, + response *P2PRemoveRelayServersResponse, + ) error } // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver @@ -271,3 +278,11 @@ type P2PAddRelayServersRequest struct { type P2PAddRelayServersResponse struct { } + +type P2PRemoveRelayServersRequest struct { + Server gomatrixserverlib.ServerName + RelayServers []gomatrixserverlib.ServerName +} + +type P2PRemoveRelayServersResponse struct { +} diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index cfc96cdb0..b9684f767 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -855,6 +855,21 @@ func (r *FederationInternalAPI) P2PAddRelayServers( return nil } +// P2PRemoveRelayServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) P2PRemoveRelayServers( + ctx context.Context, + request *api.P2PRemoveRelayServersRequest, + response *api.P2PRemoveRelayServersResponse, +) error { + logrus.Infof("Adding relay servers for: %s", request.Server) + err := r.db.P2PRemoveRelayServersForServer(ctx, request.Server, request.RelayServers) + if err != nil { + return err + } + + return nil +} + func (r *FederationInternalAPI) shouldAttemptDirectFederation( destination gomatrixserverlib.ServerName, ) bool { diff --git a/federationapi/internal/perform_test.go b/federationapi/internal/perform_test.go index e8e0d00a3..e6e366f99 100644 --- a/federationapi/internal/perform_test.go +++ b/federationapi/internal/perform_test.go @@ -123,6 +123,47 @@ func TestQueryRelayServers(t *testing.T) { assert.Equal(t, len(relayServers), len(res.RelayServers)) } +func TestRemoveRelayServers(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} + err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers) + assert.NoError(t, err) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.P2PRemoveRelayServersRequest{ + Server: server, + RelayServers: []gomatrixserverlib.ServerName{"relay1"}, + } + res := api.P2PRemoveRelayServersResponse{} + err = fedAPI.P2PRemoveRelayServers(context.Background(), &req, &res) + assert.NoError(t, err) + + finalRelays, err := testDB.P2PGetRelayServersForServer(context.Background(), server) + assert.NoError(t, err) + assert.Equal(t, 1, len(finalRelays)) + assert.Equal(t, gomatrixserverlib.ServerName("relay2"), finalRelays[0]) +} + func TestPerformDirectoryLookup(t *testing.T) { testDB := test.NewInMemoryFederationDatabase() diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index e6f42335d..00e069d1e 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -26,6 +26,7 @@ const ( FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers" FederationAPIQueryRelayServers = "/federationapi/queryRelayServers" FederationAPIAddRelayServers = "/federationapi/addRelayServers" + FederationAPIRemoveRelayServers = "/federationapi/removeRelayServers" FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" @@ -534,3 +535,14 @@ func (h *httpFederationInternalAPI) P2PAddRelayServers( h.httpClient, ctx, request, response, ) } + +func (h *httpFederationInternalAPI) P2PRemoveRelayServers( + ctx context.Context, + request *api.P2PRemoveRelayServersRequest, + response *api.P2PRemoveRelayServersResponse, +) error { + return httputil.CallInternalRPCAPI( + "RemoveRelayServers", h.federationAPIURL+FederationAPIRemoveRelayServers, + h.httpClient, ctx, request, response, + ) +} diff --git a/test/memory_federation_db.go b/test/memory_federation_db.go index cc9e1e8fd..de0dc54eb 100644 --- a/test/memory_federation_db.go +++ b/test/memory_federation_db.go @@ -399,6 +399,33 @@ func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer( return nil } +func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if knownRelayServers, ok := d.relayServers[serverName]; ok { + for _, relayServer := range relayServers { + for i, knownRelayServer := range knownRelayServers { + if relayServer == knownRelayServer { + d.relayServers[serverName] = append( + d.relayServers[serverName][:i], + d.relayServers[serverName][i+1:]..., + ) + break + } + } + } + } else { + d.relayServers[serverName] = relayServers + } + + return nil +} + func (d *InMemoryFederationDatabase) FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { return nil, nil } @@ -431,10 +458,6 @@ func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline(ctx context. return nil } -func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error { - return nil -} - func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error { return nil }