mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-04 04:33:10 -06:00
Add new db migration
This commit is contained in:
parent
979738b2da
commit
050290ee64
|
|
@ -1,130 +1,115 @@
|
||||||
package sqlutil
|
package sqlutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"sync"
|
||||||
"sort"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
"github.com/pressly/goose"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Migrations struct {
|
// Migration defines a migration to be run.
|
||||||
registeredGoMigrations map[int64]*goose.Migration
|
type Migration struct {
|
||||||
|
// Version is a simple name description/name of this migration
|
||||||
|
Version string
|
||||||
|
// Up defines function to execute
|
||||||
|
Up func(ctx context.Context, txn *sql.Tx) error
|
||||||
|
// Down defines function to execute (not implemented yet)
|
||||||
|
Down func(ctx context.Context, txn *sql.Tx) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMigrations() *Migrations {
|
// Migrator the structure used by migrations
|
||||||
return &Migrations{
|
type Migrator struct {
|
||||||
registeredGoMigrations: make(map[int64]*goose.Migration),
|
db *sql.DB
|
||||||
|
migrations []Migration
|
||||||
|
knownMigrations map[string]bool
|
||||||
|
mutex *sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMigrator creates a new DB migrator
|
||||||
|
func NewMigrator(db *sql.DB) *Migrator {
|
||||||
|
return &Migrator{
|
||||||
|
db: db,
|
||||||
|
migrations: []Migration{},
|
||||||
|
knownMigrations: make(map[string]bool),
|
||||||
|
mutex: &sync.Mutex{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy-pasted from goose directly to store migrations into a map we control
|
// AddMigration adds new migrations to the list.
|
||||||
|
// De-duplicates migrations by their version
|
||||||
// AddMigration adds a migration.
|
func (m *Migrator) AddMigration(migration Migration) {
|
||||||
func (m *Migrations) AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
|
m.mutex.Lock()
|
||||||
_, filename, _, _ := runtime.Caller(1)
|
defer m.mutex.Unlock()
|
||||||
m.AddNamedMigration(filename, up, down)
|
if !m.knownMigrations[migration.Version] {
|
||||||
}
|
m.migrations = append(m.migrations, migration)
|
||||||
|
m.knownMigrations[migration.Version] = true
|
||||||
// AddNamedMigration : Add a named migration.
|
|
||||||
func (m *Migrations) AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) {
|
|
||||||
v, _ := goose.NumericComponent(filename)
|
|
||||||
migration := &goose.Migration{Version: v, Next: -1, Previous: -1, Registered: true, UpFn: up, DownFn: down, Source: filename}
|
|
||||||
|
|
||||||
if existing, ok := m.registeredGoMigrations[v]; ok {
|
|
||||||
panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m.registeredGoMigrations[v] = migration
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunDeltas up to the latest version.
|
// AddMigrations is a convenience method to add migrations
|
||||||
func (m *Migrations) RunDeltas(db *sql.DB, props *config.DatabaseOptions) error {
|
func (m *Migrator) AddMigrations(migrations ...Migration) {
|
||||||
maxVer := goose.MaxVersion
|
for _, mig := range migrations {
|
||||||
minVer := int64(0)
|
m.AddMigration(mig)
|
||||||
migrations, err := m.collect(minVer, maxVer)
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Up executes all migrations
|
||||||
|
func (m *Migrator) Up(ctx context.Context) error {
|
||||||
|
var err error
|
||||||
|
// ensure there is a table for known migrations
|
||||||
|
executedMigrations, err := m.ExecutedMigrations(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("runDeltas: Failed to collect migrations: %w", err)
|
return fmt.Errorf("unable to create/get migrations: %w", err)
|
||||||
}
|
}
|
||||||
if props.ConnectionString.IsPostgres() {
|
|
||||||
if err = goose.SetDialect("postgres"); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else if props.ConnectionString.IsSQLite() {
|
|
||||||
if err = goose.SetDialect("sqlite3"); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("unknown connection string: %s", props.ConnectionString)
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
current, err := goose.EnsureDBVersion(db)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("runDeltas: Failed to EnsureDBVersion: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
next, err := migrations.Next(current)
|
txn, err := m.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to begin transaction: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == goose.ErrNoNextVersion {
|
_ = txn.Rollback()
|
||||||
return nil
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i := range m.migrations {
|
||||||
|
migration := m.migrations[i]
|
||||||
|
if !executedMigrations[migration.Version] {
|
||||||
|
err = migration.Up(ctx, txn)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
|
||||||
|
}
|
||||||
|
_, err = txn.ExecContext(ctx, "INSERT INTO db_migrations (version, time) VALUES ($1, $2)", migration.Version, time.Now().UTC())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to insert executed migrations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("runDeltas: Failed to load next migration to %+v : %w", next, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = next.Up(db); err != nil {
|
|
||||||
return fmt.Errorf("runDeltas: Failed run migration: %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err = txn.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("unable to commit transaction: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Migrations) collect(current, target int64) (goose.Migrations, error) {
|
// executedMigrations returns a map with already executed migrations
|
||||||
var migrations goose.Migrations
|
func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]bool, error) {
|
||||||
|
result := make(map[string]bool)
|
||||||
// Go migrations registered via goose.AddMigration().
|
_, err := m.db.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS db_migrations ( version TEXT, time TEXT );")
|
||||||
for _, migration := range m.registeredGoMigrations {
|
if err != nil {
|
||||||
v, err := goose.NumericComponent(migration.Source)
|
return nil, fmt.Errorf("unable to create db_migrations: %w", err)
|
||||||
if err != nil {
|
}
|
||||||
return nil, err
|
rows, err := m.db.QueryContext(ctx, "SELECT version FROM db_migrations")
|
||||||
}
|
if err != nil {
|
||||||
if versionFilter(v, current, target) {
|
return nil, fmt.Errorf("unable to query db_migrations: %w", err)
|
||||||
migrations = append(migrations, migration)
|
}
|
||||||
|
var version string
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&version); err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to scan version: %w", err)
|
||||||
}
|
}
|
||||||
|
result[version] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
migrations = sortAndConnectMigrations(migrations)
|
return result, rows.Err()
|
||||||
|
|
||||||
return migrations, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func sortAndConnectMigrations(migrations goose.Migrations) goose.Migrations {
|
|
||||||
sort.Sort(migrations)
|
|
||||||
|
|
||||||
// now that we're sorted in the appropriate direction,
|
|
||||||
// populate next and previous for each migration
|
|
||||||
for i, m := range migrations {
|
|
||||||
prev := int64(-1)
|
|
||||||
if i > 0 {
|
|
||||||
prev = migrations[i-1].Version
|
|
||||||
migrations[i-1].Next = m.Version
|
|
||||||
}
|
|
||||||
migrations[i].Previous = prev
|
|
||||||
}
|
|
||||||
|
|
||||||
return migrations
|
|
||||||
}
|
|
||||||
|
|
||||||
func versionFilter(v, current, target int64) bool {
|
|
||||||
|
|
||||||
if target > current {
|
|
||||||
return v > current && v <= target
|
|
||||||
}
|
|
||||||
|
|
||||||
if target < current {
|
|
||||||
return v <= current && v > target
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
110
internal/sqlutil/migrate_test.go
Normal file
110
internal/sqlutil/migrate_test.go
Normal file
|
|
@ -0,0 +1,110 @@
|
||||||
|
package sqlutil_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
var dummyMigrations = []sqlutil.Migration{
|
||||||
|
{
|
||||||
|
Version: "init",
|
||||||
|
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||||
|
_, err := txn.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS dummy ( test TEXT );")
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Version: "v2",
|
||||||
|
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||||
|
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;")
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Version: "v2", // duplicate, this migration will be skipped
|
||||||
|
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||||
|
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;")
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Version: "multiple execs",
|
||||||
|
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||||
|
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test3 TEXT;")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test4 TEXT;")
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var failMigration = sqlutil.Migration{
|
||||||
|
Version: "iFail",
|
||||||
|
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||||
|
return fmt.Errorf("iFail")
|
||||||
|
},
|
||||||
|
Down: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_migrations_Up(t *testing.T) {
|
||||||
|
withFail := make([]sqlutil.Migration, len(dummyMigrations))
|
||||||
|
copy(withFail, dummyMigrations)
|
||||||
|
withFail = append(withFail, failMigration)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
connectionString string
|
||||||
|
ctx context.Context
|
||||||
|
migrations []sqlutil.Migration
|
||||||
|
wantResult map[string]bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "dummy migration",
|
||||||
|
connectionString: "file::memory:",
|
||||||
|
migrations: dummyMigrations,
|
||||||
|
ctx: context.Background(),
|
||||||
|
wantResult: map[string]bool{
|
||||||
|
"init": true,
|
||||||
|
"v2": true,
|
||||||
|
"multiple execs": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with fail",
|
||||||
|
connectionString: "file::memory:",
|
||||||
|
migrations: withFail,
|
||||||
|
ctx: context.Background(),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", tt.connectionString)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unable to open database: %w", err)
|
||||||
|
}
|
||||||
|
m := sqlutil.NewMigrator(db)
|
||||||
|
m.AddMigrations(tt.migrations...)
|
||||||
|
if err := m.Up(tt.ctx); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
result, err := m.ExecutedMigrations(tt.ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unable to get executed migrations: %w", err)
|
||||||
|
}
|
||||||
|
if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) {
|
||||||
|
t.Errorf("expected: %+v, got %v", tt.wantResult, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue