diff --git a/federationapi/internal/perform_test.go b/federationapi/internal/perform_test.go index 05643de93..7e2d7b20a 100644 --- a/federationapi/internal/perform_test.go +++ b/federationapi/internal/perform_test.go @@ -78,3 +78,40 @@ func TestPerformWakeupServers(t *testing.T) { assert.NoError(t, err) assert.False(t, offline) } + +func TestQueryRelayServers(t *testing.T) { + testDB := storage.NewFakeFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} + err := testDB.AddRelayServersForServer(server, relayServers) + assert.NoError(t, err) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, 8, 3) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.QueryRelayServersRequest{ + Server: server, + } + res := api.QueryRelayServersResponse{} + err = fedAPI.QueryRelayServers(context.Background(), &req, &res) + assert.NoError(t, err) + + assert.Equal(t, len(relayServers), len(res.RelayServers)) +} diff --git a/federationapi/storage/fake_federation_db.go b/federationapi/storage/fake_federation_db.go index 39e40c0b6..0031ba871 100644 --- a/federationapi/storage/fake_federation_db.go +++ b/federationapi/storage/fake_federation_db.go @@ -329,6 +329,8 @@ func (d *FakeFederationDatabase) AddRelayServersForServer(serverName gomatrixser d.relayServers[serverName] = append(d.relayServers[serverName], relayServer) } } + } else { + d.relayServers[serverName] = relayServers } return nil