mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-11 16:13:10 -06:00
PR comments; insertMigration test
This commit is contained in:
parent
266678cceb
commit
b449d539e2
|
|
@ -55,7 +55,7 @@ type Migrator struct {
|
|||
migrations []Migration
|
||||
knownMigrations map[string]struct{}
|
||||
mutex *sync.Mutex
|
||||
provider execProvider
|
||||
insertStmt *sql.Stmt
|
||||
}
|
||||
|
||||
// NewMigrator creates a new DB migrator.
|
||||
|
|
@ -65,7 +65,6 @@ func NewMigrator(db *sql.DB) *Migrator {
|
|||
migrations: []Migration{},
|
||||
knownMigrations: make(map[string]struct{}),
|
||||
mutex: &sync.Mutex{},
|
||||
provider: db,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -89,9 +88,9 @@ func (m *Migrator) Up(ctx context.Context) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("unable to create/get migrations: %w", err)
|
||||
}
|
||||
|
||||
// ensure we close the insert statement, as it's not needed anymore
|
||||
defer m.close()
|
||||
return WithTransaction(m.db, func(txn *sql.Tx) error {
|
||||
m.provider = txn
|
||||
for i := range m.migrations {
|
||||
migration := m.migrations[i]
|
||||
// Skip migration if it was already executed
|
||||
|
|
@ -103,7 +102,7 @@ func (m *Migrator) Up(ctx context.Context) error {
|
|||
if err = migration.Up(ctx, txn); err != nil {
|
||||
return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
|
||||
}
|
||||
if err = m.insertMigration(ctx, migration.Version); err != nil {
|
||||
if err = m.insertMigration(ctx, txn, migration.Version); err != nil {
|
||||
return fmt.Errorf("unable to insert executed migrations: %w", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -111,12 +110,16 @@ func (m *Migrator) Up(ctx context.Context) error {
|
|||
})
|
||||
}
|
||||
|
||||
type execProvider interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
func (m *Migrator) insertMigration(ctx context.Context, migrationName string) error {
|
||||
_, err := m.provider.ExecContext(ctx, insertVersionSQL,
|
||||
func (m *Migrator) insertMigration(ctx context.Context, txn *sql.Tx, migrationName string) error {
|
||||
if m.insertStmt == nil {
|
||||
stmt, err := m.db.Prepare(insertVersionSQL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to prepare insert statement: %w", err)
|
||||
}
|
||||
m.insertStmt = stmt
|
||||
}
|
||||
stmt := TxStmtContext(ctx, txn, m.insertStmt)
|
||||
_, err := stmt.ExecContext(ctx,
|
||||
migrationName,
|
||||
time.Now().Format(time.RFC3339),
|
||||
internal.VersionString(),
|
||||
|
|
@ -153,6 +156,7 @@ func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]struct{},
|
|||
// This should only be used when manually inserting migrations.
|
||||
func InsertMigration(ctx context.Context, db *sql.DB, migrationName string) error {
|
||||
m := NewMigrator(db)
|
||||
defer m.close()
|
||||
existingMigrations, err := m.ExecutedMigrations(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -160,5 +164,11 @@ func InsertMigration(ctx context.Context, db *sql.DB, migrationName string) erro
|
|||
if _, ok := existingMigrations[migrationName]; ok {
|
||||
return nil
|
||||
}
|
||||
return m.insertMigration(ctx, migrationName)
|
||||
return m.insertMigration(ctx, nil, migrationName)
|
||||
}
|
||||
|
||||
func (m *Migrator) close() {
|
||||
if m.insertStmt != nil {
|
||||
internal.CloseAndLogIfError(context.Background(), m.insertStmt, "unable to close insert statement")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,9 +7,10 @@ import (
|
|||
"reflect"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
var dummyMigrations = []sqlutil.Migration{
|
||||
|
|
@ -81,11 +82,12 @@ func Test_migrations_Up(t *testing.T) {
|
|||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
conStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
defer close()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
conStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
defer close()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
driverName := "sqlite3"
|
||||
if dbType == test.DBTypePostgres {
|
||||
driverName = "postgres"
|
||||
|
|
@ -107,6 +109,26 @@ func Test_migrations_Up(t *testing.T) {
|
|||
t.Errorf("expected: %+v, got %v", tt.wantResult, result)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_insertMigration(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
conStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
defer close()
|
||||
driverName := "sqlite3"
|
||||
if dbType == test.DBTypePostgres {
|
||||
driverName = "postgres"
|
||||
}
|
||||
|
||||
db, err := sql.Open(driverName, conStr)
|
||||
if err != nil {
|
||||
t.Errorf("unable to open database: %v", err)
|
||||
}
|
||||
|
||||
if err := sqlutil.InsertMigration(context.Background(), db, "testing"); err != nil {
|
||||
t.Fatalf("unable to insert migration: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue