Remove AddMigration

Use WithTransaction
Add Dendrite version to table
This commit is contained in:
Till Faelligen 2022-03-15 13:48:38 +01:00
parent 673e0f601b
commit ffe666fef9
14 changed files with 72 additions and 67 deletions

View file

@ -85,7 +85,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationC
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(d.db) m := sqlutil.NewMigrator(d.db)
m.AddMigration(sqlutil.Migration{ m.AddMigrations(sqlutil.Migration{
Version: "drop federationsender_rooms", Version: "drop federationsender_rooms",
Up: deltas.UpRemoveRoomsTable, Up: deltas.UpRemoveRoomsTable,
}) })

View file

@ -84,7 +84,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationC
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(d.db) m := sqlutil.NewMigrator(d.db)
m.AddMigration(sqlutil.Migration{ m.AddMigrations(sqlutil.Migration{
Version: "drop federationsender_rooms", Version: "drop federationsender_rooms",
Up: deltas.UpRemoveRoomsTable, Up: deltas.UpRemoveRoomsTable,
}) })

View file

@ -20,8 +20,24 @@ import (
"fmt" "fmt"
"sync" "sync"
"time" "time"
"github.com/matrix-org/dendrite/internal"
) )
const createDBMigrationsSQL = "" +
"CREATE TABLE IF NOT EXISTS db_migrations (" +
" version TEXT PRIMARY KEY," +
" time TEXT," +
" dendrite_version TEXT" +
");"
const insertVersionSQL = "" +
"INSERT INTO db_migrations (version, time, dendrite_version)" +
" VALUES ($1, $2, $3) " +
" ON CONFLICT(version) DO UPDATE SET dendrite_version = $4, time = $5"
const selectDBMigrationsSQL = "SELECT version FROM db_migrations"
// Migration defines a migration to be run. // Migration defines a migration to be run.
type Migration struct { type Migration struct {
// Version is a simple name description/name of this migration // Version is a simple name description/name of this migration
@ -50,73 +66,68 @@ func NewMigrator(db *sql.DB) *Migrator {
} }
} }
// AddMigration adds new migrations to the list. // AddMigrations adds new migrations to the list.
// De-duplicates migrations by their version // De-duplicates migrations by their version
func (m *Migrator) AddMigration(migration Migration) { func (m *Migrator) AddMigrations(migrations ...Migration) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if !m.knownMigrations[migration.Version] {
m.migrations = append(m.migrations, migration)
m.knownMigrations[migration.Version] = true
}
}
// AddMigrations is a convenience method to add migrations
func (m *Migrator) AddMigrations(migrations ...Migration) {
for _, mig := range migrations { for _, mig := range migrations {
m.AddMigration(mig) if !m.knownMigrations[mig.Version] {
m.migrations = append(m.migrations, mig)
m.knownMigrations[mig.Version] = true
}
} }
} }
// Up executes all migrations // Up executes all migrations
func (m *Migrator) Up(ctx context.Context) error { func (m *Migrator) Up(ctx context.Context) error {
var err error var (
err error
dendriteVersion = internal.VersionString()
)
// ensure there is a table for known migrations // ensure there is a table for known migrations
executedMigrations, err := m.ExecutedMigrations(ctx) executedMigrations, err := m.ExecutedMigrations(ctx)
if err != nil { if err != nil {
return fmt.Errorf("unable to create/get migrations: %w", err) return fmt.Errorf("unable to create/get migrations: %w", err)
} }
txn, err := m.db.BeginTx(ctx, nil) return WithTransaction(m.db, func(txn *sql.Tx) error {
if err != nil {
return fmt.Errorf("unable to begin transaction: %w", err)
}
defer func() {
if err != nil {
_ = txn.Rollback()
}
}()
for i := range m.migrations { for i := range m.migrations {
now := time.Now().UTC().Format(time.RFC3339)
migration := m.migrations[i] migration := m.migrations[i]
if !executedMigrations[migration.Version] { if !executedMigrations[migration.Version] {
err = migration.Up(ctx, txn) err = migration.Up(ctx, txn)
if err != nil { if err != nil {
return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err) return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
} }
_, err = txn.ExecContext(ctx, "INSERT INTO db_migrations (version, time) VALUES ($1, $2)", migration.Version, time.Now().UTC().Format(time.RFC3339)) _, err = txn.ExecContext(ctx, insertVersionSQL,
migration.Version,
now,
dendriteVersion,
dendriteVersion,
now,
)
if err != nil { if err != nil {
return fmt.Errorf("unable to insert executed migrations: %w", err) return fmt.Errorf("unable to insert executed migrations: %w", err)
} }
} }
} }
if err = txn.Commit(); err != nil {
return fmt.Errorf("unable to commit transaction: %w", err)
}
return nil return nil
})
} }
// ExecutedMigrations returns a map with already executed migrations // ExecutedMigrations returns a map with already executed migrations
func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]bool, error) { func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]bool, error) {
result := make(map[string]bool) result := make(map[string]bool)
_, err := m.db.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS db_migrations ( version TEXT, time TEXT );") _, err := m.db.ExecContext(ctx, createDBMigrationsSQL)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to create db_migrations: %w", err) return nil, fmt.Errorf("unable to create db_migrations: %w", err)
} }
rows, err := m.db.QueryContext(ctx, "SELECT version FROM db_migrations") rows, err := m.db.QueryContext(ctx, selectDBMigrationsSQL)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to query db_migrations: %w", err) return nil, fmt.Errorf("unable to query db_migrations: %w", err)
} }
defer rows.Close() // nolint: errcheck
var version string var version string
for rows.Next() { for rows.Next() {
if err := rows.Scan(&version); err != nil { if err := rows.Scan(&version); err != nil {

View file

@ -68,7 +68,7 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&count) err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&count)
if err == nil { if err == nil {
m := sqlutil.NewMigrator(db) m := sqlutil.NewMigrator(db)
m.AddMigration(sqlutil.Migration{ m.AddMigrations(sqlutil.Migration{
Version: "refactor key changes", Version: "refactor key changes",
Up: deltas.UpRefactorKeyChanges, Up: deltas.UpRefactorKeyChanges,
}) })

View file

@ -65,7 +65,7 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&count) err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&count)
if err == nil { if err == nil {
m := sqlutil.NewMigrator(db) m := sqlutil.NewMigrator(db)
m.AddMigration(sqlutil.Migration{ m.AddMigrations(sqlutil.Migration{
Version: "refactor key changes", Version: "refactor key changes",
Up: deltas.UpRefactorKeyChanges, Up: deltas.UpRefactorKeyChanges,
}) })

View file

@ -167,7 +167,7 @@ func createMembershipTable(db *sql.DB) error {
return err return err
} }
m := sqlutil.NewMigrator(db) m := sqlutil.NewMigrator(db)
m.AddMigration(sqlutil.Migration{ m.AddMigrations(sqlutil.Migration{
Version: "add forgotten column", Version: "add forgotten column",
Up: deltas.UpAddForgottenColumn, Up: deltas.UpAddForgottenColumn,
}) })

View file

@ -57,12 +57,10 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches)
err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&count) err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&count)
if err == nil { if err == nil {
m := sqlutil.NewMigrator(db) m := sqlutil.NewMigrator(db)
m.AddMigrations([]sqlutil.Migration{ m.AddMigrations(sqlutil.Migration{
{
Version: "state blocks refactor", Version: "state blocks refactor",
Up: deltas.UpStateBlocksRefactor, Up: deltas.UpStateBlocksRefactor,
}, })
}...)
if err := m.Up(context.Background()); err != nil { if err := m.Up(context.Background()); err != nil {
return nil, err return nil, err
} }

View file

@ -65,12 +65,10 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches)
err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&count) err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&count)
if err == nil { if err == nil {
m := sqlutil.NewMigrator(db) m := sqlutil.NewMigrator(db)
m.AddMigrations([]sqlutil.Migration{ m.AddMigrations(sqlutil.Migration{
{
Version: "state blocks refactor", Version: "state blocks refactor",
Up: deltas.UpStateBlocksRefactor, Up: deltas.UpStateBlocksRefactor,
}, })
}...)
if err := m.Up(context.Background()); err != nil { if err := m.Up(context.Background()); err != nil {
return nil, err return nil, err
} }

View file

@ -75,7 +75,7 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db) m := sqlutil.NewMigrator(db)
m.AddMigration(sqlutil.Migration{ m.AddMigrations(sqlutil.Migration{
Version: "fix sequences", Version: "fix sequences",
Up: deltas.UpFixSequences, Up: deltas.UpFixSequences,
}) })

View file

@ -77,7 +77,7 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db) m := sqlutil.NewMigrator(db)
m.AddMigration(sqlutil.Migration{ m.AddMigrations(sqlutil.Migration{
Version: "drop sent_by_token", Version: "drop sent_by_token",
Up: deltas.UpRemoveSendToDeviceSentColumn, Up: deltas.UpRemoveSendToDeviceSentColumn,
}) })

View file

@ -73,7 +73,7 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Re
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db) m := sqlutil.NewMigrator(db)
m.AddMigration(sqlutil.Migration{ m.AddMigrations(sqlutil.Migration{
Version: "fix sequences", Version: "fix sequences",
Up: deltas.UpFixSequences, Up: deltas.UpFixSequences,
}) })

View file

@ -78,7 +78,7 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db) m := sqlutil.NewMigrator(db)
m.AddMigration(sqlutil.Migration{ m.AddMigrations(sqlutil.Migration{
Version: "drop sent_by_token", Version: "drop sent_by_token",
Up: deltas.UpRemoveSendToDeviceSentColumn, Up: deltas.UpRemoveSendToDeviceSentColumn,
}) })

View file

@ -122,7 +122,7 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db) m := sqlutil.NewMigrator(db)
m.AddMigration(sqlutil.Migration{ m.AddMigrations(sqlutil.Migration{
Version: "add last_seen_ts", Version: "add last_seen_ts",
Up: deltas.UpLastSeenTSIP, Up: deltas.UpLastSeenTSIP,
}) })

View file

@ -109,12 +109,10 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db) m := sqlutil.NewMigrator(db)
m.AddMigrations([]sqlutil.Migration{ m.AddMigrations(sqlutil.Migration{
{
Version: "add last_seen_ts", Version: "add last_seen_ts",
Up: deltas.UpLastSeenTSIP, Up: deltas.UpLastSeenTSIP,
}, })
}...)
err = m.Up(context.Background()) err = m.Up(context.Background())
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertDeviceStmt, insertDeviceSQL}, {&s.insertDeviceStmt, insertDeviceSQL},