diff --git a/userapi/storage/postgres/deltas/2022110411000000_server_names.go b/userapi/storage/postgres/deltas/2022110411000000_server_names.go index 0c9bd015b..c4e5c75fd 100644 --- a/userapi/storage/postgres/deltas/2022110411000000_server_names.go +++ b/userapi/storage/postgres/deltas/2022110411000000_server_names.go @@ -34,26 +34,6 @@ func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib if _, err := tx.ExecContext(ctx, q); err != nil { return fmt.Errorf("add server name to %q error: %w", table, err) } - q = fmt.Sprintf( - "UPDATE %s SET server_name = %s WHERE server_name = '';", - pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)), - ) - if _, err := tx.ExecContext(ctx, q); err != nil { - return fmt.Errorf("write server names to %q error: %w", table, err) - } - } - return nil -} - -func DownServerNames(ctx context.Context, tx *sql.Tx) error { - for _, table := range serverNamesTables { - q := fmt.Sprintf( - "ALTER TABLE IF EXISTS %s DELETE COLUMN server_name;", - pq.QuoteIdentifier(table), - ) - if _, err := tx.ExecContext(ctx, q); err != nil { - return fmt.Errorf("remove server name from %q error: %w", table, err) - } } return nil } diff --git a/userapi/storage/postgres/deltas/2022110411000001_server_names.go b/userapi/storage/postgres/deltas/2022110411000001_server_names.go new file mode 100644 index 000000000..04a47fa7b --- /dev/null +++ b/userapi/storage/postgres/deltas/2022110411000001_server_names.go @@ -0,0 +1,28 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" +) + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "UPDATE %s SET server_name = %s WHERE server_name = '';", + pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("write server names to %q error: %w", table, err) + } + } + return nil +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 31c4f019c..92dc48081 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -45,6 +45,12 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, Up: deltas.UpRenameTables, Down: deltas.DownRenameTables, }) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNames(ctx, txn, serverName) + }, + }) if err = m.Up(base.Context()); err != nil { return nil, err } @@ -100,11 +106,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, m = sqlutil.NewMigrator(db) m.AddMigrations(sqlutil.Migration{ - Version: "userapi: server names", + Version: "userapi: server names populate", Up: func(ctx context.Context, txn *sql.Tx) error { - return deltas.UpServerNames(ctx, txn, serverName) + return deltas.UpServerNamesPopulate(ctx, txn, serverName) }, - Down: deltas.DownServerNames, }) if err = m.Up(base.Context()); err != nil { return nil, err diff --git a/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go index ba0abfb9b..5e05d5cad 100644 --- a/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go +++ b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go @@ -42,26 +42,6 @@ func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib if _, err := tx.ExecContext(ctx, q); err != nil { return fmt.Errorf("add server name to %q error: %w", table, err) } - q = fmt.Sprintf( - "UPDATE %s SET server_name = %s WHERE server_name = '';", - pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)), - ) - if _, err := tx.ExecContext(ctx, q); err != nil { - return fmt.Errorf("write server names to %q error: %w", table, err) - } - } - return nil -} - -func DownServerNames(ctx context.Context, tx *sql.Tx) error { - for _, table := range serverNamesTables { - q := fmt.Sprintf( - "ALTER TABLE %s DELETE COLUMN server_name;", - pq.QuoteIdentifier(table), - ) - if _, err := tx.ExecContext(ctx, q); err != nil { - return fmt.Errorf("remove server name from %q error: %w", table, err) - } } return nil } diff --git a/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go new file mode 100644 index 000000000..36b3f30e3 --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go @@ -0,0 +1,36 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" +) + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "SELECT name FROM sqlite_schema WHERE type='table' AND name=%s;", + pq.QuoteIdentifier(table), + ) + var c int + if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 0 { + continue + } + q = fmt.Sprintf( + "UPDATE %s SET server_name = %s WHERE server_name = '';", + pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("write server names to %q error: %w", table, err) + } + } + return nil +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index b3b3231d2..85a1f7063 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -43,6 +43,12 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, Up: deltas.UpRenameTables, Down: deltas.DownRenameTables, }) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNames(ctx, txn, serverName) + }, + }) if err = m.Up(base.Context()); err != nil { return nil, err } @@ -98,11 +104,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, m = sqlutil.NewMigrator(db) m.AddMigrations(sqlutil.Migration{ - Version: "userapi: server names", + Version: "userapi: server names populate", Up: func(ctx context.Context, txn *sql.Tx) error { - return deltas.UpServerNames(ctx, txn, serverName) + return deltas.UpServerNamesPopulate(ctx, txn, serverName) }, - Down: deltas.DownServerNames, }) if err = m.Up(base.Context()); err != nil { return nil, err