Compare commits

...

82 commits

Author SHA1 Message Date
Till Faelligen cabc5f4bc9 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-05-19 14:21:59 +02:00
Till Faelligen 7df7a966b8 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-05-09 17:21:37 +02:00
Till Faelligen 75ca5490bc Fix build 2022-05-09 17:21:27 +02:00
kegsay c9409078ac
Merge branch 'main' into s7evink/consent-tracking 2022-05-09 16:04:36 +01:00
Till Faelligen 3d3773d3d4 Rename migrations so they are executed 2022-05-06 10:15:42 +02:00
Till Faelligen 4bd9a73c13 Fix database locked 2022-05-06 09:50:20 +02:00
Till Faelligen 4d285fff60 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-05-06 09:43:50 +02:00
Till Faelligen 10dc02f1ea Debugging unit tests.. 2022-05-04 18:38:43 +02:00
Till Faelligen 94ed2d3689 Add tests, use simple http.HandlerFunc 2022-05-04 17:57:46 +02:00
Till Faelligen 964e1cef85 Remove noise 2022-05-04 17:57:29 +02:00
Till Faelligen 2eb3aab07e Split out UserConsentPolicyAPI for easier testing 2022-05-04 17:56:46 +02:00
Till Faelligen 60ba4b5612 Fix stupid mistake.. and just return the NullString 2022-05-04 17:53:49 +02:00
Till Faelligen 88612ddd0c Deduplicate constructing consent URL 2022-05-04 14:40:25 +02:00
Till Faelligen bddf8ed3ac Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-05-04 14:34:16 +02:00
Till Faelligen 4d5feb2544 Fix userapi issues 2022-05-04 14:33:51 +02:00
Till Faelligen cd7a7606a1 Remove noise 2022-05-04 14:12:52 +02:00
Till Faelligen dc8cea6d57 PR comments config 2022-05-04 13:47:08 +02:00
kegsay e0cdf64c33
Merge branch 'main' into s7evink/consent-tracking 2022-05-03 17:31:36 +01:00
Till Faelligen ef62255685 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-05-03 09:54:25 +02:00
Till Faelligen 1f64fc79c8 Merge branch 's7evink/consent-tracking' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-04-20 17:35:27 +02:00
Till Faelligen 2b496be2c3 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-04-20 17:26:27 +02:00
Till Faelligen 733b601aa9 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-04-04 15:56:40 +02:00
Till Faelligen 2a18023a1a Linter again.. 2022-03-28 12:13:47 +02:00
Till Faelligen c99e3aff1b Fix linter issues 2022-03-28 12:01:15 +02:00
Till Faelligen f1e8d19cea Fix build 2022-03-28 11:33:12 +02:00
Till Faelligen 019f0922ea Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-03-28 11:30:16 +02:00
kegsay 6324c1d01f
Merge branch 'main' into s7evink/consent-tracking 2022-03-18 08:37:31 +00:00
Till Faelligen b9479a6f18 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-03-16 08:43:19 +01:00
Till Faelligen e42ef1706b Merge branch 's7evink/consent-tracking' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-03-16 08:43:12 +01:00
Till Faelligen 31ac3ac081 Use DefaultRoomVersion as roomVersion 2022-03-16 08:42:04 +01:00
kegsay 39d9d88b02
Merge branch 'main' into s7evink/consent-tracking 2022-03-09 10:05:50 +00:00
Till Faelligen 710007d600 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-03-07 09:46:28 +01:00
Till Faelligen dcfc0bcd43 URL Encode, use new method to get server notice room 2022-03-07 09:45:59 +01:00
Till Faelligen c7d2254698 Update templates, remove default base URL 2022-03-07 09:45:24 +01:00
Till Faelligen 7c6a162c0f Remove MemberShipStateAll 2022-03-07 09:42:02 +01:00
Till Faelligen 699617ee4d Add server_notice_room_id and methods to update/get it 2022-03-07 09:41:25 +01:00
Till Faelligen 519ea13510 Add AuthAPICheck and optional functional checks
Rename several variables
2022-03-04 17:01:18 +01:00
Till Faelligen fa26aa9138 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-03-04 15:10:23 +01:00
Till Faelligen e80ca307d3 Fix receipts 2022-03-04 09:59:15 +01:00
Till Faelligen df7218e230 Fix parameters 2022-03-04 09:30:46 +01:00
Till Faelligen e6e62497c9 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-03-04 09:18:34 +01:00
Till Faelligen 2ad15f308f Rename some functions 2022-02-25 15:32:19 +01:00
Till Faelligen ed16a2f107 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-25 15:21:06 +01:00
Till Faelligen ce658ab8f2 Fix missing params 2022-02-24 08:54:54 +01:00
Till Faelligen 79e1c9e4bd Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-24 08:47:38 +01:00
Till Faelligen 2042303c6c Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-22 07:48:36 +01:00
Till Faelligen 4f2d161401 Remove consentMux 2022-02-21 17:09:25 +01:00
Till Faelligen c65eb2bf52 Fix query 2022-02-21 16:57:24 +01:00
Till Faelligen 0ae8293abd Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-21 16:43:45 +01:00
Till Faelligen dac29c1786 Split migrations in to two statements 2022-02-21 16:43:28 +01:00
Till Faelligen c2b6019c35 Add possibility to query all membership states 2022-02-21 16:40:59 +01:00
Till Faelligen 61cdb714df Use typed values for Consent 2022-02-21 16:23:28 +01:00
Till Faelligen 185cb7a582 Remove BaseURL from Global
Update template
2022-02-21 16:22:25 +01:00
Till Faelligen e2b0ff675b Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-21 14:44:52 +01:00
Till Faelligen c0845ea1ad Update logging 2022-02-21 14:44:16 +01:00
Till Faelligen 6622fda08c Add sending server notices on startup 2022-02-21 14:27:59 +01:00
Till Faelligen 219a15c4c3 Load templates into one variable 2022-02-21 14:27:13 +01:00
Till Faelligen fb95331aa2 Add posibility to track sent policy versions 2022-02-21 14:26:00 +01:00
Till Faelligen cb4526793d Add missing migrations 2022-02-21 12:15:56 +01:00
Till Faelligen 2e6987f8bd Add missing files 2022-02-21 12:12:07 +01:00
Till Faelligen 9c3a1cfd47 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-21 12:08:03 +01:00
Till Faelligen 74da1f0fb3 Remove "magic" Enabled function and use simple bool 2022-02-16 10:11:23 +01:00
Till Faelligen 26accb8c5d Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-16 08:39:54 +01:00
Till Faelligen 6482630f7b Fix insert statement 2022-02-15 14:28:51 +01:00
Till Faelligen 2fc1c46743 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-15 14:16:37 +01:00
Till Faelligen 5a0ec6e443 Add policy version to create-account & mediaapi 2022-02-15 14:15:18 +01:00
Till Faelligen 535d388ec0 Add new login type "m.login.terms" 2022-02-15 14:14:39 +01:00
Till Faelligen cbdbbb0839 Make sure we use the correct login stages 2022-02-15 14:13:22 +01:00
Till Faelligen f8bebe5e5a Add policy_version to insertAccount statement 2022-02-15 14:10:50 +01:00
Till Faelligen d19518fca5 Add ConsentNotGiven error
Verify consent on desired endpoints
Store consent on POST requests
2022-02-15 11:07:24 +01:00
Till Faelligen 89340cfc52 Verify the user has given their consent, otherwise block access 2022-02-14 18:11:56 +01:00
Till Faelligen 11144de92f Implement consent tracking 2022-02-14 16:18:51 +01:00
Till Faelligen b2045c24cb Add missing yaml tag 2022-02-14 16:18:19 +01:00
Till Faelligen 097f1d4609 Add a way to update the policy_version for a user 2022-02-14 15:08:00 +01:00
Till Faelligen a505471c90 Add table migrations 2022-02-14 14:52:16 +01:00
Till Faelligen 3c5c3ea7fb Add methods to query the policy version 2022-02-14 14:03:30 +01:00
Till Faelligen 9583784e8a Add new coloumn to track accepted policy version 2022-02-14 14:02:13 +01:00
Till Faelligen b6ee34918c Add consent tracking endpoint 2022-02-14 13:41:21 +01:00
Till Faelligen ac343861ad Add missing form_secret
Add tests
2022-02-14 13:06:36 +01:00
Till Faelligen 4da7df5e3e Add consent tracking template 2022-02-14 13:01:26 +01:00
Till Faelligen ccc11f94f7 Add consentAPIMux to components 2022-02-14 13:00:07 +01:00
Till Faelligen 5702b84dae Add User consent configuration
Add consentAPIMux
2022-02-14 12:59:13 +01:00
40 changed files with 1740 additions and 203 deletions

View file

@ -11,4 +11,5 @@ const (
LoginTypeRecaptcha = "m.login.recaptcha"
LoginTypeApplicationService = "m.login.application_service"
LoginTypeToken = "m.login.token"
LoginTypeTerms = "m.login.terms"
)

View file

@ -29,6 +29,13 @@ type MatrixError struct {
Err string `json:"error"`
}
// ConsentError is an error returned to users, who didn't accept the
// TOS of this server yet.
type ConsentError struct {
MatrixError
ConsentURI string `json:"consent_uri"`
}
func (e MatrixError) Error() string {
return fmt.Sprintf("%s: %s", e.ErrCode, e.Err)
}
@ -207,3 +214,15 @@ func NotTrusted(serverName string) *MatrixError {
Err: fmt.Sprintf("Untrusted server '%s'", serverName),
}
}
// ConsentNotGiven is an error returned to users, who didn't accept the
// TOS of this server yet.
func ConsentNotGiven(consentURI string, msg string) *ConsentError {
return &ConsentError{
MatrixError: MatrixError{
ErrCode: "M_CONSENT_NOT_GIVEN",
Err: msg,
},
ConsentURI: consentURI,
}
}

View file

