// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package relayapi_test

import (
	"crypto/ed25519"
	"encoding/hex"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gorilla/mux"
	"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
	"github.com/matrix-org/dendrite/relayapi"
	"github.com/matrix-org/dendrite/test"
	"github.com/matrix-org/dendrite/test/testrig"
	"github.com/matrix-org/gomatrixserverlib"
	"github.com/stretchr/testify/assert"
)

// TestCreateNewRelayInternalAPI checks that relayapi.NewRelayInternalAPI
// returns a non-nil API for every database type that
// test.WithAllDatabases runs the callback with.
func TestCreateNewRelayInternalAPI(t *testing.T) {
	runForDB := func(t *testing.T, dbType test.DBType) {
		base, shutdown := testrig.CreateBaseDendrite(t, dbType)
		defer shutdown()

		api := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, true)
		assert.NotNil(t, api)
	}

	test.WithAllDatabases(t, runForDB)
}

// TestCreateRelayInternalInvalidDatabasePanics overwrites the relay API
// database connection string and asserts that constructing the relay
// internal API with it panics.
func TestCreateRelayInternalInvalidDatabasePanics(t *testing.T) {
	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		base, shutdown := testrig.CreateBaseDendrite(t, dbType)
		defer shutdown()

		// SQLite gets a different connection string from every other
		// database type; the test expects both to cause a panic.
		switch dbType {
		case test.DBTypeSQLite:
			base.Cfg.RelayAPI.Database.ConnectionString = "file:"
		default:
			base.Cfg.RelayAPI.Database.ConnectionString = "test"
		}

		construct := func() {
			relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, true)
		}
		assert.Panics(t, construct)
	})
}

// TestCreateInvalidRelayPublicRoutesPanics asserts that
// relayapi.AddPublicRoutes panics when called with a nil key ring and a
// nil relay API.
func TestCreateInvalidRelayPublicRoutesPanics(t *testing.T) {
	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		base, shutdown := testrig.CreateBaseDendrite(t, dbType)
		defer shutdown()

		addRoutes := func() {
			relayapi.AddPublicRoutes(base, nil, nil)
		}
		assert.Panics(t, addRoutes)
	})
}

// createGetRelayTxnHTTPRequest builds a signed federation GET request for
// /_matrix/federation/v1/relay_txn/{userID} addressed to serverName.
//
// A fresh ed25519 key pair is generated on every call. The hex-encoded
// public key is used as the origin server name and the request is signed
// with the matching private key under signing.KeyID. The request content is
// a gomatrixserverlib.RelayEntry with EntryID 0, and the "userID" URL
// variable is attached with mux.SetURLVars.
//
// The helper has no *testing.T, so it panics with a descriptive message if
// any step of building the request fails rather than returning a
// half-built (possibly nil) request.
func createGetRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, userID string) *http.Request {
	// Passing nil makes GenerateKey use crypto/rand.Reader.
	pk, sk, err := ed25519.GenerateKey(nil)
	if err != nil {
		panic("generating ed25519 key: " + err.Error())
	}
	origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk))

	req := gomatrixserverlib.NewFederationRequest("GET", origin, serverName, "/_matrix/federation/v1/relay_txn/"+userID)
	if err := req.SetContent(gomatrixserverlib.RelayEntry{EntryID: 0}); err != nil {
		panic("setting relay_txn request content: " + err.Error())
	}
	if err := req.Sign(origin, gomatrixserverlib.KeyID(signing.KeyID), sk); err != nil {
		panic("signing relay_txn request: " + err.Error())
	}

	httpReq, err := req.HTTPRequest()
	if err != nil {
		panic("building relay_txn HTTP request: " + err.Error())
	}
	return mux.SetURLVars(httpReq, map[string]string{"userID": userID})
}

// sendRelayContent is the JSON request body used by
// createSendRelayTxnHTTPRequest when building send_relay requests.
// It serialises to an object with "pdus" and "edus" keys.
type sendRelayContent struct {
	// PDUs holds raw, undecoded JSON messages under the "pdus" key.
	PDUs []json.RawMessage       `json:"pdus"`
	// EDUs holds gomatrixserverlib EDU values under the "edus" key.
	EDUs []gomatrixserverlib.EDU `json:"edus"`
}

