diff --git a/clientapi/auth/storage/devices/storage_test.go b/clientapi/auth/storage/devices/storage_test.go index e4ab229e0..a139c324a 100644 --- a/clientapi/auth/storage/devices/storage_test.go +++ b/clientapi/auth/storage/devices/storage_test.go @@ -5,67 +5,27 @@ import ( "fmt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/stretchr/testify/assert" - "os" - "os/exec" - "strings" "testing" ) -var dataSource string -var insideCi = false -var insideDocker = false +const dataSourceName = "postgres://dendrite:itsasecret@postgres/dendrite_device?sslmode=disable" -const dbName = "dendrite_device" - -func init() { - for _, val := range os.Environ() { - tokens := strings.Split(val, "=") - if tokens[0] == "CI" && tokens[1] == "true" { - insideCi = true - } - } - if !insideCi { - if _, err := os.Open("/.dockerenv"); err == nil { - insideDocker = true - } - } - - if insideCi { - dataSource = fmt.Sprintf("postgres://postgres@localhost/%s?sslmode=disable", dbName) - } else if insideDocker { - dataSource = fmt.Sprintf("postgres://dendrite:itsasecret@postgres/%s?sslmode=disable", dbName) - } else { - dataSource = fmt.Sprintf("postgres://dendrite:itsasecret@localhost:15432/%s?sslmode=disable", dbName) - } - - if insideCi { - database := "dendrite_device" - cmd := exec.Command("psql", "postgres") - cmd.Stdin = strings.NewReader( - fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", database, database), - ) - // Send stdout and stderr to our stderr so that we see error messages from - // the psql process - cmd.Stdout = os.Stderr - cmd.Stderr = os.Stderr - _ = cmd.Run() - } -} +//const dataSourceLocal = "postgres://dendrite:itsasecret@localhost:15432/dendrite_device?sslmode=disable" type deviceSpec struct { localPart string - devID string + devId string accessToken string displayName string } func TestDatabase_GetDevicesByLocalpart(t *testing.T) { - db, err := NewDatabase(dataSource, "localhost") + db, err := NewDatabase(dataSourceName, "localhost") assert.Nil(t, err) devSpec := deviceSpec{ localPart: "get-device-test-local-part", - devID: "get-device-test-device-id", + devId: "get-device-test-device-id", accessToken: "get-device-test-access-token", displayName: "get-device-test-display-name", } @@ -85,7 +45,7 @@ func TestDatabase_GetDevicesByLocalpart(t *testing.T) { func TestDatabase_CreateDevice(t *testing.T) { devSpec := deviceSpec{ localPart: "create-test-local-part", - devID: "create-test-device-id", + devId: "create-test-device-id", accessToken: "create-test-access-token", displayName: "create-test-display-name", } @@ -99,21 +59,22 @@ func TestDatabase_CreateDevice(t *testing.T) { // create a number of device entries in the database, using ``devSpecScheme`` as the creation pattern. func createTestDevice(devSpecScheme *deviceSpec, count int) (devices []*authtypes.Device, err error) { - db, err := NewDatabase(dataSource, "localhost") + db, err := NewDatabase(dataSourceName, "localhost") if err != nil { fmt.Println(err) } for i := 0; i < count; i++ { - devID := fmt.Sprintf("%s%d", devSpecScheme.devID, i) + devId := fmt.Sprintf("%s%d", devSpecScheme.devId, i) displayName := fmt.Sprintf("%s%d", devSpecScheme.displayName, i) if device, err := db.CreateDevice( context.Background(), fmt.Sprintf("%s%d", devSpecScheme.localPart, i), - &devID, + &devId, fmt.Sprintf("%s%d", devSpecScheme.accessToken, i), &displayName); err != nil { fmt.Println(err) + return nil, err } else { devices = append(devices, device) }