mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-01 11:13:12 -06:00
116 lines
3.1 KiB
Go
116 lines
3.1 KiB
Go
package sqlutil
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Migration describes a single schema migration that a Migrator can run.
type Migration struct {
	// Version is the name of this migration. The Migrator uses it as the
	// key for de-duplicating added migrations and for recording which
	// migrations have already been executed.
	Version string
	// Up applies the migration. It is called with the transaction in which
	// the Migrator runs all pending migrations.
	Up func(ctx context.Context, txn *sql.Tx) error
	// Down is the counterpart to Up (not implemented yet: nothing in this
	// file calls it).
	Down func(ctx context.Context, txn *sql.Tx) error
}
|
|
|
|
// Migrator the structure used by migrations
|
|
type Migrator struct {
|
|
db *sql.DB
|
|
migrations []Migration
|
|
knownMigrations map[string]bool
|
|
mutex *sync.Mutex
|
|
}
|
|
|
|
// NewMigrator creates a new DB migrator
|
|
func NewMigrator(db *sql.DB) *Migrator {
|
|
return &Migrator{
|
|
db: db,
|
|
migrations: []Migration{},
|
|
knownMigrations: make(map[string]bool),
|
|
mutex: &sync.Mutex{},
|
|
}
|
|
}
|
|
|
|
// AddMigration adds new migrations to the list.
|
|
// De-duplicates migrations by their version
|
|
func (m *Migrator) AddMigration(migration Migration) {
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
if !m.knownMigrations[migration.Version] {
|
|
m.migrations = append(m.migrations, migration)
|
|
m.knownMigrations[migration.Version] = true
|
|
}
|
|
}
|
|
|
|
// AddMigrations is a convenience method to add migrations
|
|
func (m *Migrator) AddMigrations(migrations ...Migration) {
|
|
for _, mig := range migrations {
|
|
m.AddMigration(mig)
|
|
}
|
|
}
|
|
|
|
// Up executes all migrations
|
|
func (m *Migrator) Up(ctx context.Context) error {
|
|
var err error
|
|
// ensure there is a table for known migrations
|
|
executedMigrations, err := m.ExecutedMigrations(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to create/get migrations: %w", err)
|
|
}
|
|
|
|
txn, err := m.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to begin transaction: %w", err)
|
|
}
|
|
defer func() {
|
|
if err != nil {
|
|
_ = txn.Rollback()
|
|
}
|
|
}()
|
|
|
|
for i := range m.migrations {
|
|
migration := m.migrations[i]
|
|
if !executedMigrations[migration.Version] {
|
|
err = migration.Up(ctx, txn)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
|
|
}
|
|
_, err = txn.ExecContext(ctx, "INSERT INTO db_migrations (version, time) VALUES ($1, $2)", migration.Version, time.Now().UTC().Format(time.RFC3339))
|
|
if err != nil {
|
|
return fmt.Errorf("unable to insert executed migrations: %w", err)
|
|
}
|
|
}
|
|
}
|
|
if err = txn.Commit(); err != nil {
|
|
return fmt.Errorf("unable to commit transaction: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// executedMigrations returns a map with already executed migrations
|
|
func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]bool, error) {
|
|
result := make(map[string]bool)
|
|
_, err := m.db.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS db_migrations ( version TEXT, time TEXT );")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to create db_migrations: %w", err)
|
|
}
|
|
rows, err := m.db.QueryContext(ctx, "SELECT version FROM db_migrations")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to query db_migrations: %w", err)
|
|
}
|
|
var version string
|
|
for rows.Next() {
|
|
if err := rows.Scan(&version); err != nil {
|
|
return nil, fmt.Errorf("unable to scan version: %w", err)
|
|
}
|
|
result[version] = true
|
|
}
|
|
|
|
return result, rows.Err()
|
|
}
|