// createSendRelayTxnHTTPRequest builds a signed federation PUT request for
// /_matrix/federation/v1/send_relay/{txnID}/{userID} addressed to
// serverName.
//
// A fresh ed25519 key pair is generated on every call. The hex-encoded
// public key is used as the origin server name and the request is signed
// with the matching private key under signing.KeyID. The request content is
// an empty sendRelayContent, and the "userID" and "txnID" URL variables are
// attached with mux.SetURLVars.
//
// The helper has no *testing.T, so it panics with a descriptive message if
// any step of building the request fails rather than returning a
// half-built (possibly nil) request.
func createSendRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, txnID string, userID string) *http.Request {
	// Passing nil makes GenerateKey use crypto/rand.Reader.
	pk, sk, err := ed25519.GenerateKey(nil)
	if err != nil {
		panic("generating ed25519 key: " + err.Error())
	}
	origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk))

	req := gomatrixserverlib.NewFederationRequest("PUT", origin, serverName, "/_matrix/federation/v1/send_relay/"+txnID+"/"+userID)
	if err := req.SetContent(sendRelayContent{}); err != nil {
		panic("setting send_relay request content: " + err.Error())
	}
	if err := req.Sign(origin, gomatrixserverlib.KeyID(signing.KeyID), sk); err != nil {
		panic("signing send_relay request: " + err.Error())
	}

	httpReq, err := req.HTTPRequest()
	if err != nil {
		panic("building send_relay HTTP request: " + err.Error())
	}
	return mux.SetURLVars(httpReq, map[string]string{"userID": userID, "txnID": txnID})
}

// TestCreateRelayPublicRoutes registers the relay public routes on the base
// federation mux, using a relay API constructed with its final argument set
// to true, and serves signed relay_txn and send_relay requests through it.
//
// For both endpoints the user ID "user:local" is expected to produce HTTP
// 400 and "@user:local" is expected to produce HTTP 200.
//
// Each table entry runs as its own subtest so that one failing case does
// not stop the remaining cases from being checked and reported.
func TestCreateRelayPublicRoutes(t *testing.T) {
	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		base, shutdown := testrig.CreateBaseDendrite(t, dbType)
		defer shutdown()

		relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, true)
		assert.NotNil(t, relayAPI)

		serverKeyAPI := &signing.YggdrasilKeys{}
		keyRing := serverKeyAPI.KeyRing()
		relayapi.AddPublicRoutes(base, keyRing, relayAPI)

		testCases := []struct {
			name     string
			req      *http.Request
			wantCode int
		}{
			{
				name:     "relay_txn invalid user id",
				req:      createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "user:local"),
				wantCode: http.StatusBadRequest,
			},
			{
				name:     "relay_txn valid user id",
				req:      createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"),
				wantCode: http.StatusOK,
			},
			{
				name:     "send_relay invalid user id",
				req:      createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "user:local"),
				wantCode: http.StatusBadRequest,
			},
			{
				name:     "send_relay valid user id",
				req:      createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"),
				wantCode: http.StatusOK,
			},
		}

		for _, tc := range testCases {
			// Subtests run synchronously here (no t.Parallel), so using
			// the loop variable inside the closure is safe on any Go
			// version.
			t.Run(tc.name, func(t *testing.T) {
				w := httptest.NewRecorder()
				base.PublicFederationAPIMux.ServeHTTP(w, tc.req)
				if w.Code != tc.wantCode {
					t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode)
				}
			})
		}
	})
}

// TestDisableRelayPublicRoutes registers the relay public routes on the
// base federation mux, using a relay API constructed with its final
// argument set to false, and serves signed relay_txn and send_relay
// requests through it.
//
// Both endpoints are expected to respond with HTTP 404 even though the
// requests carry the user ID "@user:local".
//
// Each table entry runs as its own subtest so that one failing case does
// not stop the remaining cases from being checked and reported.
func TestDisableRelayPublicRoutes(t *testing.T) {
	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		base, shutdown := testrig.CreateBaseDendrite(t, dbType)
		defer shutdown()

		relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, false)
		assert.NotNil(t, relayAPI)

		serverKeyAPI := &signing.YggdrasilKeys{}
		keyRing := serverKeyAPI.KeyRing()
		relayapi.AddPublicRoutes(base, keyRing, relayAPI)

		testCases := []struct {
			name     string
			req      *http.Request
			wantCode int
		}{
			{
				name:     "relay_txn valid user id",
				req:      createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"),
				wantCode: http.StatusNotFound,
			},
			{
				name:     "send_relay valid user id",
				req:      createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"),
				wantCode: http.StatusNotFound,
			},
		}

		for _, tc := range testCases {
			// Subtests run synchronously here (no t.Parallel), so using
			// the loop variable inside the closure is safe on any Go
			// version.
			t.Run(tc.name, func(t *testing.T) {
				w := httptest.NewRecorder()
				base.PublicFederationAPIMux.ServeHTTP(w, tc.req)
				if w.Code != tc.wantCode {
					t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode)
				}
			})
		}
	})
}