package sqlutil_test

import (
	"reflect"
	"testing"

	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/dendrite/setup/process"
	"github.com/matrix-org/dendrite/test"
)

// TestConnectionManager checks what callers may expect from a connection
// manager created with sqlutil.NewConnectionManager:
//   - a connection string set on the component's options is honoured,
//   - with empty component options, the options given to NewConnectionManager
//     are expected to be used instead,
//   - asking twice with the same options yields the same connection and writer,
//   - shutting down the process context returns once a connection was opened,
//   - an invalid connection string produces an error.
//
// The database-backed subtests run once per database type supplied by
// test.WithAllDatabases.
func TestConnectionManager(t *testing.T) {
	// assertWriterType fails the test if the writer returned for dbType is not
	// the implementation this test expects for that backend.
	assertWriterType := func(t *testing.T, dbType test.DBType, writer interface{}) {
		t.Helper()
		switch dbType {
		case test.DBTypeSQLite:
			if _, ok := writer.(*sqlutil.ExclusiveWriter); !ok {
				t.Fatalf("expected exclusive writer")
			}
		case test.DBTypePostgres:
			if _, ok := writer.(*sqlutil.DummyWriter); !ok {
				t.Fatalf("expected dummy writer")
			}
		}
	}

	// assertReused fails the test if the values returned by a second
	// Connection call differ from those returned by the first one.
	// reflect.DeepEqual reports true for identical pointers, so a reused
	// connection or writer always satisfies the check.
	assertReused := func(t *testing.T, db, db2, writer, writer2 interface{}) {
		t.Helper()
		if !reflect.DeepEqual(db, db2) {
			t.Fatalf("expected database connection to be reused")
		}
		if !reflect.DeepEqual(writer, writer2) {
			t.Fatalf("expected database writer to be reused")
		}
	}

	t.Run("component defined connection string", func(t *testing.T) {
		test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
			// Named closeDB rather than close to avoid shadowing the builtin.
			conStr, closeDB := test.PrepareDBConnectionString(t, dbType)
			t.Cleanup(closeDB)

			// No global options: the connection string comes from the component.
			cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
			dbProps := &config.DatabaseOptions{ConnectionString: config.DataSource(conStr)}

			db, writer, err := cm.Connection(dbProps)
			if err != nil {
				t.Fatal(err)
			}
			assertWriterType(t, dbType, writer)

			// reuse existing connection
			db2, writer2, err := cm.Connection(dbProps)
			if err != nil {
				t.Fatal(err)
			}
			assertReused(t, db, db2, writer, writer2)
		})
	})

	t.Run("global connection pool", func(t *testing.T) {
		test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
			conStr, closeDB := test.PrepareDBConnectionString(t, dbType)
			t.Cleanup(closeDB)

			// Only the manager knows the connection string; the component
			// options are deliberately left empty.
			cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{ConnectionString: config.DataSource(conStr)})
			dbProps := &config.DatabaseOptions{}

			db, writer, err := cm.Connection(dbProps)
			if err != nil {
				t.Fatal(err)
			}
			assertWriterType(t, dbType, writer)

			// reuse existing connection
			db2, writer2, err := cm.Connection(dbProps)
			if err != nil {
				t.Fatal(err)
			}
			assertReused(t, db, db2, writer, writer2)
		})
	})

	t.Run("shutdown", func(t *testing.T) {
		test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
			conStr, closeDB := test.PrepareDBConnectionString(t, dbType)
			t.Cleanup(closeDB)

			processCtx := process.NewProcessContext()
			cm := sqlutil.NewConnectionManager(processCtx, config.DatabaseOptions{ConnectionString: config.DataSource(conStr)})

			dbProps := &config.DatabaseOptions{}
			if _, _, err := cm.Connection(dbProps); err != nil {
				t.Fatal(err)
			}

			// There is no explicit assertion here: the subtest passes as long
			// as both calls return.
			// NOTE(review): presumably the manager registers with processCtx
			// and releases on shutdown — confirm in sqlutil.
			processCtx.ShutdownDendrite()
			processCtx.WaitForComponentsToFinish()
		})
	})

	// test invalid connection string configured
	t.Run("invalid connection string", func(t *testing.T) {
		cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
		_, _, err := cm.Connection(&config.DatabaseOptions{ConnectionString: "http://"})
		if err == nil {
			t.Fatal("expected an error but got none")
		}
	})
}