From ffe666fef9c72c210abd6002a9a619e401407720 Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Tue, 15 Mar 2022 13:48:38 +0100 Subject: [PATCH] Remove AddMigration Use WithTransaction Add Dendrite version to table --- federationapi/storage/postgres/storage.go | 2 +- federationapi/storage/sqlite3/storage.go | 2 +- internal/sqlutil/migrate.go | 89 +++++++++++-------- .../storage/postgres/key_changes_table.go | 2 +- .../storage/sqlite3/key_changes_table.go | 2 +- .../storage/postgres/membership_table.go | 2 +- roomserver/storage/postgres/storage.go | 10 +-- roomserver/storage/sqlite3/storage.go | 10 +-- syncapi/storage/postgres/receipt_table.go | 2 +- .../storage/postgres/send_to_device_table.go | 2 +- syncapi/storage/sqlite3/receipt_table.go | 2 +- .../storage/sqlite3/send_to_device_table.go | 2 +- userapi/storage/postgres/devices_table.go | 2 +- userapi/storage/sqlite3/devices_table.go | 10 +-- 14 files changed, 72 insertions(+), 67 deletions(-) diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index e1683e966..c530fa5c3 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -85,7 +85,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationC return nil, err } m := sqlutil.NewMigrator(d.db) - m.AddMigration(sqlutil.Migration{ + m.AddMigrations(sqlutil.Migration{ Version: "drop federationsender_rooms", Up: deltas.UpRemoveRoomsTable, }) diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index a6695140b..83fd61901 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -84,7 +84,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationC return nil, err } m := sqlutil.NewMigrator(d.db) - m.AddMigration(sqlutil.Migration{ + m.AddMigrations(sqlutil.Migration{ Version: "drop federationsender_rooms", Up: deltas.UpRemoveRoomsTable, }) diff --git a/internal/sqlutil/migrate.go b/internal/sqlutil/migrate.go index 62d0e4272..d2e23155d 100644 --- a/internal/sqlutil/migrate.go +++ b/internal/sqlutil/migrate.go @@ -20,8 +20,24 @@ import ( "fmt" "sync" "time" + + "github.com/matrix-org/dendrite/internal" ) +const createDBMigrationsSQL = "" + + "CREATE TABLE IF NOT EXISTS db_migrations (" + + " version TEXT PRIMARY KEY," + + " time TEXT," + + " dendrite_version TEXT" + + ");" + +const insertVersionSQL = "" + + "INSERT INTO db_migrations (version, time, dendrite_version)" + + " VALUES ($1, $2, $3) " + + " ON CONFLICT(version) DO UPDATE SET dendrite_version = $4, time = $5" + +const selectDBMigrationsSQL = "SELECT version FROM db_migrations" + // Migration defines a migration to be run. type Migration struct { // Version is a simple name description/name of this migration @@ -50,73 +66,68 @@ func NewMigrator(db *sql.DB) *Migrator { } } -// AddMigration adds new migrations to the list. +// AddMigrations adds new migrations to the list. // De-duplicates migrations by their version -func (m *Migrator) AddMigration(migration Migration) { +func (m *Migrator) AddMigrations(migrations ...Migration) { m.mutex.Lock() defer m.mutex.Unlock() - if !m.knownMigrations[migration.Version] { - m.migrations = append(m.migrations, migration) - m.knownMigrations[migration.Version] = true - } -} - -// AddMigrations is a convenience method to add migrations -func (m *Migrator) AddMigrations(migrations ...Migration) { for _, mig := range migrations { - m.AddMigration(mig) + if !m.knownMigrations[mig.Version] { + m.migrations = append(m.migrations, mig) + m.knownMigrations[mig.Version] = true + } } } // Up executes all migrations func (m *Migrator) Up(ctx context.Context) error { - var err error + var ( + err error + dendriteVersion = internal.VersionString() + ) // ensure there is a table for known migrations executedMigrations, err := m.ExecutedMigrations(ctx) if err != nil { return fmt.Errorf("unable to create/get migrations: %w", err) } - txn, err := m.db.BeginTx(ctx, nil) - if err != nil { - return fmt.Errorf("unable to begin transaction: %w", err) - } - defer func() { - if err != nil { - _ = txn.Rollback() - } - }() - - for i := range m.migrations { - migration := m.migrations[i] - if !executedMigrations[migration.Version] { - err = migration.Up(ctx, txn) - if err != nil { - return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err) - } - _, err = txn.ExecContext(ctx, "INSERT INTO db_migrations (version, time) VALUES ($1, $2)", migration.Version, time.Now().UTC().Format(time.RFC3339)) - if err != nil { - return fmt.Errorf("unable to insert executed migrations: %w", err) + return WithTransaction(m.db, func(txn *sql.Tx) error { + for i := range m.migrations { + now := time.Now().UTC().Format(time.RFC3339) + migration := m.migrations[i] + if !executedMigrations[migration.Version] { + err = migration.Up(ctx, txn) + if err != nil { + return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err) + } + _, err = txn.ExecContext(ctx, insertVersionSQL, + migration.Version, + now, + dendriteVersion, + dendriteVersion, + now, + ) + if err != nil { + return fmt.Errorf("unable to insert executed migrations: %w", err) + } } } - } - if err = txn.Commit(); err != nil { - return fmt.Errorf("unable to commit transaction: %w", err) - } - return nil + return nil + }) } // ExecutedMigrations returns a map with already executed migrations func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]bool, error) { result := make(map[string]bool) - _, err := m.db.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS db_migrations ( version TEXT, time TEXT );") + _, err := m.db.ExecContext(ctx, createDBMigrationsSQL) if err != nil { return nil, fmt.Errorf("unable to create db_migrations: %w", err) } - rows, err := m.db.QueryContext(ctx, "SELECT version FROM db_migrations") + rows, err := m.db.QueryContext(ctx, selectDBMigrationsSQL) if err != nil { return nil, fmt.Errorf("unable to query db_migrations: %w", err) } + defer rows.Close() // nolint: errcheck var version string for rows.Next() { if err := rows.Scan(&version); err != nil { diff --git a/keyserver/storage/postgres/key_changes_table.go b/keyserver/storage/postgres/key_changes_table.go index 9a69657b7..b0eedf5a5 100644 --- a/keyserver/storage/postgres/key_changes_table.go +++ b/keyserver/storage/postgres/key_changes_table.go @@ -68,7 +68,7 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&count) if err == nil { m := sqlutil.NewMigrator(db) - m.AddMigration(sqlutil.Migration{ + m.AddMigrations(sqlutil.Migration{ Version: "refactor key changes", Up: deltas.UpRefactorKeyChanges, }) diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go index 3a0e87b17..b783414f9 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -65,7 +65,7 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&count) if err == nil { m := sqlutil.NewMigrator(db) - m.AddMigration(sqlutil.Migration{ + m.AddMigrations(sqlutil.Migration{ Version: "refactor key changes", Up: deltas.UpRefactorKeyChanges, }) diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 5d633bea1..7cee271ba 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -167,7 +167,7 @@ func createMembershipTable(db *sql.DB) error { return err } m := sqlutil.NewMigrator(db) - m.AddMigration(sqlutil.Migration{ + m.AddMigrations(sqlutil.Migration{ Version: "add forgotten column", Up: deltas.UpAddForgottenColumn, }) diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 235886c9d..703debbd4 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -57,12 +57,10 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&count) if err == nil { m := sqlutil.NewMigrator(db) - m.AddMigrations([]sqlutil.Migration{ - { - Version: "state blocks refactor", - Up: deltas.UpStateBlocksRefactor, - }, - }...) + m.AddMigrations(sqlutil.Migration{ + Version: "state blocks refactor", + Up: deltas.UpStateBlocksRefactor, + }) if err := m.Up(context.Background()); err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index fb5a34c92..dbf402af9 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -65,12 +65,10 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&count) if err == nil { m := sqlutil.NewMigrator(db) - m.AddMigrations([]sqlutil.Migration{ - { - Version: "state blocks refactor", - Up: deltas.UpStateBlocksRefactor, - }, - }...) + m.AddMigrations(sqlutil.Migration{ + Version: "state blocks refactor", + Up: deltas.UpStateBlocksRefactor, + }) if err := m.Up(context.Background()); err != nil { return nil, err } diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go index 0a03eb584..37a826ff7 100644 --- a/syncapi/storage/postgres/receipt_table.go +++ b/syncapi/storage/postgres/receipt_table.go @@ -75,7 +75,7 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { return nil, err } m := sqlutil.NewMigrator(db) - m.AddMigration(sqlutil.Migration{ + m.AddMigrations(sqlutil.Migration{ Version: "fix sequences", Up: deltas.UpFixSequences, }) diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go index 29d29019c..3709e54fc 100644 --- a/syncapi/storage/postgres/send_to_device_table.go +++ b/syncapi/storage/postgres/send_to_device_table.go @@ -77,7 +77,7 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { return nil, err } m := sqlutil.NewMigrator(db) - m.AddMigration(sqlutil.Migration{ + m.AddMigrations(sqlutil.Migration{ Version: "drop sent_by_token", Up: deltas.UpRemoveSendToDeviceSentColumn, }) diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index 40b8e1b50..893bdbf48 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -73,7 +73,7 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Re return nil, err } m := sqlutil.NewMigrator(db) - m.AddMigration(sqlutil.Migration{ + m.AddMigrations(sqlutil.Migration{ Version: "fix sequences", Up: deltas.UpFixSequences, }) diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index bebcc5b66..1b2eb60e3 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -78,7 +78,7 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { return nil, err } m := sqlutil.NewMigrator(db) - m.AddMigration(sqlutil.Migration{ + m.AddMigrations(sqlutil.Migration{ Version: "drop sent_by_token", Up: deltas.UpRemoveSendToDeviceSentColumn, }) diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index ffd92d3d4..3b2332214 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -122,7 +122,7 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName return nil, err } m := sqlutil.NewMigrator(db) - m.AddMigration(sqlutil.Migration{ + m.AddMigrations(sqlutil.Migration{ Version: "add last_seen_ts", Up: deltas.UpLastSeenTSIP, }) diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index e3210256b..f214d5161 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -109,12 +109,10 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) return nil, err } m := sqlutil.NewMigrator(db) - m.AddMigrations([]sqlutil.Migration{ - { - Version: "add last_seen_ts", - Up: deltas.UpLastSeenTSIP, - }, - }...) + m.AddMigrations(sqlutil.Migration{ + Version: "add last_seen_ts", + Up: deltas.UpLastSeenTSIP, + }) err = m.Up(context.Background()) return s, sqlutil.StatementList{ {&s.insertDeviceStmt, insertDeviceSQL},