Update tests

This commit is contained in:
Till Faelligen 2022-06-16 10:00:21 +02:00
parent 99b657c7ea
commit 91c6cde48b

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/test"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
@ -55,23 +56,17 @@ var failMigration = sqlutil.Migration{
} }
func Test_migrations_Up(t *testing.T) { func Test_migrations_Up(t *testing.T) {
withFail := make([]sqlutil.Migration, len(dummyMigrations)) withFail := append(dummyMigrations, failMigration)
copy(withFail, dummyMigrations)
withFail = append(withFail, failMigration)
tests := []struct { tests := []struct {
name string name string
connectionString string migrations []sqlutil.Migration
ctx context.Context wantResult map[string]bool
migrations []sqlutil.Migration wantErr bool
wantResult map[string]bool
wantErr bool
}{ }{
{ {
name: "dummy migration", name: "dummy migration",
connectionString: "file::memory:", migrations: dummyMigrations,
migrations: dummyMigrations,
ctx: context.Background(),
wantResult: map[string]bool{ wantResult: map[string]bool{
"init": true, "init": true,
"v2": true, "v2": true,
@ -79,32 +74,39 @@ func Test_migrations_Up(t *testing.T) {
}, },
}, },
{ {
name: "with fail", name: "with fail",
connectionString: "file::memory:", migrations: withFail,
migrations: withFail, wantErr: true,
ctx: context.Background(),
wantErr: true,
}, },
} }
ctx := context.Background()
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db, err := sql.Open("sqlite3", tt.connectionString) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
if err != nil { conStr, close := test.PrepareDBConnectionString(t, dbType)
t.Errorf("unable to open database: %v", err) defer close()
} driverName := "sqlite3"
m := sqlutil.NewMigrator(db) if dbType == test.DBTypePostgres {
m.AddMigrations(tt.migrations...) driverName = "postgres"
if err = m.Up(tt.ctx); (err != nil) != tt.wantErr { }
t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr) db, err := sql.Open(driverName, conStr)
} if err != nil {
result, err := m.ExecutedMigrations(tt.ctx) t.Errorf("unable to open database: %v", err)
if err != nil { }
t.Errorf("unable to get executed migrations: %v", err) m := sqlutil.NewMigrator(db)
} m.AddMigrations(tt.migrations...)
if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) { if err = m.Up(ctx); (err != nil) != tt.wantErr {
t.Errorf("expected: %+v, got %v", tt.wantResult, result) t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr)
} }
result, err := m.ExecutedMigrations(ctx)
if err != nil {
t.Errorf("unable to get executed migrations: %v", err)
}
if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) {
t.Errorf("expected: %+v, got %v", tt.wantResult, result)
}
})
}) })
} }
} }