From 050290ee6473de277fff558f3dde86ef56694f79 Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Wed, 9 Mar 2022 13:01:06 +0100 Subject: [PATCH] Add new db migration --- internal/sqlutil/migrate.go | 189 ++++++++++++++----------------- internal/sqlutil/migrate_test.go | 110 ++++++++++++++++++ 2 files changed, 197 insertions(+), 102 deletions(-) create mode 100644 internal/sqlutil/migrate_test.go diff --git a/internal/sqlutil/migrate.go b/internal/sqlutil/migrate.go index 7518df3c8..3109b8f49 100644 --- a/internal/sqlutil/migrate.go +++ b/internal/sqlutil/migrate.go @@ -1,130 +1,115 @@ package sqlutil import ( + "context" "database/sql" "fmt" - "runtime" - "sort" - - "github.com/matrix-org/dendrite/setup/config" - "github.com/pressly/goose" + "sync" + "time" ) -type Migrations struct { - registeredGoMigrations map[int64]*goose.Migration +// Migration defines a migration to be run. +type Migration struct { + // Version is a simple name description/name of this migration + Version string + // Up defines function to execute + Up func(ctx context.Context, txn *sql.Tx) error + // Down defines function to execute (not implemented yet) + Down func(ctx context.Context, txn *sql.Tx) error } -func NewMigrations() *Migrations { - return &Migrations{ - registeredGoMigrations: make(map[int64]*goose.Migration), +// Migrator the structure used by migrations +type Migrator struct { + db *sql.DB + migrations []Migration + knownMigrations map[string]bool + mutex *sync.Mutex +} + +// NewMigrator creates a new DB migrator +func NewMigrator(db *sql.DB) *Migrator { + return &Migrator{ + db: db, + migrations: []Migration{}, + knownMigrations: make(map[string]bool), + mutex: &sync.Mutex{}, } } -// Copy-pasted from goose directly to store migrations into a map we control - -// AddMigration adds a migration. -func (m *Migrations) AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) { - _, filename, _, _ := runtime.Caller(1) - m.AddNamedMigration(filename, up, down) -} - -// AddNamedMigration : Add a named migration. -func (m *Migrations) AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) { - v, _ := goose.NumericComponent(filename) - migration := &goose.Migration{Version: v, Next: -1, Previous: -1, Registered: true, UpFn: up, DownFn: down, Source: filename} - - if existing, ok := m.registeredGoMigrations[v]; ok { - panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source)) +// AddMigration adds new migrations to the list. +// De-duplicates migrations by their version +func (m *Migrator) AddMigration(migration Migration) { + m.mutex.Lock() + defer m.mutex.Unlock() + if !m.knownMigrations[migration.Version] { + m.migrations = append(m.migrations, migration) + m.knownMigrations[migration.Version] = true } - - m.registeredGoMigrations[v] = migration } -// RunDeltas up to the latest version. -func (m *Migrations) RunDeltas(db *sql.DB, props *config.DatabaseOptions) error { - maxVer := goose.MaxVersion - minVer := int64(0) - migrations, err := m.collect(minVer, maxVer) +// AddMigrations is a convenience method to add migrations +func (m *Migrator) AddMigrations(migrations ...Migration) { + for _, mig := range migrations { + m.AddMigration(mig) + } +} + +// Up executes all migrations +func (m *Migrator) Up(ctx context.Context) error { + var err error + // ensure there is a table for known migrations + executedMigrations, err := m.ExecutedMigrations(ctx) if err != nil { - return fmt.Errorf("runDeltas: Failed to collect migrations: %w", err) + return fmt.Errorf("unable to create/get migrations: %w", err) } - if props.ConnectionString.IsPostgres() { - if err = goose.SetDialect("postgres"); err != nil { - return err - } - } else if props.ConnectionString.IsSQLite() { - if err = goose.SetDialect("sqlite3"); err != nil { - return err - } - } else { - return fmt.Errorf("unknown connection string: %s", props.ConnectionString) - } - for { - current, err := goose.EnsureDBVersion(db) - if err != nil { - return fmt.Errorf("runDeltas: Failed to EnsureDBVersion: %w", err) - } - next, err := migrations.Next(current) + txn, err := m.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("unable to begin transaction: %w", err) + } + defer func() { if err != nil { - if err == goose.ErrNoNextVersion { - return nil + _ = txn.Rollback() + } + }() + + for i := range m.migrations { + migration := m.migrations[i] + if !executedMigrations[migration.Version] { + err = migration.Up(ctx, txn) + if err != nil { + return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err) + } + _, err = txn.ExecContext(ctx, "INSERT INTO db_migrations (version, time) VALUES ($1, $2)", migration.Version, time.Now().UTC()) + if err != nil { + return fmt.Errorf("unable to insert executed migrations: %w", err) } - - return fmt.Errorf("runDeltas: Failed to load next migration to %+v : %w", next, err) - } - - if err = next.Up(db); err != nil { - return fmt.Errorf("runDeltas: Failed run migration: %w", err) } } + if err = txn.Commit(); err != nil { + return fmt.Errorf("unable to commit transaction: %w", err) + } + return nil } -func (m *Migrations) collect(current, target int64) (goose.Migrations, error) { - var migrations goose.Migrations - - // Go migrations registered via goose.AddMigration(). - for _, migration := range m.registeredGoMigrations { - v, err := goose.NumericComponent(migration.Source) - if err != nil { - return nil, err - } - if versionFilter(v, current, target) { - migrations = append(migrations, migration) +// executedMigrations returns a map with already executed migrations +func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]bool, error) { + result := make(map[string]bool) + _, err := m.db.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS db_migrations ( version TEXT, time TEXT );") + if err != nil { + return nil, fmt.Errorf("unable to create db_migrations: %w", err) + } + rows, err := m.db.QueryContext(ctx, "SELECT version FROM db_migrations") + if err != nil { + return nil, fmt.Errorf("unable to query db_migrations: %w", err) + } + var version string + for rows.Next() { + if err := rows.Scan(&version); err != nil { + return nil, fmt.Errorf("unable to scan version: %w", err) } + result[version] = true } - migrations = sortAndConnectMigrations(migrations) - - return migrations, nil -} - -func sortAndConnectMigrations(migrations goose.Migrations) goose.Migrations { - sort.Sort(migrations) - - // now that we're sorted in the appropriate direction, - // populate next and previous for each migration - for i, m := range migrations { - prev := int64(-1) - if i > 0 { - prev = migrations[i-1].Version - migrations[i-1].Next = m.Version - } - migrations[i].Previous = prev - } - - return migrations -} - -func versionFilter(v, current, target int64) bool { - - if target > current { - return v > current && v <= target - } - - if target < current { - return v <= current && v > target - } - - return false + return result, rows.Err() } diff --git a/internal/sqlutil/migrate_test.go b/internal/sqlutil/migrate_test.go new file mode 100644 index 000000000..ac54687a4 --- /dev/null +++ b/internal/sqlutil/migrate_test.go @@ -0,0 +1,110 @@ +package sqlutil_test + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + _ "github.com/mattn/go-sqlite3" +) + +var dummyMigrations = []sqlutil.Migration{ + { + Version: "init", + Up: func(ctx context.Context, txn *sql.Tx) error { + _, err := txn.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS dummy ( test TEXT );") + return err + }, + }, + { + Version: "v2", + Up: func(ctx context.Context, txn *sql.Tx) error { + _, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;") + return err + }, + }, + { + Version: "v2", // duplicate, this migration will be skipped + Up: func(ctx context.Context, txn *sql.Tx) error { + _, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;") + return err + }, + }, + { + Version: "multiple execs", + Up: func(ctx context.Context, txn *sql.Tx) error { + _, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test3 TEXT;") + if err != nil { + return err + } + _, err = txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test4 TEXT;") + return err + }, + }, +} + +var failMigration = sqlutil.Migration{ + Version: "iFail", + Up: func(ctx context.Context, txn *sql.Tx) error { + return fmt.Errorf("iFail") + }, + Down: nil, +} + +func Test_migrations_Up(t *testing.T) { + withFail := make([]sqlutil.Migration, len(dummyMigrations)) + copy(withFail, dummyMigrations) + withFail = append(withFail, failMigration) + + tests := []struct { + name string + connectionString string + ctx context.Context + migrations []sqlutil.Migration + wantResult map[string]bool + wantErr bool + }{ + { + name: "dummy migration", + connectionString: "file::memory:", + migrations: dummyMigrations, + ctx: context.Background(), + wantResult: map[string]bool{ + "init": true, + "v2": true, + "multiple execs": true, + }, + }, + { + name: "with fail", + connectionString: "file::memory:", + migrations: withFail, + ctx: context.Background(), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := sql.Open("sqlite3", tt.connectionString) + if err != nil { + t.Errorf("unable to open database: %w", err) + } + m := sqlutil.NewMigrator(db) + m.AddMigrations(tt.migrations...) + if err := m.Up(tt.ctx); (err != nil) != tt.wantErr { + t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr) + } + result, err := m.ExecutedMigrations(tt.ctx) + if err != nil { + t.Errorf("unable to get executed migrations: %w", err) + } + if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) { + t.Errorf("expected: %+v, got %v", tt.wantResult, result) + } + }) + } +}