package helpers

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"sort"

	"github.com/matrix-org/gomatrixserverlib"
	"github.com/matrix-org/gomatrixserverlib/spec"
	"github.com/matrix-org/util"

	"github.com/matrix-org/dendrite/roomserver/api"
	"github.com/matrix-org/dendrite/roomserver/auth"
	"github.com/matrix-org/dendrite/roomserver/state"
	"github.com/matrix-org/dendrite/roomserver/storage"
	"github.com/matrix-org/dendrite/roomserver/storage/shared"
	"github.com/matrix-org/dendrite/roomserver/storage/tables"
	"github.com/matrix-org/dendrite/roomserver/types"
)

// TODO: temporary package which has helper functions used by both internal/perform packages.
// Move these to a more sensible place.

// UpdateToInviteMembership applies the invite membership carried by add through
// the supplied membership updater and appends the resulting output events to
// updates.
//
// A new-invite output event is appended when the updater reports that the
// invite still needs sending, followed by one retire-invite output event for
// each event ID the updater reports as retired. The extended slice is
// returned; if the updater fails, nil and its error are returned instead.
func UpdateToInviteMembership(
	mu *shared.MembershipUpdater, add *types.Event, updates []api.OutputEvent,
	roomVersion gomatrixserverlib.RoomVersion,
) ([]api.OutputEvent, error) {
	// The invite may already have been sent to the user, either because this
	// event is being reprocessed or because we received the invite from a
	// remote server via the federation invite API. In those cases there is no
	// need to send the event again.
	needsSending, retired, err := mu.Update(tables.MembershipStateInvite, add)
	if err != nil {
		return nil, err
	}

	if needsSending {
		// Consumers are notified with a dedicated event even though they will
		// also learn about the change in current state from the normal room
		// event stream. That way they only need to look at a single stream of
		// events to decide whether a user is invited, instead of combining
		// several streams themselves.
		invite := &api.OutputNewInviteEvent{
			Event:       &types.HeaderedEvent{PDU: add.PDU},
			RoomVersion: roomVersion,
		}
		updates = append(updates, api.OutputEvent{
			Type:           api.OutputTypeNewInviteEvent,
			NewInviteEvent: invite,
		})
	}

	// Emit one retirement per invite that the updater retired.
	for i := range retired {
		retirement := &api.OutputRetireInviteEvent{
			EventID:          retired[i],
			RoomID:           add.RoomID(),
			Membership:       spec.Join,
			RetiredByEventID: add.EventID(),
			TargetSenderID:   spec.SenderID(*add.StateKey()),
		}
		updates = append(updates, api.OutputEvent{
			Type:              api.OutputTypeRetireInviteEvent,
			RetireInviteEvent: retirement,
		})
	}

	return updates, nil
}

// IsServerCurrentlyInRoom reports whether the given server is in the room,
// judged from the room's membership events. When serverName is empty the local
// server is checked instead, using a faster dedicated database lookup.
// An error is returned if the room is unknown or a database call fails.
// TODO: This should probably be replaced by an API call.
func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, serverName spec.ServerName, roomID string) (bool, error) {
	roomInfo, err := db.RoomInfo(ctx, roomID)
	switch {
	case err != nil:
		return false, err
	case roomInfo == nil:
		return false, fmt.Errorf("unknown room %s", roomID)
	case serverName == "":
		// Fast path for the local server.
		return db.GetLocalServerInRoom(ctx, roomInfo.RoomNID)
	}

	membershipNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, false)
	if err != nil {
		return false, err
	}

	memberEvents, err := db.Events(ctx, roomInfo.RoomVersion, membershipNIDs)
	if err != nil {
		return false, err
	}

	// Unwrap the stored events into plain PDUs for the auth helper.
	pdus := make([]gomatrixserverlib.PDU, 0, len(memberEvents))
	for i := range memberEvents {
		pdus = append(pdus, memberEvents[i].PDU)
	}

	inRoom := auth.IsAnyUserOnServerWithMembership(ctx, querier, serverName, pdus, spec.Join)
	return inRoom, nil
}

