mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-07 06:03:09 -06:00
Update tests
This commit is contained in:
parent
99b657c7ea
commit
91c6cde48b
|
|
@ -8,6 +8,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
|
|
@ -55,23 +56,17 @@ var failMigration = sqlutil.Migration{
|
|||
}
|
||||
|
||||
func Test_migrations_Up(t *testing.T) {
|
||||
withFail := make([]sqlutil.Migration, len(dummyMigrations))
|
||||
copy(withFail, dummyMigrations)
|
||||
withFail = append(withFail, failMigration)
|
||||
withFail := append(dummyMigrations, failMigration)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
connectionString string
|
||||
ctx context.Context
|
||||
migrations []sqlutil.Migration
|
||||
wantResult map[string]bool
|
||||
wantErr bool
|
||||
name string
|
||||
migrations []sqlutil.Migration
|
||||
wantResult map[string]bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "dummy migration",
|
||||
connectionString: "file::memory:",
|
||||
migrations: dummyMigrations,
|
||||
ctx: context.Background(),
|
||||
name: "dummy migration",
|
||||
migrations: dummyMigrations,
|
||||
wantResult: map[string]bool{
|
||||
"init": true,
|
||||
"v2": true,
|
||||
|
|
@ -79,32 +74,39 @@ func Test_migrations_Up(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
name: "with fail",
|
||||
connectionString: "file::memory:",
|
||||
migrations: withFail,
|
||||
ctx: context.Background(),
|
||||
wantErr: true,
|
||||
name: "with fail",
|
||||
migrations: withFail,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", tt.connectionString)
|
||||
if err != nil {
|
||||
t.Errorf("unable to open database: %v", err)
|
||||
}
|
||||
m := sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(tt.migrations...)
|
||||
if err = m.Up(tt.ctx); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
result, err := m.ExecutedMigrations(tt.ctx)
|
||||
if err != nil {
|
||||
t.Errorf("unable to get executed migrations: %v", err)
|
||||
}
|
||||
if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) {
|
||||
t.Errorf("expected: %+v, got %v", tt.wantResult, result)
|
||||
}
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
conStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
defer close()
|
||||
driverName := "sqlite3"
|
||||
if dbType == test.DBTypePostgres {
|
||||
driverName = "postgres"
|
||||
}
|
||||
db, err := sql.Open(driverName, conStr)
|
||||
if err != nil {
|
||||
t.Errorf("unable to open database: %v", err)
|
||||
}
|
||||
m := sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(tt.migrations...)
|
||||
if err = m.Up(ctx); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
result, err := m.ExecutedMigrations(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("unable to get executed migrations: %v", err)
|
||||
}
|
||||
if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) {
|
||||
t.Errorf("expected: %+v, got %v", tt.wantResult, result)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue