package appservice_test

import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"net/http/httptest"
	"path"
	"reflect"
	"regexp"
	"strings"
	"testing"
	"time"

	"github.com/matrix-org/dendrite/federationapi/statistics"
	"github.com/stretchr/testify/assert"

	"github.com/matrix-org/dendrite/appservice"
	"github.com/matrix-org/dendrite/appservice/api"
	"github.com/matrix-org/dendrite/appservice/consumers"
	"github.com/matrix-org/dendrite/internal/caching"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/roomserver"
	rsapi "github.com/matrix-org/dendrite/roomserver/api"
	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/dendrite/setup/jetstream"
	"github.com/matrix-org/dendrite/test"
	"github.com/matrix-org/dendrite/userapi"
	"github.com/matrix-org/gomatrixserverlib/spec"

	"github.com/matrix-org/dendrite/test/testrig"
)

// testIsBlacklistedOrBackingOff is a stub callback handed to
// userapi.NewInternalAPI in these tests. It ignores the server name and always
// returns empty (zero-value) server statistics together with a nil error.
var testIsBlacklistedOrBackingOff = func(_ spec.ServerName) (*statistics.ServerStatistics, error) {
	stats := statistics.ServerStatistics{}
	return &stats, nil
}

// TestAppserviceInternalAPI exercises the appservice internal API (user ID and
// room alias existence checks plus the third party location, user and protocol
// lookups) against a stub application service served by httptest. The cases are
// run once for every database type provided by test.WithAllDatabases.
func TestAppserviceInternalAPI(t *testing.T) {

	// Set expected results
	existingProtocol := "irc"
	wantLocationResponse := []api.ASLocationResponse{{Protocol: existingProtocol, Fields: []byte("{}")}}
	wantUserResponse := []api.ASUserResponse{{Protocol: existingProtocol, Fields: []byte("{}")}}
	wantProtocolResponse := api.ASProtocolResponse{Instances: []api.ProtocolInstance{{Fields: []byte("{}")}}}
	wantProtocolResult := map[string]api.ASProtocolResponse{
		existingProtocol: wantProtocolResponse,
	}

	// create a dummy AS url, handling some cases
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// A request targets the known protocol when its path ends in it.
		known := strings.HasSuffix(r.URL.Path, existingProtocol)

		// body is what gets JSON encoded as the response. The order of the
		// cases matters: the first matching path fragment wins.
		var body interface{}
		switch {
		case strings.Contains(r.URL.Path, "location"):
			if known {
				body = wantLocationResponse
			} else {
				body = []api.ASLocationResponse{}
			}
		case strings.Contains(r.URL.Path, "user"):
			if known {
				body = wantUserResponse
			} else {
				body = []api.ASUserResponse{}
			}
		case strings.Contains(r.URL.Path, "protocol"):
			// For an unknown protocol body stays nil, which encodes as JSON null.
			if known {
				body = wantProtocolResponse
			}
		default:
			t.Logf("hit location: %s", r.URL.Path)
			return
		}

		// The handler runs on a server goroutine, so it must not call
		// t.Fatalf/FailNow; record the failure with t.Errorf instead.
		if err := json.NewEncoder(w).Encode(body); err != nil {
			t.Errorf("failed to encode response: %s", err)
		}
	}))
	// Registered with t.Cleanup rather than defer so the server is only shut
	// down once this test and all of its subtests have completed, even if the
	// subtests started by WithAllDatabases outlive this function body.
	t.Cleanup(srv.Close)

	// The test cases to run
	runCases := func(t *testing.T, testAPI api.AppServiceInternalAPI) {
		t.Run("UserIDExists", func(t *testing.T) {
			testUserIDExists(t, testAPI, "@as-testing:test", true)
			testUserIDExists(t, testAPI, "@as1-testing:test", false)
		})

		t.Run("AliasExists", func(t *testing.T) {
			testAliasExists(t, testAPI, "@asroom-testing:test", true)
			testAliasExists(t, testAPI, "@asroom1-testing:test", false)
		})

		t.Run("Locations", func(t *testing.T) {
			testLocations(t, testAPI, existingProtocol, wantLocationResponse)
			testLocations(t, testAPI, "abc", nil)
		})

		t.Run("User", func(t *testing.T) {
			testUser(t, testAPI, existingProtocol, wantUserResponse)
			testUser(t, testAPI, "abc", nil)
		})

		t.Run("Protocols", func(t *testing.T) {
			testProtocol(t, testAPI, existingProtocol, wantProtocolResult)
			testProtocol(t, testAPI, existingProtocol, wantProtocolResult) // tests the cache
			testProtocol(t, testAPI, "", wantProtocolResult)               // tests getting all protocols
			testProtocol(t, testAPI, "abc", nil)
		})
	}

	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		// Named closeDB so the builtin close is not shadowed.
		cfg, ctx, closeDB := testrig.CreateConfig(t, dbType)
		defer closeDB()

		// Create a dummy application service
		as := &config.ApplicationService{
			ID:              "someID",
			URL:             srv.URL,
			ASToken:         "",
			HSToken:         "",
			SenderLocalpart: "senderLocalPart",
			NamespaceMap: map[string][]config.ApplicationServiceNamespace{
				"users":   {{RegexpObject: regexp.MustCompile("as-.*")}},
				"aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}},
			},
			Protocols: []string{existingProtocol},
		}
		as.CreateHTTPClient(cfg.AppServiceAPI.DisableTLSValidation)
		cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{*as}
		t.Cleanup(func() {
			ctx.ShutdownDendrite()
			ctx.WaitForShutdown()
		})
		caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
		// Create required internal APIs
		natsInstance := jetstream.NATSInstance{}
		cm := sqlutil.NewConnectionManager(ctx, cfg.Global.DatabaseOptions)
		rsAPI := roomserver.NewInternalAPI(ctx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
		rsAPI.SetFederationAPI(nil, nil)
		usrAPI := userapi.NewInternalAPI(ctx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
		asAPI := appservice.NewInternalAPI(ctx, cfg, &natsInstance, usrAPI, rsAPI)

		runCases(t, asAPI)
	})
}