// IsInvitePending checks whether the user identified by senderID has a pending
// invite in the room identified by roomID.
//
// When an invite is found it returns true together with the sender ID of the
// first invite's sender, the invite event ID recorded for that sender, and the
// event parsed from the invite JSON returned by the database. When no invite
// is found it returns false with zero values and a nil error.
//
// On any failure (unknown room, missing NIDs, database errors, unsupported
// room version or unparsable invite JSON) it returns false, zero values and a
// non-nil error.
func IsInvitePending(
	ctx context.Context, db storage.Database,
	roomID string, senderID spec.SenderID,
) (bool, spec.SenderID, string, gomatrixserverlib.PDU, error) {
	// Look up the room NID for the supplied room ID.
	info, err := db.RoomInfo(ctx, roomID)
	if err != nil {
		return false, "", "", nil, fmt.Errorf("r.DB.RoomInfo: %w", err)
	}
	if info == nil {
		return false, "", "", nil, fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID)
	}

	// Look up the state key NID for the supplied user ID.
	targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{string(senderID)})
	if err != nil {
		return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
	}
	targetUserNID, targetUserFound := targetUserNIDs[string(senderID)]
	if !targetUserFound {
		return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", senderID, targetUserNIDs)
	}

	// Let's see if we have an event active for the user in the room. If
	// we do then it will contain a server name that we can direct the
	// send_leave to.
	senderUserNIDs, eventIDs, eventJSON, err := db.GetInvitesForUser(ctx, info.RoomNID, targetUserNID)
	if err != nil {
		return false, "", "", nil, fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
	}
	if len(senderUserNIDs) == 0 {
		// No invite on record for this user.
		return false, "", "", nil, nil
	}
	// Associate each inviting sender NID with the invite event ID at the same
	// index; a later entry for the same NID overrides an earlier one.
	userNIDToEventID := make(map[types.EventStateKeyNID]string, len(senderUserNIDs))
	for i, nid := range senderUserNIDs {
		userNIDToEventID[nid] = eventIDs[i]
	}

	// Look up the user ID from the NID.
	senderUsers, err := db.EventStateKeys(ctx, senderUserNIDs)
	if err != nil {
		return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeys: %w", err)
	}
	if len(senderUsers) == 0 {
		return false, "", "", nil, fmt.Errorf("no senderUsers")
	}

	// Only the first inviting sender is reported back to the caller.
	senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
	if !senderUserFound {
		return false, "", "", nil, fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
	}

	verImpl, err := gomatrixserverlib.GetRoomVersion(info.RoomVersion)
	if err != nil {
		return false, "", "", nil, fmt.Errorf("gomatrixserverlib.GetRoomVersion: %w", err)
	}

	// Do not report a pending invite alongside an error: if the stored JSON
	// cannot be parsed, return zero values like every other failure path.
	event, err := verImpl.NewEventFromTrustedJSON(eventJSON, false)
	if err != nil {
		return false, "", "", nil, fmt.Errorf("verImpl.NewEventFromTrustedJSON: %w", err)
	}

	return true, spec.SenderID(senderUser), userNIDToEventID[senderUserNIDs[0]], event, nil
}

// GetMembershipsAtState loads the "m.room.member" events referenced by the
// given state entries. When joinedOnly is true, only the events whose
// membership is "join" are returned; otherwise every membership event is
// returned.
//
// An empty, non-nil slice is returned when stateEntries contains no membership
// entries. types.ErrorInvalidRoomInfo is returned if there are membership
// entries but roomInfo is nil, and any error from fetching the events or from
// reading an event's membership is returned as-is.
func GetMembershipsAtState(
	ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, joinedOnly bool,
) ([]types.Event, error) {

	var eventNIDs types.EventNIDs
	for _, entry := range stateEntries {
		// Filter the events to retrieve to only keep the membership events
		if entry.EventTypeNID == types.MRoomMemberNID {
			eventNIDs = append(eventNIDs, entry.EventNID)
		}
	}

	// There are no events to get, don't bother asking the database
	if len(eventNIDs) == 0 {
		return []types.Event{}, nil
	}

	// util.Unique compacts the sorted list in place and returns the length it
	// should be truncated to; without the truncation the duplicates would
	// still be sent to the database.
	sort.Sort(eventNIDs)
	eventNIDs = eventNIDs[:util.Unique(eventNIDs)]

	// Get all of the events in this state
	if roomInfo == nil {
		return nil, types.ErrorInvalidRoomInfo
	}
	stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs)
	if err != nil {
		return nil, err
	}

	if !joinedOnly {
		return stateEvents, nil
	}

	// Filter the events to only keep the "join" membership events
	var events []types.Event
	for _, event := range stateEvents {
		membership, err := event.Membership()
		if err != nil {
			return nil, err
		}

		if membership == spec.Join {
			events = append(events, event)
		}
	}

	return events, nil
}

