mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-17 02:53:11 -06:00
Cleanup relay servers table and add batch deletion to sqlite
This commit is contained in:
parent
98d4e4f89b
commit
16eb9e4e49
|
|
@ -18,6 +18,7 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -44,7 +45,7 @@ const selectRelayServersSQL = "" +
|
|||
"SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1"
|
||||
|
||||
const deleteRelayServersSQL = "" +
|
||||
"DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name IN ($2)"
|
||||
"DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name = ANY($2)"
|
||||
|
||||
const deleteAllRelayServersSQL = "" +
|
||||
"DELETE FROM federationsender_relay_servers WHERE server_name = $1"
|
||||
|
|
@ -118,13 +119,9 @@ func (s *relayServersStatements) DeleteRelayServers(
|
|||
serverName gomatrixserverlib.ServerName,
|
||||
relayServers []gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
for _, relayServer := range relayServers {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt)
|
||||
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName, pq.Array(relayServers))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *relayServersStatements) DeleteAllRelayServers(
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
|
|
@ -50,10 +51,10 @@ const deleteAllRelayServersSQL = "" +
|
|||
"DELETE FROM federationsender_relay_servers WHERE server_name = $1"
|
||||
|
||||
type relayServersStatements struct {
|
||||
db *sql.DB
|
||||
insertRelayServersStmt *sql.Stmt
|
||||
selectRelayServersStmt *sql.Stmt
|
||||
deleteRelayServersStmt *sql.Stmt
|
||||
db *sql.DB
|
||||
insertRelayServersStmt *sql.Stmt
|
||||
selectRelayServersStmt *sql.Stmt
|
||||
// deleteRelayServersStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
deleteAllRelayServersStmt *sql.Stmt
|
||||
}
|
||||
|
||||
|
|
@ -69,7 +70,6 @@ func NewSQLiteRelayServersTable(db *sql.DB) (s *relayServersStatements, err erro
|
|||
return s, sqlutil.StatementList{
|
||||
{&s.insertRelayServersStmt, insertRelayServersSQL},
|
||||
{&s.selectRelayServersStmt, selectRelayServersSQL},
|
||||
{&s.deleteRelayServersStmt, deleteRelayServersSQL},
|
||||
{&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
|
@ -118,13 +118,21 @@ func (s *relayServersStatements) DeleteRelayServers(
|
|||
serverName gomatrixserverlib.ServerName,
|
||||
relayServers []gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
for _, relayServer := range relayServers {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt)
|
||||
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil {
|
||||
return err
|
||||
}
|
||||
deleteSQL := strings.Replace(deleteRelayServersSQL, "($2)", sqlutil.QueryVariadicOffset(len(relayServers), 1), 1)
|
||||
deleteStmt, err := s.db.Prepare(deleteSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
stmt := sqlutil.TxStmt(txn, deleteStmt)
|
||||
params := make([]interface{}, len(relayServers)+1)
|
||||
params[0] = serverName
|
||||
for i, v := range relayServers {
|
||||
params[i+1] = v
|
||||
}
|
||||
|
||||
_, err = stmt.ExecContext(ctx, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *relayServersStatements) DeleteAllRelayServers(
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ const (
|
|||
server1 = "server1"
|
||||
server2 = "server2"
|
||||
server3 = "server3"
|
||||
server4 = "server4"
|
||||
)
|
||||
|
||||
type RelayServersDatabase struct {
|
||||
|
|
@ -96,13 +97,14 @@ func TestShouldDeleteCorrectRelayServers(t *testing.T) {
|
|||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateRelayServersTable(t, dbType)
|
||||
defer close()
|
||||
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
|
||||
relayServers1 := []gomatrixserverlib.ServerName{server2, server3}
|
||||
relayServers2 := []gomatrixserverlib.ServerName{server1, server3, server4}
|
||||
|
||||
err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers)
|
||||
err := db.Table.InsertRelayServers(ctx, nil, server1, relayServers1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers)
|
||||
err = db.Table.InsertRelayServers(ctx, nil, server2, relayServers2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
|
|
@ -111,21 +113,25 @@ func TestShouldDeleteCorrectRelayServers(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error())
|
||||
}
|
||||
err = db.Table.DeleteRelayServers(ctx, nil, server2, []gomatrixserverlib.ServerName{server1, server4})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed deleting relay servers for %s: %s", server2, err.Error())
|
||||
}
|
||||
|
||||
expectedRelayServers1 := []gomatrixserverlib.ServerName{server3}
|
||||
updatedExpectedRelayServers := []gomatrixserverlib.ServerName{server3}
|
||||
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
|
||||
}
|
||||
if !Equal(relayServers, expectedRelayServers1) {
|
||||
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers)
|
||||
if !Equal(relayServers, updatedExpectedRelayServers) {
|
||||
t.Fatalf("Expected: %v \nActual: %v", updatedExpectedRelayServers, relayServers)
|
||||
}
|
||||
relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
|
||||
}
|
||||
if !Equal(relayServers, expectedRelayServers) {
|
||||
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
|
||||
if !Equal(relayServers, updatedExpectedRelayServers) {
|
||||
t.Fatalf("Expected: %v \nActual: %v", updatedExpectedRelayServers, relayServers)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue