diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 7a81b9854..00e444510 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -82,8 +82,19 @@ type Database interface { } type P2PDatabase interface { + // Stores the given list of servers as relay servers for the provided destination server. + // Providing duplicates will only lead to a single entry and won't lead to an error. P2PAddRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + + // Get the list of relay servers associated with the provided destination server. + // If no entry exists in the table, an empty list is returned and does not result in an error. P2PGetRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + + // Deletes any entries for the provided destination server that match the provided relayServers list. + // If any of the provided servers don't match an entry, nothing happens and no error is returned. P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + + // Deletes all entries for the provided destination server. + // If the destination server doesn't exist in the table, nothing happens and no error is returned. P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error } diff --git a/federationapi/storage/tables/relay_servers_table_test.go b/federationapi/storage/tables/relay_servers_table_test.go index 07cb4ce2e..b41211551 100644 --- a/federationapi/storage/tables/relay_servers_table_test.go +++ b/federationapi/storage/tables/relay_servers_table_test.go @@ -92,6 +92,54 @@ func TestShouldInsertRelayServers(t *testing.T) { }) } +func TestShouldInsertRelayServersWithDuplicates(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + insertRelayServers := []gomatrixserverlib.ServerName{server2, server2, server2, server3, server2} + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + // Insert the same list again, this shouldn't fail and should have no effect. + err = db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldGetRelayServersUnknownDestination(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + + // Query relay servers for a destination that doesn't exist in the table. + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, []gomatrixserverlib.ServerName{}) { + t.Fatalf("Expected: %v \nActual: %v", []gomatrixserverlib.ServerName{}, relayServers) + } + }) +} + func TestShouldDeleteCorrectRelayServers(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -118,20 +166,20 @@ func TestShouldDeleteCorrectRelayServers(t *testing.T) { t.Fatalf("Failed deleting relay servers for %s: %s", server2, err.Error()) } - updatedExpectedRelayServers := []gomatrixserverlib.ServerName{server3} + expectedRelayServers := []gomatrixserverlib.ServerName{server3} relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) if err != nil { t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) } - if !Equal(relayServers, updatedExpectedRelayServers) { - t.Fatalf("Expected: %v \nActual: %v", updatedExpectedRelayServers, relayServers) + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) } relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2) if err != nil { t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) } - if !Equal(relayServers, updatedExpectedRelayServers) { - t.Fatalf("Expected: %v \nActual: %v", updatedExpectedRelayServers, relayServers) + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) } }) }