// StateBeforeEvent resolves the room state associated with the event
// identified by eventNID. It translates the NID to an event ID, fetches the
// stored state-at-event record for that ID and lets the state resolver load the
// combined state entries from it. Any database or resolver error is returned
// unchanged.
func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID, querier api.QuerySenderIDAPI) ([]types.StateEntry, error) {
	resolver := state.NewStateResolution(db, info, querier)

	// Translate the event NID into its event ID.
	idsByNID, err := db.EventIDs(ctx, []types.EventNID{eventNID})
	if err != nil {
		return nil, err
	}

	stateAtEvents, err := db.StateAtEventIDs(ctx, []string{idsByNID[eventNID]})
	if err != nil {
		return nil, err
	}

	// Load the state as it was when this event was fired.
	return resolver.LoadCombinedStateAfterEvents(ctx, stateAtEvents)
}

// MembershipAtEvent loads, for each of the given event IDs, the membership
// state entries of the user identified by stateKeyNID. The work is delegated to
// a state resolver built for the supplied room info; the result is keyed by
// string as returned by the resolver.
func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID, querier api.QuerySenderIDAPI) (map[string][]types.StateEntry, error) {
	resolver := state.NewStateResolution(db, info, querier)

	// Load the membership as it was when each event was fired.
	memberships, err := resolver.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID)
	return memberships, err
}

// LoadEvents fetches the events with the given NIDs from the database, using
// the room version from roomInfo, and returns them as PDUs in the order the
// database returned them. It returns types.ErrorInvalidRoomInfo when roomInfo
// is nil.
func LoadEvents(
	ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, eventNIDs []types.EventNID,
) ([]gomatrixserverlib.PDU, error) {
	if roomInfo == nil {
		return nil, types.ErrorInvalidRoomInfo
	}

	loaded, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs)
	if err != nil {
		return nil, err
	}

	// Unwrap each stored event into its PDU.
	pdus := make([]gomatrixserverlib.PDU, 0, len(loaded))
	for i := range loaded {
		pdus = append(pdus, loaded[i].PDU)
	}
	return pdus, nil
}

// LoadStateEvents loads the events referenced by the given state entries. It
// collects the event NID of every entry, in order, and hands them to
// LoadEvents.
func LoadStateEvents(
	ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
) ([]gomatrixserverlib.PDU, error) {
	nids := make([]types.EventNID, 0, len(stateEntries))
	for _, entry := range stateEntries {
		nids = append(nids, entry.EventNID)
	}
	return LoadEvents(ctx, db, roomInfo, nids)
}

// CheckServerAllowedToSeeEvent decides whether serverName may see the event
// identified by eventID.
//
// It first asks the database for the history visibility state at the event.
// If the database reports that this optimisation is not supported, the state
// is computed with slowGetHistoryVisibilityState instead. If the state is
// missing, the event is treated as visible. Otherwise the decision is
// delegated to auth.IsServerAllowed, passing isServerInRoom through.
//
// Errors are matched with errors.Is / errors.As so that wrapped errors are
// recognised as well. Any other error is returned together with false.
func CheckServerAllowedToSeeEvent(
	ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, querier api.QuerySenderIDAPI,
) (bool, error) {
	stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName))

	var missingState types.MissingStateError
	switch {
	case err == nil:
		// No error, so continue normally
	case errors.Is(err, tables.OptimisationNotSupportedError):
		// The database engine didn't support this optimisation, so fall back to using
		// the old and slow method
		stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName, querier)
		if err != nil {
			return false, err
		}
	case errors.As(err, &missingState):
		// If there's no state then we assume it's open visibility, as Synapse does:
		// https://github.com/matrix-org/synapse/blob/aec87a0f9369a3015b2a53469f88d1de274e8b71/synapse/visibility.py#L654-L655
		return true, nil
	default:
		// Something else went wrong
		return false, err
	}

	return auth.IsServerAllowed(ctx, querier, serverName, isServerInRoom, stateAtEvent), nil
}

