Update tests

This commit is contained in:
Till Faelligen 2022-06-16 10:00:21 +02:00
parent 99b657c7ea
commit 91c6cde48b

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/test"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
@ -55,23 +56,17 @@ var failMigration = sqlutil.Migration{
} }
func Test_migrations_Up(t *testing.T) { func Test_migrations_Up(t *testing.T) {
withFail := make([]sqlutil.Migration, len(dummyMigrations)) withFail := append(dummyMigrations, failMigration)
copy(withFail, dummyMigrations)
withFail = append(withFail, failMigration)
tests := []struct { tests := []struct {
name string name string
connectionString string
ctx context.Context
migrations []sqlutil.Migration migrations []sqlutil.Migration
wantResult map[string]bool wantResult map[string]bool
wantErr bool wantErr bool
}{ }{
{ {
name: "dummy migration", name: "dummy migration",
connectionString: "file::memory:",
migrations: dummyMigrations, migrations: dummyMigrations,
ctx: context.Background(),
wantResult: map[string]bool{ wantResult: map[string]bool{
"init": true, "init": true,
"v2": true, "v2": true,
@ -80,25 +75,31 @@ func Test_migrations_Up(t *testing.T) {
}, },
{ {
name: "with fail", name: "with fail",
connectionString: "file::memory:",
migrations: withFail, migrations: withFail,
ctx: context.Background(),
wantErr: true, wantErr: true,
}, },
} }
ctx := context.Background()
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db, err := sql.Open("sqlite3", tt.connectionString) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
defer close()
driverName := "sqlite3"
if dbType == test.DBTypePostgres {
driverName = "postgres"
}
db, err := sql.Open(driverName, conStr)
if err != nil { if err != nil {
t.Errorf("unable to open database: %v", err) t.Errorf("unable to open database: %v", err)
} }
m := sqlutil.NewMigrator(db) m := sqlutil.NewMigrator(db)
m.AddMigrations(tt.migrations...) m.AddMigrations(tt.migrations...)
if err = m.Up(tt.ctx); (err != nil) != tt.wantErr { if err = m.Up(ctx); (err != nil) != tt.wantErr {
t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr)
} }
result, err := m.ExecutedMigrations(tt.ctx) result, err := m.ExecutedMigrations(ctx)
if err != nil { if err != nil {
t.Errorf("unable to get executed migrations: %v", err) t.Errorf("unable to get executed migrations: %v", err)
} }
@ -106,5 +107,6 @@ func Test_migrations_Up(t *testing.T) {
t.Errorf("expected: %+v, got %v", tt.wantResult, result) t.Errorf("expected: %+v, got %v", tt.wantResult, result)
} }
}) })
})
} }
} }