mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-01 11:13:12 -06:00
Add new db migration
This commit is contained in:
parent
979738b2da
commit
050290ee64
|
|
@ -1,130 +1,115 @@
|
|||
package sqlutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sort"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/pressly/goose"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Migrations struct {
|
||||
registeredGoMigrations map[int64]*goose.Migration
|
||||
// Migration defines a migration to be run.
|
||||
type Migration struct {
|
||||
// Version is a simple name description/name of this migration
|
||||
Version string
|
||||
// Up defines function to execute
|
||||
Up func(ctx context.Context, txn *sql.Tx) error
|
||||
// Down defines function to execute (not implemented yet)
|
||||
Down func(ctx context.Context, txn *sql.Tx) error
|
||||
}
|
||||
|
||||
func NewMigrations() *Migrations {
|
||||
return &Migrations{
|
||||
registeredGoMigrations: make(map[int64]*goose.Migration),
|
||||
// Migrator the structure used by migrations
|
||||
type Migrator struct {
|
||||
db *sql.DB
|
||||
migrations []Migration
|
||||
knownMigrations map[string]bool
|
||||
mutex *sync.Mutex
|
||||
}
|
||||
|
||||
// NewMigrator creates a new DB migrator
|
||||
func NewMigrator(db *sql.DB) *Migrator {
|
||||
return &Migrator{
|
||||
db: db,
|
||||
migrations: []Migration{},
|
||||
knownMigrations: make(map[string]bool),
|
||||
mutex: &sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// Copy-pasted from goose directly to store migrations into a map we control
|
||||
|
||||
// AddMigration adds a migration.
|
||||
func (m *Migrations) AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
|
||||
_, filename, _, _ := runtime.Caller(1)
|
||||
m.AddNamedMigration(filename, up, down)
|
||||
}
|
||||
|
||||
// AddNamedMigration : Add a named migration.
|
||||
func (m *Migrations) AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) {
|
||||
v, _ := goose.NumericComponent(filename)
|
||||
migration := &goose.Migration{Version: v, Next: -1, Previous: -1, Registered: true, UpFn: up, DownFn: down, Source: filename}
|
||||
|
||||
if existing, ok := m.registeredGoMigrations[v]; ok {
|
||||
panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source))
|
||||
// AddMigration adds new migrations to the list.
|
||||
// De-duplicates migrations by their version
|
||||
func (m *Migrator) AddMigration(migration Migration) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
if !m.knownMigrations[migration.Version] {
|
||||
m.migrations = append(m.migrations, migration)
|
||||
m.knownMigrations[migration.Version] = true
|
||||
}
|
||||
|
||||
m.registeredGoMigrations[v] = migration
|
||||
}
|
||||
|
||||
// RunDeltas up to the latest version.
|
||||
func (m *Migrations) RunDeltas(db *sql.DB, props *config.DatabaseOptions) error {
|
||||
maxVer := goose.MaxVersion
|
||||
minVer := int64(0)
|
||||
migrations, err := m.collect(minVer, maxVer)
|
||||
// AddMigrations is a convenience method to add migrations
|
||||
func (m *Migrator) AddMigrations(migrations ...Migration) {
|
||||
for _, mig := range migrations {
|
||||
m.AddMigration(mig)
|
||||
}
|
||||
}
|
||||
|
||||
// Up executes all migrations
|
||||
func (m *Migrator) Up(ctx context.Context) error {
|
||||
var err error
|
||||
// ensure there is a table for known migrations
|
||||
executedMigrations, err := m.ExecutedMigrations(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("runDeltas: Failed to collect migrations: %w", err)
|
||||
return fmt.Errorf("unable to create/get migrations: %w", err)
|
||||
}
|
||||
if props.ConnectionString.IsPostgres() {
|
||||
if err = goose.SetDialect("postgres"); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if props.ConnectionString.IsSQLite() {
|
||||
if err = goose.SetDialect("sqlite3"); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("unknown connection string: %s", props.ConnectionString)
|
||||
}
|
||||
for {
|
||||
current, err := goose.EnsureDBVersion(db)
|
||||
if err != nil {
|
||||
return fmt.Errorf("runDeltas: Failed to EnsureDBVersion: %w", err)
|
||||
}
|
||||
|
||||
next, err := migrations.Next(current)
|
||||
txn, err := m.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to begin transaction: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if err == goose.ErrNoNextVersion {
|
||||
return nil
|
||||
_ = txn.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
for i := range m.migrations {
|
||||
migration := m.migrations[i]
|
||||
if !executedMigrations[migration.Version] {
|
||||
err = migration.Up(ctx, txn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
|
||||
}
|
||||
_, err = txn.ExecContext(ctx, "INSERT INTO db_migrations (version, time) VALUES ($1, $2)", migration.Version, time.Now().UTC())
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to insert executed migrations: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("runDeltas: Failed to load next migration to %+v : %w", next, err)
|
||||
}
|
||||
|
||||
if err = next.Up(db); err != nil {
|
||||
return fmt.Errorf("runDeltas: Failed run migration: %w", err)
|
||||
}
|
||||
}
|
||||
if err = txn.Commit(); err != nil {
|
||||
return fmt.Errorf("unable to commit transaction: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migrations) collect(current, target int64) (goose.Migrations, error) {
|
||||
var migrations goose.Migrations
|
||||
|
||||
// Go migrations registered via goose.AddMigration().
|
||||
for _, migration := range m.registeredGoMigrations {
|
||||
v, err := goose.NumericComponent(migration.Source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if versionFilter(v, current, target) {
|
||||
migrations = append(migrations, migration)
|
||||
// executedMigrations returns a map with already executed migrations
|
||||
func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]bool, error) {
|
||||
result := make(map[string]bool)
|
||||
_, err := m.db.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS db_migrations ( version TEXT, time TEXT );")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create db_migrations: %w", err)
|
||||
}
|
||||
rows, err := m.db.QueryContext(ctx, "SELECT version FROM db_migrations")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to query db_migrations: %w", err)
|
||||
}
|
||||
var version string
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&version); err != nil {
|
||||
return nil, fmt.Errorf("unable to scan version: %w", err)
|
||||
}
|
||||
result[version] = true
|
||||
}
|
||||
|
||||
migrations = sortAndConnectMigrations(migrations)
|
||||
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
func sortAndConnectMigrations(migrations goose.Migrations) goose.Migrations {
|
||||
sort.Sort(migrations)
|
||||
|
||||
// now that we're sorted in the appropriate direction,
|
||||
// populate next and previous for each migration
|
||||
for i, m := range migrations {
|
||||
prev := int64(-1)
|
||||
if i > 0 {
|
||||
prev = migrations[i-1].Version
|
||||
migrations[i-1].Next = m.Version
|
||||
}
|
||||
migrations[i].Previous = prev
|
||||
}
|
||||
|
||||
return migrations
|
||||
}
|
||||
|
||||
func versionFilter(v, current, target int64) bool {
|
||||
|
||||
if target > current {
|
||||
return v > current && v <= target
|
||||
}
|
||||
|
||||
if target < current {
|
||||
return v <= current && v > target
|
||||
}
|
||||
|
||||
return false
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
|
|
|||
110
internal/sqlutil/migrate_test.go
Normal file
110
internal/sqlutil/migrate_test.go
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
package sqlutil_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
var dummyMigrations = []sqlutil.Migration{
|
||||
{
|
||||
Version: "init",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
_, err := txn.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS dummy ( test TEXT );")
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
Version: "v2",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;")
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
Version: "v2", // duplicate, this migration will be skipped
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;")
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
Version: "multiple execs",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test3 TEXT;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test4 TEXT;")
|
||||
return err
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var failMigration = sqlutil.Migration{
|
||||
Version: "iFail",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return fmt.Errorf("iFail")
|
||||
},
|
||||
Down: nil,
|
||||
}
|
||||
|
||||
func Test_migrations_Up(t *testing.T) {
|
||||
withFail := make([]sqlutil.Migration, len(dummyMigrations))
|
||||
copy(withFail, dummyMigrations)
|
||||
withFail = append(withFail, failMigration)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
connectionString string
|
||||
ctx context.Context
|
||||
migrations []sqlutil.Migration
|
||||
wantResult map[string]bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "dummy migration",
|
||||
connectionString: "file::memory:",
|
||||
migrations: dummyMigrations,
|
||||
ctx: context.Background(),
|
||||
wantResult: map[string]bool{
|
||||
"init": true,
|
||||
"v2": true,
|
||||
"multiple execs": true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with fail",
|
||||
connectionString: "file::memory:",
|
||||
migrations: withFail,
|
||||
ctx: context.Background(),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", tt.connectionString)
|
||||
if err != nil {
|
||||
t.Errorf("unable to open database: %w", err)
|
||||
}
|
||||
m := sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(tt.migrations...)
|
||||
if err := m.Up(tt.ctx); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
result, err := m.ExecutedMigrations(tt.ctx)
|
||||
if err != nil {
|
||||
t.Errorf("unable to get executed migrations: %w", err)
|
||||
}
|
||||
if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) {
|
||||
t.Errorf("expected: %+v, got %v", tt.wantResult, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue