package deltas

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"github.com/lib/pq"
	"github.com/matrix-org/gomatrixserverlib/spec"
	"github.com/sirupsen/logrus"
)

var (
	// serverNamesTables lists every userapi table that UpServerNames gives a
	// server_name column to.
	serverNamesTables = []string{
		"userapi_accounts",
		"userapi_account_datas",
		"userapi_devices",
		"userapi_notifications",
		"userapi_openid_tokens",
		"userapi_profiles",
		"userapi_pushers",
		"userapi_threepids",
	}

	// serverNamesDropPK lists the tables whose PRIMARY KEY constraint has to
	// be dropped, so that a new unique index that includes the server name
	// can be created in its place.
	serverNamesDropPK = []string{
		"userapi_accounts",
		"userapi_account_datas",
		"userapi_profiles",
	}

	// serverNamesDropIndex names indices that are out of date and are
	// therefore dropped. They will get recreated automatically.
	serverNamesDropIndex = []string{
		"userapi_pusher_localpart_idx",
		"userapi_pusher_app_id_pushkey_localpart_idx",
	}
)

// Why are table and index names interpolated with fmt.Sprintf wherever they
// appear as identifiers, rather than passed as $1-style arguments to
// ExecContext? Because placeholders can only stand in for values, not
// identifiers: SQLite (like PostgreSQL) does not accept a bound parameter in
// place of a table or index name, and trying it results in a syntax error.

// UpServerNames migrates the SQLite userapi schema towards per-server-name
// rows, inside the caller's transaction tx:
//
//   - each existing table in serverNamesTables that has no server_name column
//     gets one (TEXT NOT NULL DEFAULT ''). Tables that already have the column
//     are logged and skipped, and tables that do not exist are ignored.
//   - each existing table in serverNamesDropPK is rebuilt: a copy is created
//     from its stored CREATE TABLE statement, the rows are copied over, the
//     original is dropped and the copy is renamed into its place.
//   - each index in serverNamesDropIndex is dropped if it exists.
//
// The serverName argument is not used by this migration. Any failing schema
// query or statement is returned as an error naming the table or index
// involved, so the caller can roll back the transaction.
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName spec.ServerName) error {
	for _, table := range serverNamesTables {
		// The table name is bound as a value here instead of being quoted into
		// the query text. In SQLite a double-quoted token in an expression is an
		// identifier. It only falls back to a string literal on builds that keep
		// the legacy double-quoted-string behaviour.
		// sqlite_master is used rather than its newer alias sqlite_schema
		// (SQLite >= 3.33.0) so older SQLite builds work too.
		var tables int
		if err := tx.QueryRowContext(ctx,
			"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=$1;", table,
		).Scan(&tables); err != nil {
			return fmt.Errorf("look up table %q error: %w", table, err)
		}
		if tables == 0 {
			continue // table has not been created; nothing to migrate
		}

		var columns int
		if err := tx.QueryRowContext(ctx,
			"SELECT COUNT(*) FROM pragma_table_info($1) WHERE name='server_name';", table,
		).Scan(&columns); err != nil {
			return fmt.Errorf("look up server name column of %q error: %w", table, err)
		}
		if columns > 0 {
			logrus.Infof("Table %s already has column, skipping", table)
			continue
		}

		q := fmt.Sprintf(
			"ALTER TABLE %s ADD COLUMN server_name TEXT NOT NULL DEFAULT '';",
			pq.QuoteIdentifier(table),
		)
		if _, err := tx.ExecContext(ctx, q); err != nil {
			return fmt.Errorf("add server name to %q error: %w", table, err)
		}
	}

	for _, table := range serverNamesDropPK {
		var createSQL string
		err := tx.QueryRowContext(ctx,
			"SELECT sql FROM sqlite_master WHERE type='table' AND name=$1;", table,
		).Scan(&createSQL)
		switch {
		case errors.Is(err, sql.ErrNoRows):
			continue // table does not exist; nothing to rebuild
		case err != nil:
			return fmt.Errorf("read schema of %q error: %w", table, err)
		}

		tmp := table + "_tmp"
		// NOTE(review): the temporary table is created from the original
		// CREATE TABLE text with only its name changed, so any PRIMARY KEY
		// clause in that text is carried over unchanged. Confirm that the PK
		// is really removed as the serverNamesDropPK comment intends.
		q := fmt.Sprintf(`
			%s; -- create temporary table
			INSERT INTO %s SELECT * FROM %s; -- copy data
			DROP TABLE %s; -- drop original table
			ALTER TABLE %s RENAME TO %s; -- rename new table
		`,
			strings.Replace(createSQL, table, tmp, 1), // create temporary table
			pq.QuoteIdentifier(tmp), pq.QuoteIdentifier(table), // copy data
			pq.QuoteIdentifier(table), // drop original table
			pq.QuoteIdentifier(tmp), pq.QuoteIdentifier(table), // rename new table
		)
		if _, err := tx.ExecContext(ctx, q); err != nil {
			return fmt.Errorf("drop PK from %q error: %w", table, err)
		}
	}

	for _, index := range serverNamesDropIndex {
		q := fmt.Sprintf(
			"DROP INDEX IF EXISTS %s;",
			pq.QuoteIdentifier(index),
		)
		if _, err := tx.ExecContext(ctx, q); err != nil {
			return fmt.Errorf("drop index %q error: %w", index, err)
		}
	}
	return nil
}