@ -0,0 +1,219 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// The data used to populate the /consent request
type constentTemplateData struct {
UserID string
Version string
UserHMAC string
HasConsented bool
ReadOnly bool
}
func writeHeaderAndText(w http.ResponseWriter, statusCode int) {
w.WriteHeader(statusCode)
_, _ = w.Write([]byte(http.StatusText(statusCode)))
}
func consent(writer http.ResponseWriter, req *http.Request, userAPI userapi.UserConsentPolicyAPI, cfg *config.ClientAPI) {
consentCfg := cfg.Matrix.UserConsentOptions
// The data used to populate the /consent request
data := constentTemplateData{
UserID: req.FormValue("u"),
Version: req.FormValue("v"),
UserHMAC: req.FormValue("h"),
}
switch req.Method {
case http.MethodGet:
// display the privacy policy without a form
data.ReadOnly = data.UserID == "" || data.UserHMAC == "" || data.Version == ""
// let's see if the user already consented to the current version
if !data.ReadOnly {
if ok, err := validHMAC(data.UserID, data.UserHMAC, consentCfg.FormSecret); err != nil || !ok {
writeHeaderAndText(writer, http.StatusForbidden)
return
}
res := &userapi.QueryPolicyVersionResponse{}
localpart, _, err := gomatrixserverlib.SplitID('@', data.UserID)
if err != nil {
logrus.WithError(err).Error("unable to split username")
writeHeaderAndText(writer, http.StatusInternalServerError)
return
}
if err = userAPI.QueryPolicyVersion(req.Context(), &userapi.QueryPolicyVersionRequest{
Localpart: localpart,
}, res); err != nil {
logrus.WithError(err).Error("unable query policy version")
writeHeaderAndText(writer, http.StatusInternalServerError)
return
}
data.HasConsented = res.PolicyVersion == consentCfg.Version
}
err := consentCfg.Templates.ExecuteTemplate(writer, consentCfg.Version+".gohtml", data)
if err != nil {
logrus.WithError(err).Error("unable to execute consent template")
writeHeaderAndText(writer, http.StatusInternalServerError)
return
}
case http.MethodPost:
ok, err := validHMAC(data.UserID, data.UserHMAC, consentCfg.FormSecret)
if err != nil || !ok {
if !ok {
writeHeaderAndText(writer, http.StatusForbidden)
return
}
writeHeaderAndText(writer, http.StatusInternalServerError)
return
}
localpart, _, err := gomatrixserverlib.SplitID('@', data.UserID)
if err != nil {
logrus.WithError(err).Error("unable to split username")
writeHeaderAndText(writer, http.StatusInternalServerError)
return
}
if err = userAPI.PerformUpdatePolicyVersion(
req.Context(),
&userapi.UpdatePolicyVersionRequest{
PolicyVersion: data.Version,
Localpart: localpart,
},
&userapi.UpdatePolicyVersionResponse{},
); err != nil {
writeHeaderAndText(writer, http.StatusInternalServerError)
return
}
// display the privacy policy without a form
data.ReadOnly = false
data.HasConsented = true
err = consentCfg.Templates.ExecuteTemplate(writer, consentCfg.Version+".gohtml", data)
if err != nil {
logrus.WithError(err).Error("unable to print consent template")
writeHeaderAndText(writer, http.StatusInternalServerError)
}
}
}
func sendServerNoticeForConsent(userAPI userapi.ClientUserAPI, rsAPI api.ClientRoomserverAPI,
cfgNotices *config.ServerNotices,
cfgClient *config.ClientAPI,
senderDevice *userapi.Device,
asAPI appserviceAPI.AppServiceInternalAPI,
) {
res := &userapi.QueryOutdatedPolicyResponse{}
if err := userAPI.QueryOutdatedPolicy(context.Background(), &userapi.QueryOutdatedPolicyRequest{
PolicyVersion: cfgClient.Matrix.UserConsentOptions.Version,
}, res); err != nil {
logrus.WithError(err).Error("unable to fetch users with outdated consent policy")
return
}
var (
consentOpts = cfgClient.Matrix.UserConsentOptions
data = make(map[string]string)
err error
sentMessages int
)
if len(res.UserLocalparts) == 0 {
return
}
logrus.WithField("count", len(res.UserLocalparts)).Infof("Sending server notice to users who have not yet accepted the policy")
for _, localpart := range res.UserLocalparts {
if localpart == cfgClient.Matrix.ServerNotices.LocalPart {
continue
}
userID := fmt.Sprintf("@%s:%s", localpart, cfgClient.Matrix.ServerName)
data["ConsentURL"], err = consentOpts.ConsentURL(userID)
if err != nil {
logrus.WithError(err).WithField("userID", userID).Error("unable to construct consentURI")
continue
}
msgBody := &bytes.Buffer{}
if err = consentOpts.TextTemplates.ExecuteTemplate(msgBody, "serverNoticeTemplate", data); err != nil {
logrus.WithError(err).WithField("userID", userID).Error("unable to execute serverNoticeTemplate")
continue
}
req := sendServerNoticeRequest{
UserID: userID,
Content: struct {
MsgType string `json:"msgtype,omitempty"`
Body string `json:"body,omitempty"`
}{
MsgType: consentOpts.ServerNoticeContent.MsgType,
Body: msgBody.String(),
},
}
_, err = sendServerNotice(context.Background(), req, rsAPI, cfgNotices, cfgClient, senderDevice, asAPI, userAPI, nil, nil, nil)
if err != nil {
logrus.WithError(err).WithField("userID", userID).Error("failed to send server notice for consent to user")
continue
}
sentMessages++
res := &userapi.UpdatePolicyVersionResponse{}
if err = userAPI.PerformUpdatePolicyVersion(context.Background(), &userapi.UpdatePolicyVersionRequest{
PolicyVersion: consentOpts.Version,
Localpart: userID,
ServerNoticeUpdate: true,
}, res); err != nil {
logrus.WithError(err).WithField("userID", userID).Error("failed to update policy version")
continue
}
}
if sentMessages > 0 {
logrus.Infof("Sent messages to %d users", sentMessages)
}
}
func validHMAC(username, userHMAC, secret string) (bool, error) {
mac := hmac.New(sha256.New, []byte(secret))
_, err := mac.Write([]byte(username))
if err != nil {
return false, err
}
expectedMAC := mac.Sum(nil)
decoded, err := hex.DecodeString(userHMAC)
if err != nil {
return false, err
}
return hmac.Equal(decoded, expectedMAC), nil
}

View file

@ -0,0 +1,236 @@
package routing
import (
"context"
"fmt"
"html/template"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
func Test_validHMAC(t *testing.T) {
type args struct {
username string
userHMAC string
secret string
}
tests := []struct {
name string
args args
want bool
wantErr bool
}{
{
name: "invalid hmac",
args: args{},
wantErr: false,
want: false,
},
// $ echo -n '@alice:localhost' | openssl sha256 -hmac 'helloWorld'
//(stdin)= 121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e
//
{
name: "valid hmac",
args: args{
username: "@alice:localhost",
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
secret: "helloWorld",
},
want: true,
},
{
name: "invalid hmac",
args: args{
username: "@bob:localhost",
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
secret: "helloWorld",
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := validHMAC(tt.args.username, tt.args.userHMAC, tt.args.secret)
if (err != nil) != tt.wantErr {
t.Errorf("validHMAC() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("validHMAC() got = %v, want %v", got, tt.want)
}
})
}
}
type dummyAPI struct {
usersConsent map[string]string
}
func (d dummyAPI) QueryOutdatedPolicy(ctx context.Context, req *userapi.QueryOutdatedPolicyRequest, res *userapi.QueryOutdatedPolicyResponse) error {
return nil
}
func (d dummyAPI) PerformUpdatePolicyVersion(ctx context.Context, req *userapi.UpdatePolicyVersionRequest, res *userapi.UpdatePolicyVersionResponse) error {
d.usersConsent[req.Localpart] = req.PolicyVersion
return nil
}
func (d dummyAPI) QueryPolicyVersion(ctx context.Context, req *userapi.QueryPolicyVersionRequest, res *userapi.QueryPolicyVersionResponse) error {
res.PolicyVersion = "v2.0"
return nil
}
const dummyTemplate = `
{{ if .HasConsented }}
Consent given.
{{ else }}
WithoutForm
{{ if not .ReadOnly }}
With Form.
{{ end }}
{{ end }}`
func Test_consent(t *testing.T) {
type args struct {
username string
userHMAC string
version string
method string
}
tests := []struct {
name string
args args
wantRespCode int
wantBodyContains string
}{
{
name: "not a userID, valid hmac",
args: args{
username: "notAuserID",
userHMAC: "7578bbface5ebb250a63935cebc05ca12060f58ebdbd271ecbc25e25a3da154d",
version: "v1.0",
method: http.MethodGet,
},
wantRespCode: http.StatusInternalServerError,
},
// $ echo -n '@alice:localhost' | openssl sha256 -hmac 'helloWorld'
//(stdin)= 121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e
//
{
name: "valid hmac for alice GET, not consented",
args: args{
username: "@alice:localhost",
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
version: "v1.0",
method: http.MethodGet,
},
wantRespCode: http.StatusOK,
wantBodyContains: "With form",
},
{
name: "alice consents successfully",
args: args{
username: "@alice:localhost",
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
version: "v1.0",
method: http.MethodPost,
},
wantRespCode: http.StatusOK,
wantBodyContains: "Consent given",
},
{
name: "valid hmac for alice GET, new version",
args: args{
username: "@alice:localhost",
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
version: "v2.0",
method: http.MethodGet,
},
wantRespCode: http.StatusOK,
wantBodyContains: "With form",
},
{
name: "no hmac provided for alice, read only should be displayed",
args: args{
username: "@alice:localhost",
userHMAC: "",
version: "v1.0",
method: http.MethodGet,
},
wantRespCode: http.StatusOK,
wantBodyContains: "WithoutForm",
},
{
name: "alice trying to get bobs status is forbidden",
args: args{
username: "@bob:localhost",
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
version: "v1.0",
method: http.MethodGet,
},
wantRespCode: http.StatusForbidden,
wantBodyContains: "forbidden",
},
{
name: "alice trying to consent for bob is forbidden",
args: args{
username: "@bob:localhost",
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
version: "v1.0",
method: http.MethodPost,
},
wantRespCode: http.StatusForbidden,
wantBodyContains: "forbidden",
},
}
userAPI := dummyAPI{
usersConsent: map[string]string{},
}
consentTemplates := template.Must(template.New("v1.0.gohtml").Parse(dummyTemplate))
consentTemplates = template.Must(consentTemplates.New("v2.0.gohtml").Parse(dummyTemplate))
userconsentOpts := config.UserConsentOptions{
FormSecret: "helloWorld",
Version: "v1.0",
Templates: consentTemplates,
BaseURL: "http://localhost",
}
cfg := &config.ClientAPI{
Matrix: &config.Global{
UserConsentOptions: userconsentOpts,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
url := fmt.Sprintf("%s/consent?u=%s&v=%s&h=%s",
userconsentOpts.BaseURL, tt.args.username, tt.args.version, tt.args.userHMAC,
)
req := httptest.NewRequest(tt.args.method, url, nil)
w := httptest.NewRecorder()
consent(w, req, userAPI, cfg)
resp := w.Result()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unable to read response body: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != tt.wantRespCode {
t.Fatalf("expected http %d, got %d", tt.wantRespCode, resp.StatusCode)
}
if !strings.Contains(strings.ToLower(string(body)), strings.ToLower(tt.wantBodyContains)) {
t.Fatalf("expected body to contain %s, but got %s", tt.wantBodyContains, string(body))
}
})
}
}

View file

@ -31,13 +31,12 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/tidwall/gjson"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -721,6 +720,8 @@ func handleRegistrationFlow(
}
switch r.Auth.Type {
case authtypes.LoginTypeTerms:
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeTerms)
case authtypes.LoginTypeRecaptcha:
// Check given captcha response
resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
@ -788,11 +789,16 @@ func handleApplicationServiceRegistration(
return *err
}
policyVersion := ""
if cfg.Matrix.UserConsentOptions.Enabled {
policyVersion = cfg.Matrix.UserConsentOptions.Version
}
// If no error, application service was successfully validated.
// Don't need to worry about appending to registration stages as
// application service registration is entirely separate.
return completeRegistration(
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session,
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, policyVersion,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService,
)
}
@ -809,9 +815,13 @@ func checkAndCompleteFlow(
userAPI userapi.ClientUserAPI,
) util.JSONResponse {
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
policyVersion := ""
if cfg.Matrix.UserConsentOptions.Enabled {
policyVersion = cfg.Matrix.UserConsentOptions.Version
}
// This flow was completed, registration can continue
return completeRegistration(
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID,
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, policyVersion,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser,
)
}
@ -834,7 +844,7 @@ func checkAndCompleteFlow(
func completeRegistration(
ctx context.Context,
userAPI userapi.ClientUserAPI,
username, password, appserviceID, ipAddr, userAgent, sessionID string,
username, password, appserviceID, ipAddr, userAgent, sessionID, policyVersion string,
inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string,
accType userapi.AccountType,
@ -861,11 +871,12 @@ func completeRegistration(
}
var accRes userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
AppServiceID: appserviceID,
Localpart: username,
Password: password,
AccountType: accType,
OnConflict: userapi.ConflictAbort,
AppServiceID: appserviceID,
Localpart: username,
Password: password,
AccountType: accType,
OnConflict: userapi.ConflictAbort,
PolicyVersion: policyVersion,
}, &accRes)
if err != nil {
if _, ok := err.(*userapi.ErrorConflict); ok { // user already exists
@ -1073,5 +1084,5 @@ func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSec
if ssrr.Admin {
accType = userapi.AccountTypeAdmin
}
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType)
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", "", false, &ssrr.User, &deviceID, accType)
}

View file

@ -15,6 +15,7 @@
package routing
import (
"context"
"encoding/json"
"net/http"
@ -93,7 +94,7 @@ func PutTag(
}
tagContent.Tags[tag] = properties
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil {
if err = saveTagData(req.Context(), userID, roomID, userAPI, tagContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError()
}
@ -145,7 +146,7 @@ func DeleteTag(
}
}
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil {
if err = saveTagData(req.Context(), userID, roomID, userAPI, tagContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError()
}
@ -191,7 +192,7 @@ func obtainSavedTags(
// saveTagData saves the provided tag data into the database
func saveTagData(
req *http.Request,
context context.Context,
userID string,
roomID string,
userAPI api.ClientUserAPI,
@ -208,5 +209,5 @@ func saveTagData(
AccountData: json.RawMessage(newTagData),
}
dataRes := api.InputAccountDataResponse{}
return userAPI.InputAccountData(req.Context(), &dataReq, &dataRes)
return userAPI.InputAccountData(context, &dataReq, &dataRes)
}

View file

@ -127,9 +127,13 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions)
// server notifications
var (
serverNotificationSender *userapi.Device
err error
)
if cfg.Matrix.ServerNotices.Enabled {
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")
serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, cfg)
serverNotificationSender, err = getSenderDevice(context.Background(), userAPI, cfg)
if err != nil {
logrus.WithError(err).Fatal("unable to get account for sending sending server notices")
}
@ -177,13 +181,27 @@ func Setup(
// using ?: so the final regexp becomes what is below. We also need a trailing slash to stop 'v33333' matching.
// Note that 'apiversion' is chosen because it must not collide with a variable used in any of the routing!
v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter()
// NOTSPEC: consent tracking
if cfg.Matrix.UserConsentOptions.Enabled {
if !cfg.Matrix.ServerNotices.Enabled {
logrus.Warnf("Consent tracking is enabled, but server notes are not. No server notice will be sent to users")
} else {
// start a new go routine to send messages about consent
go sendServerNoticeForConsent(userAPI, rsAPI, &cfg.Matrix.ServerNotices, cfg, serverNotificationSender, asAPI)
}
publicAPIMux.HandleFunc("/consent", func(writer http.ResponseWriter, request *http.Request) {
consent(writer, request, userAPI, cfg)
}).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
}
consentRequiredCheck := httputil.WithConsentCheck(cfg.Matrix.UserConsentOptions, userAPI)
v3mux.Handle("/createRoom",
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI)
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/join/{roomIDOrAlias}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -213,7 +231,7 @@ func Setup(
return PeekRoomByIDOrAlias(
req, device, rsAPI, vars["roomIDOrAlias"],
)
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
}
v3mux.Handle("/joined_rooms",
@ -258,7 +276,7 @@ func Setup(
return UnpeekRoomByID(
req, device, rsAPI, vars["roomID"],
)
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/ban",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -267,7 +285,7 @@ func Setup(
return util.ErrorResponse(err)
}
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/invite",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -279,7 +297,7 @@ func Setup(
return util.ErrorResponse(err)
}
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/kick",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -288,7 +306,7 @@ func Setup(
return util.ErrorResponse(err)
}
return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/unban",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -297,7 +315,7 @@ func Setup(
return util.ErrorResponse(err)
}
return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/send/{eventType}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -317,7 +335,7 @@ func Setup(
txnID := vars["txnID"]
return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID,
nil, cfg, rsAPI, transactionsCache)
}),
}, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/event/{eventID}",
httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -326,7 +344,7 @@ func Setup(
return util.ErrorResponse(err)
}
return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI)
}),
}, consentRequiredCheck),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -335,7 +353,7 @@ func Setup(
return util.ErrorResponse(err)
}
return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions)
}, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -343,7 +361,7 @@ func Setup(
return util.ErrorResponse(err)
}
return GetAliases(req, rsAPI, device, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions)
}, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -354,7 +372,7 @@ func Setup(
eventType := strings.TrimSuffix(vars["type"], "/")
eventFormat := req.URL.Query().Get("format") == "event"
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat)
})).Methods(http.MethodGet, http.MethodOptions)
}, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -363,7 +381,7 @@ func Setup(
}
eventFormat := req.URL.Query().Get("format") == "event"
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat)
})).Methods(http.MethodGet, http.MethodOptions)
}, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -374,7 +392,7 @@ func Setup(
emptyString := ""
eventType := strings.TrimSuffix(vars["eventType"], "/")
return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil)
}),
}, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
@ -385,7 +403,7 @@ func Setup(
}
stateKey := vars["stateKey"]
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil)
}),
}, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
@ -487,7 +505,7 @@ func Setup(
return util.ErrorResponse(err)
}
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, nil, nil)
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -497,7 +515,7 @@ func Setup(
}
txnID := vars["txnId"]
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, &txnID, transactionsCache)
}),
}, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/sendToDevice/{eventType}/{txnID}",
@ -508,7 +526,7 @@ func Setup(
}
txnID := vars["txnID"]
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
}),
}, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions)
// This is only here because sytest refers to /unstable for this endpoint
@ -522,7 +540,7 @@ func Setup(
}
txnID := vars["txnID"]
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
}),
}, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/account/whoami",
@ -531,7 +549,7 @@ func Setup(
return *r
}
return Whoami(req, device)
}),
}, consentRequiredCheck),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/account/password",
@ -738,7 +756,7 @@ func Setup(
return util.ErrorResponse(err)
}
return SetAvatarURL(req, userAPI, device, vars["userID"], cfg, rsAPI)
}),
}, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions)
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
// PUT requests, so we need to allow this method
@ -763,7 +781,7 @@ func Setup(
return util.ErrorResponse(err)
}
return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI)
}),
}, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions)
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
// PUT requests, so we need to allow this method
@ -771,19 +789,19 @@ func Setup(
v3mux.Handle("/account/3pid",
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetAssociated3PIDs(req, userAPI, device)
}),
}, consentRequiredCheck),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/account/3pid",
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CheckAndSave3PIDAssociation(req, userAPI, device, cfg)
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/account/3pid/delete",
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return Forget3PID(req, userAPI)
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
@ -798,7 +816,7 @@ func Setup(
return *r
}
return RequestTurnServer(req, device, cfg)
}),
}, consentRequiredCheck),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thirdparty/protocols",
@ -868,7 +886,7 @@ func Setup(
return util.ErrorResponse(err)
}
return GetAdminWhois(req, userAPI, device, vars["userID"])
}),
}, consentRequiredCheck),
).Methods(http.MethodGet)
v3mux.Handle("/user/{userID}/openid/request_token",
@ -881,7 +899,7 @@ func Setup(
return util.ErrorResponse(err)
}
return CreateOpenIDToken(req, userAPI, device, vars["userID"], cfg)
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/user_directory/search",
@ -907,7 +925,7 @@ func Setup(
postContent.SearchString,
postContent.Limit,
)
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/members",
@ -953,7 +971,7 @@ func Setup(
return util.ErrorResponse(err)
}
return SendForget(req, device, vars["roomID"], rsAPI)
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/upgrade",
@ -1065,7 +1083,7 @@ func Setup(
return util.ErrorResponse(err)
}
return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
}),
}, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
@ -1075,7 +1093,7 @@ func Setup(
return util.ErrorResponse(err)
}
return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
}),
}, consentRequiredCheck),
).Methods(http.MethodDelete, http.MethodOptions)
v3mux.Handle("/capabilities",
@ -1095,11 +1113,11 @@ func Setup(
return util.ErrorResponse(err)
}
return KeyBackupVersion(req, userAPI, device, vars["version"])
})
}, consentRequiredCheck)
getLatestBackupKeysVersion := httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return KeyBackupVersion(req, userAPI, device, "")
})
}, consentRequiredCheck)
putBackupKeysVersion := httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -1107,7 +1125,7 @@ func Setup(
return util.ErrorResponse(err)
}
return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"])
})
}, consentRequiredCheck)
deleteBackupKeysVersion := httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -1119,7 +1137,7 @@ func Setup(
postNewBackupKeysVersion := httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CreateKeyBackupVersion(req, userAPI, device)
})
}, consentRequiredCheck)
v3mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
@ -1150,7 +1168,7 @@ func Setup(
return *resErr
}
return UploadBackupKeys(req, userAPI, device, version, &reqBody)
})
}, consentRequiredCheck)
// Single room bulk session
putBackupKeysRoom := httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -1182,7 +1200,7 @@ func Setup(
}
reqBody.Rooms[roomID] = body
return UploadBackupKeys(req, userAPI, device, version, &reqBody)
})
}, consentRequiredCheck)
// Single room, single session
putBackupKeysRoomSession := httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -1215,7 +1233,7 @@ func Setup(
}
keyReq.Rooms[roomID].Sessions[sessionID] = reqBody
return UploadBackupKeys(req, userAPI, device, version, &keyReq)
})
}, consentRequiredCheck)
v3mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
v3mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
@ -1229,7 +1247,7 @@ func Setup(
getBackupKeys := httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "")
})
}, consentRequiredCheck)
getBackupKeysRoom := httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -1245,7 +1263,7 @@ func Setup(
return util.ErrorResponse(err)
}
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"])
})
}, consentRequiredCheck)
v3mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
@ -1261,11 +1279,11 @@ func Setup(
postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, keyAPI, device, userAPI, cfg)
})
}, consentRequiredCheck)
postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadCrossSigningDeviceSignatures(req, keyAPI, device)
})
}, consentRequiredCheck)
v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
@ -1277,12 +1295,12 @@ func Setup(
v3mux.Handle("/keys/upload/{deviceID}",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device)
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/upload",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device)
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/query",
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -1305,7 +1323,7 @@ func Setup(
}
return SetReceipt(req, syncProducer, device, vars["roomId"], vars["receiptType"], vars["eventId"])
}),
}, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/presence/{userId}/status",
httputil.MakeAuthAPI("set_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {

View file

@ -84,74 +84,66 @@ func SendServerNotice(
if resErr != nil {
return *resErr
}
res, _ := sendServerNotice(ctx, r, rsAPI, cfgNotices, cfgClient, senderDevice, asAPI, userAPI, txnID, device, txnCache)
return res
}
func sendServerNotice(
ctx context.Context,
serverNoticeRequest sendServerNoticeRequest,
rsAPI api.ClientRoomserverAPI,
cfgNotices *config.ServerNotices,
cfgClient *config.ClientAPI,
senderDevice *userapi.Device,
asAPI appserviceAPI.AppServiceInternalAPI,
userAPI userapi.ClientUserAPI,
txnID *string,
device *userapi.Device,
txnCache *transactions.Cache,
) (util.JSONResponse, error) {
// check that all required fields are set
if !r.valid() {
if !serverNoticeRequest.valid() {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Invalid request"),
}
}, fmt.Errorf("Invalid JSON")
}
// get rooms for specified user
allUserRooms := []string{}
userRooms := api.QueryRoomsForUserResponse{}
// Get rooms the user is either joined, invited or has left.
for _, membership := range []string{"join", "invite", "leave"} {
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: r.UserID,
WantMembership: membership,
}, &userRooms); err != nil {
return util.ErrorResponse(err)
}
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
qryServerNoticeRoom := &userapi.QueryServerNoticeRoomResponse{}
localpart, _, err := gomatrixserverlib.SplitID('@', serverNoticeRequest.UserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Invalid request"),
}, err
}
err = userAPI.SelectServerNoticeRoomID(ctx, &userapi.QueryServerNoticeRoomRequest{Localpart: localpart}, qryServerNoticeRoom)
if err != nil {
return util.ErrorResponse(err), err
}
// get rooms of the sender
senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName)
senderRooms := api.QueryRoomsForUserResponse{}
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: senderUserID,
WantMembership: "join",
}, &senderRooms); err != nil {
return util.ErrorResponse(err)
}
// check if we have rooms in common
commonRooms := []string{}
for _, userRoomID := range allUserRooms {
for _, senderRoomID := range senderRooms.RoomIDs {
if userRoomID == senderRoomID {
commonRooms = append(commonRooms, senderRoomID)
}
}
}
if len(commonRooms) > 1 {
return util.ErrorResponse(fmt.Errorf("expected to find one room, but got %d", len(commonRooms)))
}
var (
roomID string
roomVersion = version.DefaultRoomVersion()
)
roomID := qryServerNoticeRoom.RoomID
roomVersion := version.DefaultRoomVersion()
// create a new room for the user
if len(commonRooms) == 0 {
if qryServerNoticeRoom.RoomID == "" {
var pl, cc []byte
powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID)
powerLevelContent.Users[r.UserID] = -10 // taken from Synapse
pl, err := json.Marshal(powerLevelContent)
powerLevelContent.Users[serverNoticeRequest.UserID] = -10 // taken from Synapse
pl, err = json.Marshal(powerLevelContent)
if err != nil {
return util.ErrorResponse(err)
return util.ErrorResponse(err), err
}
createContent := map[string]interface{}{}
createContent["m.federate"] = false
cc, err := json.Marshal(createContent)
cc, err = json.Marshal(createContent)
if err != nil {
return util.ErrorResponse(err)
return util.ErrorResponse(err), err
}
crReq := createRoomRequest{
Invite: []string{r.UserID},
Invite: []string{serverNoticeRequest.UserID},
Name: cfgNotices.RoomName,
Visibility: "private",
Preset: presetPrivateChat,
@ -166,36 +158,40 @@ func SendServerNotice(
switch data := roomRes.JSON.(type) {
case createRoomResponse:
roomID = data.RoomID
res := &userapi.UpdateServerNoticeRoomResponse{}
err = userAPI.UpdateServerNoticeRoomID(ctx, &userapi.UpdateServerNoticeRoomRequest{RoomID: roomID, Localpart: localpart}, res)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("UpdateServerNoticeRoomID failed")
return jsonerror.InternalServerError(), err
}
// tag the room, so we can later check if the user tries to reject an invite
serverAlertTag := gomatrix.TagContent{Tags: map[string]gomatrix.TagProperties{
"m.server_notice": {
Order: 1.0,
},
}}
if err = saveTagData(req, r.UserID, roomID, userAPI, serverAlertTag); err != nil {
if err = saveTagData(ctx, serverNoticeRequest.UserID, roomID, userAPI, serverAlertTag); err != nil {
util.GetLogger(ctx).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError()
return jsonerror.InternalServerError(), err
}
default:
// if we didn't get a createRoomResponse, we probably received an error, so return that.
return roomRes
return roomRes, fmt.Errorf("Unable to create room")
}
} else {
// we've found a room in common, check the membership
roomID = commonRooms[0]
membershipRes := api.QueryMembershipForUserResponse{}
err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes)
res := &api.QueryMembershipForUserResponse{}
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: serverNoticeRequest.UserID, RoomID: roomID}, res)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("unable to query membership for user")
return jsonerror.InternalServerError()
return util.ErrorResponse(err), err
}
if !membershipRes.IsInRoom {
// re-invite the user
res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now())
// re-invite the user
if res.Membership != gomatrixserverlib.Join {
var inviteRes util.JSONResponse
inviteRes, err = sendInvite(ctx, userAPI, senderDevice, roomID, serverNoticeRequest.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now())
if err != nil {
return res
return inviteRes, err
}
}
}
@ -203,13 +199,13 @@ func SendServerNotice(
startedGeneratingEvent := time.Now()
request := map[string]interface{}{
"body": r.Content.Body,
"msgtype": r.Content.MsgType,
"body": serverNoticeRequest.Content.Body,
"msgtype": serverNoticeRequest.Content.MsgType,
}
e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, cfgClient, rsAPI, time.Now())
if resErr != nil {
logrus.Errorf("failed to send message: %+v", resErr)
return *resErr
return *resErr, fmt.Errorf("Unable to send event")
}
timeToGenerateEvent := time.Since(startedGeneratingEvent)
@ -224,7 +220,7 @@ func SendServerNotice(
// pass the new event to the roomserver and receive the correct event ID
// event ID in case of duplicate transaction is discarded
startedSubmittingEvent := time.Now()
if err := api.SendEvents(
if err = api.SendEvents(
ctx, rsAPI,
api.KindNew,
[]*gomatrixserverlib.HeaderedEvent{
@ -236,7 +232,7 @@ func SendServerNotice(
false,
); err != nil {
util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError()
return jsonerror.InternalServerError(), err
}
util.GetLogger(ctx).WithFields(logrus.Fields{
"event_id": e.EventID(),
@ -259,7 +255,7 @@ func SendServerNotice(
sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds()))
sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvent.Milliseconds()))
return res
return res, nil
}
func (r sendServerNoticeRequest) valid() (ok bool) {

View file

@ -146,7 +146,12 @@ func main() {
logrus.Fatalln("Username is already in use.")
}
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType)
policyVersion := ""
if cfg.Global.UserConsentOptions.Enabled {
policyVersion = cfg.Global.UserConsentOptions.Version
}
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", policyVersion, accType)
if err != nil {
logrus.Fatalln("Failed to create the account:", err.Error())
}

View file

@ -71,6 +71,35 @@ global:
# appear in user clients.
room_name: "Server Alerts"
# Consent tracking configuration
user_consent:
# If the user consent tracking is enabled or not
enabled: false
# The base URL this homeserver will serve clients on, e.g. https://matrix.org
base_url: http://localhost
# Randomly generated string (e.g. by using "pwgen -sy 32") to be used to calculate the HMAC
form_secret: "superSecretRandomlyGeneratedSecret"
# Require consent when user registers for the first time
require_at_registration: false
# The name to be shown to the user
policy_name: "Privacy policy"
# The directory to search for templates
template_dir: "./templates/privacy"
# The version of the policy. When loading templates, ".gohtml" template is added as a suffix
# e.g: ${template_dir}/1.0.gohtml needs to exist, if this is set to "1.0"
version: "1.0"
# Send a consent message to guest users
send_server_notice_to_guest: false
# Default message to send to users
server_notice_content:
msg_type: "m.text"
body: >-
Please give your consent to the privacy policy at {{ .ConsentURL }}.
# The error message to display if the user hasn't given their consent yet
block_events_error: >-
You can't send any messages until you consent to the privacy policy at
{{ .ConsentURL }}.
# Configuration for NATS JetStream
jetstream:
# A list of NATS Server addresses to connect to. If none are specified, an

26
docs/templates/privacy/1.0.gohtml vendored Normal file
View file

@ -0,0 +1,26 @@
<!doctype html>
<html lang="en">
<head>
<title>Privacy policy</title>
</head>
<body>
{{ if .HasConsented }}
<p>
You have already given your consent.
</p>
{{ else }}
<p>
Please give your consent to keep using this homeserver.
</p>
{{ if not .ReadOnly }}
<!-- The variables used here are only provided when the 'u' param is given to the homeserver -->
<form method="post" action="consent">
<input type="hidden" name="v" value="{{ .Version }}"/>
<input type="hidden" name="u" value="{{ .UserID }}"/>
<input type="hidden" name="h" value="{{ .UserHMAC }}"/>
<input type="submit" value="I consent"/>
</form>
{{ end }}
{{ end }}
</body>
</html>

View file

@ -15,6 +15,8 @@
package httputil
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
@ -25,9 +27,12 @@ import (
"github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
@ -41,10 +46,24 @@ type BasicAuth struct {
Password string `yaml:"password"`
}
// AuthAPICheck is an option to MakeAuthAPI to add additional checks (e.g. WithConsentCheck) to verify
// the user is allowed to do specific things.
type AuthAPICheck func(ctx context.Context, device *userapi.Device) *util.JSONResponse
// WithConsentCheck checks that a user has given his consent.
func WithConsentCheck(options config.UserConsentOptions, api userapi.QueryPolicyVersionAPI) AuthAPICheck {
return func(ctx context.Context, device *userapi.Device) *util.JSONResponse {
if !options.Enabled {
return nil
}
return checkConsent(ctx, device.UserID, api, options)
}
}
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
func MakeAuthAPI(
metricsName string, userAPI userapi.QueryAcccessTokenAPI,
f func(*http.Request, *userapi.Device) util.JSONResponse,
f func(*http.Request, *userapi.Device) util.JSONResponse, checks ...AuthAPICheck,
) http.Handler {
h := func(req *http.Request) util.JSONResponse {
logger := util.GetLogger(req.Context())
@ -72,6 +91,14 @@ func MakeAuthAPI(
}
}()
// apply additional checks, if any
for _, opt := range checks {
resp := opt(req.Context(), device)
if resp != nil {
return *resp
}
}
jsonRes := f(req, device)
// do not log 4xx as errors as they are client fails, not server fails
if hub != nil && jsonRes.Code >= 500 {
@ -83,6 +110,53 @@ func MakeAuthAPI(
return MakeExternalAPI(metricsName, h)
}
func checkConsent(ctx context.Context, userID string, userAPI userapi.QueryPolicyVersionAPI, userConsentCfg config.UserConsentOptions) *util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return nil
}
// check which version of the policy the user accepted
res := &userapi.QueryPolicyVersionResponse{}
err = userAPI.QueryPolicyVersion(ctx, &userapi.QueryPolicyVersionRequest{
Localpart: localpart,
}, res)
if err != nil {
return &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to get policy version"),
}
}
// user hasn't accepted any policy, block access.
if userConsentCfg.Version != res.PolicyVersion {
uri, err := userConsentCfg.ConsentURL(userID)
if err != nil {
return &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to get consent URL"),
}
}
msg := &bytes.Buffer{}
c := struct {
ConsentURL string
}{
ConsentURL: uri,
}
if err = userConsentCfg.TextTemplates.ExecuteTemplate(msg, "blockEventsError", c); err != nil {
logrus.Infof("error consent message: %+v", err)
return &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to execute template"),
}
}
return &util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.ConsentNotGiven(uri, msg.String()),
}
}
return nil
}
// MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler.
// This is used for APIs that are called from the internet.
func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {

View file

@ -59,15 +59,12 @@ func Setup(
PathToResult: map[string]*types.ThumbnailGenerationResult{},
}
uploadHandler := httputil.MakeAuthAPI(
"upload", userAPI,
func(req *http.Request, dev *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
}
return Upload(req, cfg, dev, db, activeThumbnailGeneration)
},
)
uploadHandler := httputil.MakeAuthAPI("upload", userAPI, func(req *http.Request, dev *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
}
return Upload(req, cfg, dev, db, activeThumbnailGeneration)
})
configHandler := httputil.MakeAuthAPI("config", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {

View file

@ -268,6 +268,7 @@ func (b *BaseDendrite) Close() error {
func (b *BaseDendrite) DatabaseConnection(dbProperties *config.DatabaseOptions, writer sqlutil.Writer) (*sql.DB, sqlutil.Writer, error) {
if dbProperties.ConnectionString != "" || b == nil {
// Open a new database connection using the supplied config.
logrus.Infof("Open a new database connection using the supplied config.: %+v", dbProperties.ConnectionString)
db, err := sqlutil.Open(dbProperties, writer)
return db, writer, err
}

View file

@ -265,6 +265,21 @@ func loadConfig(
return &c, nil
}
type Terms struct {
Policies Policies `json:"policies"`
}
type En struct {
Name string `json:"name"`
URL string `json:"url"`
}
type PrivacyPolicy struct {
En En `json:"en"`
Version string `json:"version"`
}
type Policies struct {
PrivacyPolicy PrivacyPolicy `json:"privacy_policy"`
}
// Derive generates data that is derived from various values provided in
// the config file.
func (config *Dendrite) Derive() error {
@ -275,13 +290,39 @@ func (config *Dendrite) Derive() error {
// TODO: Add email auth type
// TODO: Add MSISDN auth type
if config.Global.UserConsentOptions.Enabled && config.Global.UserConsentOptions.RequireAtRegistration {
uri := config.Global.UserConsentOptions.BaseURL + "/_matrix/client/consent?v=" + config.Global.UserConsentOptions.Version
config.Derived.Registration.Params[authtypes.LoginTypeTerms] = Terms{
Policies: Policies{
PrivacyPolicy: PrivacyPolicy{
En: En{
Name: config.Global.UserConsentOptions.PolicyName,
URL: uri,
},
Version: config.Global.UserConsentOptions.Version,
},
},
}
}
if config.ClientAPI.RecaptchaEnabled {
config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey}
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}})
} else {
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}})
}
if config.Derived.Registration.Flows == nil {
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, authtypes.Flow{
Stages: []authtypes.LoginType{authtypes.LoginTypeDummy},
})
}
// prepend each flow with LoginTypeTerms or LoginTypeRecaptcha
for i, flow := range config.Derived.Registration.Flows {
if config.Global.UserConsentOptions.Enabled && config.Global.UserConsentOptions.RequireAtRegistration {
flow.Stages = append([]authtypes.LoginType{authtypes.LoginTypeTerms}, flow.Stages...)
}
if config.ClientAPI.RecaptchaEnabled {
flow.Stages = append([]authtypes.LoginType{authtypes.LoginTypeRecaptcha}, flow.Stages...)
}
config.Derived.Registration.Flows[i] = flow
}
// Load application service configuration files

