Cleanup relay servers table and add batch deletion to sqlite

This commit is contained in:
Devon Hudson 2023-01-13 16:00:36 -07:00
parent 98d4e4f89b
commit 16eb9e4e49
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
3 changed files with 38 additions and 27 deletions

View file

@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -44,7 +45,7 @@ const selectRelayServersSQL = "" +
"SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1" "SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1"
const deleteRelayServersSQL = "" + const deleteRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name IN ($2)" "DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name = ANY($2)"
const deleteAllRelayServersSQL = "" + const deleteAllRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1" "DELETE FROM federationsender_relay_servers WHERE server_name = $1"
@ -118,13 +119,9 @@ func (s *relayServersStatements) DeleteRelayServers(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName,
) error { ) error {
for _, relayServer := range relayServers { stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt)
stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt) _, err := stmt.ExecContext(ctx, serverName, pq.Array(relayServers))
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil { return err
return err
}
}
return nil
} }
func (s *relayServersStatements) DeleteAllRelayServers( func (s *relayServersStatements) DeleteAllRelayServers(

View file

@ -17,6 +17,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -50,10 +51,10 @@ const deleteAllRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1" "DELETE FROM federationsender_relay_servers WHERE server_name = $1"
type relayServersStatements struct { type relayServersStatements struct {
db *sql.DB db *sql.DB
insertRelayServersStmt *sql.Stmt insertRelayServersStmt *sql.Stmt
selectRelayServersStmt *sql.Stmt selectRelayServersStmt *sql.Stmt
deleteRelayServersStmt *sql.Stmt // deleteRelayServersStmt *sql.Stmt - prepared at runtime due to variadic
deleteAllRelayServersStmt *sql.Stmt deleteAllRelayServersStmt *sql.Stmt
} }
@ -69,7 +70,6 @@ func NewSQLiteRelayServersTable(db *sql.DB) (s *relayServersStatements, err erro
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertRelayServersStmt, insertRelayServersSQL}, {&s.insertRelayServersStmt, insertRelayServersSQL},
{&s.selectRelayServersStmt, selectRelayServersSQL}, {&s.selectRelayServersStmt, selectRelayServersSQL},
{&s.deleteRelayServersStmt, deleteRelayServersSQL},
{&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL}, {&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -118,13 +118,21 @@ func (s *relayServersStatements) DeleteRelayServers(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName,
) error { ) error {
for _, relayServer := range relayServers { deleteSQL := strings.Replace(deleteRelayServersSQL, "($2)", sqlutil.QueryVariadicOffset(len(relayServers), 1), 1)
stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt) deleteStmt, err := s.db.Prepare(deleteSQL)
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil { if err != nil {
return err return err
}
} }
return nil
stmt := sqlutil.TxStmt(txn, deleteStmt)
params := make([]interface{}, len(relayServers)+1)
params[0] = serverName
for i, v := range relayServers {
params[i+1] = v
}
_, err = stmt.ExecContext(ctx, params...)
return err
} }
func (s *relayServersStatements) DeleteAllRelayServers( func (s *relayServersStatements) DeleteAllRelayServers(

View file

@ -19,6 +19,7 @@ const (
server1 = "server1" server1 = "server1"
server2 = "server2" server2 = "server2"
server3 = "server3" server3 = "server3"
server4 = "server4"
) )
type RelayServersDatabase struct { type RelayServersDatabase struct {
@ -96,13 +97,14 @@ func TestShouldDeleteCorrectRelayServers(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType) db, close := mustCreateRelayServersTable(t, dbType)
defer close() defer close()
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} relayServers1 := []gomatrixserverlib.ServerName{server2, server3}
relayServers2 := []gomatrixserverlib.ServerName{server1, server3, server4}
err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) err := db.Table.InsertRelayServers(ctx, nil, server1, relayServers1)
if err != nil { if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error()) t.Fatalf("Failed inserting transaction: %s", err.Error())
} }
err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers) err = db.Table.InsertRelayServers(ctx, nil, server2, relayServers2)
if err != nil { if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error()) t.Fatalf("Failed inserting transaction: %s", err.Error())
} }
@ -111,21 +113,25 @@ func TestShouldDeleteCorrectRelayServers(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error())
} }
err = db.Table.DeleteRelayServers(ctx, nil, server2, []gomatrixserverlib.ServerName{server1, server4})
if err != nil {
t.Fatalf("Failed deleting relay servers for %s: %s", server2, err.Error())
}
expectedRelayServers1 := []gomatrixserverlib.ServerName{server3} updatedExpectedRelayServers := []gomatrixserverlib.ServerName{server3}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil { if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
} }
if !Equal(relayServers, expectedRelayServers1) { if !Equal(relayServers, updatedExpectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers) t.Fatalf("Expected: %v \nActual: %v", updatedExpectedRelayServers, relayServers)
} }
relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2) relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2)
if err != nil { if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
} }
if !Equal(relayServers, expectedRelayServers) { if !Equal(relayServers, updatedExpectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) t.Fatalf("Expected: %v \nActual: %v", updatedExpectedRelayServers, relayServers)
} }
}) })
} }