// TestAppserviceInternalAPI_UnixSocket_Simple serves the stub application
// service on a unix domain socket inside a temporary directory, configures the
// application service with a "unix://" URL pointing at that socket and checks
// UserIDExists through the appservice internal API (SQLite only).
func TestAppserviceInternalAPI_UnixSocket_Simple(t *testing.T) {

	// Set expected results
	existingProtocol := "irc"
	wantLocationResponse := []api.ASLocationResponse{{Protocol: existingProtocol, Fields: []byte("{}")}}
	wantUserResponse := []api.ASUserResponse{{Protocol: existingProtocol, Fields: []byte("{}")}}
	wantProtocolResponse := api.ASProtocolResponse{Instances: []api.ProtocolInstance{{Fields: []byte("{}")}}}

	// create a dummy AS url, handling some cases
	srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// A request targets the known protocol when its path ends in it.
		known := strings.HasSuffix(r.URL.Path, existingProtocol)

		// body is what gets JSON encoded as the response. The order of the
		// cases matters: the first matching path fragment wins.
		var body interface{}
		switch {
		case strings.Contains(r.URL.Path, "location"):
			if known {
				body = wantLocationResponse
			} else {
				body = []api.ASLocationResponse{}
			}
		case strings.Contains(r.URL.Path, "user"):
			if known {
				body = wantUserResponse
			} else {
				body = []api.ASUserResponse{}
			}
		case strings.Contains(r.URL.Path, "protocol"):
			// For an unknown protocol body stays nil, which encodes as JSON null.
			if known {
				body = wantProtocolResponse
			}
		default:
			t.Logf("hit location: %s", r.URL.Path)
			return
		}

		// The handler runs on a server goroutine, so it must not call
		// t.Fatalf/FailNow; record the failure with t.Errorf instead.
		if err := json.NewEncoder(w).Encode(body); err != nil {
			t.Errorf("failed to encode response: %s", err)
		}
	}))

	tmpDir := t.TempDir()
	socket := path.Join(tmpDir, "socket")
	l, err := net.Listen("unix", socket)
	// Without a listener nothing below can work (srv.Start would be handed a
	// nil listener), so stop here instead of merely recording the failure.
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	// Swap the listener httptest created by default for the unix socket one.
	_ = srv.Listener.Close()
	srv.Listener = l
	srv.Start()
	defer srv.Close()

	cfg, ctx, tearDown := testrig.CreateConfig(t, test.DBTypeSQLite)
	defer tearDown()

	// Create a dummy application service
	as := &config.ApplicationService{
		ID:              "someID",
		URL:             fmt.Sprintf("unix://%s", socket),
		ASToken:         "",
		HSToken:         "",
		SenderLocalpart: "senderLocalPart",
		NamespaceMap: map[string][]config.ApplicationServiceNamespace{
			"users":   {{RegexpObject: regexp.MustCompile("as-.*")}},
			"aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}},
		},
		Protocols: []string{existingProtocol},
	}
	as.CreateHTTPClient(cfg.AppServiceAPI.DisableTLSValidation)
	cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{*as}

	t.Cleanup(func() {
		ctx.ShutdownDendrite()
		ctx.WaitForShutdown()
	})
	caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
	// Create required internal APIs
	natsInstance := jetstream.NATSInstance{}
	cm := sqlutil.NewConnectionManager(ctx, cfg.Global.DatabaseOptions)
	rsAPI := roomserver.NewInternalAPI(ctx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
	rsAPI.SetFederationAPI(nil, nil)
	usrAPI := userapi.NewInternalAPI(ctx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
	asAPI := appservice.NewInternalAPI(ctx, cfg, &natsInstance, usrAPI, rsAPI)

	t.Run("UserIDExists", func(t *testing.T) {
		testUserIDExists(t, asAPI, "@as-testing:test", true)
		testUserIDExists(t, asAPI, "@as1-testing:test", false)
	})

}

// testUserIDExists calls asAPI.UserIDExists for userID and records a test error
// if the call fails or if the reported existence differs from wantExists.
func testUserIDExists(t *testing.T, asAPI api.AppServiceInternalAPI, userID string, wantExists bool) {
	req := api.UserIDExistsRequest{UserID: userID}
	var res api.UserIDExistsResponse

	err := asAPI.UserIDExists(context.Background(), &req, &res)
	if err != nil {
		t.Errorf("failed to get userID: %s", err)
	}

	// The response is still compared after an error, so both problems are reported.
	if got := res.UserIDExists; got != wantExists {
		t.Errorf("unexpected result for UserIDExists(%s): %v, expected %v", userID, got, wantExists)
	}
}

// testAliasExists calls asAPI.RoomAliasExists for alias and records a test
// error if the call fails or if the reported existence differs from wantExists.
func testAliasExists(t *testing.T, asAPI api.AppServiceInternalAPI, alias string, wantExists bool) {
	req := api.RoomAliasExistsRequest{Alias: alias}
	var res api.RoomAliasExistsResponse

	err := asAPI.RoomAliasExists(context.Background(), &req, &res)
	if err != nil {
		t.Errorf("failed to get alias: %s", err)
	}

	// The response is still compared after an error, so both problems are reported.
	if got := res.AliasExists; got != wantExists {
		t.Errorf("unexpected result for RoomAliasExists(%s): %v, expected %v", alias, got, wantExists)
	}
}

// testLocations calls asAPI.Locations for the protocol proto and records a test
// error if the call fails or if the returned locations are not deeply equal to
// wantResult.
func testLocations(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult []api.ASLocationResponse) {
	req := api.LocationRequest{Protocol: proto}
	var res api.LocationResponse

	err := asAPI.Locations(context.Background(), &req, &res)
	if err != nil {
		t.Errorf("failed to get locations: %s", err)
	}

	if got := res.Locations; !reflect.DeepEqual(got, wantResult) {
		t.Errorf("unexpected result for Locations(%s): %+v, expected %+v", proto, got, wantResult)
	}
}

// testUser calls asAPI.User for the protocol proto and records a test error if
// the call fails or if the returned users are not deeply equal to wantResult.
func testUser(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult []api.ASUserResponse) {
	req := api.UserRequest{Protocol: proto}
	var res api.UserResponse

	err := asAPI.User(context.Background(), &req, &res)
	if err != nil {
		t.Errorf("failed to get user: %s", err)
	}

	if got := res.Users; !reflect.DeepEqual(got, wantResult) {
		t.Errorf("unexpected result for User(%s): %+v, expected %+v", proto, got, wantResult)
	}
}

// testProtocol calls asAPI.Protocols for the protocol proto and records a test
// error if the call fails or if the returned protocol map is not deeply equal
// to wantResult.
func testProtocol(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult map[string]api.ASProtocolResponse) {
	ctx := context.Background()
	protoResp := &api.ProtocolResponse{}

	if err := asAPI.Protocols(ctx, &api.ProtocolRequest{
		Protocol: proto,
	}, protoResp); err != nil {
		t.Errorf("failed to get Protocols: %s", err)
	}
	if !reflect.DeepEqual(protoResp.Protocols, wantResult) {
		// Print the whole returned map: that is what was compared, and indexing
		// by proto would show an empty entry for "" or for unknown protocols.
		t.Errorf("unexpected result for Protocols(%s): %+v, expected %+v", proto, protoResp.Protocols, wantResult)
	}
}

// TestRoomserverConsumerOneInvite tests that the roomserver consumer only
// receives one invite. A room in which alice invites bob is sent to the
// roomserver, and a stub application service counts the m.room.member events
// whose state key is bob's user ID. Seeing more than one fails the test.
func TestRoomserverConsumerOneInvite(t *testing.T) {

	alice := test.NewUser(t)
	bob := test.NewUser(t)
	room := test.NewRoom(t, alice)

	// Invite Bob
	room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{
		"membership": "invite",
	}, test.WithStateKey(bob.ID))

	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType)
		defer closeDB()
		cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
		natsInstance := &jetstream.NATSInstance{}

		// evChan carries one signal per invite seen by the stub AS.
		evChan := make(chan struct{})
		// done is closed when this function returns. It releases handlers that
		// are still trying to send on evChan once nobody is receiving anymore.
		// evChan itself is never closed: closing it from the receiving side
		// would make a late sender panic.
		done := make(chan struct{})

		// create a dummy AS url, handling the events
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			var txn consumers.ApplicationServiceTransaction
			if err := json.NewDecoder(r.Body).Decode(&txn); err != nil {
				// The handler runs on a server goroutine, so it must not call
				// t.Fatal/FailNow; record the failure and reject the request.
				t.Errorf("failed to decode transaction: %s", err)
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			for _, ev := range txn.Events {
				if ev.Type != spec.MRoomMember {
					continue
				}
				// Usually we would check the event content for the membership, but since
				// we only invited bob, this should be fine for this test.
				if ev.StateKey != nil && *ev.StateKey == bob.ID {
					select {
					case evChan <- struct{}{}:
					case <-done:
						return
					}
				}
			}
		}))
		// Deferred calls run last-in first-out: done is closed first, so that
		// srv.Close() never waits on a handler stuck sending on evChan.
		defer srv.Close()
		defer close(done)

		as := &config.ApplicationService{
			ID:              "someID",
			URL:             srv.URL,
			ASToken:         "",
			HSToken:         "",
			SenderLocalpart: "senderLocalPart",
			NamespaceMap: map[string][]config.ApplicationServiceNamespace{
				"users":   {{RegexpObject: regexp.MustCompile(bob.ID)}},
				"aliases": {{RegexpObject: regexp.MustCompile(room.ID)}},
			},
		}
		as.CreateHTTPClient(cfg.AppServiceAPI.DisableTLSValidation)

		// Create a dummy application service
		cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{*as}

		caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
		// Create required internal APIs
		rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics)
		rsAPI.SetFederationAPI(nil, nil)
		usrAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
		// start the consumer
		appservice.NewInternalAPI(processCtx, cfg, natsInstance, usrAPI, rsAPI)

		// Create the room
		if err := rsapi.SendEvents(context.Background(), rsAPI, rsapi.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
			t.Fatalf("failed to send events: %v", err)
		}
		var seenInvitesForBob int
	waitLoop:
		for {
			select {
			// The timer restarts on every iteration, so the loop ends after
			// 50ms without any new invite signal.
			case <-time.After(time.Millisecond * 50): // wait for the AS to process the events
				break waitLoop
			case <-evChan:
				seenInvitesForBob++
				if seenInvitesForBob != 1 {
					t.Fatalf("received unexpected invites: %d", seenInvitesForBob)
				}
			}
		}
	})
}