diff --git a/internal/sqlutil/migrate.go b/internal/sqlutil/migrate.go index 7e37f2675..a6fa1cf57 100644 --- a/internal/sqlutil/migrate.go +++ b/internal/sqlutil/migrate.go @@ -55,6 +55,7 @@ type Migrator struct { migrations []Migration knownMigrations map[string]struct{} mutex *sync.Mutex + provider execProvider } // NewMigrator creates a new DB migrator. @@ -64,6 +65,7 @@ func NewMigrator(db *sql.DB) *Migrator { migrations: []Migration{}, knownMigrations: make(map[string]struct{}), mutex: &sync.Mutex{}, + provider: db, } } @@ -89,6 +91,7 @@ func (m *Migrator) Up(ctx context.Context) error { } return WithTransaction(m.db, func(txn *sql.Tx) error { + m.provider = txn for i := range m.migrations { migration := m.migrations[i] // Skip migration if it was already executed @@ -100,7 +103,7 @@ func (m *Migrator) Up(ctx context.Context) error { if err = migration.Up(ctx, txn); err != nil { return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err) } - if err = m.insertMigration(ctx, txn, migration.Version); err != nil { + if err = m.insertMigration(ctx, migration.Version); err != nil { return fmt.Errorf("unable to insert executed migrations: %w", err) } } @@ -112,8 +115,8 @@ type execProvider interface { ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) } -func (m *Migrator) insertMigration(ctx context.Context, provider execProvider, migrationName string) error { - _, err := provider.ExecContext(ctx, insertVersionSQL, +func (m *Migrator) insertMigration(ctx context.Context, migrationName string) error { + _, err := m.provider.ExecContext(ctx, insertVersionSQL, migrationName, time.Now().Format(time.RFC3339), internal.VersionString(), @@ -157,5 +160,5 @@ func InsertMigration(ctx context.Context, db *sql.DB, migrationName string) erro if _, ok := existingMigrations[migrationName]; ok { return nil } - return m.insertMigration(ctx, nil, migrationName) + return m.insertMigration(ctx, migrationName) }