From 16eb9e4e491852270b7d9f226a3c35eedd25dd2c Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 13 Jan 2023 16:00:36 -0700 Subject: [PATCH] Cleanup relay servers table and add batch deletion to sqlite --- .../storage/postgres/relay_servers_table.go | 13 ++++---- .../storage/sqlite3/relay_servers_table.go | 30 ++++++++++++------- .../tables/relay_servers_table_test.go | 22 +++++++++----- 3 files changed, 38 insertions(+), 27 deletions(-) diff --git a/federationapi/storage/postgres/relay_servers_table.go b/federationapi/storage/postgres/relay_servers_table.go index e4a1e4e2c..f7267978f 100644 --- a/federationapi/storage/postgres/relay_servers_table.go +++ b/federationapi/storage/postgres/relay_servers_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" + "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" @@ -44,7 +45,7 @@ const selectRelayServersSQL = "" + "SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1" const deleteRelayServersSQL = "" + - "DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name IN ($2)" + "DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name = ANY($2)" const deleteAllRelayServersSQL = "" + "DELETE FROM federationsender_relay_servers WHERE server_name = $1" @@ -118,13 +119,9 @@ func (s *relayServersStatements) DeleteRelayServers( serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName, ) error { - for _, relayServer := range relayServers { - stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt) - if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil { - return err - } - } - return nil + stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt) + _, err := stmt.ExecContext(ctx, serverName, pq.Array(relayServers)) + return err } func (s *relayServersStatements) DeleteAllRelayServers( diff --git a/federationapi/storage/sqlite3/relay_servers_table.go b/federationapi/storage/sqlite3/relay_servers_table.go index 27eed7bc7..27c3cca2c 100644 --- a/federationapi/storage/sqlite3/relay_servers_table.go +++ b/federationapi/storage/sqlite3/relay_servers_table.go @@ -17,6 +17,7 @@ package sqlite3 import ( "context" "database/sql" + "strings" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -50,10 +51,10 @@ const deleteAllRelayServersSQL = "" + "DELETE FROM federationsender_relay_servers WHERE server_name = $1" type relayServersStatements struct { - db *sql.DB - insertRelayServersStmt *sql.Stmt - selectRelayServersStmt *sql.Stmt - deleteRelayServersStmt *sql.Stmt + db *sql.DB + insertRelayServersStmt *sql.Stmt + selectRelayServersStmt *sql.Stmt + // deleteRelayServersStmt *sql.Stmt - prepared at runtime due to variadic deleteAllRelayServersStmt *sql.Stmt } @@ -69,7 +70,6 @@ func NewSQLiteRelayServersTable(db *sql.DB) (s *relayServersStatements, err erro return s, sqlutil.StatementList{ {&s.insertRelayServersStmt, insertRelayServersSQL}, {&s.selectRelayServersStmt, selectRelayServersSQL}, - {&s.deleteRelayServersStmt, deleteRelayServersSQL}, {&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL}, }.Prepare(db) } @@ -118,13 +118,21 @@ func (s *relayServersStatements) DeleteRelayServers( serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName, ) error { - for _, relayServer := range relayServers { - stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt) - if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil { - return err - } + deleteSQL := strings.Replace(deleteRelayServersSQL, "($2)", sqlutil.QueryVariadicOffset(len(relayServers), 1), 1) + deleteStmt, err := s.db.Prepare(deleteSQL) + if err != nil { + return err } - return nil + + stmt := sqlutil.TxStmt(txn, deleteStmt) + params := make([]interface{}, len(relayServers)+1) + params[0] = serverName + for i, v := range relayServers { + params[i+1] = v + } + + _, err = stmt.ExecContext(ctx, params...) + return err } func (s *relayServersStatements) DeleteAllRelayServers( diff --git a/federationapi/storage/tables/relay_servers_table_test.go b/federationapi/storage/tables/relay_servers_table_test.go index 4d29514e8..07cb4ce2e 100644 --- a/federationapi/storage/tables/relay_servers_table_test.go +++ b/federationapi/storage/tables/relay_servers_table_test.go @@ -19,6 +19,7 @@ const ( server1 = "server1" server2 = "server2" server3 = "server3" + server4 = "server4" ) type RelayServersDatabase struct { @@ -96,13 +97,14 @@ func TestShouldDeleteCorrectRelayServers(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateRelayServersTable(t, dbType) defer close() - expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + relayServers1 := []gomatrixserverlib.ServerName{server2, server3} + relayServers2 := []gomatrixserverlib.ServerName{server1, server3, server4} - err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) + err := db.Table.InsertRelayServers(ctx, nil, server1, relayServers1) if err != nil { t.Fatalf("Failed inserting transaction: %s", err.Error()) } - err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers) + err = db.Table.InsertRelayServers(ctx, nil, server2, relayServers2) if err != nil { t.Fatalf("Failed inserting transaction: %s", err.Error()) } @@ -111,21 +113,25 @@ func TestShouldDeleteCorrectRelayServers(t *testing.T) { if err != nil { t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) } + err = db.Table.DeleteRelayServers(ctx, nil, server2, []gomatrixserverlib.ServerName{server1, server4}) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server2, err.Error()) + } - expectedRelayServers1 := []gomatrixserverlib.ServerName{server3} + updatedExpectedRelayServers := []gomatrixserverlib.ServerName{server3} relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) if err != nil { t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) } - if !Equal(relayServers, expectedRelayServers1) { - t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers) + if !Equal(relayServers, updatedExpectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", updatedExpectedRelayServers, relayServers) } relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2) if err != nil { t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) } - if !Equal(relayServers, expectedRelayServers) { - t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + if !Equal(relayServers, updatedExpectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", updatedExpectedRelayServers, relayServers) } }) }