diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 45607b3f4..64038c84b 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -45,18 +45,25 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, Up: deltas.UpRenameTables, Down: deltas.DownRenameTables, }) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNames(ctx, txn, serverName) + }, + Down: deltas.DownServerNames, + }) if err = m.Up(base.Context()); err != nil { return nil, err } - accountDataTable, err := NewPostgresAccountDataTable(db) - if err != nil { - return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) - } accountsTable, err := NewPostgresAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) } + accountDataTable, err := NewPostgresAccountDataTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) + } devicesTable, err := NewPostgresDevicesTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err) @@ -99,13 +106,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, } m = sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "userapi: server names", - Up: func(ctx context.Context, txn *sql.Tx) error { - return deltas.UpServerNames(ctx, txn, serverName) - }, - Down: deltas.DownServerNames, - }) + if err = m.Up(base.Context()); err != nil { return nil, err } diff --git a/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go index 0c9bd015b..b867530b9 100644 --- a/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go +++ b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go @@ -28,7 +28,18 @@ var serverNamesTables = []string{ func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { for _, table := range serverNamesTables { q := fmt.Sprintf( - "ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';", + "SELECT name FROM sqlite_schema WHERE type='table' AND name=%s;", + pq.QuoteIdentifier(table), + ) + var c int + if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 0 { + fmt.Println("Error:", err) + continue + } else { + fmt.Println("HAPPY DAYS!", table) + } + q = fmt.Sprintf( + "ALTER TABLE %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';", pq.QuoteIdentifier(table), ) if _, err := tx.ExecContext(ctx, q); err != nil { @@ -48,7 +59,7 @@ func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib func DownServerNames(ctx context.Context, tx *sql.Tx) error { for _, table := range serverNamesTables { q := fmt.Sprintf( - "ALTER TABLE IF EXISTS %s DELETE COLUMN server_name;", + "ALTER TABLE %s DELETE COLUMN server_name;", pq.QuoteIdentifier(table), ) if _, err := tx.ExecContext(ctx, q); err != nil { diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index c078e5051..4a06444ff 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -54,14 +54,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return nil, err } - accountDataTable, err := NewSQLiteAccountDataTable(db) - if err != nil { - return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) - } accountsTable, err := NewSQLiteAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) } + accountDataTable, err := NewSQLiteAccountDataTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) + } devicesTable, err := NewSQLiteDevicesTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err) @@ -102,6 +102,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err) } + + m = sqlutil.NewMigrator(db) + + if err = m.Up(base.Context()); err != nil { + return nil, err + } + return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable,