Implement consent tracking

This commit is contained in:
Till Faelligen 2022-02-14 16:18:51 +01:00
parent b2045c24cb
commit 11144de92f
2 changed files with 128 additions and 6 deletions

View file

@ -1,10 +1,14 @@
package routing
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"net/http"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
@ -25,23 +29,93 @@ func consent(userAPI userapi.UserInternalAPI, cfg *config.ClientAPI) http.Handle
_, _ = writer.Write([]byte("consent tracking is disabled"))
return
}
// The data used to populate the /consent request
data := constentTemplateData{
User: req.FormValue("u"),
Version: req.FormValue("v"),
UserHMAC: req.FormValue("h"),
}
switch req.Method {
case http.MethodGet:
// The data used to populate the /consent request
data := constentTemplateData{
User: req.FormValue("u"),
Version: req.FormValue("v"),
UserHMAC: req.FormValue("h"),
}
// display the privacy policy without a form
data.PublicVersion = data.User == "" || data.UserHMAC == "" || data.Version == ""
// let's see if the user already consented to the current version
if !data.PublicVersion {
res := &userapi.QueryPolicyVersionResponse{}
localPart, _, err := gomatrixserverlib.SplitID('@', data.User)
if err != nil {
logrus.WithError(err).Error("unable to print consent template")
return
}
if err = userAPI.QueryPolicyVersion(req.Context(), &userapi.QueryPolicyVersionRequest{
LocalPart: localPart,
}, res); err != nil {
logrus.WithError(err).Error("unable to print consent template")
return
}
data.HasConsented = res.PolicyVersion == consentCfg.Version
}
err := consentCfg.Templates.ExecuteTemplate(writer, consentCfg.Version+".gohtml", data)
if err != nil {
logrus.WithError(err).Error("unable to print consent template")
return
}
case http.MethodPost:
localPart, _, err := gomatrixserverlib.SplitID('@', data.User)
if err != nil {
logrus.WithError(err).Error("unable to split username")
return
}
ok, err := validHMAC(data.User, data.UserHMAC, consentCfg.FormSecret)
if err != nil || !ok {
writer.WriteHeader(http.StatusBadRequest)
_, err = writer.Write([]byte("invalid HMAC provided"))
if err != nil {
return
}
return
}
if err := userAPI.PerformUpdatePolicyVersion(
req.Context(),
&userapi.UpdatePolicyVersionRequest{
PolicyVersion: data.Version,
LocalPart: localPart,
},
&userapi.UpdatePolicyVersionResponse{},
); err != nil {
writer.WriteHeader(http.StatusInternalServerError)
_, err = writer.Write([]byte("unable to update database"))
if err != nil {
logrus.WithError(err).Error("unable to write to database")
}
return
}
// display the privacy policy without a form
data.PublicVersion = false
data.HasConsented = true
err = consentCfg.Templates.ExecuteTemplate(writer, consentCfg.Version+".gohtml", data)
if err != nil {
logrus.WithError(err).Error("unable to print consent template")
return
}
}
}
}
func validHMAC(username, userHMAC, secret string) (bool, error) {
mac := hmac.New(sha256.New, []byte(secret))
_, err := mac.Write([]byte(username))
if err != nil {
return false, err
}
expectedMAC := mac.Sum(nil)
decoded, err := hex.DecodeString(userHMAC)
if err != nil {
return false, err
}
return hmac.Equal(decoded, expectedMAC), nil
}

View file

@ -0,0 +1,48 @@
package routing
import "testing"
func Test_validHMAC(t *testing.T) {
type args struct {
username string
userHMAC string
secret string
}
tests := []struct {
name string
args args
want bool
wantErr bool
}{
{
name: "invalid hmac",
args: args{},
wantErr: false,
want: false,
},
// $ echo -n '@alice:localhost' | openssl sha256 -hmac 'helloWorld' 27m ⚑ ◒ 15:35:54
//(stdin)= 121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e
//
{
name: "valid hmac",
args: args{
username: "@alice:localhost",
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
secret: "helloWorld",
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := validHMAC(tt.args.username, tt.args.userHMAC, tt.args.secret)
if (err != nil) != tt.wantErr {
t.Errorf("validHMAC() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("validHMAC() got = %v, want %v", got, tt.want)
}
})
}
}