Migration tweaks

This commit is contained in:
Neil Alexander 2022-11-07 10:22:13 +00:00
parent d058e052fc
commit 3b816f306e
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
3 changed files with 36 additions and 17 deletions

View file

@ -45,18 +45,25 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
Up: deltas.UpRenameTables, Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables, Down: deltas.DownRenameTables,
}) })
m.AddMigrations(sqlutil.Migration{
Version: "userapi: server names",
Up: func(ctx context.Context, txn *sql.Tx) error {
return deltas.UpServerNames(ctx, txn, serverName)
},
Down: deltas.DownServerNames,
})
if err = m.Up(base.Context()); err != nil { if err = m.Up(base.Context()); err != nil {
return nil, err return nil, err
} }
accountDataTable, err := NewPostgresAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
}
accountsTable, err := NewPostgresAccountsTable(db, serverName) accountsTable, err := NewPostgresAccountsTable(db, serverName)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
} }
accountDataTable, err := NewPostgresAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
}
devicesTable, err := NewPostgresDevicesTable(db, serverName) devicesTable, err := NewPostgresDevicesTable(db, serverName)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err) return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err)
@ -99,13 +106,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
} }
m = sqlutil.NewMigrator(db) m = sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: server names",
Up: func(ctx context.Context, txn *sql.Tx) error {
return deltas.UpServerNames(ctx, txn, serverName)
},
Down: deltas.DownServerNames,
})
if err = m.Up(base.Context()); err != nil { if err = m.Up(base.Context()); err != nil {
return nil, err return nil, err
} }

View file

@ -28,7 +28,18 @@ var serverNamesTables = []string{
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
for _, table := range serverNamesTables { for _, table := range serverNamesTables {
q := fmt.Sprintf( q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';", "SELECT name FROM sqlite_schema WHERE type='table' AND name=%s;",
pq.QuoteIdentifier(table),
)
var c int
if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 0 {
fmt.Println("Error:", err)
continue
} else {
fmt.Println("HAPPY DAYS!", table)
}
q = fmt.Sprintf(
"ALTER TABLE %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';",
pq.QuoteIdentifier(table), pq.QuoteIdentifier(table),
) )
if _, err := tx.ExecContext(ctx, q); err != nil { if _, err := tx.ExecContext(ctx, q); err != nil {
@ -48,7 +59,7 @@ func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib
func DownServerNames(ctx context.Context, tx *sql.Tx) error { func DownServerNames(ctx context.Context, tx *sql.Tx) error {
for _, table := range serverNamesTables { for _, table := range serverNamesTables {
q := fmt.Sprintf( q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s DELETE COLUMN server_name;", "ALTER TABLE %s DELETE COLUMN server_name;",
pq.QuoteIdentifier(table), pq.QuoteIdentifier(table),
) )
if _, err := tx.ExecContext(ctx, q); err != nil { if _, err := tx.ExecContext(ctx, q); err != nil {

View file

@ -54,14 +54,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
return nil, err return nil, err
} }
accountDataTable, err := NewSQLiteAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
}
accountsTable, err := NewSQLiteAccountsTable(db, serverName) accountsTable, err := NewSQLiteAccountsTable(db, serverName)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
} }
accountDataTable, err := NewSQLiteAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
}
devicesTable, err := NewSQLiteDevicesTable(db, serverName) devicesTable, err := NewSQLiteDevicesTable(db, serverName)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err) return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err)
@ -102,6 +102,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err) return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err)
} }
m = sqlutil.NewMigrator(db)
if err = m.Up(base.Context()); err != nil {
return nil, err
}
return &shared.Database{ return &shared.Database{
AccountDatas: accountDataTable, AccountDatas: accountDataTable,
Accounts: accountsTable, Accounts: accountsTable,