Update tests

This commit is contained in:
Till Faelligen 2022-06-16 10:00:21 +02:00
parent 99b657c7ea
commit 91c6cde48b

View file

@ -8,6 +8,7 @@ import (
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/test"
_ "github.com/mattn/go-sqlite3"
)
@ -55,23 +56,17 @@ var failMigration = sqlutil.Migration{
}
func Test_migrations_Up(t *testing.T) {
withFail := make([]sqlutil.Migration, len(dummyMigrations))
copy(withFail, dummyMigrations)
withFail = append(withFail, failMigration)
withFail := append(dummyMigrations, failMigration)
tests := []struct {
name string
connectionString string
ctx context.Context
migrations []sqlutil.Migration
wantResult map[string]bool
wantErr bool
name string
migrations []sqlutil.Migration
wantResult map[string]bool
wantErr bool
}{
{
name: "dummy migration",
connectionString: "file::memory:",
migrations: dummyMigrations,
ctx: context.Background(),
name: "dummy migration",
migrations: dummyMigrations,
wantResult: map[string]bool{
"init": true,
"v2": true,
@ -79,32 +74,39 @@ func Test_migrations_Up(t *testing.T) {
},
},
{
name: "with fail",
connectionString: "file::memory:",
migrations: withFail,
ctx: context.Background(),
wantErr: true,
name: "with fail",
migrations: withFail,
wantErr: true,
},
}
ctx := context.Background()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db, err := sql.Open("sqlite3", tt.connectionString)
if err != nil {
t.Errorf("unable to open database: %v", err)
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(tt.migrations...)
if err = m.Up(tt.ctx); (err != nil) != tt.wantErr {
t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr)
}
result, err := m.ExecutedMigrations(tt.ctx)
if err != nil {
t.Errorf("unable to get executed migrations: %v", err)
}
if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) {
t.Errorf("expected: %+v, got %v", tt.wantResult, result)
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
defer close()
driverName := "sqlite3"
if dbType == test.DBTypePostgres {
driverName = "postgres"
}
db, err := sql.Open(driverName, conStr)
if err != nil {
t.Errorf("unable to open database: %v", err)
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(tt.migrations...)
if err = m.Up(ctx); (err != nil) != tt.wantErr {
t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr)
}
result, err := m.ExecutedMigrations(ctx)
if err != nil {
t.Errorf("unable to get executed migrations: %v", err)
}
if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) {
t.Errorf("expected: %+v, got %v", tt.wantResult, result)
}
})
})
}
}