PR comments; insertMigration test

This commit is contained in:
Till Faelligen 2022-09-20 12:16:51 +02:00
parent 266678cceb
commit b449d539e2
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
2 changed files with 52 additions and 20 deletions

View file

@ -55,7 +55,7 @@ type Migrator struct {
migrations []Migration
knownMigrations map[string]struct{}
mutex *sync.Mutex
provider execProvider
insertStmt *sql.Stmt
}
// NewMigrator creates a new DB migrator.
@ -65,7 +65,6 @@ func NewMigrator(db *sql.DB) *Migrator {
migrations: []Migration{},
knownMigrations: make(map[string]struct{}),
mutex: &sync.Mutex{},
provider: db,
}
}
@ -89,9 +88,9 @@ func (m *Migrator) Up(ctx context.Context) error {
if err != nil {
return fmt.Errorf("unable to create/get migrations: %w", err)
}
// ensure we close the insert statement, as it's not needed anymore
defer m.close()
return WithTransaction(m.db, func(txn *sql.Tx) error {
m.provider = txn
for i := range m.migrations {
migration := m.migrations[i]
// Skip migration if it was already executed
@ -103,7 +102,7 @@ func (m *Migrator) Up(ctx context.Context) error {
if err = migration.Up(ctx, txn); err != nil {
return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
}
if err = m.insertMigration(ctx, migration.Version); err != nil {
if err = m.insertMigration(ctx, txn, migration.Version); err != nil {
return fmt.Errorf("unable to insert executed migrations: %w", err)
}
}
@ -111,12 +110,16 @@ func (m *Migrator) Up(ctx context.Context) error {
})
}
type execProvider interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
func (m *Migrator) insertMigration(ctx context.Context, migrationName string) error {
_, err := m.provider.ExecContext(ctx, insertVersionSQL,
func (m *Migrator) insertMigration(ctx context.Context, txn *sql.Tx, migrationName string) error {
if m.insertStmt == nil {
stmt, err := m.db.Prepare(insertVersionSQL)
if err != nil {
return fmt.Errorf("unable to prepare insert statement: %w", err)
}
m.insertStmt = stmt
}
stmt := TxStmtContext(ctx, txn, m.insertStmt)
_, err := stmt.ExecContext(ctx,
migrationName,
time.Now().Format(time.RFC3339),
internal.VersionString(),
@ -153,6 +156,7 @@ func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]struct{},
// This should only be used when manually inserting migrations.
func InsertMigration(ctx context.Context, db *sql.DB, migrationName string) error {
m := NewMigrator(db)
defer m.close()
existingMigrations, err := m.ExecutedMigrations(ctx)
if err != nil {
return err
@ -160,5 +164,11 @@ func InsertMigration(ctx context.Context, db *sql.DB, migrationName string) erro
if _, ok := existingMigrations[migrationName]; ok {
return nil
}
return m.insertMigration(ctx, migrationName)
return m.insertMigration(ctx, nil, migrationName)
}
func (m *Migrator) close() {
if m.insertStmt != nil {
internal.CloseAndLogIfError(context.Background(), m.insertStmt, "unable to close insert statement")
}
}

View file

@ -7,9 +7,10 @@ import (
"reflect"
"testing"
_ "github.com/mattn/go-sqlite3"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/test"
_ "github.com/mattn/go-sqlite3"
)
var dummyMigrations = []sqlutil.Migration{
@ -81,11 +82,12 @@ func Test_migrations_Up(t *testing.T) {
}
ctx := context.Background()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
defer close()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
defer close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
driverName := "sqlite3"
if dbType == test.DBTypePostgres {
driverName = "postgres"
@ -107,6 +109,26 @@ func Test_migrations_Up(t *testing.T) {
t.Errorf("expected: %+v, got %v", tt.wantResult, result)
}
})
})
}
}
})
}
func Test_insertMigration(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
defer close()
driverName := "sqlite3"
if dbType == test.DBTypePostgres {
driverName = "postgres"
}
db, err := sql.Open(driverName, conStr)
if err != nil {
t.Errorf("unable to open database: %v", err)
}
if err := sqlutil.InsertMigration(context.Background(), db, "testing"); err != nil {
t.Fatalf("unable to insert migration: %s", err)
}
})
}