From b449d539e2b4bbfa1e8178a45970e147c867e0e3 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Tue, 20 Sep 2022 12:16:51 +0200 Subject: [PATCH] PR comments; insertMigration test --- internal/sqlutil/migrate.go | 34 ++++++++++++++++++---------- internal/sqlutil/migrate_test.go | 38 +++++++++++++++++++++++++------- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/internal/sqlutil/migrate.go b/internal/sqlutil/migrate.go index a6fa1cf57..a66a75826 100644 --- a/internal/sqlutil/migrate.go +++ b/internal/sqlutil/migrate.go @@ -55,7 +55,7 @@ type Migrator struct { migrations []Migration knownMigrations map[string]struct{} mutex *sync.Mutex - provider execProvider + insertStmt *sql.Stmt } // NewMigrator creates a new DB migrator. @@ -65,7 +65,6 @@ func NewMigrator(db *sql.DB) *Migrator { migrations: []Migration{}, knownMigrations: make(map[string]struct{}), mutex: &sync.Mutex{}, - provider: db, } } @@ -89,9 +88,9 @@ func (m *Migrator) Up(ctx context.Context) error { if err != nil { return fmt.Errorf("unable to create/get migrations: %w", err) } - + // ensure we close the insert statement, as it's not needed anymore + defer m.close() return WithTransaction(m.db, func(txn *sql.Tx) error { - m.provider = txn for i := range m.migrations { migration := m.migrations[i] // Skip migration if it was already executed @@ -103,7 +102,7 @@ func (m *Migrator) Up(ctx context.Context) error { if err = migration.Up(ctx, txn); err != nil { return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err) } - if err = m.insertMigration(ctx, migration.Version); err != nil { + if err = m.insertMigration(ctx, txn, migration.Version); err != nil { return fmt.Errorf("unable to insert executed migrations: %w", err) } } @@ -111,12 +110,16 @@ func (m *Migrator) Up(ctx context.Context) error { }) } -type execProvider interface { - ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) -} - -func (m *Migrator) insertMigration(ctx context.Context, migrationName string) error { - _, err := m.provider.ExecContext(ctx, insertVersionSQL, +func (m *Migrator) insertMigration(ctx context.Context, txn *sql.Tx, migrationName string) error { + if m.insertStmt == nil { + stmt, err := m.db.Prepare(insertVersionSQL) + if err != nil { + return fmt.Errorf("unable to prepare insert statement: %w", err) + } + m.insertStmt = stmt + } + stmt := TxStmtContext(ctx, txn, m.insertStmt) + _, err := stmt.ExecContext(ctx, migrationName, time.Now().Format(time.RFC3339), internal.VersionString(), @@ -153,6 +156,7 @@ func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]struct{}, // This should only be used when manually inserting migrations. func InsertMigration(ctx context.Context, db *sql.DB, migrationName string) error { m := NewMigrator(db) + defer m.close() existingMigrations, err := m.ExecutedMigrations(ctx) if err != nil { return err @@ -160,5 +164,11 @@ func InsertMigration(ctx context.Context, db *sql.DB, migrationName string) erro if _, ok := existingMigrations[migrationName]; ok { return nil } - return m.insertMigration(ctx, migrationName) + return m.insertMigration(ctx, nil, migrationName) +} + +func (m *Migrator) close() { + if m.insertStmt != nil { + internal.CloseAndLogIfError(context.Background(), m.insertStmt, "unable to close insert statement") + } } diff --git a/internal/sqlutil/migrate_test.go b/internal/sqlutil/migrate_test.go index d8bcae196..a1088a712 100644 --- a/internal/sqlutil/migrate_test.go +++ b/internal/sqlutil/migrate_test.go @@ -7,9 +7,10 @@ import ( "reflect" "testing" + _ "github.com/mattn/go-sqlite3" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/test" - _ "github.com/mattn/go-sqlite3" ) var dummyMigrations = []sqlutil.Migration{ @@ -81,11 +82,12 @@ func Test_migrations_Up(t *testing.T) { } ctx := context.Background() - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - conStr, close := test.PrepareDBConnectionString(t, dbType) - defer close() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + conStr, close := test.PrepareDBConnectionString(t, dbType) + defer close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { driverName := "sqlite3" if dbType == test.DBTypePostgres { driverName = "postgres" @@ -107,6 +109,26 @@ func Test_migrations_Up(t *testing.T) { t.Errorf("expected: %+v, got %v", tt.wantResult, result) } }) - }) - } + } + }) +} + +func Test_insertMigration(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + conStr, close := test.PrepareDBConnectionString(t, dbType) + defer close() + driverName := "sqlite3" + if dbType == test.DBTypePostgres { + driverName = "postgres" + } + + db, err := sql.Open(driverName, conStr) + if err != nil { + t.Errorf("unable to open database: %v", err) + } + + if err := sqlutil.InsertMigration(context.Background(), db, "testing"); err != nil { + t.Fatalf("unable to insert migration: %s", err) + } + }) }