From f9291f59a9f370ad83b727a83ccb0d61b8cfe36b Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 15 Sep 2022 08:30:08 +0200 Subject: [PATCH] Some more tweaks --- internal/sqlutil/migrate.go | 38 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/internal/sqlutil/migrate.go b/internal/sqlutil/migrate.go index 830f82e75..7e37f2675 100644 --- a/internal/sqlutil/migrate.go +++ b/internal/sqlutil/migrate.go @@ -49,7 +49,7 @@ type Migration struct { Down func(ctx context.Context, txn *sql.Tx) error } -// Migrator +// Migrator contains fields required to run migrations. type Migrator struct { db *sql.DB migrations []Migration @@ -82,10 +82,6 @@ func (m *Migrator) AddMigrations(migrations ...Migration) { // Up executes all migrations in order they were added. func (m *Migrator) Up(ctx context.Context) error { - var ( - err error - dendriteVersion = internal.VersionString() - ) // ensure there is a table for known migrations executedMigrations, err := m.ExecutedMigrations(ctx) if err != nil { @@ -94,23 +90,17 @@ func (m *Migrator) Up(ctx context.Context) error { return WithTransaction(m.db, func(txn *sql.Tx) error { for i := range m.migrations { - now := time.Now().UTC().Format(time.RFC3339) migration := m.migrations[i] // Skip migration if it was already executed if _, ok := executedMigrations[migration.Version]; ok { continue } logrus.Debugf("Executing database migration '%s'", migration.Version) - err = migration.Up(ctx, txn) - if err != nil { + + if err = migration.Up(ctx, txn); err != nil { return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err) } - _, err = txn.ExecContext(ctx, insertVersionSQL, - migration.Version, - now, - dendriteVersion, - ) - if err != nil { + if err = m.insertMigration(ctx, txn, migration.Version); err != nil { return fmt.Errorf("unable to insert executed migrations: %w", err) } } @@ -118,6 +108,19 @@ func (m *Migrator) Up(ctx context.Context) error { }) } +type execProvider interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) +} + +func (m *Migrator) insertMigration(ctx context.Context, provider execProvider, migrationName string) error { + _, err := provider.ExecContext(ctx, insertVersionSQL, + migrationName, + time.Now().Format(time.RFC3339), + internal.VersionString(), + ) + return err +} + // ExecutedMigrations returns a map with already executed migrations in addition to creating the // migrations table, if it doesn't exist. func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]struct{}, error) { @@ -154,10 +157,5 @@ func InsertMigration(ctx context.Context, db *sql.DB, migrationName string) erro if _, ok := existingMigrations[migrationName]; ok { return nil } - _, err = m.db.ExecContext(ctx, insertVersionSQL, - migrationName, - time.Now().Format(time.RFC3339), - internal.VersionString(), - ) - return err + return m.insertMigration(ctx, nil, migrationName) }