package routing

import (
	"context"
	"fmt"
	"html/template"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/matrix-org/dendrite/setup/config"
	userapi "github.com/matrix-org/dendrite/userapi/api"
)

// Test_validHMAC checks validHMAC against HMAC-SHA256 hex digests generated
// with openssl. Every case has a unique name so a single case can be selected
// with `go test -run 'Test_validHMAC/<name>'` and failures are unambiguous.
func Test_validHMAC(t *testing.T) {
	type args struct {
		username string
		userHMAC string
		secret   string
	}
	tests := []struct {
		name    string
		args    args
		want    bool
		wantErr bool
	}{
		{
			name:    "empty arguments are not a valid hmac",
			args:    args{},
			wantErr: false,
			want:    false,
		},
		// Expected digest generated with:
		// $ echo -n '@alice:localhost' | openssl sha256 -hmac 'helloWorld'
		// (stdin)= 121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e
		{
			name: "valid hmac for alice",
			args: args{
				username: "@alice:localhost",
				userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
				secret:   "helloWorld",
			},
			want: true,
		},
		{
			// alice's digest must not validate for a different user.
			name: "alice hmac is invalid for bob",
			args: args{
				username: "@bob:localhost",
				userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
				secret:   "helloWorld",
			},
			want: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := validHMAC(tt.args.username, tt.args.userHMAC, tt.args.secret)
			if (err != nil) != tt.wantErr {
				t.Errorf("validHMAC() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if got != tt.want {
				t.Errorf("validHMAC() got = %v, want %v", got, tt.want)
			}
		})
	}
}

// dummyAPI is an in-memory test double for the user API calls made by the
// consent handler. Policy versions written through PerformUpdatePolicyVersion
// are stored in usersConsent, keyed by the request's Localpart.
type dummyAPI struct {
	usersConsent map[string]string
}

// QueryOutdatedPolicy does nothing and always reports success.
func (api dummyAPI) QueryOutdatedPolicy(_ context.Context, _ *userapi.QueryOutdatedPolicyRequest, _ *userapi.QueryOutdatedPolicyResponse) error {
	return nil
}

// PerformUpdatePolicyVersion stores req.PolicyVersion under req.Localpart.
// A value receiver is sufficient because maps are reference types: the write
// lands in the map shared with the caller's dummyAPI value.
func (api dummyAPI) PerformUpdatePolicyVersion(_ context.Context, req *userapi.UpdatePolicyVersionRequest, _ *userapi.UpdatePolicyVersionResponse) error {
	consents := api.usersConsent
	consents[req.Localpart] = req.PolicyVersion
	return nil
}

// QueryPolicyVersion reports "v2.0" for every user. It does not consult
// usersConsent.
func (api dummyAPI) QueryPolicyVersion(_ context.Context, _ *userapi.QueryPolicyVersionRequest, res *userapi.QueryPolicyVersionResponse) error {
	const reportedVersion = "v2.0"
	res.PolicyVersion = reportedVersion
	return nil
}

// dummyTemplate stands in for the real consent page templates; the tests
// register it under both "v1.0.gohtml" and "v2.0.gohtml". It emits plain-text
// markers that the tests search for:
//   - "Consent given." when .HasConsented is true;
//   - otherwise "WithoutForm", followed by "With Form." when .ReadOnly is false.
//
// Note that "WithoutForm" is printed in the editable branch as well, so finding
// it alone does not prove the form is hidden.
const dummyTemplate = `
{{ if .HasConsented }}
Consent given.
{{ else }}
WithoutForm
	{{ if not .ReadOnly }}
    With Form.
    {{ end }}
{{ end }}`