View file

@ -1,7 +1,15 @@
package config
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"html/template"
"math/rand"
"net/url"
"path/filepath"
textTemplate "text/template"
"time"
"github.com/matrix-org/gomatrixserverlib"
@ -71,6 +79,9 @@ type Global struct {
// ServerNotices configuration used for sending server notices
ServerNotices ServerNotices `yaml:"server_notices"`
// Consent tracking options
UserConsentOptions UserConsentOptions `yaml:"user_consent"`
// ReportStats configures opt-in anonymous stats reporting.
ReportStats ReportStats `yaml:"report_stats"`
}
@ -88,6 +99,7 @@ func (c *Global) Defaults(generate bool) {
c.Metrics.Defaults(generate)
c.DNSCache.Defaults()
c.Sentry.Defaults()
c.UserConsentOptions.Defaults()
c.ServerNotices.Defaults(generate)
c.ReportStats.Defaults()
}
@ -100,6 +112,7 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) {
c.Metrics.Verify(configErrs, isMonolith)
c.Sentry.Verify(configErrs, isMonolith)
c.DNSCache.Verify(configErrs, isMonolith)
c.UserConsentOptions.Verify(configErrs, isMonolith)
c.ServerNotices.Verify(configErrs, isMonolith)
c.ReportStats.Verify(configErrs, isMonolith)
}
@ -261,6 +274,102 @@ func (c *DNSCacheOptions) Verify(configErrs *ConfigErrors, isMonolith bool) {
checkPositive(configErrs, "cache_lifetime", int64(c.CacheLifetime))
}
// Consent tracking configuration
// If either require_at_registration or send_server_notice_to_guest are true, consent
// messages will be sent to the users.
type UserConsentOptions struct {
// If consent tracking is enabled or not
Enabled bool `yaml:"enabled"`
// Randomly generated string to be used to calculate the HMAC
FormSecret string `yaml:"form_secret"`
// Require consent when user registers for the first time
RequireAtRegistration bool `yaml:"require_at_registration"`
// The name to be shown to the user
PolicyName string `yaml:"policy_name"`
// The directory to search for *.gohtml templates
TemplateDir string `yaml:"template_dir"`
// The version of the policy. When loading templates, ".gohtml" template is added as a suffix
// e.g: ${template_dir}/1.0.gohtml needs to exist, if this is set to 1.0
Version string `yaml:"version"`
// Send a consent message to guest users
SendServerNoticeToGuest bool `yaml:"send_server_notice_to_guest"`
// Default message to send to users
ServerNoticeContent struct {
MsgType string `yaml:"msg_type"`
Body string `yaml:"body"`
} `yaml:"server_notice_content"`
// The error message to display if the user hasn't given their consent yet
BlockEventsError string `yaml:"block_events_error"`
// All loaded templates
Templates *template.Template `yaml:"-"`
TextTemplates *textTemplate.Template `yaml:"-"`
// The base URL this homeserver will serve clients on, e.g. https://matrix.org
BaseURL string `yaml:"base_url"`
}
func (c *UserConsentOptions) Defaults() {
c.Enabled = false
c.RequireAtRegistration = false
c.SendServerNoticeToGuest = false
c.PolicyName = "Privacy Policy"
c.Version = "1.0"
c.TemplateDir = "./templates/privacy"
}
func (c *UserConsentOptions) Verify(configErrors *ConfigErrors, isMonolith bool) {
if !c.Enabled {
return
}
checkNotEmpty(configErrors, "template_dir", c.TemplateDir)
checkNotEmpty(configErrors, "version", c.Version)
checkNotEmpty(configErrors, "policy_name", c.PolicyName)
checkNotEmpty(configErrors, "form_secret", c.FormSecret)
checkNotEmpty(configErrors, "base_url", c.BaseURL)
if len(*configErrors) > 0 {
return
}
p, err := filepath.Abs(c.TemplateDir)
if err != nil {
configErrors.Add("unable to get template directory")
return
}
c.TextTemplates = textTemplate.Must(textTemplate.New("blockEventsError").Parse(c.BlockEventsError))
c.TextTemplates = textTemplate.Must(c.TextTemplates.New("serverNoticeTemplate").Parse(c.ServerNoticeContent.Body))
// Read all defined *.gohtml templates
t, err := template.ParseGlob(filepath.Join(p, "*.gohtml"))
if err != nil || t == nil {
configErrors.Add(fmt.Sprintf("unable to read consent templates: %+v", err))
return
}
c.Templates = t
// Verify we've got a template for the defined version
versionTemplate := c.Templates.Lookup(c.Version + ".gohtml")
if versionTemplate == nil {
configErrors.Add(fmt.Sprintf("unable to load defined '%s' policy template", c.Version))
}
}
// ConsentURL constructs the URL shown to users to accept the TOS
func (c *UserConsentOptions) ConsentURL(userID string) (string, error) {
mac := hmac.New(sha256.New, []byte(c.FormSecret))
_, err := mac.Write([]byte(userID))
if err != nil {
return "", err
}
userMAC := hex.EncodeToString(mac.Sum(nil))
params := url.Values{}
params.Add("u", userID)
params.Add("h", userMAC)
params.Add("v", c.Version)
return fmt.Sprintf("%s/_matrix/client/consent?%s", c.BaseURL, params.Encode()), nil
}
// PresenceOptions defines possible configurations for presence events.
type PresenceOptions struct {
// Whether inbound presence events are allowed

View file

@ -0,0 +1,115 @@
package config
import (
"testing"
)
func TestUserConsentOptions_Verify(t *testing.T) {
type args struct {
configErrors *ConfigErrors
isMonolith bool
}
tests := []struct {
name string
fields UserConsentOptions
args args
wantErr bool
}{
{
name: "template dir not set",
fields: UserConsentOptions{
RequireAtRegistration: true,
},
args: struct {
configErrors *ConfigErrors
isMonolith bool
}{configErrors: &ConfigErrors{}, isMonolith: true},
wantErr: true,
},
{
name: "template dir set",
fields: UserConsentOptions{
RequireAtRegistration: true,
TemplateDir: "testdata/privacy",
},
args: struct {
configErrors *ConfigErrors
isMonolith bool
}{configErrors: &ConfigErrors{}, isMonolith: true},
wantErr: true,
},
{
name: "policy name not set",
fields: UserConsentOptions{
RequireAtRegistration: true,
TemplateDir: "testdata/privacy",
},
args: struct {
configErrors *ConfigErrors
isMonolith bool
}{configErrors: &ConfigErrors{}, isMonolith: true},
wantErr: true,
},
{
name: "policy name set",
fields: UserConsentOptions{
RequireAtRegistration: true,
TemplateDir: "testdata/privacy",
PolicyName: "Privacy policy",
},
args: struct {
configErrors *ConfigErrors
isMonolith bool
}{configErrors: &ConfigErrors{}, isMonolith: true},
wantErr: true,
},
{
name: "version not set",
fields: UserConsentOptions{
RequireAtRegistration: true,
TemplateDir: "testdata/privacy",
},
args: struct {
configErrors *ConfigErrors
isMonolith bool
}{configErrors: &ConfigErrors{}, isMonolith: true},
wantErr: true,
},
{
name: "everyhing required set",
fields: UserConsentOptions{
RequireAtRegistration: true,
TemplateDir: "./testdata/privacy",
Version: "1.0",
PolicyName: "Privacy policy",
FormSecret: "helloWorld",
BaseURL: "http://localhost",
},
args: struct {
configErrors *ConfigErrors
isMonolith bool
}{configErrors: &ConfigErrors{}, isMonolith: true},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &UserConsentOptions{
Enabled: true,
BaseURL: tt.fields.BaseURL,
FormSecret: tt.fields.FormSecret,
RequireAtRegistration: tt.fields.RequireAtRegistration,
PolicyName: tt.fields.PolicyName,
Version: tt.fields.Version,
TemplateDir: tt.fields.TemplateDir,
SendServerNoticeToGuest: tt.fields.SendServerNoticeToGuest,
ServerNoticeContent: tt.fields.ServerNoticeContent,
BlockEventsError: tt.fields.BlockEventsError,
}
c.Verify(tt.args.configErrors, tt.args.isMonolith)
if !tt.wantErr && len(*tt.args.configErrors) > 0 {
t.Errorf("expected no errors, got '%+v'", tt.args.configErrors)
}
})
}
}

