// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sqlite3

import (
	"context"
	"database/sql"
	"strings"

	"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"
)

// membershipSchema creates the account_memberships table on first use. Each
// row maps a (localpart, room_id) pair to the ID of the join membership event;
// the primary key allows one row per user per room, and event_id is unique
// across the table.
const membershipSchema = `
-- Stores data about users memberships to rooms.
CREATE TABLE IF NOT EXISTS account_memberships (
    -- The Matrix user ID localpart for the member
    localpart TEXT NOT NULL,
    -- The room this user is a member of
    room_id TEXT NOT NULL,
    -- The ID of the join membership event
    event_id TEXT NOT NULL,

    -- A user can only be member of a room once
    PRIMARY KEY (localpart, room_id),

		UNIQUE (event_id)
);
`

// insertMembershipSQL upserts a membership row: when a row for the same
// (localpart, room_id) already exists, its event_id is replaced with the new one.
const insertMembershipSQL = `
	INSERT INTO account_memberships(localpart, room_id, event_id) VALUES ($1, $2, $3)
	ON CONFLICT (localpart, room_id) DO UPDATE SET event_id = EXCLUDED.event_id
`

// selectMembershipsByLocalpartSQL lists (room_id, event_id) for every
// membership held by a localpart.
const selectMembershipsByLocalpartSQL = "" +
	"SELECT room_id, event_id FROM account_memberships WHERE localpart = $1"

// selectMembershipInRoomByLocalpartSQL fetches the event_id of a single
// (localpart, room_id) membership.
const selectMembershipInRoomByLocalpartSQL = "" +
	"SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2"

// selectRoomIDsByLocalPartSQL lists the room IDs a localpart has a membership in.
const selectRoomIDsByLocalPartSQL = "" +
	"SELECT room_id FROM account_memberships WHERE localpart = $1"

// deleteMembershipsByEventIDsSQL is a template, not a prepared statement: the
// literal "($1)" is replaced at call time by deleteMembershipsByEventIDs with a
// placeholder list sized to the number of event IDs being deleted.
const deleteMembershipsByEventIDsSQL = "" +
	"DELETE FROM account_memberships WHERE event_id IN ($1)"

// membershipStatements holds the prepared statements for the
// account_memberships table. The statements are populated by prepare. The
// delete-by-event-IDs query has no field here because its placeholder list
// is built per call.
type membershipStatements struct {
	insertMembershipStmt                  *sql.Stmt
	selectMembershipInRoomByLocalpartStmt *sql.Stmt
	selectMembershipsByLocalpartStmt      *sql.Stmt
	selectRoomIDsByLocalPartStmt          *sql.Stmt
}

// prepare ensures the account_memberships table exists and then compiles each
// statically-known statement against db, in declaration order. It stops at the
// first failure and returns that error. The delete-by-event-IDs query is not
// prepared because its IN-list is expanded on every call.
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
	if _, err = db.Exec(membershipSchema); err != nil {
		return err
	}
	statements := []struct {
		target **sql.Stmt
		query  string
	}{
		{&s.insertMembershipStmt, insertMembershipSQL},
		{&s.selectMembershipInRoomByLocalpartStmt, selectMembershipInRoomByLocalpartSQL},
		{&s.selectMembershipsByLocalpartStmt, selectMembershipsByLocalpartSQL},
		{&s.selectRoomIDsByLocalPartStmt, selectRoomIDsByLocalPartSQL},
	}
	for _, st := range statements {
		if *st.target, err = db.Prepare(st.query); err != nil {
			return err
		}
	}
	return nil
}

// insertMembership records, inside txn, that localpart is a member of roomID
// through the join event eventID. If a row for the same (localpart, roomID)
// already exists, its stored event ID is overwritten.
func (s *membershipStatements) insertMembership(
	ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) (err error) {
	_, err = txn.Stmt(s.insertMembershipStmt).ExecContext(ctx, localpart, roomID, eventID)
	return err
}

// deleteMembershipsByEventIDs deletes, inside txn, every membership row whose
// event_id is one of eventIDs. The "($1)" in deleteMembershipsByEventIDsSQL is
// replaced with sqlutil.QueryVariadic(len(eventIDs)), presumably one
// placeholder per ID (verify against sqlutil). Each ID is then bound as a
// query argument in order.
func (s *membershipStatements) deleteMembershipsByEventIDs(
	ctx context.Context, txn *sql.Tx, eventIDs []string,
) (err error) {
	args := make([]interface{}, 0, len(eventIDs))
	for _, id := range eventIDs {
		args = append(args, id)
	}
	placeholders := sqlutil.QueryVariadic(len(eventIDs))
	query := strings.Replace(deleteMembershipsByEventIDsSQL, "($1)", placeholders, 1)
	_, err = txn.ExecContext(ctx, query, args...)
	return err
}

// selectMembershipInRoomByLocalpart looks up the join event ID stored for
// localpart in roomID. The returned Membership always carries the supplied
// localpart and roomID. EventID is filled in only when Scan succeeds. The
// error is whatever Scan reports (for example sql.ErrNoRows when no row
// matches).
func (s *membershipStatements) selectMembershipInRoomByLocalpart(
	ctx context.Context, localpart, roomID string,
) (authtypes.Membership, error) {
	var eventID string
	row := s.selectMembershipInRoomByLocalpartStmt.QueryRowContext(ctx, localpart, roomID)
	scanErr := row.Scan(&eventID)
	return authtypes.Membership{
		Localpart: localpart,
		RoomID:    roomID,
		EventID:   eventID,
	}, scanErr
}

// selectMembershipsByLocalpart returns every stored membership for localpart,
// with Localpart, RoomID and EventID populated. On success the slice is
// non-nil (empty when there are no rows). On any error the slice is nil.
func (s *membershipStatements) selectMembershipsByLocalpart(
	ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
	rows, err := s.selectMembershipsByLocalpartStmt.QueryContext(ctx, localpart)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsByLocalpart: rows.close() failed")

	memberships = []authtypes.Membership{}
	for rows.Next() {
		m := authtypes.Membership{Localpart: localpart}
		if err = rows.Scan(&m.RoomID, &m.EventID); err != nil {
			return nil, err
		}
		memberships = append(memberships, m)
	}
	// rows.Next returns false both when the result set is exhausted and when
	// iteration fails. Only rows.Err tells them apart, so check it rather than
	// returning a silently truncated list.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return memberships, nil
}
// selectRoomIDsByLocalPart returns the IDs of all rooms localPart has a stored
// membership in. On success the slice is non-nil (empty when there are no
// rows). On any error the slice is nil.
func (s *membershipStatements) selectRoomIDsByLocalPart(
	ctx context.Context, localPart string,
) ([]string, error) {
	rows, err := s.selectRoomIDsByLocalPartStmt.QueryContext(ctx, localPart)
	if err != nil {
		return nil, err
	}
	// Log any error from closing the result set instead of silently dropping
	// it, as selectMembershipsByLocalpart in this file does.
	defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsByLocalPart: rows.close() failed")

	roomIDs := []string{}
	for rows.Next() {
		var roomID string
		if err = rows.Scan(&roomID); err != nil {
			return nil, err
		}
		roomIDs = append(roomIDs, roomID)
	}
	// A false rows.Next may mean iteration failed. Do not hand back a partial
	// list together with an error.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return roomIDs, nil
}