Update comments, outdent if

This commit is contained in:
Till Faelligen 2022-07-07 12:43:11 +02:00
parent 2dde6109a2
commit 26024af1f9
2 changed files with 44 additions and 42 deletions

View file

@ -22,64 +22,64 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/sirupsen/logrus"
) )
const createDBMigrationsSQL = "" + const createDBMigrationsSQL = "" +
"CREATE TABLE IF NOT EXISTS db_migrations (" + "CREATE TABLE IF NOT EXISTS db_migrations (" +
" version TEXT PRIMARY KEY," + " version TEXT PRIMARY KEY NOT NULL," +
" time TEXT," + " time TEXT NOT NULL," +
" dendrite_version TEXT" + " dendrite_version TEXT NOT NULL" +
");" ");"
const insertVersionSQL = "" + const insertVersionSQL = "" +
"INSERT INTO db_migrations (version, time, dendrite_version)" + "INSERT INTO db_migrations (version, time, dendrite_version)" +
" VALUES ($1, $2, $3) " + " VALUES ($1, $2, $3)"
" ON CONFLICT(version) DO UPDATE SET dendrite_version = $4, time = $5"
const selectDBMigrationsSQL = "SELECT version FROM db_migrations" const selectDBMigrationsSQL = "SELECT version FROM db_migrations"
// Migration defines a migration to be run. // Migration defines a migration to be run.
type Migration struct { type Migration struct {
// Version is a simple name description/name of this migration // Version is a simple description/name of this migration.
Version string Version string
// Up defines function to execute // Up defines the function to execute for an upgrade.
Up func(ctx context.Context, txn *sql.Tx) error Up func(ctx context.Context, txn *sql.Tx) error
// Down defines function to execute (not implemented yet) // Down defines the function to execute for a downgrade (not implemented yet).
Down func(ctx context.Context, txn *sql.Tx) error Down func(ctx context.Context, txn *sql.Tx) error
} }
// Migrator the structure used by migrations // Migrator
type Migrator struct { type Migrator struct {
db *sql.DB db *sql.DB
migrations []Migration migrations []Migration
knownMigrations map[string]bool knownMigrations map[string]struct{}
mutex *sync.Mutex mutex *sync.Mutex
} }
// NewMigrator creates a new DB migrator // NewMigrator creates a new DB migrator.
func NewMigrator(db *sql.DB) *Migrator { func NewMigrator(db *sql.DB) *Migrator {
return &Migrator{ return &Migrator{
db: db, db: db,
migrations: []Migration{}, migrations: []Migration{},
knownMigrations: make(map[string]bool), knownMigrations: make(map[string]struct{}),
mutex: &sync.Mutex{}, mutex: &sync.Mutex{},
} }
} }
// AddMigrations adds new migrations to the list. // AddMigrations appends migrations to the list of migrations. Migrations are executed
// De-duplicates migrations by their version // in the order they are added to the list. De-duplicates migrations using their Version field.
func (m *Migrator) AddMigrations(migrations ...Migration) { func (m *Migrator) AddMigrations(migrations ...Migration) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
for _, mig := range migrations { for _, mig := range migrations {
if !m.knownMigrations[mig.Version] { if _, ok := m.knownMigrations[mig.Version]; !ok {
m.migrations = append(m.migrations, mig) m.migrations = append(m.migrations, mig)
m.knownMigrations[mig.Version] = true m.knownMigrations[mig.Version] = struct{}{}
} }
} }
} }
// Up executes all migrations // Up executes all migrations in order they were added.
func (m *Migrator) Up(ctx context.Context) error { func (m *Migrator) Up(ctx context.Context) error {
var ( var (
err error err error
@ -95,30 +95,32 @@ func (m *Migrator) Up(ctx context.Context) error {
for i := range m.migrations { for i := range m.migrations {
now := time.Now().UTC().Format(time.RFC3339) now := time.Now().UTC().Format(time.RFC3339)
migration := m.migrations[i] migration := m.migrations[i]
if !executedMigrations[migration.Version] { logrus.Debugf("Executing database migration '%s'", migration.Version)
err = migration.Up(ctx, txn) // Skip migration if it was already executed
if err != nil { if _, ok := executedMigrations[migration.Version]; ok {
return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err) continue
} }
_, err = txn.ExecContext(ctx, insertVersionSQL, err = migration.Up(ctx, txn)
migration.Version, if err != nil {
now, return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
dendriteVersion, }
dendriteVersion, _, err = txn.ExecContext(ctx, insertVersionSQL,
now, migration.Version,
) now,
if err != nil { dendriteVersion,
return fmt.Errorf("unable to insert executed migrations: %w", err) )
} if err != nil {
return fmt.Errorf("unable to insert executed migrations: %w", err)
} }
} }
return nil return nil
}) })
} }
// ExecutedMigrations returns a map with already executed migrations // ExecutedMigrations returns a map with already executed migrations in addition to creating the
func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]bool, error) { // migrations table, if it doesn't exist.
result := make(map[string]bool) func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]struct{}, error) {
result := make(map[string]struct{})
_, err := m.db.ExecContext(ctx, createDBMigrationsSQL) _, err := m.db.ExecContext(ctx, createDBMigrationsSQL)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to create db_migrations: %w", err) return nil, fmt.Errorf("unable to create db_migrations: %w", err)
@ -127,13 +129,13 @@ func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]bool, err
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to query db_migrations: %w", err) return nil, fmt.Errorf("unable to query db_migrations: %w", err)
} }
defer rows.Close() // nolint: errcheck defer internal.CloseAndLogIfError(ctx, rows, "ExecutedMigrations: rows.close() failed")
var version string var version string
for rows.Next() { for rows.Next() {
if err = rows.Scan(&version); err != nil { if err = rows.Scan(&version); err != nil {
return nil, fmt.Errorf("unable to scan version: %w", err) return nil, fmt.Errorf("unable to scan version: %w", err)
} }
result[version] = true result[version] = struct{}{}
} }
return result, rows.Err() return result, rows.Err()

View file

@ -61,16 +61,16 @@ func Test_migrations_Up(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
migrations []sqlutil.Migration migrations []sqlutil.Migration
wantResult map[string]bool wantResult map[string]struct{}
wantErr bool wantErr bool
}{ }{
{ {
name: "dummy migration", name: "dummy migration",
migrations: dummyMigrations, migrations: dummyMigrations,
wantResult: map[string]bool{ wantResult: map[string]struct{}{
"init": true, "init": {},
"v2": true, "v2": {},
"multiple execs": true, "multiple execs": {},
}, },
}, },
{ {