View file

@ -0,0 +1,26 @@
<!doctype html>
<html lang="en">
<head>
<title>Privacy policy</title>
</head>
<body>
{{ if .HasConsented }}
<p>
You have already given your consent.
</p>
{{ else }}
<p>
Please give your consent to keep using this homeserver.
</p>
{{ if not .ReadOnly }}
<!-- The variables used here are only provided when the 'u' param is given to the homeserver -->
<form method="post" action="consent">
<input type="hidden" name="v" value="{{ .Version }}"/>
<input type="hidden" name="u" value="{{ .UserID }}"/>
<input type="hidden" name="h" value="{{ .UserHMAC }}"/>
<input type="submit" value="I consent"/>
</form>
{{ end }}
{{ end }}
</body>
</html>

View file

@ -44,7 +44,6 @@ func ParseFlags(monolith bool) *config.Dendrite {
}
cfg, err := config.Load(*configPath, monolith)
if err != nil {
logrus.Fatalf("Invalid config file: %s", err)
}

View file

@ -93,6 +93,6 @@ func Setup(
vars["roomId"], vars["eventId"],
lazyLoadCache,
)
}),
}, httputil.WithConsentCheck(cfg.Matrix.UserConsentOptions, userAPI)),
).Methods(http.MethodGet, http.MethodOptions)
}

