From 91c6cde48b2a500034293b074fd71e5fd5f4f1c1 Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Thu, 16 Jun 2022 10:00:21 +0200 Subject: [PATCH] Update tests --- internal/sqlutil/migrate_test.go | 70 ++++++++++++++++---------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/internal/sqlutil/migrate_test.go b/internal/sqlutil/migrate_test.go index 3b00e9e6c..eb7571afc 100644 --- a/internal/sqlutil/migrate_test.go +++ b/internal/sqlutil/migrate_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/test" _ "github.com/mattn/go-sqlite3" ) @@ -55,23 +56,17 @@ var failMigration = sqlutil.Migration{ } func Test_migrations_Up(t *testing.T) { - withFail := make([]sqlutil.Migration, len(dummyMigrations)) - copy(withFail, dummyMigrations) - withFail = append(withFail, failMigration) + withFail := append(dummyMigrations, failMigration) tests := []struct { - name string - connectionString string - ctx context.Context - migrations []sqlutil.Migration - wantResult map[string]bool - wantErr bool + name string + migrations []sqlutil.Migration + wantResult map[string]bool + wantErr bool }{ { - name: "dummy migration", - connectionString: "file::memory:", - migrations: dummyMigrations, - ctx: context.Background(), + name: "dummy migration", + migrations: dummyMigrations, wantResult: map[string]bool{ "init": true, "v2": true, @@ -79,32 +74,39 @@ func Test_migrations_Up(t *testing.T) { }, }, { - name: "with fail", - connectionString: "file::memory:", - migrations: withFail, - ctx: context.Background(), - wantErr: true, + name: "with fail", + migrations: withFail, + wantErr: true, }, } + ctx := context.Background() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - db, err := sql.Open("sqlite3", tt.connectionString) - if err != nil { - t.Errorf("unable to open database: %v", err) - } - m := sqlutil.NewMigrator(db) - m.AddMigrations(tt.migrations...) - if err = m.Up(tt.ctx); (err != nil) != tt.wantErr { - t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr) - } - result, err := m.ExecutedMigrations(tt.ctx) - if err != nil { - t.Errorf("unable to get executed migrations: %v", err) - } - if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) { - t.Errorf("expected: %+v, got %v", tt.wantResult, result) - } + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + conStr, close := test.PrepareDBConnectionString(t, dbType) + defer close() + driverName := "sqlite3" + if dbType == test.DBTypePostgres { + driverName = "postgres" + } + db, err := sql.Open(driverName, conStr) + if err != nil { + t.Errorf("unable to open database: %v", err) + } + m := sqlutil.NewMigrator(db) + m.AddMigrations(tt.migrations...) + if err = m.Up(ctx); (err != nil) != tt.wantErr { + t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr) + } + result, err := m.ExecutedMigrations(ctx) + if err != nil { + t.Errorf("unable to get executed migrations: %v", err) + } + if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) { + t.Errorf("expected: %+v, got %v", tt.wantResult, result) + } + }) }) } }