// slowGetHistoryVisibilityState loads the state at the given event and returns
// the state events that remain after filtering by serverName.
//
// State entries whose state key resolves to a user on a different server are
// dropped. Entries with an empty state key are kept, as are entries whose
// state key could not be resolved to a user ID. It returns nil, nil when the
// state lookup reports sql.ErrNoRows or when nothing is left after filtering.
func slowGetHistoryVisibilityState(
	ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, querier api.QuerySenderIDAPI,
) ([]gomatrixserverlib.PDU, error) {
	resolver := state.NewStateResolution(db, info, querier)
	entries, err := resolver.LoadStateAtEvent(ctx, eventID)
	switch {
	case err == nil:
		// State found, carry on.
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	default:
		return nil, fmt.Errorf("roomState.LoadStateAtEvent: %w", err)
	}

	// Collect every event state key NID present in the room state.
	var keyNIDs []types.EventStateKeyNID
	for i := range entries {
		keyNIDs = append(keyNIDs, entries[i].EventStateKeyNID)
	}

	// Resolve those NIDs to their state key strings.
	keysByNID, err := db.EventStateKeys(ctx, keyNIDs)
	if err != nil {
		return nil, fmt.Errorf("db.EventStateKeys: %w", err)
	}

	parsedRoomID, err := spec.NewRoomID(roomID)
	if err != nil {
		return nil, err
	}

	// Drop every state key that belongs to a user on another server. Empty
	// state keys are preserved because those entries contain history
	// visibility etc.
	for nid, key := range keysByNID {
		if key == "" {
			continue
		}
		userID, queryErr := querier.QueryUserIDForSender(ctx, *parsedRoomID, spec.SenderID(key))
		if queryErr != nil || userID == nil {
			// Could not resolve the sender to a user ID; leave the key in place.
			continue
		}
		if userID.Domain() != serverName {
			delete(keysByNID, nid)
		}
	}

	// Keep only the state entries whose state key NID survived the filter.
	var kept []types.StateEntry
	for _, entry := range entries {
		if _, found := keysByNID[entry.EventStateKeyNID]; found {
			kept = append(kept, entry)
		}
	}
	if len(kept) == 0 {
		return nil, nil
	}

	return LoadStateEvents(ctx, db, info, kept)
}