View file

@ -66,6 +66,7 @@ type FederationUserAPI interface {
// api functions required by the sync api
type SyncUserAPI interface {
QueryAcccessTokenAPI
QueryPolicyVersionAPI
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
@ -78,6 +79,7 @@ type ClientUserAPI interface {
QueryAcccessTokenAPI
LoginTokenInternalAPI
UserLoginAPI
UserConsentPolicyAPI
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
@ -106,6 +108,18 @@ type ClientUserAPI interface {
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) (err error)
UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) (err error)
}
type UserConsentPolicyAPI interface {
QueryPolicyVersionAPI
QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error
PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error
}
type QueryPolicyVersionAPI interface {
QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error
}
// custom api functions required by pinecone / p2p demos
@ -318,12 +332,12 @@ type QuerySearchProfilesResponse struct {
// PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformAccountCreationRequest struct {
AccountType AccountType // Required: whether this is a guest or user account
Localpart string // Required: The localpart for this account. Ignored if account type is guest.
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
Password string // optional: if missing then this account will be a passwordless account
OnConflict Conflict
AccountType AccountType // Required: whether this is a guest or user account
Localpart string // Required: The localpart for this account. Ignored if account type is guest.
PolicyVersion string // optional: the privacy policy this account has accepted
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
Password string // optional: if missing then this account will be a passwordless account
OnConflict Conflict
}
// PerformAccountCreationResponse is the response for PerformAccountCreation
@ -412,6 +426,53 @@ type QueryOpenIDTokenResponse struct {
ExpiresAtMS int64
}
// QueryPolicyVersionRequest is the request for QueryPolicyVersionRequest
type QueryPolicyVersionRequest struct {
Localpart string
}
// QueryPolicyVersionResponse is the response for QueryPolicyVersionRequest
type QueryPolicyVersionResponse struct {
PolicyVersion string
}
// QueryOutdatedPolicyRequest is the request for QueryOutdatedPolicyRequest
type QueryOutdatedPolicyRequest struct {
PolicyVersion string
}
// QueryOutdatedPolicyResponse is the response for QueryOutdatedPolicyRequest
type QueryOutdatedPolicyResponse struct {
UserLocalparts []string
}
// UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest
type UpdatePolicyVersionRequest struct {
PolicyVersion, Localpart string
ServerNoticeUpdate bool
}
// UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest
type UpdatePolicyVersionResponse struct{}
// QueryServerNoticeRoomRequest is the request for QueryServerNoticeRoomRequest
type QueryServerNoticeRoomRequest struct {
Localpart string
}
// QueryServerNoticeRoomResponse is the response for QueryServerNoticeRoomRequest
type QueryServerNoticeRoomResponse struct {
RoomID string
}
// UpdateServerNoticeRoomRequest is the request for UpdateServerNoticeRoomRequest
type UpdateServerNoticeRoomRequest struct {
Localpart, RoomID string
}
// UpdateServerNoticeRoomResponse is the response for UpdateServerNoticeRoomRequest
type UpdateServerNoticeRoomResponse struct{}
// Device represents a client's device (mobile, web, etc)
type Device struct {
ID string

View file

@ -203,6 +203,36 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex
return err
}
func (t *UserInternalAPITrace) QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error {
err := t.Impl.QueryPolicyVersion(ctx, req, res)
util.GetLogger(ctx).Infof("QueryPolicyVersion req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error {
err := t.Impl.QueryOutdatedPolicy(ctx, req, res)
util.GetLogger(ctx).Infof("QueryOutdatedPolicy req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error {
err := t.Impl.PerformUpdatePolicyVersion(ctx, req, res)
util.GetLogger(ctx).Infof("PerformUpdatePolicyVersion req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) error {
err := t.Impl.SelectServerNoticeRoomID(ctx, req, res)
util.GetLogger(ctx).Infof("SelectServerNoticeRoomID req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) error {
err := t.Impl.UpdateServerNoticeRoomID(ctx, req, res)
util.GetLogger(ctx).Infof("UpdateServerNoticeRoomID req=%+v res=%+v", js(req), js(res))
return err
}
func js(thing interface{}) string {
b, err := json.Marshal(thing)
if err != nil {

View file

@ -66,7 +66,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion, req.AccountType)
if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict {
@ -833,3 +833,60 @@ func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, re
}
const pushRulesAccountDataType = "m.push_rules"
func (a *UserInternalAPI) QueryPolicyVersion(
ctx context.Context,
req *api.QueryPolicyVersionRequest,
res *api.QueryPolicyVersionResponse,
) error {
var err error
res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.Localpart)
if err != nil {
return err
}
return nil
}
func (a *UserInternalAPI) QueryOutdatedPolicy(
ctx context.Context,
req *api.QueryOutdatedPolicyRequest,
res *api.QueryOutdatedPolicyResponse,
) error {
var err error
res.UserLocalparts, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion)
if err != nil {
return err
}
return nil
}
func (a *UserInternalAPI) PerformUpdatePolicyVersion(
ctx context.Context,
req *api.UpdatePolicyVersionRequest,
res *api.UpdatePolicyVersionResponse,
) error {
return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.Localpart, req.ServerNoticeUpdate)
}
func (a *UserInternalAPI) SelectServerNoticeRoomID(
ctx context.Context,
req *api.QueryServerNoticeRoomRequest,
res *api.QueryServerNoticeRoomResponse,
) (err error) {
roomID, err := a.DB.SelectServerNoticeRoomID(ctx, req.Localpart)
if err != nil {
return err
}
res.RoomID = roomID
return nil
}
func (a *UserInternalAPI) UpdateServerNoticeRoomID(
ctx context.Context,
req *api.UpdateServerNoticeRoomRequest,
res *api.UpdateServerNoticeRoomResponse,
) (err error) {
return a.DB.UpdateServerNoticeRoomID(ctx, req.Localpart, req.RoomID)
}

View file

@ -44,6 +44,8 @@ const (
PerformSetDisplayNamePath = "/userapi/performSetDisplayName"
PerformForgetThreePIDPath = "/userapi/performForgetThreePID"
PerformSaveThreePIDAssociationPath = "/userapi/performSaveThreePIDAssociation"
PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion"
PerformUpdateServerNoticeRoomPath = "/userapi/performUpdateServerNoticeRoom"
QueryKeyBackupPath = "/userapi/queryKeyBackup"
QueryProfilePath = "/userapi/queryProfile"
@ -61,6 +63,9 @@ const (
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
QueryPolicyVersionPath = "/userapi/queryPolicyVersion"
QueryOutdatedPolicyUsersPath = "/userapi/queryOutdatedPolicy"
QueryServerNoticeRoomPath = "/userapi/queryServerNoticeRoom"
)
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -391,3 +396,43 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context
apiURL := h.apiURL + PerformSaveThreePIDAssociationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) QueryOutdatedPolicy(ctx context.Context, req *api.QueryOutdatedPolicyRequest, res *api.QueryOutdatedPolicyResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOutdatedPolicy")
defer span.Finish()
apiURL := h.apiURL + QueryOutdatedPolicyUsersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) PerformUpdatePolicyVersion(ctx context.Context, req *api.UpdatePolicyVersionRequest, res *api.UpdatePolicyVersionResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUpdatePolicyVersion")
defer span.Finish()
apiURL := h.apiURL + PerformUpdatePolicyVersionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) SelectServerNoticeRoomID(ctx context.Context, req *api.QueryServerNoticeRoomRequest, res *api.QueryServerNoticeRoomResponse) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "SelectServerNoticeRoomID")
defer span.Finish()
apiURL := h.apiURL + QueryServerNoticeRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) QueryPolicyVersion(ctx context.Context, req *api.QueryPolicyVersionRequest, res *api.QueryPolicyVersionResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPolicyVersion")
defer span.Finish()
apiURL := h.apiURL + QueryPolicyVersionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) UpdateServerNoticeRoomID(ctx context.Context, req *api.UpdateServerNoticeRoomRequest, res *api.UpdateServerNoticeRoomResponse) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "UpdateServerNoticeRoomID")
defer span.Finish()
apiURL := h.apiURL + PerformUpdateServerNoticeRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -457,4 +457,74 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
}),
)
internalAPIMux.Handle(QueryPolicyVersionPath,
httputil.MakeInternalAPI("queryPolicyVersion", func(req *http.Request) util.JSONResponse {
request := api.QueryPolicyVersionRequest{}
response := api.QueryPolicyVersionResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
err := s.QueryPolicyVersion(req.Context(), &request, &response)
if err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryOutdatedPolicyUsersPath,
httputil.MakeInternalAPI("queryOutdatedPolicyUsers", func(req *http.Request) util.JSONResponse {
request := api.QueryOutdatedPolicyRequest{}
response := api.QueryOutdatedPolicyResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
err := s.QueryOutdatedPolicy(req.Context(), &request, &response)
if err != nil {
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformUpdatePolicyVersionPath,
httputil.MakeInternalAPI("performUpdatePolicyVersionPath", func(req *http.Request) util.JSONResponse {
request := api.UpdatePolicyVersionRequest{}
response := api.UpdatePolicyVersionResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
err := s.PerformUpdatePolicyVersion(req.Context(), &request, &response)
if err != nil {
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryServerNoticeRoomPath,
httputil.MakeInternalAPI("queryServerNoticeRoom", func(req *http.Request) util.JSONResponse {
request := api.QueryServerNoticeRoomRequest{}
response := api.QueryServerNoticeRoomResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
err := s.SelectServerNoticeRoomID(req.Context(), &request, &response)
if err != nil {
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformUpdateServerNoticeRoomPath,
httputil.MakeInternalAPI("performUpdateServerNoticeRoom", func(req *http.Request) util.JSONResponse {
request := api.UpdateServerNoticeRoomRequest{}
response := api.UpdateServerNoticeRoomResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
err := s.UpdateServerNoticeRoomID(req.Context(), &request, &response)
if err != nil {
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
}

View file

@ -36,7 +36,7 @@ type Account interface {
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists.
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, policyVersion string, accountType api.AccountType) (*api.Account, error)
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
GetNewNumericLocalpart(ctx context.Context) (int64, error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
@ -126,9 +126,18 @@ type Notification interface {
DeleteOldNotifications(ctx context.Context) error
}
type ConsentTracking interface {
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) error
SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error)
UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error)
}
type Database interface {
Account
AccountData
ConsentTracking
Device
KeyBackup
LoginToken

View file

@ -19,6 +19,7 @@ import (
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/userutil"
@ -43,14 +44,19 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- If the account is currently active
is_deactivated BOOLEAN DEFAULT FALSE,
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
account_type SMALLINT NOT NULL
account_type SMALLINT NOT NULL,
-- The policy version this user has accepted
policy_version TEXT,
-- The policy version the user received from the server notices room
policy_version_sent TEXT,
server_notice_room_id TEXT
-- TODO:
-- upgraded_ts, devices, any email reset stuff?
);
`
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type, policy_version) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
@ -67,14 +73,38 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'"
const selectPrivacyPolicySQL = "" +
"SELECT policy_version FROM account_accounts WHERE localpart = $1"
const batchSelectPrivacyPolicySQL = "" +
"SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $1)"
const updatePolicyVersionSQL = "" +
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
const updatePolicyVersionServerNoticeSQL = "" +
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
const selectServerNoticeRoomSQL = "" +
"SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1"
const updateServerNoticeRoomSQL = "" +
"UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2"
type accountsStatements struct {
insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
deactivateAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
deactivateAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
selectPrivacyPolicyStmt *sql.Stmt
batchSelectPrivacyPolicyStmt *sql.Stmt
updatePolicyVersionStmt *sql.Stmt
updatePolicyVersionServerNoticeStmt *sql.Stmt
selectServerNoticeRoomStmt *sql.Stmt
updateServerNoticeRoomStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
@ -92,6 +122,12 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
{&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL},
{&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL},
}.Prepare(db)
}
@ -99,16 +135,16 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error
if accountType != api.AccountTypeAppService {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
} else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion)
}
if err != nil {
return nil, err
@ -178,3 +214,71 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
err = stmt.QueryRowContext(ctx).Scan(&id)
return id + 1, err
}
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
func (s *accountsStatements) SelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) {
var policyNull sql.NullString
stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt)
err = stmt.QueryRowContext(ctx, localPart).Scan(&policyNull)
return policyNull.String, err
}
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
func (s *accountsStatements) BatchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) {
stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt)
rows, err := stmt.QueryContext(ctx, policyVersion)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "BatchSelectPrivacyPolicy: rows.close() failed")
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return userIDs, err
}
userIDs = append(userIDs, userID)
}
return userIDs, rows.Err()
}
// updatePolicyVersion sets the policy_version for a specific user
func (s *accountsStatements) UpdatePolicyVersion(
ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool,
) (err error) {
stmt := s.updatePolicyVersionStmt
if serverNotice {
stmt = s.updatePolicyVersionServerNoticeStmt
}
stmt = sqlutil.TxStmt(txn, stmt)
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
return err
}
// SelectServerNoticeRoomID queries the server notice room ID.
func (s *accountsStatements) SelectServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart string,
) (roomID string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt)
roomIDNull := sql.NullString{}
row := stmt.QueryRowContext(ctx, localpart)
err = row.Scan(&roomIDNull)
if err != nil && err != sql.ErrNoRows {
return "", err
}
// roomIDNull.String is either the roomID or an empty string
return roomIDNull.String, nil
}
// UpdateServerNoticeRoomID sets the server notice room ID.
func (s *accountsStatements) UpdateServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart, roomID string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt)
_, err = stmt.ExecContext(ctx, roomID, localpart)
return
}

View file

@ -12,6 +12,7 @@ import (
func LoadFromGoose() {
goose.AddMigration(UpIsActive, DownIsActive)
goose.AddMigration(UpAddAccountType, DownAddAccountType)
goose.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
}
func LoadIsActive(m *sqlutil.Migrations) {

View file

@ -0,0 +1,45 @@
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadAddPolicyVersion(m *sqlutil.Migrations) {
m.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
}
func UpAddPolicyVersion(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version TEXT;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version_sent TEXT;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS server_notice_room_id TEXT;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownAddPolicyVersion(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -46,6 +46,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m)
deltas.LoadAddPolicyVersion(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}

View file

@ -28,6 +28,7 @@ import (
"github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -125,7 +126,7 @@ func (d *Database) SetPassword(
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string, accountType api.AccountType,
) (acc *api.Account, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// For guest accounts, we create a new numeric local part
@ -139,7 +140,7 @@ func (d *Database) CreateAccount(
plaintextPassword = ""
appserviceID = ""
}
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion, accountType)
return err
})
return
@ -148,7 +149,7 @@ func (d *Database) CreateAccount(
// WARNING! This function assumes that the relevant mutexes have already
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string, accountType api.AccountType,
) (*api.Account, error) {
var err error
var account *api.Account
@ -160,7 +161,8 @@ func (d *Database) createAccount(
return nil, err
}
}
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion, accountType); err != nil {
logrus.WithError(err).Error("d.Accounts.InsertAccount error")
return nil, sqlutil.ErrUserExists
}
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
@ -763,3 +765,42 @@ func (d *Database) RemovePushers(
func (d *Database) UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) {
return d.Stats.UserStatistics(ctx, nil)
}
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
policyVersion, err = d.Accounts.SelectPrivacyPolicy(ctx, txn, localpart)
return err
})
return
}
// GetOutdatedPolicy queries all users which didn't accept the current policy version
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
userIDs, err = d.Accounts.BatchSelectPrivacyPolicy(ctx, txn, policyVersion)
return err
})
return
}
// UpdatePolicyVersion sets the accepted policy_version for a user.
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) (err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart, serverNotice)
})
return
}
// SelectServerNoticeRoomID returns the server notice room, if one is set.
func (d *Database) SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error) {
return d.Accounts.SelectServerNoticeRoomID(ctx, nil, localpart)
}
// UpdateServerNoticeRoomID updates the server notice room
func (d *Database) UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Accounts.UpdateServerNoticeRoomID(ctx, txn, localpart, roomID)
})
return
}

View file

@ -19,6 +19,7 @@ import (
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/userutil"
@ -43,14 +44,19 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- If the account is currently active
is_deactivated BOOLEAN DEFAULT 0,
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
account_type INTEGER NOT NULL
account_type INTEGER NOT NULL,
-- The policy version this user has accepted
policy_version TEXT,
-- The policy version the user received from the server notices room
policy_version_sent TEXT,
server_notice_room_id TEXT
-- TODO:
-- upgraded_ts, devices, any email reset stuff?
);
`
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type, policy_version) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
@ -67,15 +73,39 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0"
const selectPrivacyPolicySQL = "" +
"SELECT policy_version FROM account_accounts WHERE localpart = $1"
const batchSelectPrivacyPolicySQL = "" +
"SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $2)"
const updatePolicyVersionSQL = "" +
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
const updatePolicyVersionServerNoticeSQL = "" +
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
const selectServerNoticeRoomSQL = "" +
"SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1"
const updateServerNoticeRoomSQL = "" +
"UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2"
type accountsStatements struct {
db *sql.DB
insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
deactivateAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
db *sql.DB
insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
deactivateAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
selectPrivacyPolicyStmt *sql.Stmt
batchSelectPrivacyPolicyStmt *sql.Stmt
updatePolicyVersionStmt *sql.Stmt
updatePolicyVersionServerNoticeStmt *sql.Stmt
selectServerNoticeRoomStmt *sql.Stmt
updateServerNoticeRoomStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
@ -94,6 +124,12 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
{&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL},
{&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL},
}.Prepare(db)
}
@ -101,16 +137,16 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error
if accountType != api.AccountTypeAppService {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
} else {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion)
}
if err != nil {
return nil, err
@ -183,3 +219,72 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
}
return id + 1, err
}
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
func (s *accountsStatements) SelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) {
var policyNull sql.NullString
stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt)
err = stmt.QueryRowContext(ctx, localPart).Scan(&policyNull)
return policyNull.String, err
}
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
func (s *accountsStatements) BatchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) {
stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt)
rows, err := stmt.QueryContext(ctx, policyVersion, policyVersion)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "BatchSelectPrivacyPolicy: rows.close() failed")
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return userIDs, err
}
userIDs = append(userIDs, userID)
}
return userIDs, rows.Err()
}
// updatePolicyVersion sets the policy_version for a specific user
func (s *accountsStatements) UpdatePolicyVersion(
ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool,
) (err error) {
stmt := s.updatePolicyVersionStmt
if serverNotice {
stmt = s.updatePolicyVersionServerNoticeStmt
}
stmt = sqlutil.TxStmt(txn, stmt)
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
return err
}
// SelectServerNoticeRoomID queries the server notice room ID.
func (s *accountsStatements) SelectServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart string,
) (roomID string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt)
roomIDNull := sql.NullString{}
row := stmt.QueryRowContext(ctx, localpart)
err = row.Scan(&roomIDNull)
if err != nil && err != sql.ErrNoRows {
return "", err
}
// roomIDNull.String is either the roomID or an empty string
return roomIDNull.String, nil
}
// UpdateServerNoticeRoomID sets the server notice room ID.
func (s *accountsStatements) UpdateServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart, roomID string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt)
_, err = stmt.ExecContext(ctx, roomID, localpart)
return
}

