Add new db migration

This commit is contained in:
Till Faelligen 2022-03-09 13:01:06 +01:00
parent 979738b2da
commit 050290ee64
2 changed files with 197 additions and 102 deletions

View file

@ -1,130 +1,115 @@
package sqlutil
import (
"context"
"database/sql"
"fmt"
"runtime"
"sort"
"github.com/matrix-org/dendrite/setup/config"
"github.com/pressly/goose"
"sync"
"time"
)
type Migrations struct {
registeredGoMigrations map[int64]*goose.Migration
// Migration defines a migration to be run.
type Migration struct {
// Version is a simple name description/name of this migration
Version string
// Up defines function to execute
Up func(ctx context.Context, txn *sql.Tx) error
// Down defines function to execute (not implemented yet)
Down func(ctx context.Context, txn *sql.Tx) error
}
func NewMigrations() *Migrations {
return &Migrations{
registeredGoMigrations: make(map[int64]*goose.Migration),
// Migrator the structure used by migrations
type Migrator struct {
db *sql.DB
migrations []Migration
knownMigrations map[string]bool
mutex *sync.Mutex
}
// NewMigrator creates a new DB migrator
func NewMigrator(db *sql.DB) *Migrator {
return &Migrator{
db: db,
migrations: []Migration{},
knownMigrations: make(map[string]bool),
mutex: &sync.Mutex{},
}
}
// Copy-pasted from goose directly to store migrations into a map we control
// AddMigration adds a migration.
func (m *Migrations) AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
_, filename, _, _ := runtime.Caller(1)
m.AddNamedMigration(filename, up, down)
}
// AddNamedMigration : Add a named migration.
func (m *Migrations) AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) {
v, _ := goose.NumericComponent(filename)
migration := &goose.Migration{Version: v, Next: -1, Previous: -1, Registered: true, UpFn: up, DownFn: down, Source: filename}
if existing, ok := m.registeredGoMigrations[v]; ok {
panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source))
// AddMigration adds new migrations to the list.
// De-duplicates migrations by their version
func (m *Migrator) AddMigration(migration Migration) {
m.mutex.Lock()
defer m.mutex.Unlock()
if !m.knownMigrations[migration.Version] {
m.migrations = append(m.migrations, migration)
m.knownMigrations[migration.Version] = true
}
m.registeredGoMigrations[v] = migration
}
// RunDeltas up to the latest version.
func (m *Migrations) RunDeltas(db *sql.DB, props *config.DatabaseOptions) error {
maxVer := goose.MaxVersion
minVer := int64(0)
migrations, err := m.collect(minVer, maxVer)
// AddMigrations is a convenience method to add migrations
func (m *Migrator) AddMigrations(migrations ...Migration) {
for _, mig := range migrations {
m.AddMigration(mig)
}
}
// Up executes all migrations
func (m *Migrator) Up(ctx context.Context) error {
var err error
// ensure there is a table for known migrations
executedMigrations, err := m.ExecutedMigrations(ctx)
if err != nil {
return fmt.Errorf("runDeltas: Failed to collect migrations: %w", err)
return fmt.Errorf("unable to create/get migrations: %w", err)
}
if props.ConnectionString.IsPostgres() {
if err = goose.SetDialect("postgres"); err != nil {
return err
}
} else if props.ConnectionString.IsSQLite() {
if err = goose.SetDialect("sqlite3"); err != nil {
return err
}
} else {
return fmt.Errorf("unknown connection string: %s", props.ConnectionString)
}
for {
current, err := goose.EnsureDBVersion(db)
if err != nil {
return fmt.Errorf("runDeltas: Failed to EnsureDBVersion: %w", err)
}
next, err := migrations.Next(current)
txn, err := m.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("unable to begin transaction: %w", err)
}
defer func() {
if err != nil {
if err == goose.ErrNoNextVersion {
return nil
_ = txn.Rollback()
}
}()
for i := range m.migrations {
migration := m.migrations[i]
if !executedMigrations[migration.Version] {
err = migration.Up(ctx, txn)
if err != nil {
return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
}
_, err = txn.ExecContext(ctx, "INSERT INTO db_migrations (version, time) VALUES ($1, $2)", migration.Version, time.Now().UTC())
if err != nil {
return fmt.Errorf("unable to insert executed migrations: %w", err)
}
return fmt.Errorf("runDeltas: Failed to load next migration to %+v : %w", next, err)
}
if err = next.Up(db); err != nil {
return fmt.Errorf("runDeltas: Failed run migration: %w", err)
}
}
if err = txn.Commit(); err != nil {
return fmt.Errorf("unable to commit transaction: %w", err)
}
return nil
}
func (m *Migrations) collect(current, target int64) (goose.Migrations, error) {
var migrations goose.Migrations
// Go migrations registered via goose.AddMigration().
for _, migration := range m.registeredGoMigrations {
v, err := goose.NumericComponent(migration.Source)
if err != nil {
return nil, err
}
if versionFilter(v, current, target) {
migrations = append(migrations, migration)
// executedMigrations returns a map with already executed migrations
func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]bool, error) {
result := make(map[string]bool)
_, err := m.db.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS db_migrations ( version TEXT, time TEXT );")
if err != nil {
return nil, fmt.Errorf("unable to create db_migrations: %w", err)
}
rows, err := m.db.QueryContext(ctx, "SELECT version FROM db_migrations")
if err != nil {
return nil, fmt.Errorf("unable to query db_migrations: %w", err)
}
var version string
for rows.Next() {
if err := rows.Scan(&version); err != nil {
return nil, fmt.Errorf("unable to scan version: %w", err)
}
result[version] = true
}
migrations = sortAndConnectMigrations(migrations)
return migrations, nil
}
func sortAndConnectMigrations(migrations goose.Migrations) goose.Migrations {
sort.Sort(migrations)
// now that we're sorted in the appropriate direction,
// populate next and previous for each migration
for i, m := range migrations {
prev := int64(-1)
if i > 0 {
prev = migrations[i-1].Version
migrations[i-1].Next = m.Version
}
migrations[i].Previous = prev
}
return migrations
}
func versionFilter(v, current, target int64) bool {
if target > current {
return v > current && v <= target
}
if target < current {
return v <= current && v > target
}
return false
return result, rows.Err()
}

View file

@ -0,0 +1,110 @@
package sqlutil_test
import (
"context"
"database/sql"
"fmt"
"reflect"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
_ "github.com/mattn/go-sqlite3"
)
var dummyMigrations = []sqlutil.Migration{
{
Version: "init",
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS dummy ( test TEXT );")
return err
},
},
{
Version: "v2",
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;")
return err
},
},
{
Version: "v2", // duplicate, this migration will be skipped
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;")
return err
},
},
{
Version: "multiple execs",
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test3 TEXT;")
if err != nil {
return err
}
_, err = txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test4 TEXT;")
return err
},
},
}
var failMigration = sqlutil.Migration{
Version: "iFail",
Up: func(ctx context.Context, txn *sql.Tx) error {
return fmt.Errorf("iFail")
},
Down: nil,
}
func Test_migrations_Up(t *testing.T) {
withFail := make([]sqlutil.Migration, len(dummyMigrations))
copy(withFail, dummyMigrations)
withFail = append(withFail, failMigration)
tests := []struct {
name string
connectionString string
ctx context.Context
migrations []sqlutil.Migration
wantResult map[string]bool
wantErr bool
}{
{
name: "dummy migration",
connectionString: "file::memory:",
migrations: dummyMigrations,
ctx: context.Background(),
wantResult: map[string]bool{
"init": true,
"v2": true,
"multiple execs": true,
},
},
{
name: "with fail",
connectionString: "file::memory:",
migrations: withFail,
ctx: context.Background(),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db, err := sql.Open("sqlite3", tt.connectionString)
if err != nil {
t.Errorf("unable to open database: %w", err)
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(tt.migrations...)
if err := m.Up(tt.ctx); (err != nil) != tt.wantErr {
t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr)
}
result, err := m.ExecutedMigrations(tt.ctx)
if err != nil {
t.Errorf("unable to get executed migrations: %w", err)
}
if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) {
t.Errorf("expected: %+v, got %v", tt.wantResult, result)
}
})
}
}