Cleanup relay servers table and add batch deletion to sqlite

This commit is contained in:
Devon Hudson 2023-01-13 16:00:36 -07:00
parent 98d4e4f89b
commit 16eb9e4e49
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
3 changed files with 38 additions and 27 deletions

View file

@ -18,6 +18,7 @@ import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
@ -44,7 +45,7 @@ const selectRelayServersSQL = "" +
"SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1"
const deleteRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name IN ($2)"
"DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name = ANY($2)"
const deleteAllRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1"
@ -118,13 +119,9 @@ func (s *relayServersStatements) DeleteRelayServers(
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
for _, relayServer := range relayServers {
stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil {
return err
}
}
return nil
stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt)
_, err := stmt.ExecContext(ctx, serverName, pq.Array(relayServers))
return err
}
func (s *relayServersStatements) DeleteAllRelayServers(

View file

@ -17,6 +17,7 @@ package sqlite3
import (
"context"
"database/sql"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
@ -50,10 +51,10 @@ const deleteAllRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1"
type relayServersStatements struct {
db *sql.DB
insertRelayServersStmt *sql.Stmt
selectRelayServersStmt *sql.Stmt
deleteRelayServersStmt *sql.Stmt
db *sql.DB
insertRelayServersStmt *sql.Stmt
selectRelayServersStmt *sql.Stmt
// deleteRelayServersStmt *sql.Stmt - prepared at runtime due to variadic
deleteAllRelayServersStmt *sql.Stmt
}
@ -69,7 +70,6 @@ func NewSQLiteRelayServersTable(db *sql.DB) (s *relayServersStatements, err erro
return s, sqlutil.StatementList{
{&s.insertRelayServersStmt, insertRelayServersSQL},
{&s.selectRelayServersStmt, selectRelayServersSQL},
{&s.deleteRelayServersStmt, deleteRelayServersSQL},
{&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL},
}.Prepare(db)
}
@ -118,13 +118,21 @@ func (s *relayServersStatements) DeleteRelayServers(
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
for _, relayServer := range relayServers {
stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil {
return err
}
deleteSQL := strings.Replace(deleteRelayServersSQL, "($2)", sqlutil.QueryVariadicOffset(len(relayServers), 1), 1)
deleteStmt, err := s.db.Prepare(deleteSQL)
if err != nil {
return err
}
return nil
stmt := sqlutil.TxStmt(txn, deleteStmt)
params := make([]interface{}, len(relayServers)+1)
params[0] = serverName
for i, v := range relayServers {
params[i+1] = v
}
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *relayServersStatements) DeleteAllRelayServers(

View file

@ -19,6 +19,7 @@ const (
server1 = "server1"
server2 = "server2"
server3 = "server3"
server4 = "server4"
)
type RelayServersDatabase struct {
@ -96,13 +97,14 @@ func TestShouldDeleteCorrectRelayServers(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
relayServers1 := []gomatrixserverlib.ServerName{server2, server3}
relayServers2 := []gomatrixserverlib.ServerName{server1, server3, server4}
err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers)
err := db.Table.InsertRelayServers(ctx, nil, server1, relayServers1)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers)
err = db.Table.InsertRelayServers(ctx, nil, server2, relayServers2)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
@ -111,21 +113,25 @@ func TestShouldDeleteCorrectRelayServers(t *testing.T) {
if err != nil {
t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error())
}
err = db.Table.DeleteRelayServers(ctx, nil, server2, []gomatrixserverlib.ServerName{server1, server4})
if err != nil {
t.Fatalf("Failed deleting relay servers for %s: %s", server2, err.Error())
}
expectedRelayServers1 := []gomatrixserverlib.ServerName{server3}
updatedExpectedRelayServers := []gomatrixserverlib.ServerName{server3}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers1) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers)
if !Equal(relayServers, updatedExpectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", updatedExpectedRelayServers, relayServers)
}
relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
if !Equal(relayServers, updatedExpectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", updatedExpectedRelayServers, relayServers)
}
})
}