View file

@ -12,6 +12,7 @@ import (
func LoadFromGoose() {
goose.AddMigration(UpIsActive, DownIsActive)
goose.AddMigration(UpAddAccountType, DownAddAccountType)
goose.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
}
func LoadIsActive(m *sqlutil.Migrations) {

View file

@ -4,15 +4,9 @@ import (
"database/sql"
"fmt"
"github.com/pressly/goose"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func init() {
goose.AddMigration(UpAddAccountType, DownAddAccountType)
}
func LoadAddAccountType(m *sqlutil.Migrations) {
m.AddMigration(UpAddAccountType, DownAddAccountType)
}

View file

@ -0,0 +1,44 @@
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadAddPolicyVersion(m *sqlutil.Migrations) {
m.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
}
func UpAddPolicyVersion(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version TEXT;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version_sent TEXT;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN server_notice_room_id TEXT;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownAddPolicyVersion(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -18,11 +18,10 @@ import (
"fmt"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/userapi/storage/shared"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
@ -47,6 +46,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m)
deltas.LoadAddPolicyVersion(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}

View file

@ -78,7 +78,7 @@ func Test_Accounts(t *testing.T) {
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", "v1.0", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account")
// verify the newly create account is the same as returned by CreateAccount
var accGet *api.Account
@ -102,7 +102,7 @@ func Test_Accounts(t *testing.T) {
first, err := db.GetNewNumericLocalpart(ctx)
assert.NoError(t, err, "failed to get new numeric localpart")
// Create a new account to verify the numeric localpart is updated
_, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest)
_, err = db.CreateAccount(ctx, "", "testing", "", "v1.0", api.AccountTypeGuest)
assert.NoError(t, err, "failed to create account")
second, err := db.GetNewNumericLocalpart(ctx)
assert.NoError(t, err)
@ -350,7 +350,7 @@ func Test_Profile(t *testing.T) {
defer close()
// create account, which also creates a profile
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", "v1.0", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account")
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)

View file

@ -32,12 +32,18 @@ type AccountDataTable interface {
}
type AccountsTable interface {
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType) (*api.Account, error)
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
DeactivateAccount(ctx context.Context, localpart string) (err error)
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error)
BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error)
UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool) (err error)
SelectServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart string) (roomID string, err error)
UpdateServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart, roomID string) (err error)
}
type DevicesTable interface {

View file

@ -88,7 +88,7 @@ func mustMakeAccountAndDevice(
appServiceID = util.RandomString(16)
}
_, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType)
_, err := accDB.InsertAccount(ctx, nil, localpart, "", "", appServiceID, accType)
if err != nil {
t.Fatalf("unable to create account: %v", err)
}

View file

@ -75,7 +75,7 @@ func TestQueryProfile(t *testing.T) {
// only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite)
defer close()
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser)
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", "", api.AccountTypeUser)
if err != nil {
t.Fatalf("failed to make account: %s", err)
}
@ -154,7 +154,7 @@ func TestLoginToken(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
defer close()
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser)
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", "", api.AccountTypeUser)
if err != nil {
t.Fatalf("failed to make account: %s", err)
}