// ScanEventTree walks the event graph breadth-first, starting at the event IDs
// in front and following each event's prev_events, until limit event NIDs have
// been collected or there is nothing left to visit.
//
// It returns:
//   - the NIDs of the loaded events, excluding events whose IDs were already
//     marked in visited when the function was called;
//   - the set of prev event IDs that serverName is not allowed to see (these
//     events are still traversed; the caller is expected to handle them);
//   - an error if loading events from the database failed.
//
// The visited map is modified in place: every prev event ID encountered is
// marked as visited.
//
// A failure to determine whether the server is in the room is logged and
// treated as "not in room"; it does not fail the scan. A failure while checking
// visibility of a prev event ends the scan early and returns the results
// gathered so far with a nil error.
//
// TODO: Remove this when we have tests to assert correctness of this function
func ScanEventTree(
	ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int,
	serverName spec.ServerName, querier api.QuerySenderIDAPI,
) ([]types.EventNID, map[string]struct{}, error) {
	var err error
	var allowed bool
	var events []types.Event
	var next []string

	// TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be)
	// Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing
	// so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in
	// duplicate events being sent in response to /backfill requests.
	initialIgnoreList := make(map[string]bool, len(visited))
	for k, v := range visited {
		initialIgnoreList[k] = v
	}

	resultNIDs := make([]types.EventNID, 0, limit)

	var checkedServerInRoom bool
	var isServerInRoom bool
	redactEventIDs := make(map[string]struct{})

	// Loop through the event IDs to retrieve the requested events and go
	// through the whole tree (up to the provided limit) using the events'
	// "prev_event" key.
BFSLoop:
	for len(front) > 0 {
		// Prevent unnecessary allocations: reset the slice only when not empty.
		if len(next) > 0 {
			next = make([]string, 0)
		}
		// Retrieve the events to process from the database.
		events, err = db.EventsFromIDs(ctx, info, front)
		if err != nil {
			return resultNIDs, redactEventIDs, err
		}

		if !checkedServerInRoom && len(events) > 0 {
			// It's nasty that we have to extract the room ID from an event, but many federation requests
			// only talk in event IDs, no room IDs at all (!!!)
			ev := events[0]
			// Keep this error in its own variable: it is deliberately
			// non-fatal and must not be returned by the final return below.
			var inRoomErr error
			isServerInRoom, inRoomErr = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID())
			if inRoomErr != nil {
				util.GetLogger(ctx).WithError(inRoomErr).Error("Failed to check if server is currently in room, assuming not.")
			}
			checkedServerInRoom = true
		}

		for _, ev := range events {
			// Break out of the loop if the provided limit is reached.
			if len(resultNIDs) == limit {
				break BFSLoop
			}

			if !initialIgnoreList[ev.EventID()] {
				// Update the list of events to retrieve.
				resultNIDs = append(resultNIDs, ev.EventNID)
			}
			// Loop through the event's parents.
			for _, pre := range ev.PrevEventIDs() {
				// Only add an event to the list of next events to process if it
				// hasn't been seen before.
				if visited[pre] {
					continue
				}
				visited[pre] = true
				allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom, querier)
				if err != nil {
					util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
						"Error checking if allowed to see event",
					)
					// drop the error, as we will often error at the DB level if we don't have the prev_event itself. Let's
					// just return what we have.
					return resultNIDs, redactEventIDs, nil
				}

				// The event hasn't been seen before, so add it to the list
				// of events to retrieve. If the requesting HS is not
				// allowed to see it, record it so it can be redacted.
				next = append(next, pre)
				if !allowed {
					util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event")
					redactEventIDs[pre] = struct{}{}
				}
			}
		}
		// Repeat the same process with the parent events we just processed.
		front = next
	}

	// Every fatal error has already been returned above.
	return resultNIDs, redactEventIDs, nil
}

// QueryLatestEventsAndState fills response with the latest events, depth, room
// version and current state events of the room named in request.
//
// If the room is unknown or only a stub, response.RoomExists is set to false
// and nil is returned. When request.StateToFetch is empty all current room
// state is loaded; otherwise only the requested tuples are loaded. The loaded
// state events are appended to response.StateEvents.
func QueryLatestEventsAndState(
	ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI,
	request *api.QueryLatestEventsAndStateRequest,
	response *api.QueryLatestEventsAndStateResponse,
) error {
	info, err := db.RoomInfo(ctx, request.RoomID)
	if err != nil {
		return err
	}
	if info == nil || info.IsStub() {
		response.RoomExists = false
		return nil
	}

	resolver := state.NewStateResolution(db, info, querier)
	response.RoomExists = true
	response.RoomVersion = info.RoomVersion

	var snapshotNID types.StateSnapshotNID
	response.LatestEvents, snapshotNID, response.Depth, err = db.LatestEventIDs(ctx, info.RoomNID)
	if err != nil {
		return err
	}

	var entries []types.StateEntry
	if len(request.StateToFetch) > 0 {
		// Only the current state for the requested tuples is wanted.
		entries, err = resolver.LoadStateAtSnapshotForStringTuples(ctx, snapshotNID, request.StateToFetch)
	} else {
		// Nothing specific was requested, so load all room state.
		entries, err = resolver.LoadStateAtSnapshot(ctx, snapshotNID)
	}
	if err != nil {
		return err
	}

	pdus, err := LoadStateEvents(ctx, db, info, entries)
	if err != nil {
		return err
	}

	// Wrap each state event before handing it back to the caller.
	for i := range pdus {
		response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{PDU: pdus[i]})
	}

	return nil
}