// Test_consent drives the consent handler with GET and POST requests, using
// dummyAPI as the user API and dummyTemplate for both policy versions.
//
// The cases run sequentially and share one dummyAPI instance, so they are
// order-dependent: alice's POST is recorded before the later cases run.
// Body checks are case-insensitive. wantBodyNotContains, when set, must be
// absent from the body. It is needed because dummyTemplate prints
// "WithoutForm" in the editable branch too, so a positive check alone cannot
// show that the form is hidden.
func Test_consent(t *testing.T) {
	type args struct {
		username string
		userHMAC string
		version  string
		method   string
	}
	tests := []struct {
		name                string
		args                args
		wantRespCode        int
		wantBodyContains    string
		wantBodyNotContains string
	}{
		{
			name: "not a userID, valid hmac",
			args: args{
				username: "notAuserID",
				userHMAC: "7578bbface5ebb250a63935cebc05ca12060f58ebdbd271ecbc25e25a3da154d",
				version:  "v1.0",
				method:   http.MethodGet,
			},
			wantRespCode: http.StatusInternalServerError,
		},

		// Expected digest generated with:
		// $ echo -n '@alice:localhost' | openssl sha256 -hmac 'helloWorld'
		// (stdin)= 121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e
		{
			name: "valid hmac for alice GET, not consented",
			args: args{
				username: "@alice:localhost",
				userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
				version:  "v1.0",
				method:   http.MethodGet,
			},
			wantRespCode:     http.StatusOK,
			wantBodyContains: "With form",
		},
		{
			name: "alice consents successfully",
			args: args{
				username: "@alice:localhost",
				userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
				version:  "v1.0",
				method:   http.MethodPost,
			},
			wantRespCode:        http.StatusOK,
			wantBodyContains:    "Consent given",
			wantBodyNotContains: "With form",
		},
		{
			name: "valid hmac for alice GET, new version",
			args: args{
				username: "@alice:localhost",
				userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
				version:  "v2.0",
				method:   http.MethodGet,
			},
			wantRespCode:     http.StatusOK,
			wantBodyContains: "With form",
		},
		{
			name: "no hmac provided for alice, read only should be displayed",
			args: args{
				username: "@alice:localhost",
				userHMAC: "",
				version:  "v1.0",
				method:   http.MethodGet,
			},
			wantRespCode:        http.StatusOK,
			wantBodyContains:    "WithoutForm",
			wantBodyNotContains: "With form",
		},
		{
			name: "alice trying to get bobs status is forbidden",
			args: args{
				username: "@bob:localhost",
				userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
				version:  "v1.0",
				method:   http.MethodGet,
			},
			wantRespCode:     http.StatusForbidden,
			wantBodyContains: "forbidden",
		},
		{
			name: "alice trying to consent for bob is forbidden",
			args: args{
				username: "@bob:localhost",
				userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
				version:  "v1.0",
				method:   http.MethodPost,
			},
			wantRespCode:     http.StatusForbidden,
			wantBodyContains: "forbidden",
		},
	}

	userAPI := dummyAPI{
		usersConsent: map[string]string{},
	}
	consentTemplates := template.Must(template.New("v1.0.gohtml").Parse(dummyTemplate))
	consentTemplates = template.Must(consentTemplates.New("v2.0.gohtml").Parse(dummyTemplate))
	userconsentOpts := config.UserConsentOptions{
		FormSecret: "helloWorld",
		Version:    "v1.0",
		Templates:  consentTemplates,
		BaseURL:    "http://localhost",
	}
	cfg := &config.ClientAPI{
		Matrix: &config.Global{
			UserConsentOptions: userconsentOpts,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			url := fmt.Sprintf("%s/consent?u=%s&v=%s&h=%s",
				userconsentOpts.BaseURL, tt.args.username, tt.args.version, tt.args.userHMAC,
			)

			req := httptest.NewRequest(tt.args.method, url, nil)
			w := httptest.NewRecorder()

			consent(w, req, userAPI, cfg)

			resp := w.Result()
			// Register the close as soon as we own the response, before any
			// early exit via t.Fatalf.
			defer resp.Body.Close()

			body, err := ioutil.ReadAll(resp.Body)
			if err != nil {
				t.Fatalf("unable to read response body: %v", err)
			}
			if resp.StatusCode != tt.wantRespCode {
				t.Fatalf("expected http %d, got %d; body: %q", tt.wantRespCode, resp.StatusCode, body)
			}

			lowerBody := strings.ToLower(string(body))
			if !strings.Contains(lowerBody, strings.ToLower(tt.wantBodyContains)) {
				t.Fatalf("expected body to contain %q, but got %q", tt.wantBodyContains, body)
			}
			if tt.wantBodyNotContains != "" && strings.Contains(lowerBody, strings.ToLower(tt.wantBodyNotContains)) {
				t.Fatalf("expected body not to contain %q, but got %q", tt.wantBodyNotContains, body)
			}
		})
	}
}