diff --git a/clientapi/routing/consent_tracking.go b/clientapi/routing/consent_tracking.go index 5947855d4..d3c1e3947 100644 --- a/clientapi/routing/consent_tracking.go +++ b/clientapi/routing/consent_tracking.go @@ -1,10 +1,14 @@ package routing import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" "net/http" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) @@ -25,23 +29,93 @@ func consent(userAPI userapi.UserInternalAPI, cfg *config.ClientAPI) http.Handle _, _ = writer.Write([]byte("consent tracking is disabled")) return } + // The data used to populate the /consent request + data := constentTemplateData{ + User: req.FormValue("u"), + Version: req.FormValue("v"), + UserHMAC: req.FormValue("h"), + } switch req.Method { case http.MethodGet: - // The data used to populate the /consent request - data := constentTemplateData{ - User: req.FormValue("u"), - Version: req.FormValue("v"), - UserHMAC: req.FormValue("h"), - } // display the privacy policy without a form data.PublicVersion = data.User == "" || data.UserHMAC == "" || data.Version == "" + // let's see if the user already consented to the current version + if !data.PublicVersion { + res := &userapi.QueryPolicyVersionResponse{} + localPart, _, err := gomatrixserverlib.SplitID('@', data.User) + if err != nil { + logrus.WithError(err).Error("unable to print consent template") + return + } + if err = userAPI.QueryPolicyVersion(req.Context(), &userapi.QueryPolicyVersionRequest{ + LocalPart: localPart, + }, res); err != nil { + logrus.WithError(err).Error("unable to print consent template") + return + } + data.HasConsented = res.PolicyVersion == consentCfg.Version + } + err := consentCfg.Templates.ExecuteTemplate(writer, consentCfg.Version+".gohtml", data) if err != nil { logrus.WithError(err).Error("unable to print consent template") + return } case http.MethodPost: + localPart, _, err := gomatrixserverlib.SplitID('@', data.User) + if err != nil { + logrus.WithError(err).Error("unable to split username") + return + } + ok, err := validHMAC(data.User, data.UserHMAC, consentCfg.FormSecret) + if err != nil || !ok { + writer.WriteHeader(http.StatusBadRequest) + _, err = writer.Write([]byte("invalid HMAC provided")) + if err != nil { + return + } + return + } + if err := userAPI.PerformUpdatePolicyVersion( + req.Context(), + &userapi.UpdatePolicyVersionRequest{ + PolicyVersion: data.Version, + LocalPart: localPart, + }, + &userapi.UpdatePolicyVersionResponse{}, + ); err != nil { + writer.WriteHeader(http.StatusInternalServerError) + _, err = writer.Write([]byte("unable to update database")) + if err != nil { + logrus.WithError(err).Error("unable to write to database") + } + return + } + // display the privacy policy without a form + data.PublicVersion = false + data.HasConsented = true + + err = consentCfg.Templates.ExecuteTemplate(writer, consentCfg.Version+".gohtml", data) + if err != nil { + logrus.WithError(err).Error("unable to print consent template") + return + } } } } + +func validHMAC(username, userHMAC, secret string) (bool, error) { + mac := hmac.New(sha256.New, []byte(secret)) + _, err := mac.Write([]byte(username)) + if err != nil { + return false, err + } + expectedMAC := mac.Sum(nil) + decoded, err := hex.DecodeString(userHMAC) + if err != nil { + return false, err + } + return hmac.Equal(decoded, expectedMAC), nil +} diff --git a/clientapi/routing/consent_tracking_test.go b/clientapi/routing/consent_tracking_test.go new file mode 100644 index 000000000..3ddcad778 --- /dev/null +++ b/clientapi/routing/consent_tracking_test.go @@ -0,0 +1,48 @@ +package routing + +import "testing" + +func Test_validHMAC(t *testing.T) { + type args struct { + username string + userHMAC string + secret string + } + tests := []struct { + name string + args args + want bool + wantErr bool + }{ + { + name: "invalid hmac", + args: args{}, + wantErr: false, + want: false, + }, + // $ echo -n '@alice:localhost' | openssl sha256 -hmac 'helloWorld' 27m ⚑ ◒ 15:35:54 + //(stdin)= 121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e + // + { + name: "valid hmac", + args: args{ + username: "@alice:localhost", + userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e", + secret: "helloWorld", + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := validHMAC(tt.args.username, tt.args.userHMAC, tt.args.secret) + if (err != nil) != tt.wantErr { + t.Errorf("validHMAC() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("validHMAC() got = %v, want %v", got, tt.want) + } + }) + } +}