This commit is contained in:
Till Faelligen 2022-09-15 08:34:23 +02:00
parent f9291f59a9
commit 266678cceb
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E

View file

@ -55,6 +55,7 @@ type Migrator struct {
migrations []Migration migrations []Migration
knownMigrations map[string]struct{} knownMigrations map[string]struct{}
mutex *sync.Mutex mutex *sync.Mutex
provider execProvider
} }
// NewMigrator creates a new DB migrator. // NewMigrator creates a new DB migrator.
@ -64,6 +65,7 @@ func NewMigrator(db *sql.DB) *Migrator {
migrations: []Migration{}, migrations: []Migration{},
knownMigrations: make(map[string]struct{}), knownMigrations: make(map[string]struct{}),
mutex: &sync.Mutex{}, mutex: &sync.Mutex{},
provider: db,
} }
} }
@ -89,6 +91,7 @@ func (m *Migrator) Up(ctx context.Context) error {
} }
return WithTransaction(m.db, func(txn *sql.Tx) error { return WithTransaction(m.db, func(txn *sql.Tx) error {
m.provider = txn
for i := range m.migrations { for i := range m.migrations {
migration := m.migrations[i] migration := m.migrations[i]
// Skip migration if it was already executed // Skip migration if it was already executed
@ -100,7 +103,7 @@ func (m *Migrator) Up(ctx context.Context) error {
if err = migration.Up(ctx, txn); err != nil { if err = migration.Up(ctx, txn); err != nil {
return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err) return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
} }
if err = m.insertMigration(ctx, txn, migration.Version); err != nil { if err = m.insertMigration(ctx, migration.Version); err != nil {
return fmt.Errorf("unable to insert executed migrations: %w", err) return fmt.Errorf("unable to insert executed migrations: %w", err)
} }
} }
@ -112,8 +115,8 @@ type execProvider interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
} }
func (m *Migrator) insertMigration(ctx context.Context, provider execProvider, migrationName string) error { func (m *Migrator) insertMigration(ctx context.Context, migrationName string) error {
_, err := provider.ExecContext(ctx, insertVersionSQL, _, err := m.provider.ExecContext(ctx, insertVersionSQL,
migrationName, migrationName,
time.Now().Format(time.RFC3339), time.Now().Format(time.RFC3339),
internal.VersionString(), internal.VersionString(),
@ -157,5 +160,5 @@ func InsertMigration(ctx context.Context, db *sql.DB, migrationName string) erro
if _, ok := existingMigrations[migrationName]; ok { if _, ok := existingMigrations[migrationName]; ok {
return nil return nil
} }
return m.insertMigration(ctx, nil, migrationName) return m.insertMigration(ctx, migrationName)
} }