Compare commits

...

82 commits

Author SHA1 Message Date
Till Faelligen cabc5f4bc9 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-05-19 14:21:59 +02:00
Till Faelligen 7df7a966b8 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-05-09 17:21:37 +02:00
Till Faelligen 75ca5490bc Fix build 2022-05-09 17:21:27 +02:00
kegsay c9409078ac
Merge branch 'main' into s7evink/consent-tracking 2022-05-09 16:04:36 +01:00
Till Faelligen 3d3773d3d4 Rename migrations so they are executed 2022-05-06 10:15:42 +02:00
Till Faelligen 4bd9a73c13 Fix database locked 2022-05-06 09:50:20 +02:00
Till Faelligen 4d285fff60 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-05-06 09:43:50 +02:00
Till Faelligen 10dc02f1ea Debugging unit tests.. 2022-05-04 18:38:43 +02:00
Till Faelligen 94ed2d3689 Add tests, use simple http.HandlerFunc 2022-05-04 17:57:46 +02:00
Till Faelligen 964e1cef85 Remove noise 2022-05-04 17:57:29 +02:00
Till Faelligen 2eb3aab07e Split out UserConsentPolicyAPI for easier testing 2022-05-04 17:56:46 +02:00
Till Faelligen 60ba4b5612 Fix stupid mistake.. and just return the NullString 2022-05-04 17:53:49 +02:00
Till Faelligen 88612ddd0c Deduplicate constructing consent URL 2022-05-04 14:40:25 +02:00
Till Faelligen bddf8ed3ac Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-05-04 14:34:16 +02:00
Till Faelligen 4d5feb2544 Fix userapi issues 2022-05-04 14:33:51 +02:00
Till Faelligen cd7a7606a1 Remove noise 2022-05-04 14:12:52 +02:00
Till Faelligen dc8cea6d57 PR comments config 2022-05-04 13:47:08 +02:00
kegsay e0cdf64c33
Merge branch 'main' into s7evink/consent-tracking 2022-05-03 17:31:36 +01:00
Till Faelligen ef62255685 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-05-03 09:54:25 +02:00
Till Faelligen 1f64fc79c8 Merge branch 's7evink/consent-tracking' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-04-20 17:35:27 +02:00
Till Faelligen 2b496be2c3 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-04-20 17:26:27 +02:00
Till Faelligen 733b601aa9 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-04-04 15:56:40 +02:00
Till Faelligen 2a18023a1a Linter again.. 2022-03-28 12:13:47 +02:00
Till Faelligen c99e3aff1b Fix linter issues 2022-03-28 12:01:15 +02:00
Till Faelligen f1e8d19cea Fix build 2022-03-28 11:33:12 +02:00
Till Faelligen 019f0922ea Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-03-28 11:30:16 +02:00
kegsay 6324c1d01f
Merge branch 'main' into s7evink/consent-tracking 2022-03-18 08:37:31 +00:00
Till Faelligen b9479a6f18 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-03-16 08:43:19 +01:00
Till Faelligen e42ef1706b Merge branch 's7evink/consent-tracking' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-03-16 08:43:12 +01:00
Till Faelligen 31ac3ac081 Use DefaultRoomVersion as roomVersion 2022-03-16 08:42:04 +01:00
kegsay 39d9d88b02
Merge branch 'main' into s7evink/consent-tracking 2022-03-09 10:05:50 +00:00
Till Faelligen 710007d600 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-03-07 09:46:28 +01:00
Till Faelligen dcfc0bcd43 URL Encode, use new method to get server notice room 2022-03-07 09:45:59 +01:00
Till Faelligen c7d2254698 Update templates, remove default base URL 2022-03-07 09:45:24 +01:00
Till Faelligen 7c6a162c0f Remove MemberShipStateAll 2022-03-07 09:42:02 +01:00
Till Faelligen 699617ee4d Add server_notice_room_id and methods to update/get it 2022-03-07 09:41:25 +01:00
Till Faelligen 519ea13510 Add AuthAPICheck and optional functional checks
Rename several variables
2022-03-04 17:01:18 +01:00
Till Faelligen fa26aa9138 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-03-04 15:10:23 +01:00
Till Faelligen e80ca307d3 Fix receipts 2022-03-04 09:59:15 +01:00
Till Faelligen df7218e230 Fix parameters 2022-03-04 09:30:46 +01:00
Till Faelligen e6e62497c9 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-03-04 09:18:34 +01:00
Till Faelligen 2ad15f308f Rename some functions 2022-02-25 15:32:19 +01:00
Till Faelligen ed16a2f107 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-25 15:21:06 +01:00
Till Faelligen ce658ab8f2 Fix missing params 2022-02-24 08:54:54 +01:00
Till Faelligen 79e1c9e4bd Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-24 08:47:38 +01:00
Till Faelligen 2042303c6c Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-22 07:48:36 +01:00
Till Faelligen 4f2d161401 Remove consentMux 2022-02-21 17:09:25 +01:00
Till Faelligen c65eb2bf52 Fix query 2022-02-21 16:57:24 +01:00
Till Faelligen 0ae8293abd Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-21 16:43:45 +01:00
Till Faelligen dac29c1786 Split migrations in to two statements 2022-02-21 16:43:28 +01:00
Till Faelligen c2b6019c35 Add possibility to query all membership states 2022-02-21 16:40:59 +01:00
Till Faelligen 61cdb714df Use typed values for Consent 2022-02-21 16:23:28 +01:00
Till Faelligen 185cb7a582 Remove BaseURL from Global
Update template
2022-02-21 16:22:25 +01:00
Till Faelligen e2b0ff675b Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-21 14:44:52 +01:00
Till Faelligen c0845ea1ad Update logging 2022-02-21 14:44:16 +01:00
Till Faelligen 6622fda08c Add sending server notices on startup 2022-02-21 14:27:59 +01:00
Till Faelligen 219a15c4c3 Load templates into one variable 2022-02-21 14:27:13 +01:00
Till Faelligen fb95331aa2 Add posibility to track sent policy versions 2022-02-21 14:26:00 +01:00
Till Faelligen cb4526793d Add missing migrations 2022-02-21 12:15:56 +01:00
Till Faelligen 2e6987f8bd Add missing files 2022-02-21 12:12:07 +01:00
Till Faelligen 9c3a1cfd47 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-21 12:08:03 +01:00
Till Faelligen 74da1f0fb3 Remove "magic" Enabled function and use simple bool 2022-02-16 10:11:23 +01:00
Till Faelligen 26accb8c5d Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-16 08:39:54 +01:00
Till Faelligen 6482630f7b Fix insert statement 2022-02-15 14:28:51 +01:00
Till Faelligen 2fc1c46743 Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking 2022-02-15 14:16:37 +01:00
Till Faelligen 5a0ec6e443 Add policy version to create-account & mediaapi 2022-02-15 14:15:18 +01:00
Till Faelligen 535d388ec0 Add new login type "m.login.terms" 2022-02-15 14:14:39 +01:00
Till Faelligen cbdbbb0839 Make sure we use the correct login stages 2022-02-15 14:13:22 +01:00
Till Faelligen f8bebe5e5a Add policy_version to insertAccount statement 2022-02-15 14:10:50 +01:00
Till Faelligen d19518fca5 Add ConsentNotGiven error
Verify consent on desired endpoints
Store consent on POST requests
2022-02-15 11:07:24 +01:00
Till Faelligen 89340cfc52 Verify the user has given their consent, otherwise block access 2022-02-14 18:11:56 +01:00
Till Faelligen 11144de92f Implement consent tracking 2022-02-14 16:18:51 +01:00
Till Faelligen b2045c24cb Add missing yaml tag 2022-02-14 16:18:19 +01:00
Till Faelligen 097f1d4609 Add a way to update the policy_version for a user 2022-02-14 15:08:00 +01:00
Till Faelligen a505471c90 Add table migrations 2022-02-14 14:52:16 +01:00
Till Faelligen 3c5c3ea7fb Add methods to query the policy version 2022-02-14 14:03:30 +01:00
Till Faelligen 9583784e8a Add new coloumn to track accepted policy version 2022-02-14 14:02:13 +01:00
Till Faelligen b6ee34918c Add consent tracking endpoint 2022-02-14 13:41:21 +01:00
Till Faelligen ac343861ad Add missing form_secret
Add tests
2022-02-14 13:06:36 +01:00
Till Faelligen 4da7df5e3e Add consent tracking template 2022-02-14 13:01:26 +01:00
Till Faelligen ccc11f94f7 Add consentAPIMux to components 2022-02-14 13:00:07 +01:00
Till Faelligen 5702b84dae Add User consent configuration
Add consentAPIMux
2022-02-14 12:59:13 +01:00
40 changed files with 1740 additions and 203 deletions

View file

@ -11,4 +11,5 @@ const (
LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeRecaptcha = "m.login.recaptcha"
LoginTypeApplicationService = "m.login.application_service" LoginTypeApplicationService = "m.login.application_service"
LoginTypeToken = "m.login.token" LoginTypeToken = "m.login.token"
LoginTypeTerms = "m.login.terms"
) )

View file

@ -29,6 +29,13 @@ type MatrixError struct {
Err string `json:"error"` Err string `json:"error"`
} }
// ConsentError is an error returned to users, who didn't accept the
// TOS of this server yet.
type ConsentError struct {
MatrixError
ConsentURI string `json:"consent_uri"`
}
func (e MatrixError) Error() string { func (e MatrixError) Error() string {
return fmt.Sprintf("%s: %s", e.ErrCode, e.Err) return fmt.Sprintf("%s: %s", e.ErrCode, e.Err)
} }
@ -207,3 +214,15 @@ func NotTrusted(serverName string) *MatrixError {
Err: fmt.Sprintf("Untrusted server '%s'", serverName), Err: fmt.Sprintf("Untrusted server '%s'", serverName),
} }
} }
// ConsentNotGiven is an error returned to users, who didn't accept the
// TOS of this server yet.
func ConsentNotGiven(consentURI string, msg string) *ConsentError {
return &ConsentError{
MatrixError: MatrixError{
ErrCode: "M_CONSENT_NOT_GIVEN",
Err: msg,
},
ConsentURI: consentURI,
}
}

View file

@ -0,0 +1,219 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// The data used to populate the /consent request
type constentTemplateData struct {
UserID string
Version string
UserHMAC string
HasConsented bool
ReadOnly bool
}
func writeHeaderAndText(w http.ResponseWriter, statusCode int) {
w.WriteHeader(statusCode)
_, _ = w.Write([]byte(http.StatusText(statusCode)))
}
func consent(writer http.ResponseWriter, req *http.Request, userAPI userapi.UserConsentPolicyAPI, cfg *config.ClientAPI) {
consentCfg := cfg.Matrix.UserConsentOptions
// The data used to populate the /consent request
data := constentTemplateData{
UserID: req.FormValue("u"),
Version: req.FormValue("v"),
UserHMAC: req.FormValue("h"),
}
switch req.Method {
case http.MethodGet:
// display the privacy policy without a form
data.ReadOnly = data.UserID == "" || data.UserHMAC == "" || data.Version == ""
// let's see if the user already consented to the current version
if !data.ReadOnly {
if ok, err := validHMAC(data.UserID, data.UserHMAC, consentCfg.FormSecret); err != nil || !ok {
writeHeaderAndText(writer, http.StatusForbidden)
return
}
res := &userapi.QueryPolicyVersionResponse{}
localpart, _, err := gomatrixserverlib.SplitID('@', data.UserID)
if err != nil {
logrus.WithError(err).Error("unable to split username")
writeHeaderAndText(writer, http.StatusInternalServerError)
return
}
if err = userAPI.QueryPolicyVersion(req.Context(), &userapi.QueryPolicyVersionRequest{
Localpart: localpart,
}, res); err != nil {
logrus.WithError(err).Error("unable query policy version")
writeHeaderAndText(writer, http.StatusInternalServerError)
return
}
data.HasConsented = res.PolicyVersion == consentCfg.Version
}
err := consentCfg.Templates.ExecuteTemplate(writer, consentCfg.Version+".gohtml", data)
if err != nil {
logrus.WithError(err).Error("unable to execute consent template")
writeHeaderAndText(writer, http.StatusInternalServerError)
return
}
case http.MethodPost:
ok, err := validHMAC(data.UserID, data.UserHMAC, consentCfg.FormSecret)
if err != nil || !ok {
if !ok {
writeHeaderAndText(writer, http.StatusForbidden)
return
}
writeHeaderAndText(writer, http.StatusInternalServerError)
return
}
localpart, _, err := gomatrixserverlib.SplitID('@', data.UserID)
if err != nil {
logrus.WithError(err).Error("unable to split username")
writeHeaderAndText(writer, http.StatusInternalServerError)
return
}
if err = userAPI.PerformUpdatePolicyVersion(
req.Context(),
&userapi.UpdatePolicyVersionRequest{
PolicyVersion: data.Version,
Localpart: localpart,
},
&userapi.UpdatePolicyVersionResponse{},
); err != nil {
writeHeaderAndText(writer, http.StatusInternalServerError)
return
}
// display the privacy policy without a form
data.ReadOnly = false
data.HasConsented = true
err = consentCfg.Templates.ExecuteTemplate(writer, consentCfg.Version+".gohtml", data)
if err != nil {
logrus.WithError(err).Error("unable to print consent template")
writeHeaderAndText(writer, http.StatusInternalServerError)
}
}
}
func sendServerNoticeForConsent(userAPI userapi.ClientUserAPI, rsAPI api.ClientRoomserverAPI,
cfgNotices *config.ServerNotices,
cfgClient *config.ClientAPI,
senderDevice *userapi.Device,
asAPI appserviceAPI.AppServiceInternalAPI,
) {
res := &userapi.QueryOutdatedPolicyResponse{}
if err := userAPI.QueryOutdatedPolicy(context.Background(), &userapi.QueryOutdatedPolicyRequest{
PolicyVersion: cfgClient.Matrix.UserConsentOptions.Version,
}, res); err != nil {
logrus.WithError(err).Error("unable to fetch users with outdated consent policy")
return
}
var (
consentOpts = cfgClient.Matrix.UserConsentOptions
data = make(map[string]string)
err error
sentMessages int
)
if len(res.UserLocalparts) == 0 {
return
}
logrus.WithField("count", len(res.UserLocalparts)).Infof("Sending server notice to users who have not yet accepted the policy")
for _, localpart := range res.UserLocalparts {
if localpart == cfgClient.Matrix.ServerNotices.LocalPart {
continue
}
userID := fmt.Sprintf("@%s:%s", localpart, cfgClient.Matrix.ServerName)
data["ConsentURL"], err = consentOpts.ConsentURL(userID)
if err != nil {
logrus.WithError(err).WithField("userID", userID).Error("unable to construct consentURI")
continue
}
msgBody := &bytes.Buffer{}
if err = consentOpts.TextTemplates.ExecuteTemplate(msgBody, "serverNoticeTemplate", data); err != nil {
logrus.WithError(err).WithField("userID", userID).Error("unable to execute serverNoticeTemplate")
continue
}
req := sendServerNoticeRequest{
UserID: userID,
Content: struct {
MsgType string `json:"msgtype,omitempty"`
Body string `json:"body,omitempty"`
}{
MsgType: consentOpts.ServerNoticeContent.MsgType,
Body: msgBody.String(),
},
}
_, err = sendServerNotice(context.Background(), req, rsAPI, cfgNotices, cfgClient, senderDevice, asAPI, userAPI, nil, nil, nil)
if err != nil {
logrus.WithError(err).WithField("userID", userID).Error("failed to send server notice for consent to user")
continue
}
sentMessages++
res := &userapi.UpdatePolicyVersionResponse{}
if err = userAPI.PerformUpdatePolicyVersion(context.Background(), &userapi.UpdatePolicyVersionRequest{
PolicyVersion: consentOpts.Version,
Localpart: userID,
ServerNoticeUpdate: true,
}, res); err != nil {
logrus.WithError(err).WithField("userID", userID).Error("failed to update policy version")
continue
}
}
if sentMessages > 0 {
logrus.Infof("Sent messages to %d users", sentMessages)
}
}
func validHMAC(username, userHMAC, secret string) (bool, error) {
mac := hmac.New(sha256.New, []byte(secret))
_, err := mac.Write([]byte(username))
if err != nil {
return false, err
}
expectedMAC := mac.Sum(nil)
decoded, err := hex.DecodeString(userHMAC)
if err != nil {
return false, err
}
return hmac.Equal(decoded, expectedMAC), nil
}

View file

@ -0,0 +1,236 @@
package routing
import (
"context"
"fmt"
"html/template"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
func Test_validHMAC(t *testing.T) {
type args struct {
username string
userHMAC string
secret string
}
tests := []struct {
name string
args args
want bool
wantErr bool
}{
{
name: "invalid hmac",
args: args{},
wantErr: false,
want: false,
},
// $ echo -n '@alice:localhost' | openssl sha256 -hmac 'helloWorld'
//(stdin)= 121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e
//
{
name: "valid hmac",
args: args{
username: "@alice:localhost",
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
secret: "helloWorld",
},
want: true,
},
{
name: "invalid hmac",
args: args{
username: "@bob:localhost",
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
secret: "helloWorld",
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := validHMAC(tt.args.username, tt.args.userHMAC, tt.args.secret)
if (err != nil) != tt.wantErr {
t.Errorf("validHMAC() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("validHMAC() got = %v, want %v", got, tt.want)
}
})
}
}
type dummyAPI struct {
usersConsent map[string]string
}
func (d dummyAPI) QueryOutdatedPolicy(ctx context.Context, req *userapi.QueryOutdatedPolicyRequest, res *userapi.QueryOutdatedPolicyResponse) error {
return nil
}
func (d dummyAPI) PerformUpdatePolicyVersion(ctx context.Context, req *userapi.UpdatePolicyVersionRequest, res *userapi.UpdatePolicyVersionResponse) error {
d.usersConsent[req.Localpart] = req.PolicyVersion
return nil
}
func (d dummyAPI) QueryPolicyVersion(ctx context.Context, req *userapi.QueryPolicyVersionRequest, res *userapi.QueryPolicyVersionResponse) error {
res.PolicyVersion = "v2.0"
return nil
}
const dummyTemplate = `
{{ if .HasConsented }}
Consent given.
{{ else }}
WithoutForm
{{ if not .ReadOnly }}
With Form.
{{ end }}
{{ end }}`
func Test_consent(t *testing.T) {
type args struct {
username string
userHMAC string
version string
method string
}
tests := []struct {
name string
args args
wantRespCode int
wantBodyContains string
}{
{
name: "not a userID, valid hmac",
args: args{
username: "notAuserID",
userHMAC: "7578bbface5ebb250a63935cebc05ca12060f58ebdbd271ecbc25e25a3da154d",
version: "v1.0",
method: http.MethodGet,
},
wantRespCode: http.StatusInternalServerError,
},
// $ echo -n '@alice:localhost' | openssl sha256 -hmac 'helloWorld'
//(stdin)= 121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e
//
{
name: "valid hmac for alice GET, not consented",
args: args{
username: "@alice:localhost",
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
version: "v1.0",
method: http.MethodGet,
},
wantRespCode: http.StatusOK,
wantBodyContains: "With form",
},
{
name: "alice consents successfully",
args: args{
username: "@alice:localhost",
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
version: "v1.0",
method: http.MethodPost,
},
wantRespCode: http.StatusOK,
wantBodyContains: "Consent given",
},
{
name: "valid hmac for alice GET, new version",
args: args{
username: "@alice:localhost",
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
version: "v2.0",
method: http.MethodGet,
},
wantRespCode: http.StatusOK,
wantBodyContains: "With form",
},
{
name: "no hmac provided for alice, read only should be displayed",
args: args{
username: "@alice:localhost",
userHMAC: "",
version: "v1.0",
method: http.MethodGet,
},
wantRespCode: http.StatusOK,
wantBodyContains: "WithoutForm",
},
{
name: "alice trying to get bobs status is forbidden",
args: args{
username: "@bob:localhost",
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
version: "v1.0",
method: http.MethodGet,
},
wantRespCode: http.StatusForbidden,
wantBodyContains: "forbidden",
},
{
name: "alice trying to consent for bob is forbidden",
args: args{
username: "@bob:localhost",
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
version: "v1.0",
method: http.MethodPost,
},
wantRespCode: http.StatusForbidden,
wantBodyContains: "forbidden",
},
}
userAPI := dummyAPI{
usersConsent: map[string]string{},
}
consentTemplates := template.Must(template.New("v1.0.gohtml").Parse(dummyTemplate))
consentTemplates = template.Must(consentTemplates.New("v2.0.gohtml").Parse(dummyTemplate))
userconsentOpts := config.UserConsentOptions{
FormSecret: "helloWorld",
Version: "v1.0",
Templates: consentTemplates,
BaseURL: "http://localhost",
}
cfg := &config.ClientAPI{
Matrix: &config.Global{
UserConsentOptions: userconsentOpts,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
url := fmt.Sprintf("%s/consent?u=%s&v=%s&h=%s",
userconsentOpts.BaseURL, tt.args.username, tt.args.version, tt.args.userHMAC,
)
req := httptest.NewRequest(tt.args.method, url, nil)
w := httptest.NewRecorder()
consent(w, req, userAPI, cfg)
resp := w.Result()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unable to read response body: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != tt.wantRespCode {
t.Fatalf("expected http %d, got %d", tt.wantRespCode, resp.StatusCode)
}
if !strings.Contains(strings.ToLower(string(body)), strings.ToLower(tt.wantBodyContains)) {
t.Fatalf("expected body to contain %s, but got %s", tt.wantBodyContains, string(body))
}
})
}
}

View file

@ -31,13 +31,12 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/tidwall/gjson"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens" "github.com/matrix-org/gomatrixserverlib/tokens"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -721,6 +720,8 @@ func handleRegistrationFlow(
} }
switch r.Auth.Type { switch r.Auth.Type {
case authtypes.LoginTypeTerms:
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeTerms)
case authtypes.LoginTypeRecaptcha: case authtypes.LoginTypeRecaptcha:
// Check given captcha response // Check given captcha response
resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
@ -788,11 +789,16 @@ func handleApplicationServiceRegistration(
return *err return *err
} }
policyVersion := ""
if cfg.Matrix.UserConsentOptions.Enabled {
policyVersion = cfg.Matrix.UserConsentOptions.Version
}
// If no error, application service was successfully validated. // If no error, application service was successfully validated.
// Don't need to worry about appending to registration stages as // Don't need to worry about appending to registration stages as
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, policyVersion,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService,
) )
} }
@ -809,9 +815,13 @@ func checkAndCompleteFlow(
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
) util.JSONResponse { ) util.JSONResponse {
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
policyVersion := ""
if cfg.Matrix.UserConsentOptions.Enabled {
policyVersion = cfg.Matrix.UserConsentOptions.Version
}
// This flow was completed, registration can continue // This flow was completed, registration can continue
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, policyVersion,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser,
) )
} }
@ -834,7 +844,7 @@ func checkAndCompleteFlow(
func completeRegistration( func completeRegistration(
ctx context.Context, ctx context.Context,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
username, password, appserviceID, ipAddr, userAgent, sessionID string, username, password, appserviceID, ipAddr, userAgent, sessionID, policyVersion string,
inhibitLogin eventutil.WeakBoolean, inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string, displayName, deviceID *string,
accType userapi.AccountType, accType userapi.AccountType,
@ -861,11 +871,12 @@ func completeRegistration(
} }
var accRes userapi.PerformAccountCreationResponse var accRes userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
AppServiceID: appserviceID, AppServiceID: appserviceID,
Localpart: username, Localpart: username,
Password: password, Password: password,
AccountType: accType, AccountType: accType,
OnConflict: userapi.ConflictAbort, OnConflict: userapi.ConflictAbort,
PolicyVersion: policyVersion,
}, &accRes) }, &accRes)
if err != nil { if err != nil {
if _, ok := err.(*userapi.ErrorConflict); ok { // user already exists if _, ok := err.(*userapi.ErrorConflict); ok { // user already exists
@ -1073,5 +1084,5 @@ func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSec
if ssrr.Admin { if ssrr.Admin {
accType = userapi.AccountTypeAdmin accType = userapi.AccountTypeAdmin
} }
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", "", false, &ssrr.User, &deviceID, accType)
} }

View file

@ -15,6 +15,7 @@
package routing package routing
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
@ -93,7 +94,7 @@ func PutTag(
} }
tagContent.Tags[tag] = properties tagContent.Tags[tag] = properties
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { if err = saveTagData(req.Context(), userID, roomID, userAPI, tagContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -145,7 +146,7 @@ func DeleteTag(
} }
} }
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { if err = saveTagData(req.Context(), userID, roomID, userAPI, tagContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -191,7 +192,7 @@ func obtainSavedTags(
// saveTagData saves the provided tag data into the database // saveTagData saves the provided tag data into the database
func saveTagData( func saveTagData(
req *http.Request, context context.Context,
userID string, userID string,
roomID string, roomID string,
userAPI api.ClientUserAPI, userAPI api.ClientUserAPI,
@ -208,5 +209,5 @@ func saveTagData(
AccountData: json.RawMessage(newTagData), AccountData: json.RawMessage(newTagData),
} }
dataRes := api.InputAccountDataResponse{} dataRes := api.InputAccountDataResponse{}
return userAPI.InputAccountData(req.Context(), &dataReq, &dataRes) return userAPI.InputAccountData(context, &dataReq, &dataRes)
} }

View file

@ -127,9 +127,13 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
// server notifications // server notifications
var (
serverNotificationSender *userapi.Device
err error
)
if cfg.Matrix.ServerNotices.Enabled { if cfg.Matrix.ServerNotices.Enabled {
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")
serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, cfg) serverNotificationSender, err = getSenderDevice(context.Background(), userAPI, cfg)
if err != nil { if err != nil {
logrus.WithError(err).Fatal("unable to get account for sending sending server notices") logrus.WithError(err).Fatal("unable to get account for sending sending server notices")
} }
@ -177,13 +181,27 @@ func Setup(
// using ?: so the final regexp becomes what is below. We also need a trailing slash to stop 'v33333' matching. // using ?: so the final regexp becomes what is below. We also need a trailing slash to stop 'v33333' matching.
// Note that 'apiversion' is chosen because it must not collide with a variable used in any of the routing! // Note that 'apiversion' is chosen because it must not collide with a variable used in any of the routing!
v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter() unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter()
// NOTSPEC: consent tracking
if cfg.Matrix.UserConsentOptions.Enabled {
if !cfg.Matrix.ServerNotices.Enabled {
logrus.Warnf("Consent tracking is enabled, but server notes are not. No server notice will be sent to users")
} else {
// start a new go routine to send messages about consent
go sendServerNoticeForConsent(userAPI, rsAPI, &cfg.Matrix.ServerNotices, cfg, serverNotificationSender, asAPI)
}
publicAPIMux.HandleFunc("/consent", func(writer http.ResponseWriter, request *http.Request) {
consent(writer, request, userAPI, cfg)
}).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
}
consentRequiredCheck := httputil.WithConsentCheck(cfg.Matrix.UserConsentOptions, userAPI)
v3mux.Handle("/createRoom", v3mux.Handle("/createRoom",
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI) return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI)
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/join/{roomIDOrAlias}", v3mux.Handle("/join/{roomIDOrAlias}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -213,7 +231,7 @@ func Setup(
return PeekRoomByIDOrAlias( return PeekRoomByIDOrAlias(
req, device, rsAPI, vars["roomIDOrAlias"], req, device, rsAPI, vars["roomIDOrAlias"],
) )
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
} }
v3mux.Handle("/joined_rooms", v3mux.Handle("/joined_rooms",
@ -258,7 +276,7 @@ func Setup(
return UnpeekRoomByID( return UnpeekRoomByID(
req, device, rsAPI, vars["roomID"], req, device, rsAPI, vars["roomID"],
) )
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/ban", v3mux.Handle("/rooms/{roomID}/ban",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -267,7 +285,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/invite", v3mux.Handle("/rooms/{roomID}/invite",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -279,7 +297,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/kick", v3mux.Handle("/rooms/{roomID}/kick",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -288,7 +306,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/unban", v3mux.Handle("/rooms/{roomID}/unban",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -297,7 +315,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/send/{eventType}", v3mux.Handle("/rooms/{roomID}/send/{eventType}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -317,7 +335,7 @@ func Setup(
txnID := vars["txnID"] txnID := vars["txnID"]
return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID, return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID,
nil, cfg, rsAPI, transactionsCache) nil, cfg, rsAPI, transactionsCache)
}), }, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/event/{eventID}", v3mux.Handle("/rooms/{roomID}/event/{eventID}",
httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -326,7 +344,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI) return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI)
}), }, consentRequiredCheck),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -335,7 +353,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"]) return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions) }, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -343,7 +361,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetAliases(req, rsAPI, device, vars["roomID"]) return GetAliases(req, rsAPI, device, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions) }, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -354,7 +372,7 @@ func Setup(
eventType := strings.TrimSuffix(vars["type"], "/") eventType := strings.TrimSuffix(vars["type"], "/")
eventFormat := req.URL.Query().Get("format") == "event" eventFormat := req.URL.Query().Get("format") == "event"
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat) return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat)
})).Methods(http.MethodGet, http.MethodOptions) }, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -363,7 +381,7 @@ func Setup(
} }
eventFormat := req.URL.Query().Get("format") == "event" eventFormat := req.URL.Query().Get("format") == "event"
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat) return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat)
})).Methods(http.MethodGet, http.MethodOptions) }, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -374,7 +392,7 @@ func Setup(
emptyString := "" emptyString := ""
eventType := strings.TrimSuffix(vars["eventType"], "/") eventType := strings.TrimSuffix(vars["eventType"], "/")
return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil) return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil)
}), }, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
@ -385,7 +403,7 @@ func Setup(
} }
stateKey := vars["stateKey"] stateKey := vars["stateKey"]
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil) return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil)
}), }, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
@ -487,7 +505,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, nil, nil) return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, nil, nil)
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}", v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -497,7 +515,7 @@ func Setup(
} }
txnID := vars["txnId"] txnID := vars["txnId"]
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, &txnID, transactionsCache) return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, &txnID, transactionsCache)
}), }, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/sendToDevice/{eventType}/{txnID}", v3mux.Handle("/sendToDevice/{eventType}/{txnID}",
@ -508,7 +526,7 @@ func Setup(
} }
txnID := vars["txnID"] txnID := vars["txnID"]
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID) return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
}), }, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// This is only here because sytest refers to /unstable for this endpoint // This is only here because sytest refers to /unstable for this endpoint
@ -522,7 +540,7 @@ func Setup(
} }
txnID := vars["txnID"] txnID := vars["txnID"]
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID) return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
}), }, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/account/whoami", v3mux.Handle("/account/whoami",
@ -531,7 +549,7 @@ func Setup(
return *r return *r
} }
return Whoami(req, device) return Whoami(req, device)
}), }, consentRequiredCheck),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/account/password", v3mux.Handle("/account/password",
@ -738,7 +756,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SetAvatarURL(req, userAPI, device, vars["userID"], cfg, rsAPI) return SetAvatarURL(req, userAPI, device, vars["userID"], cfg, rsAPI)
}), }, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows // Browsers use the OPTIONS HTTP method to check if the CORS policy allows
// PUT requests, so we need to allow this method // PUT requests, so we need to allow this method
@ -763,7 +781,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI) return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI)
}), }, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows // Browsers use the OPTIONS HTTP method to check if the CORS policy allows
// PUT requests, so we need to allow this method // PUT requests, so we need to allow this method
@ -771,19 +789,19 @@ func Setup(
v3mux.Handle("/account/3pid", v3mux.Handle("/account/3pid",
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetAssociated3PIDs(req, userAPI, device) return GetAssociated3PIDs(req, userAPI, device)
}), }, consentRequiredCheck),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/account/3pid", v3mux.Handle("/account/3pid",
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CheckAndSave3PIDAssociation(req, userAPI, device, cfg) return CheckAndSave3PIDAssociation(req, userAPI, device, cfg)
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/account/3pid/delete", unstableMux.Handle("/account/3pid/delete",
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return Forget3PID(req, userAPI) return Forget3PID(req, userAPI)
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken", v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
@ -798,7 +816,7 @@ func Setup(
return *r return *r
} }
return RequestTurnServer(req, device, cfg) return RequestTurnServer(req, device, cfg)
}), }, consentRequiredCheck),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thirdparty/protocols", v3mux.Handle("/thirdparty/protocols",
@ -868,7 +886,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetAdminWhois(req, userAPI, device, vars["userID"]) return GetAdminWhois(req, userAPI, device, vars["userID"])
}), }, consentRequiredCheck),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
v3mux.Handle("/user/{userID}/openid/request_token", v3mux.Handle("/user/{userID}/openid/request_token",
@ -881,7 +899,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return CreateOpenIDToken(req, userAPI, device, vars["userID"], cfg) return CreateOpenIDToken(req, userAPI, device, vars["userID"], cfg)
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/user_directory/search", v3mux.Handle("/user_directory/search",
@ -907,7 +925,7 @@ func Setup(
postContent.SearchString, postContent.SearchString,
postContent.Limit, postContent.Limit,
) )
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/members", v3mux.Handle("/rooms/{roomID}/members",
@ -953,7 +971,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendForget(req, device, vars["roomID"], rsAPI) return SendForget(req, device, vars["roomID"], rsAPI)
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/upgrade", v3mux.Handle("/rooms/{roomID}/upgrade",
@ -1065,7 +1083,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
}), }, consentRequiredCheck),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
@ -1075,7 +1093,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
}), }, consentRequiredCheck),
).Methods(http.MethodDelete, http.MethodOptions) ).Methods(http.MethodDelete, http.MethodOptions)
v3mux.Handle("/capabilities", v3mux.Handle("/capabilities",
@ -1095,11 +1113,11 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return KeyBackupVersion(req, userAPI, device, vars["version"]) return KeyBackupVersion(req, userAPI, device, vars["version"])
}) }, consentRequiredCheck)
getLatestBackupKeysVersion := httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { getLatestBackupKeysVersion := httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return KeyBackupVersion(req, userAPI, device, "") return KeyBackupVersion(req, userAPI, device, "")
}) }, consentRequiredCheck)
putBackupKeysVersion := httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { putBackupKeysVersion := httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -1107,7 +1125,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"]) return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"])
}) }, consentRequiredCheck)
deleteBackupKeysVersion := httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { deleteBackupKeysVersion := httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -1119,7 +1137,7 @@ func Setup(
postNewBackupKeysVersion := httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { postNewBackupKeysVersion := httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CreateKeyBackupVersion(req, userAPI, device) return CreateKeyBackupVersion(req, userAPI, device)
}) }, consentRequiredCheck)
v3mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
@ -1150,7 +1168,7 @@ func Setup(
return *resErr return *resErr
} }
return UploadBackupKeys(req, userAPI, device, version, &reqBody) return UploadBackupKeys(req, userAPI, device, version, &reqBody)
}) }, consentRequiredCheck)
// Single room bulk session // Single room bulk session
putBackupKeysRoom := httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { putBackupKeysRoom := httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -1182,7 +1200,7 @@ func Setup(
} }
reqBody.Rooms[roomID] = body reqBody.Rooms[roomID] = body
return UploadBackupKeys(req, userAPI, device, version, &reqBody) return UploadBackupKeys(req, userAPI, device, version, &reqBody)
}) }, consentRequiredCheck)
// Single room, single session // Single room, single session
putBackupKeysRoomSession := httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { putBackupKeysRoomSession := httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -1215,7 +1233,7 @@ func Setup(
} }
keyReq.Rooms[roomID].Sessions[sessionID] = reqBody keyReq.Rooms[roomID].Sessions[sessionID] = reqBody
return UploadBackupKeys(req, userAPI, device, version, &keyReq) return UploadBackupKeys(req, userAPI, device, version, &keyReq)
}) }, consentRequiredCheck)
v3mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) v3mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
v3mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) v3mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
@ -1229,7 +1247,7 @@ func Setup(
getBackupKeys := httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { getBackupKeys := httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "") return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "")
}) }, consentRequiredCheck)
getBackupKeysRoom := httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { getBackupKeysRoom := httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -1245,7 +1263,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"]) return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"])
}) }, consentRequiredCheck)
v3mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
@ -1261,11 +1279,11 @@ func Setup(
postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, keyAPI, device, userAPI, cfg) return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, keyAPI, device, userAPI, cfg)
}) }, consentRequiredCheck)
postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadCrossSigningDeviceSignatures(req, keyAPI, device) return UploadCrossSigningDeviceSignatures(req, keyAPI, device)
}) }, consentRequiredCheck)
v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
@ -1277,12 +1295,12 @@ func Setup(
v3mux.Handle("/keys/upload/{deviceID}", v3mux.Handle("/keys/upload/{deviceID}",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device) return UploadKeys(req, keyAPI, device)
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/upload", v3mux.Handle("/keys/upload",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device) return UploadKeys(req, keyAPI, device)
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/query", v3mux.Handle("/keys/query",
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -1305,7 +1323,7 @@ func Setup(
} }
return SetReceipt(req, syncProducer, device, vars["roomId"], vars["receiptType"], vars["eventId"]) return SetReceipt(req, syncProducer, device, vars["roomId"], vars["receiptType"], vars["eventId"])
}), }, consentRequiredCheck),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/presence/{userId}/status", v3mux.Handle("/presence/{userId}/status",
httputil.MakeAuthAPI("set_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("set_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {

View file

@ -84,74 +84,66 @@ func SendServerNotice(
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
res, _ := sendServerNotice(ctx, r, rsAPI, cfgNotices, cfgClient, senderDevice, asAPI, userAPI, txnID, device, txnCache)
return res
}
func sendServerNotice(
ctx context.Context,
serverNoticeRequest sendServerNoticeRequest,
rsAPI api.ClientRoomserverAPI,
cfgNotices *config.ServerNotices,
cfgClient *config.ClientAPI,
senderDevice *userapi.Device,
asAPI appserviceAPI.AppServiceInternalAPI,
userAPI userapi.ClientUserAPI,
txnID *string,
device *userapi.Device,
txnCache *transactions.Cache,
) (util.JSONResponse, error) {
// check that all required fields are set // check that all required fields are set
if !r.valid() { if !serverNoticeRequest.valid() {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Invalid request"), JSON: jsonerror.BadJSON("Invalid request"),
} }, fmt.Errorf("Invalid JSON")
} }
// get rooms for specified user qryServerNoticeRoom := &userapi.QueryServerNoticeRoomResponse{}
allUserRooms := []string{} localpart, _, err := gomatrixserverlib.SplitID('@', serverNoticeRequest.UserID)
userRooms := api.QueryRoomsForUserResponse{} if err != nil {
// Get rooms the user is either joined, invited or has left. return util.JSONResponse{
for _, membership := range []string{"join", "invite", "leave"} { Code: http.StatusBadRequest,
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ JSON: jsonerror.BadJSON("Invalid request"),
UserID: r.UserID, }, err
WantMembership: membership, }
}, &userRooms); err != nil { err = userAPI.SelectServerNoticeRoomID(ctx, &userapi.QueryServerNoticeRoomRequest{Localpart: localpart}, qryServerNoticeRoom)
return util.ErrorResponse(err) if err != nil {
} return util.ErrorResponse(err), err
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
} }
// get rooms of the sender
senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName) senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName)
senderRooms := api.QueryRoomsForUserResponse{} roomID := qryServerNoticeRoom.RoomID
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ roomVersion := version.DefaultRoomVersion()
UserID: senderUserID,
WantMembership: "join",
}, &senderRooms); err != nil {
return util.ErrorResponse(err)
}
// check if we have rooms in common
commonRooms := []string{}
for _, userRoomID := range allUserRooms {
for _, senderRoomID := range senderRooms.RoomIDs {
if userRoomID == senderRoomID {
commonRooms = append(commonRooms, senderRoomID)
}
}
}
if len(commonRooms) > 1 {
return util.ErrorResponse(fmt.Errorf("expected to find one room, but got %d", len(commonRooms)))
}
var (
roomID string
roomVersion = version.DefaultRoomVersion()
)
// create a new room for the user // create a new room for the user
if len(commonRooms) == 0 { if qryServerNoticeRoom.RoomID == "" {
var pl, cc []byte
powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID) powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID)
powerLevelContent.Users[r.UserID] = -10 // taken from Synapse powerLevelContent.Users[serverNoticeRequest.UserID] = -10 // taken from Synapse
pl, err := json.Marshal(powerLevelContent) pl, err = json.Marshal(powerLevelContent)
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err), err
} }
createContent := map[string]interface{}{} createContent := map[string]interface{}{}
createContent["m.federate"] = false createContent["m.federate"] = false
cc, err := json.Marshal(createContent) cc, err = json.Marshal(createContent)
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err), err
} }
crReq := createRoomRequest{ crReq := createRoomRequest{
Invite: []string{r.UserID}, Invite: []string{serverNoticeRequest.UserID},
Name: cfgNotices.RoomName, Name: cfgNotices.RoomName,
Visibility: "private", Visibility: "private",
Preset: presetPrivateChat, Preset: presetPrivateChat,
@ -166,36 +158,40 @@ func SendServerNotice(
switch data := roomRes.JSON.(type) { switch data := roomRes.JSON.(type) {
case createRoomResponse: case createRoomResponse:
roomID = data.RoomID roomID = data.RoomID
res := &userapi.UpdateServerNoticeRoomResponse{}
err = userAPI.UpdateServerNoticeRoomID(ctx, &userapi.UpdateServerNoticeRoomRequest{RoomID: roomID, Localpart: localpart}, res)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("UpdateServerNoticeRoomID failed")
return jsonerror.InternalServerError(), err
}
// tag the room, so we can later check if the user tries to reject an invite // tag the room, so we can later check if the user tries to reject an invite
serverAlertTag := gomatrix.TagContent{Tags: map[string]gomatrix.TagProperties{ serverAlertTag := gomatrix.TagContent{Tags: map[string]gomatrix.TagProperties{
"m.server_notice": { "m.server_notice": {
Order: 1.0, Order: 1.0,
}, },
}} }}
if err = saveTagData(req, r.UserID, roomID, userAPI, serverAlertTag); err != nil { if err = saveTagData(ctx, serverNoticeRequest.UserID, roomID, userAPI, serverAlertTag); err != nil {
util.GetLogger(ctx).WithError(err).Error("saveTagData failed") util.GetLogger(ctx).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError(), err
} }
default: default:
// if we didn't get a createRoomResponse, we probably received an error, so return that. // if we didn't get a createRoomResponse, we probably received an error, so return that.
return roomRes return roomRes, fmt.Errorf("Unable to create room")
} }
} else { } else {
// we've found a room in common, check the membership res := &api.QueryMembershipForUserResponse{}
roomID = commonRooms[0] err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: serverNoticeRequest.UserID, RoomID: roomID}, res)
membershipRes := api.QueryMembershipForUserResponse{}
err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("unable to query membership for user") return util.ErrorResponse(err), err
return jsonerror.InternalServerError()
} }
if !membershipRes.IsInRoom { // re-invite the user
// re-invite the user if res.Membership != gomatrixserverlib.Join {
res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now()) var inviteRes util.JSONResponse
inviteRes, err = sendInvite(ctx, userAPI, senderDevice, roomID, serverNoticeRequest.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now())
if err != nil { if err != nil {
return res return inviteRes, err
} }
} }
} }
@ -203,13 +199,13 @@ func SendServerNotice(
startedGeneratingEvent := time.Now() startedGeneratingEvent := time.Now()
request := map[string]interface{}{ request := map[string]interface{}{
"body": r.Content.Body, "body": serverNoticeRequest.Content.Body,
"msgtype": r.Content.MsgType, "msgtype": serverNoticeRequest.Content.MsgType,
} }
e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, cfgClient, rsAPI, time.Now()) e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, cfgClient, rsAPI, time.Now())
if resErr != nil { if resErr != nil {
logrus.Errorf("failed to send message: %+v", resErr) logrus.Errorf("failed to send message: %+v", resErr)
return *resErr return *resErr, fmt.Errorf("Unable to send event")
} }
timeToGenerateEvent := time.Since(startedGeneratingEvent) timeToGenerateEvent := time.Since(startedGeneratingEvent)
@ -224,7 +220,7 @@ func SendServerNotice(
// pass the new event to the roomserver and receive the correct event ID // pass the new event to the roomserver and receive the correct event ID
// event ID in case of duplicate transaction is discarded // event ID in case of duplicate transaction is discarded
startedSubmittingEvent := time.Now() startedSubmittingEvent := time.Now()
if err := api.SendEvents( if err = api.SendEvents(
ctx, rsAPI, ctx, rsAPI,
api.KindNew, api.KindNew,
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
@ -236,7 +232,7 @@ func SendServerNotice(
false, false,
); err != nil { ); err != nil {
util.GetLogger(ctx).WithError(err).Error("SendEvents failed") util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError(), err
} }
util.GetLogger(ctx).WithFields(logrus.Fields{ util.GetLogger(ctx).WithFields(logrus.Fields{
"event_id": e.EventID(), "event_id": e.EventID(),
@ -259,7 +255,7 @@ func SendServerNotice(
sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds())) sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds()))
sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvent.Milliseconds())) sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvent.Milliseconds()))
return res return res, nil
} }
func (r sendServerNoticeRequest) valid() (ok bool) { func (r sendServerNoticeRequest) valid() (ok bool) {

View file

@ -146,7 +146,12 @@ func main() {
logrus.Fatalln("Username is already in use.") logrus.Fatalln("Username is already in use.")
} }
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType) policyVersion := ""
if cfg.Global.UserConsentOptions.Enabled {
policyVersion = cfg.Global.UserConsentOptions.Version
}
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", policyVersion, accType)
if err != nil { if err != nil {
logrus.Fatalln("Failed to create the account:", err.Error()) logrus.Fatalln("Failed to create the account:", err.Error())
} }

View file

@ -71,6 +71,35 @@ global:
# appear in user clients. # appear in user clients.
room_name: "Server Alerts" room_name: "Server Alerts"
# Consent tracking configuration
user_consent:
# If the user consent tracking is enabled or not
enabled: false
# The base URL this homeserver will serve clients on, e.g. https://matrix.org
base_url: http://localhost
# Randomly generated string (e.g. by using "pwgen -sy 32") to be used to calculate the HMAC
form_secret: "superSecretRandomlyGeneratedSecret"
# Require consent when user registers for the first time
require_at_registration: false
# The name to be shown to the user
policy_name: "Privacy policy"
# The directory to search for templates
template_dir: "./templates/privacy"
# The version of the policy. When loading templates, ".gohtml" template is added as a suffix
# e.g: ${template_dir}/1.0.gohtml needs to exist, if this is set to "1.0"
version: "1.0"
# Send a consent message to guest users
send_server_notice_to_guest: false
# Default message to send to users
server_notice_content:
msg_type: "m.text"
body: >-
Please give your consent to the privacy policy at {{ .ConsentURL }}.
# The error message to display if the user hasn't given their consent yet
block_events_error: >-
You can't send any messages until you consent to the privacy policy at
{{ .ConsentURL }}.
# Configuration for NATS JetStream # Configuration for NATS JetStream
jetstream: jetstream:
# A list of NATS Server addresses to connect to. If none are specified, an # A list of NATS Server addresses to connect to. If none are specified, an

26
docs/templates/privacy/1.0.gohtml vendored Normal file
View file

@ -0,0 +1,26 @@
<!doctype html>
<html lang="en">
<head>
<title>Privacy policy</title>
</head>
<body>
{{ if .HasConsented }}
<p>
You have already given your consent.
</p>
{{ else }}
<p>
Please give your consent to keep using this homeserver.
</p>
{{ if not .ReadOnly }}
<!-- The variables used here are only provided when the 'u' param is given to the homeserver -->
<form method="post" action="consent">
<input type="hidden" name="v" value="{{ .Version }}"/>
<input type="hidden" name="u" value="{{ .UserID }}"/>
<input type="hidden" name="h" value="{{ .UserHMAC }}"/>
<input type="submit" value="I consent"/>
</form>
{{ end }}
{{ end }}
</body>
</html>

View file

@ -15,6 +15,8 @@
package httputil package httputil
import ( import (
"bytes"
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -25,9 +27,12 @@ import (
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext" "github.com/opentracing/opentracing-go/ext"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
@ -41,10 +46,24 @@ type BasicAuth struct {
Password string `yaml:"password"` Password string `yaml:"password"`
} }
// AuthAPICheck is an option to MakeAuthAPI to add additional checks (e.g. WithConsentCheck) to verify
// the user is allowed to do specific things.
type AuthAPICheck func(ctx context.Context, device *userapi.Device) *util.JSONResponse
// WithConsentCheck checks that a user has given his consent.
func WithConsentCheck(options config.UserConsentOptions, api userapi.QueryPolicyVersionAPI) AuthAPICheck {
return func(ctx context.Context, device *userapi.Device) *util.JSONResponse {
if !options.Enabled {
return nil
}
return checkConsent(ctx, device.UserID, api, options)
}
}
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
func MakeAuthAPI( func MakeAuthAPI(
metricsName string, userAPI userapi.QueryAcccessTokenAPI, metricsName string, userAPI userapi.QueryAcccessTokenAPI,
f func(*http.Request, *userapi.Device) util.JSONResponse, f func(*http.Request, *userapi.Device) util.JSONResponse, checks ...AuthAPICheck,
) http.Handler { ) http.Handler {
h := func(req *http.Request) util.JSONResponse { h := func(req *http.Request) util.JSONResponse {
logger := util.GetLogger(req.Context()) logger := util.GetLogger(req.Context())
@ -72,6 +91,14 @@ func MakeAuthAPI(
} }
}() }()
// apply additional checks, if any
for _, opt := range checks {
resp := opt(req.Context(), device)
if resp != nil {
return *resp
}
}
jsonRes := f(req, device) jsonRes := f(req, device)
// do not log 4xx as errors as they are client fails, not server fails // do not log 4xx as errors as they are client fails, not server fails
if hub != nil && jsonRes.Code >= 500 { if hub != nil && jsonRes.Code >= 500 {
@ -83,6 +110,53 @@ func MakeAuthAPI(
return MakeExternalAPI(metricsName, h) return MakeExternalAPI(metricsName, h)
} }
func checkConsent(ctx context.Context, userID string, userAPI userapi.QueryPolicyVersionAPI, userConsentCfg config.UserConsentOptions) *util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return nil
}
// check which version of the policy the user accepted
res := &userapi.QueryPolicyVersionResponse{}
err = userAPI.QueryPolicyVersion(ctx, &userapi.QueryPolicyVersionRequest{
Localpart: localpart,
}, res)
if err != nil {
return &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to get policy version"),
}
}
// user hasn't accepted any policy, block access.
if userConsentCfg.Version != res.PolicyVersion {
uri, err := userConsentCfg.ConsentURL(userID)
if err != nil {
return &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to get consent URL"),
}
}
msg := &bytes.Buffer{}
c := struct {
ConsentURL string
}{
ConsentURL: uri,
}
if err = userConsentCfg.TextTemplates.ExecuteTemplate(msg, "blockEventsError", c); err != nil {
logrus.Infof("error consent message: %+v", err)
return &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to execute template"),
}
}
return &util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.ConsentNotGiven(uri, msg.String()),
}
}
return nil
}
// MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler. // MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler.
// This is used for APIs that are called from the internet. // This is used for APIs that are called from the internet.
func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler { func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {

View file

@ -59,15 +59,12 @@ func Setup(
PathToResult: map[string]*types.ThumbnailGenerationResult{}, PathToResult: map[string]*types.ThumbnailGenerationResult{},
} }
uploadHandler := httputil.MakeAuthAPI( uploadHandler := httputil.MakeAuthAPI("upload", userAPI, func(req *http.Request, dev *userapi.Device) util.JSONResponse {
"upload", userAPI, if r := rateLimits.Limit(req); r != nil {
func(req *http.Request, dev *userapi.Device) util.JSONResponse { return *r
if r := rateLimits.Limit(req); r != nil { }
return *r return Upload(req, cfg, dev, db, activeThumbnailGeneration)
} })
return Upload(req, cfg, dev, db, activeThumbnailGeneration)
},
)
configHandler := httputil.MakeAuthAPI("config", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { configHandler := httputil.MakeAuthAPI("config", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {

View file

@ -268,6 +268,7 @@ func (b *BaseDendrite) Close() error {
func (b *BaseDendrite) DatabaseConnection(dbProperties *config.DatabaseOptions, writer sqlutil.Writer) (*sql.DB, sqlutil.Writer, error) { func (b *BaseDendrite) DatabaseConnection(dbProperties *config.DatabaseOptions, writer sqlutil.Writer) (*sql.DB, sqlutil.Writer, error) {
if dbProperties.ConnectionString != "" || b == nil { if dbProperties.ConnectionString != "" || b == nil {
// Open a new database connection using the supplied config. // Open a new database connection using the supplied config.
logrus.Infof("Open a new database connection using the supplied config.: %+v", dbProperties.ConnectionString)
db, err := sqlutil.Open(dbProperties, writer) db, err := sqlutil.Open(dbProperties, writer)
return db, writer, err return db, writer, err
} }

View file

@ -265,6 +265,21 @@ func loadConfig(
return &c, nil return &c, nil
} }
type Terms struct {
Policies Policies `json:"policies"`
}
type En struct {
Name string `json:"name"`
URL string `json:"url"`
}
type PrivacyPolicy struct {
En En `json:"en"`
Version string `json:"version"`
}
type Policies struct {
PrivacyPolicy PrivacyPolicy `json:"privacy_policy"`
}
// Derive generates data that is derived from various values provided in // Derive generates data that is derived from various values provided in
// the config file. // the config file.
func (config *Dendrite) Derive() error { func (config *Dendrite) Derive() error {
@ -275,13 +290,39 @@ func (config *Dendrite) Derive() error {
// TODO: Add email auth type // TODO: Add email auth type
// TODO: Add MSISDN auth type // TODO: Add MSISDN auth type
if config.Global.UserConsentOptions.Enabled && config.Global.UserConsentOptions.RequireAtRegistration {
uri := config.Global.UserConsentOptions.BaseURL + "/_matrix/client/consent?v=" + config.Global.UserConsentOptions.Version
config.Derived.Registration.Params[authtypes.LoginTypeTerms] = Terms{
Policies: Policies{
PrivacyPolicy: PrivacyPolicy{
En: En{
Name: config.Global.UserConsentOptions.PolicyName,
URL: uri,
},
Version: config.Global.UserConsentOptions.Version,
},
},
}
}
if config.ClientAPI.RecaptchaEnabled { if config.ClientAPI.RecaptchaEnabled {
config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey} config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey}
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, }
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}})
} else { if config.Derived.Registration.Flows == nil {
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, authtypes.Flow{
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}) Stages: []authtypes.LoginType{authtypes.LoginTypeDummy},
})
}
// prepend each flow with LoginTypeTerms or LoginTypeRecaptcha
for i, flow := range config.Derived.Registration.Flows {
if config.Global.UserConsentOptions.Enabled && config.Global.UserConsentOptions.RequireAtRegistration {
flow.Stages = append([]authtypes.LoginType{authtypes.LoginTypeTerms}, flow.Stages...)
}
if config.ClientAPI.RecaptchaEnabled {
flow.Stages = append([]authtypes.LoginType{authtypes.LoginTypeRecaptcha}, flow.Stages...)
}
config.Derived.Registration.Flows[i] = flow
} }
// Load application service configuration files // Load application service configuration files

View file

@ -1,7 +1,15 @@
package config package config
import ( import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"html/template"
"math/rand" "math/rand"
"net/url"
"path/filepath"
textTemplate "text/template"
"time" "time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -71,6 +79,9 @@ type Global struct {
// ServerNotices configuration used for sending server notices // ServerNotices configuration used for sending server notices
ServerNotices ServerNotices `yaml:"server_notices"` ServerNotices ServerNotices `yaml:"server_notices"`
// Consent tracking options
UserConsentOptions UserConsentOptions `yaml:"user_consent"`
// ReportStats configures opt-in anonymous stats reporting. // ReportStats configures opt-in anonymous stats reporting.
ReportStats ReportStats `yaml:"report_stats"` ReportStats ReportStats `yaml:"report_stats"`
} }
@ -88,6 +99,7 @@ func (c *Global) Defaults(generate bool) {
c.Metrics.Defaults(generate) c.Metrics.Defaults(generate)
c.DNSCache.Defaults() c.DNSCache.Defaults()
c.Sentry.Defaults() c.Sentry.Defaults()
c.UserConsentOptions.Defaults()
c.ServerNotices.Defaults(generate) c.ServerNotices.Defaults(generate)
c.ReportStats.Defaults() c.ReportStats.Defaults()
} }
@ -100,6 +112,7 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) {
c.Metrics.Verify(configErrs, isMonolith) c.Metrics.Verify(configErrs, isMonolith)
c.Sentry.Verify(configErrs, isMonolith) c.Sentry.Verify(configErrs, isMonolith)
c.DNSCache.Verify(configErrs, isMonolith) c.DNSCache.Verify(configErrs, isMonolith)
c.UserConsentOptions.Verify(configErrs, isMonolith)
c.ServerNotices.Verify(configErrs, isMonolith) c.ServerNotices.Verify(configErrs, isMonolith)
c.ReportStats.Verify(configErrs, isMonolith) c.ReportStats.Verify(configErrs, isMonolith)
} }
@ -261,6 +274,102 @@ func (c *DNSCacheOptions) Verify(configErrs *ConfigErrors, isMonolith bool) {
checkPositive(configErrs, "cache_lifetime", int64(c.CacheLifetime)) checkPositive(configErrs, "cache_lifetime", int64(c.CacheLifetime))
} }
// Consent tracking configuration
// If either require_at_registration or send_server_notice_to_guest are true, consent
// messages will be sent to the users.
type UserConsentOptions struct {
// If consent tracking is enabled or not
Enabled bool `yaml:"enabled"`
// Randomly generated string to be used to calculate the HMAC
FormSecret string `yaml:"form_secret"`
// Require consent when user registers for the first time
RequireAtRegistration bool `yaml:"require_at_registration"`
// The name to be shown to the user
PolicyName string `yaml:"policy_name"`
// The directory to search for *.gohtml templates
TemplateDir string `yaml:"template_dir"`
// The version of the policy. When loading templates, ".gohtml" template is added as a suffix
// e.g: ${template_dir}/1.0.gohtml needs to exist, if this is set to 1.0
Version string `yaml:"version"`
// Send a consent message to guest users
SendServerNoticeToGuest bool `yaml:"send_server_notice_to_guest"`
// Default message to send to users
ServerNoticeContent struct {
MsgType string `yaml:"msg_type"`
Body string `yaml:"body"`
} `yaml:"server_notice_content"`
// The error message to display if the user hasn't given their consent yet
BlockEventsError string `yaml:"block_events_error"`
// All loaded templates
Templates *template.Template `yaml:"-"`
TextTemplates *textTemplate.Template `yaml:"-"`
// The base URL this homeserver will serve clients on, e.g. https://matrix.org
BaseURL string `yaml:"base_url"`
}
func (c *UserConsentOptions) Defaults() {
c.Enabled = false
c.RequireAtRegistration = false
c.SendServerNoticeToGuest = false
c.PolicyName = "Privacy Policy"
c.Version = "1.0"
c.TemplateDir = "./templates/privacy"
}
func (c *UserConsentOptions) Verify(configErrors *ConfigErrors, isMonolith bool) {
if !c.Enabled {
return
}
checkNotEmpty(configErrors, "template_dir", c.TemplateDir)
checkNotEmpty(configErrors, "version", c.Version)
checkNotEmpty(configErrors, "policy_name", c.PolicyName)
checkNotEmpty(configErrors, "form_secret", c.FormSecret)
checkNotEmpty(configErrors, "base_url", c.BaseURL)
if len(*configErrors) > 0 {
return
}
p, err := filepath.Abs(c.TemplateDir)
if err != nil {
configErrors.Add("unable to get template directory")
return
}
c.TextTemplates = textTemplate.Must(textTemplate.New("blockEventsError").Parse(c.BlockEventsError))
c.TextTemplates = textTemplate.Must(c.TextTemplates.New("serverNoticeTemplate").Parse(c.ServerNoticeContent.Body))
// Read all defined *.gohtml templates
t, err := template.ParseGlob(filepath.Join(p, "*.gohtml"))
if err != nil || t == nil {
configErrors.Add(fmt.Sprintf("unable to read consent templates: %+v", err))
return
}
c.Templates = t
// Verify we've got a template for the defined version
versionTemplate := c.Templates.Lookup(c.Version + ".gohtml")
if versionTemplate == nil {
configErrors.Add(fmt.Sprintf("unable to load defined '%s' policy template", c.Version))
}
}
// ConsentURL constructs the URL shown to users to accept the TOS
func (c *UserConsentOptions) ConsentURL(userID string) (string, error) {
mac := hmac.New(sha256.New, []byte(c.FormSecret))
_, err := mac.Write([]byte(userID))
if err != nil {
return "", err
}
userMAC := hex.EncodeToString(mac.Sum(nil))
params := url.Values{}
params.Add("u", userID)
params.Add("h", userMAC)
params.Add("v", c.Version)
return fmt.Sprintf("%s/_matrix/client/consent?%s", c.BaseURL, params.Encode()), nil
}
// PresenceOptions defines possible configurations for presence events. // PresenceOptions defines possible configurations for presence events.
type PresenceOptions struct { type PresenceOptions struct {
// Whether inbound presence events are allowed // Whether inbound presence events are allowed

View file

@ -0,0 +1,115 @@
package config
import (
"testing"
)
func TestUserConsentOptions_Verify(t *testing.T) {
type args struct {
configErrors *ConfigErrors
isMonolith bool
}
tests := []struct {
name string
fields UserConsentOptions
args args
wantErr bool
}{
{
name: "template dir not set",
fields: UserConsentOptions{
RequireAtRegistration: true,
},
args: struct {
configErrors *ConfigErrors
isMonolith bool
}{configErrors: &ConfigErrors{}, isMonolith: true},
wantErr: true,
},
{
name: "template dir set",
fields: UserConsentOptions{
RequireAtRegistration: true,
TemplateDir: "testdata/privacy",
},
args: struct {
configErrors *ConfigErrors
isMonolith bool
}{configErrors: &ConfigErrors{}, isMonolith: true},
wantErr: true,
},
{
name: "policy name not set",
fields: UserConsentOptions{
RequireAtRegistration: true,
TemplateDir: "testdata/privacy",
},
args: struct {
configErrors *ConfigErrors
isMonolith bool
}{configErrors: &ConfigErrors{}, isMonolith: true},
wantErr: true,
},
{
name: "policy name set",
fields: UserConsentOptions{
RequireAtRegistration: true,
TemplateDir: "testdata/privacy",
PolicyName: "Privacy policy",
},
args: struct {
configErrors *ConfigErrors
isMonolith bool
}{configErrors: &ConfigErrors{}, isMonolith: true},
wantErr: true,
},
{
name: "version not set",
fields: UserConsentOptions{
RequireAtRegistration: true,
TemplateDir: "testdata/privacy",
},
args: struct {
configErrors *ConfigErrors
isMonolith bool
}{configErrors: &ConfigErrors{}, isMonolith: true},
wantErr: true,
},
{
name: "everyhing required set",
fields: UserConsentOptions{
RequireAtRegistration: true,
TemplateDir: "./testdata/privacy",
Version: "1.0",
PolicyName: "Privacy policy",
FormSecret: "helloWorld",
BaseURL: "http://localhost",
},
args: struct {
configErrors *ConfigErrors
isMonolith bool
}{configErrors: &ConfigErrors{}, isMonolith: true},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &UserConsentOptions{
Enabled: true,
BaseURL: tt.fields.BaseURL,
FormSecret: tt.fields.FormSecret,
RequireAtRegistration: tt.fields.RequireAtRegistration,
PolicyName: tt.fields.PolicyName,
Version: tt.fields.Version,
TemplateDir: tt.fields.TemplateDir,
SendServerNoticeToGuest: tt.fields.SendServerNoticeToGuest,
ServerNoticeContent: tt.fields.ServerNoticeContent,
BlockEventsError: tt.fields.BlockEventsError,
}
c.Verify(tt.args.configErrors, tt.args.isMonolith)
if !tt.wantErr && len(*tt.args.configErrors) > 0 {
t.Errorf("expected no errors, got '%+v'", tt.args.configErrors)
}
})
}
}

View file

@ -0,0 +1,26 @@
<!doctype html>
<html lang="en">
<head>
<title>Privacy policy</title>
</head>
<body>
{{ if .HasConsented }}
<p>
You have already given your consent.
</p>
{{ else }}
<p>
Please give your consent to keep using this homeserver.
</p>
{{ if not .ReadOnly }}
<!-- The variables used here are only provided when the 'u' param is given to the homeserver -->
<form method="post" action="consent">
<input type="hidden" name="v" value="{{ .Version }}"/>
<input type="hidden" name="u" value="{{ .UserID }}"/>
<input type="hidden" name="h" value="{{ .UserHMAC }}"/>
<input type="submit" value="I consent"/>
</form>
{{ end }}
{{ end }}
</body>
</html>

View file

@ -44,7 +44,6 @@ func ParseFlags(monolith bool) *config.Dendrite {
} }
cfg, err := config.Load(*configPath, monolith) cfg, err := config.Load(*configPath, monolith)
if err != nil { if err != nil {
logrus.Fatalf("Invalid config file: %s", err) logrus.Fatalf("Invalid config file: %s", err)
} }

View file

@ -93,6 +93,6 @@ func Setup(
vars["roomId"], vars["eventId"], vars["roomId"], vars["eventId"],
lazyLoadCache, lazyLoadCache,
) )
}), }, httputil.WithConsentCheck(cfg.Matrix.UserConsentOptions, userAPI)),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
} }

View file

@ -66,6 +66,7 @@ type FederationUserAPI interface {
// api functions required by the sync api // api functions required by the sync api
type SyncUserAPI interface { type SyncUserAPI interface {
QueryAcccessTokenAPI QueryAcccessTokenAPI
QueryPolicyVersionAPI
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
@ -78,6 +79,7 @@ type ClientUserAPI interface {
QueryAcccessTokenAPI QueryAcccessTokenAPI
LoginTokenInternalAPI LoginTokenInternalAPI
UserLoginAPI UserLoginAPI
UserConsentPolicyAPI
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
@ -106,6 +108,18 @@ type ClientUserAPI interface {
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) (err error)
UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) (err error)
}
type UserConsentPolicyAPI interface {
QueryPolicyVersionAPI
QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error
PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error
}
type QueryPolicyVersionAPI interface {
QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error
} }
// custom api functions required by pinecone / p2p demos // custom api functions required by pinecone / p2p demos
@ -318,12 +332,12 @@ type QuerySearchProfilesResponse struct {
// PerformAccountCreationRequest is the request for PerformAccountCreation // PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformAccountCreationRequest struct { type PerformAccountCreationRequest struct {
AccountType AccountType // Required: whether this is a guest or user account AccountType AccountType // Required: whether this is a guest or user account
Localpart string // Required: The localpart for this account. Ignored if account type is guest. Localpart string // Required: The localpart for this account. Ignored if account type is guest.
PolicyVersion string // optional: the privacy policy this account has accepted
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
Password string // optional: if missing then this account will be a passwordless account Password string // optional: if missing then this account will be a passwordless account
OnConflict Conflict OnConflict Conflict
} }
// PerformAccountCreationResponse is the response for PerformAccountCreation // PerformAccountCreationResponse is the response for PerformAccountCreation
@ -412,6 +426,53 @@ type QueryOpenIDTokenResponse struct {
ExpiresAtMS int64 ExpiresAtMS int64
} }
// QueryPolicyVersionRequest is the request for QueryPolicyVersionRequest
type QueryPolicyVersionRequest struct {
Localpart string
}
// QueryPolicyVersionResponse is the response for QueryPolicyVersionRequest
type QueryPolicyVersionResponse struct {
PolicyVersion string
}
// QueryOutdatedPolicyRequest is the request for QueryOutdatedPolicyRequest
type QueryOutdatedPolicyRequest struct {
PolicyVersion string
}
// QueryOutdatedPolicyResponse is the response for QueryOutdatedPolicyRequest
type QueryOutdatedPolicyResponse struct {
UserLocalparts []string
}
// UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest
type UpdatePolicyVersionRequest struct {
PolicyVersion, Localpart string
ServerNoticeUpdate bool
}
// UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest
type UpdatePolicyVersionResponse struct{}
// QueryServerNoticeRoomRequest is the request for QueryServerNoticeRoomRequest
type QueryServerNoticeRoomRequest struct {
Localpart string
}
// QueryServerNoticeRoomResponse is the response for QueryServerNoticeRoomRequest
type QueryServerNoticeRoomResponse struct {
RoomID string
}
// UpdateServerNoticeRoomRequest is the request for UpdateServerNoticeRoomRequest
type UpdateServerNoticeRoomRequest struct {
Localpart, RoomID string
}
// UpdateServerNoticeRoomResponse is the response for UpdateServerNoticeRoomRequest
type UpdateServerNoticeRoomResponse struct{}
// Device represents a client's device (mobile, web, etc) // Device represents a client's device (mobile, web, etc)
type Device struct { type Device struct {
ID string ID string

View file

@ -203,6 +203,36 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex
return err return err
} }
func (t *UserInternalAPITrace) QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error {
err := t.Impl.QueryPolicyVersion(ctx, req, res)
util.GetLogger(ctx).Infof("QueryPolicyVersion req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error {
err := t.Impl.QueryOutdatedPolicy(ctx, req, res)
util.GetLogger(ctx).Infof("QueryOutdatedPolicy req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error {
err := t.Impl.PerformUpdatePolicyVersion(ctx, req, res)
util.GetLogger(ctx).Infof("PerformUpdatePolicyVersion req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) error {
err := t.Impl.SelectServerNoticeRoomID(ctx, req, res)
util.GetLogger(ctx).Infof("SelectServerNoticeRoomID req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) error {
err := t.Impl.UpdateServerNoticeRoomID(ctx, req, res)
util.GetLogger(ctx).Infof("UpdateServerNoticeRoomID req=%+v res=%+v", js(req), js(res))
return err
}
func js(thing interface{}) string { func js(thing interface{}) string {
b, err := json.Marshal(thing) b, err := json.Marshal(thing)
if err != nil { if err != nil {

View file

@ -66,7 +66,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
} }
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion, req.AccountType)
if err != nil { if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict { switch req.OnConflict {
@ -833,3 +833,60 @@ func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, re
} }
const pushRulesAccountDataType = "m.push_rules" const pushRulesAccountDataType = "m.push_rules"
func (a *UserInternalAPI) QueryPolicyVersion(
ctx context.Context,
req *api.QueryPolicyVersionRequest,
res *api.QueryPolicyVersionResponse,
) error {
var err error
res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.Localpart)
if err != nil {
return err
}
return nil
}
func (a *UserInternalAPI) QueryOutdatedPolicy(
ctx context.Context,
req *api.QueryOutdatedPolicyRequest,
res *api.QueryOutdatedPolicyResponse,
) error {
var err error
res.UserLocalparts, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion)
if err != nil {
return err
}
return nil
}
func (a *UserInternalAPI) PerformUpdatePolicyVersion(
ctx context.Context,
req *api.UpdatePolicyVersionRequest,
res *api.UpdatePolicyVersionResponse,
) error {
return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.Localpart, req.ServerNoticeUpdate)
}
func (a *UserInternalAPI) SelectServerNoticeRoomID(
ctx context.Context,
req *api.QueryServerNoticeRoomRequest,
res *api.QueryServerNoticeRoomResponse,
) (err error) {
roomID, err := a.DB.SelectServerNoticeRoomID(ctx, req.Localpart)
if err != nil {
return err
}
res.RoomID = roomID
return nil
}
func (a *UserInternalAPI) UpdateServerNoticeRoomID(
ctx context.Context,
req *api.UpdateServerNoticeRoomRequest,
res *api.UpdateServerNoticeRoomResponse,
) (err error) {
return a.DB.UpdateServerNoticeRoomID(ctx, req.Localpart, req.RoomID)
}

View file

@ -44,6 +44,8 @@ const (
PerformSetDisplayNamePath = "/userapi/performSetDisplayName" PerformSetDisplayNamePath = "/userapi/performSetDisplayName"
PerformForgetThreePIDPath = "/userapi/performForgetThreePID" PerformForgetThreePIDPath = "/userapi/performForgetThreePID"
PerformSaveThreePIDAssociationPath = "/userapi/performSaveThreePIDAssociation" PerformSaveThreePIDAssociationPath = "/userapi/performSaveThreePIDAssociation"
PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion"
PerformUpdateServerNoticeRoomPath = "/userapi/performUpdateServerNoticeRoom"
QueryKeyBackupPath = "/userapi/queryKeyBackup" QueryKeyBackupPath = "/userapi/queryKeyBackup"
QueryProfilePath = "/userapi/queryProfile" QueryProfilePath = "/userapi/queryProfile"
@ -61,6 +63,9 @@ const (
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword" QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID" QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart" QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
QueryPolicyVersionPath = "/userapi/queryPolicyVersion"
QueryOutdatedPolicyUsersPath = "/userapi/queryOutdatedPolicy"
QueryServerNoticeRoomPath = "/userapi/queryServerNoticeRoom"
) )
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -391,3 +396,43 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context
apiURL := h.apiURL + PerformSaveThreePIDAssociationPath apiURL := h.apiURL + PerformSaveThreePIDAssociationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpUserInternalAPI) QueryOutdatedPolicy(ctx context.Context, req *api.QueryOutdatedPolicyRequest, res *api.QueryOutdatedPolicyResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOutdatedPolicy")
defer span.Finish()
apiURL := h.apiURL + QueryOutdatedPolicyUsersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) PerformUpdatePolicyVersion(ctx context.Context, req *api.UpdatePolicyVersionRequest, res *api.UpdatePolicyVersionResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUpdatePolicyVersion")
defer span.Finish()
apiURL := h.apiURL + PerformUpdatePolicyVersionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) SelectServerNoticeRoomID(ctx context.Context, req *api.QueryServerNoticeRoomRequest, res *api.QueryServerNoticeRoomResponse) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "SelectServerNoticeRoomID")
defer span.Finish()
apiURL := h.apiURL + QueryServerNoticeRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) QueryPolicyVersion(ctx context.Context, req *api.QueryPolicyVersionRequest, res *api.QueryPolicyVersionResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPolicyVersion")
defer span.Finish()
apiURL := h.apiURL + QueryPolicyVersionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) UpdateServerNoticeRoomID(ctx context.Context, req *api.UpdateServerNoticeRoomRequest, res *api.UpdateServerNoticeRoomResponse) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "UpdateServerNoticeRoomID")
defer span.Finish()
apiURL := h.apiURL + PerformUpdateServerNoticeRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -457,4 +457,74 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}} return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
}), }),
) )
internalAPIMux.Handle(QueryPolicyVersionPath,
httputil.MakeInternalAPI("queryPolicyVersion", func(req *http.Request) util.JSONResponse {
request := api.QueryPolicyVersionRequest{}
response := api.QueryPolicyVersionResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
err := s.QueryPolicyVersion(req.Context(), &request, &response)
if err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryOutdatedPolicyUsersPath,
httputil.MakeInternalAPI("queryOutdatedPolicyUsers", func(req *http.Request) util.JSONResponse {
request := api.QueryOutdatedPolicyRequest{}
response := api.QueryOutdatedPolicyResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
err := s.QueryOutdatedPolicy(req.Context(), &request, &response)
if err != nil {
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformUpdatePolicyVersionPath,
httputil.MakeInternalAPI("performUpdatePolicyVersionPath", func(req *http.Request) util.JSONResponse {
request := api.UpdatePolicyVersionRequest{}
response := api.UpdatePolicyVersionResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
err := s.PerformUpdatePolicyVersion(req.Context(), &request, &response)
if err != nil {
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryServerNoticeRoomPath,
httputil.MakeInternalAPI("queryServerNoticeRoom", func(req *http.Request) util.JSONResponse {
request := api.QueryServerNoticeRoomRequest{}
response := api.QueryServerNoticeRoomResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
err := s.SelectServerNoticeRoomID(req.Context(), &request, &response)
if err != nil {
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformUpdateServerNoticeRoomPath,
httputil.MakeInternalAPI("performUpdateServerNoticeRoom", func(req *http.Request) util.JSONResponse {
request := api.UpdateServerNoticeRoomRequest{}
response := api.UpdateServerNoticeRoomResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
err := s.UpdateServerNoticeRoomID(req.Context(), &request, &response)
if err != nil {
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View file

@ -36,7 +36,7 @@ type Account interface {
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, policyVersion string, accountType api.AccountType) (*api.Account, error)
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
GetNewNumericLocalpart(ctx context.Context) (int64, error) GetNewNumericLocalpart(ctx context.Context) (int64, error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
@ -126,9 +126,18 @@ type Notification interface {
DeleteOldNotifications(ctx context.Context) error DeleteOldNotifications(ctx context.Context) error
} }
type ConsentTracking interface {
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) error
SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error)
UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error)
}
type Database interface { type Database interface {
Account Account
AccountData AccountData
ConsentTracking
Device Device
KeyBackup KeyBackup
LoginToken LoginToken

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"time" "time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
@ -43,14 +44,19 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- If the account is currently active -- If the account is currently active
is_deactivated BOOLEAN DEFAULT FALSE, is_deactivated BOOLEAN DEFAULT FALSE,
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
account_type SMALLINT NOT NULL account_type SMALLINT NOT NULL,
-- The policy version this user has accepted
policy_version TEXT,
-- The policy version the user received from the server notices room
policy_version_sent TEXT,
server_notice_room_id TEXT
-- TODO: -- TODO:
-- upgraded_ts, devices, any email reset stuff? -- upgraded_ts, devices, any email reset stuff?
); );
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type, policy_version) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" + const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
@ -67,14 +73,38 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'" "SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'"
const selectPrivacyPolicySQL = "" +
"SELECT policy_version FROM account_accounts WHERE localpart = $1"
const batchSelectPrivacyPolicySQL = "" +
"SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $1)"
const updatePolicyVersionSQL = "" +
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
const updatePolicyVersionServerNoticeSQL = "" +
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
const selectServerNoticeRoomSQL = "" +
"SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1"
const updateServerNoticeRoomSQL = "" +
"UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2"
type accountsStatements struct { type accountsStatements struct {
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt updatePasswordStmt *sql.Stmt
deactivateAccountStmt *sql.Stmt deactivateAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName selectPrivacyPolicyStmt *sql.Stmt
batchSelectPrivacyPolicyStmt *sql.Stmt
updatePolicyVersionStmt *sql.Stmt
updatePolicyVersionServerNoticeStmt *sql.Stmt
selectServerNoticeRoomStmt *sql.Stmt
updateServerNoticeRoomStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
} }
func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
@ -92,6 +122,12 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
{&s.selectPasswordHashStmt, selectPasswordHashSQL}, {&s.selectPasswordHashStmt, selectPasswordHashSQL},
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
{&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL},
{&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -99,16 +135,16 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) InsertAccount( func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error var err error
if accountType != api.AccountTypeAppService { if accountType != api.AccountTypeAppService {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
} else { } else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -178,3 +214,71 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
err = stmt.QueryRowContext(ctx).Scan(&id) err = stmt.QueryRowContext(ctx).Scan(&id)
return id + 1, err return id + 1, err
} }
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
func (s *accountsStatements) SelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) {
var policyNull sql.NullString
stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt)
err = stmt.QueryRowContext(ctx, localPart).Scan(&policyNull)
return policyNull.String, err
}
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
func (s *accountsStatements) BatchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) {
stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt)
rows, err := stmt.QueryContext(ctx, policyVersion)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "BatchSelectPrivacyPolicy: rows.close() failed")
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return userIDs, err
}
userIDs = append(userIDs, userID)
}
return userIDs, rows.Err()
}
// updatePolicyVersion sets the policy_version for a specific user
func (s *accountsStatements) UpdatePolicyVersion(
ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool,
) (err error) {
stmt := s.updatePolicyVersionStmt
if serverNotice {
stmt = s.updatePolicyVersionServerNoticeStmt
}
stmt = sqlutil.TxStmt(txn, stmt)
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
return err
}
// SelectServerNoticeRoomID queries the server notice room ID.
func (s *accountsStatements) SelectServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart string,
) (roomID string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt)
roomIDNull := sql.NullString{}
row := stmt.QueryRowContext(ctx, localpart)
err = row.Scan(&roomIDNull)
if err != nil && err != sql.ErrNoRows {
return "", err
}
// roomIDNull.String is either the roomID or an empty string
return roomIDNull.String, nil
}
// UpdateServerNoticeRoomID sets the server notice room ID.
func (s *accountsStatements) UpdateServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart, roomID string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt)
_, err = stmt.ExecContext(ctx, roomID, localpart)
return
}

View file

@ -12,6 +12,7 @@ import (
func LoadFromGoose() { func LoadFromGoose() {
goose.AddMigration(UpIsActive, DownIsActive) goose.AddMigration(UpIsActive, DownIsActive)
goose.AddMigration(UpAddAccountType, DownAddAccountType) goose.AddMigration(UpAddAccountType, DownAddAccountType)
goose.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
} }
func LoadIsActive(m *sqlutil.Migrations) { func LoadIsActive(m *sqlutil.Migrations) {

View file

@ -0,0 +1,45 @@
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadAddPolicyVersion(m *sqlutil.Migrations) {
m.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
}
func UpAddPolicyVersion(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version TEXT;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version_sent TEXT;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS server_notice_room_id TEXT;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownAddPolicyVersion(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -46,6 +46,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
deltas.LoadIsActive(m) deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m) //deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m) deltas.LoadAddAccountType(m)
deltas.LoadAddPolicyVersion(m)
if err = m.RunDeltas(db, dbProperties); err != nil { if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err return nil, err
} }

View file

@ -28,6 +28,7 @@ import (
"github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -125,7 +126,7 @@ func (d *Database) SetPassword(
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string, accountType api.AccountType,
) (acc *api.Account, err error) { ) (acc *api.Account, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// For guest accounts, we create a new numeric local part // For guest accounts, we create a new numeric local part
@ -139,7 +140,7 @@ func (d *Database) CreateAccount(
plaintextPassword = "" plaintextPassword = ""
appserviceID = "" appserviceID = ""
} }
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion, accountType)
return err return err
}) })
return return
@ -148,7 +149,7 @@ func (d *Database) CreateAccount(
// WARNING! This function assumes that the relevant mutexes have already // WARNING! This function assumes that the relevant mutexes have already
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
func (d *Database) createAccount( func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
var err error var err error
var account *api.Account var account *api.Account
@ -160,7 +161,8 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
} }
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion, accountType); err != nil {
logrus.WithError(err).Error("d.Accounts.InsertAccount error")
return nil, sqlutil.ErrUserExists return nil, sqlutil.ErrUserExists
} }
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
@ -763,3 +765,42 @@ func (d *Database) RemovePushers(
func (d *Database) UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) { func (d *Database) UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) {
return d.Stats.UserStatistics(ctx, nil) return d.Stats.UserStatistics(ctx, nil)
} }
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
policyVersion, err = d.Accounts.SelectPrivacyPolicy(ctx, txn, localpart)
return err
})
return
}
// GetOutdatedPolicy queries all users which didn't accept the current policy version
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
userIDs, err = d.Accounts.BatchSelectPrivacyPolicy(ctx, txn, policyVersion)
return err
})
return
}
// UpdatePolicyVersion sets the accepted policy_version for a user.
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) (err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart, serverNotice)
})
return
}
// SelectServerNoticeRoomID returns the server notice room, if one is set.
func (d *Database) SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error) {
return d.Accounts.SelectServerNoticeRoomID(ctx, nil, localpart)
}
// UpdateServerNoticeRoomID updates the server notice room
func (d *Database) UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Accounts.UpdateServerNoticeRoomID(ctx, txn, localpart, roomID)
})
return
}

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"time" "time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
@ -43,14 +44,19 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- If the account is currently active -- If the account is currently active
is_deactivated BOOLEAN DEFAULT 0, is_deactivated BOOLEAN DEFAULT 0,
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
account_type INTEGER NOT NULL account_type INTEGER NOT NULL,
-- The policy version this user has accepted
policy_version TEXT,
-- The policy version the user received from the server notices room
policy_version_sent TEXT,
server_notice_room_id TEXT
-- TODO: -- TODO:
-- upgraded_ts, devices, any email reset stuff? -- upgraded_ts, devices, any email reset stuff?
); );
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type, policy_version) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" + const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
@ -67,15 +73,39 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0" "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0"
const selectPrivacyPolicySQL = "" +
"SELECT policy_version FROM account_accounts WHERE localpart = $1"
const batchSelectPrivacyPolicySQL = "" +
"SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $2)"
const updatePolicyVersionSQL = "" +
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
const updatePolicyVersionServerNoticeSQL = "" +
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
const selectServerNoticeRoomSQL = "" +
"SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1"
const updateServerNoticeRoomSQL = "" +
"UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2"
type accountsStatements struct { type accountsStatements struct {
db *sql.DB db *sql.DB
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt updatePasswordStmt *sql.Stmt
deactivateAccountStmt *sql.Stmt deactivateAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName selectPrivacyPolicyStmt *sql.Stmt
batchSelectPrivacyPolicyStmt *sql.Stmt
updatePolicyVersionStmt *sql.Stmt
updatePolicyVersionServerNoticeStmt *sql.Stmt
selectServerNoticeRoomStmt *sql.Stmt
updateServerNoticeRoomStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
} }
func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
@ -94,6 +124,12 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
{&s.selectPasswordHashStmt, selectPasswordHashSQL}, {&s.selectPasswordHashStmt, selectPasswordHashSQL},
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
{&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL},
{&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -101,16 +137,16 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) InsertAccount( func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error var err error
if accountType != api.AccountTypeAppService { if accountType != api.AccountTypeAppService {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
} else { } else {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -183,3 +219,72 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
} }
return id + 1, err return id + 1, err
} }
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
func (s *accountsStatements) SelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) {
var policyNull sql.NullString
stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt)
err = stmt.QueryRowContext(ctx, localPart).Scan(&policyNull)
return policyNull.String, err
}
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
func (s *accountsStatements) BatchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) {
stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt)
rows, err := stmt.QueryContext(ctx, policyVersion, policyVersion)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "BatchSelectPrivacyPolicy: rows.close() failed")
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return userIDs, err
}
userIDs = append(userIDs, userID)
}
return userIDs, rows.Err()
}
// updatePolicyVersion sets the policy_version for a specific user
func (s *accountsStatements) UpdatePolicyVersion(
ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool,
) (err error) {
stmt := s.updatePolicyVersionStmt
if serverNotice {
stmt = s.updatePolicyVersionServerNoticeStmt
}
stmt = sqlutil.TxStmt(txn, stmt)
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
return err
}
// SelectServerNoticeRoomID queries the server notice room ID.
func (s *accountsStatements) SelectServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart string,
) (roomID string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt)
roomIDNull := sql.NullString{}
row := stmt.QueryRowContext(ctx, localpart)
err = row.Scan(&roomIDNull)
if err != nil && err != sql.ErrNoRows {
return "", err
}
// roomIDNull.String is either the roomID or an empty string
return roomIDNull.String, nil
}
// UpdateServerNoticeRoomID sets the server notice room ID.
func (s *accountsStatements) UpdateServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart, roomID string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt)
_, err = stmt.ExecContext(ctx, roomID, localpart)
return
}

View file

@ -12,6 +12,7 @@ import (
func LoadFromGoose() { func LoadFromGoose() {
goose.AddMigration(UpIsActive, DownIsActive) goose.AddMigration(UpIsActive, DownIsActive)
goose.AddMigration(UpAddAccountType, DownAddAccountType) goose.AddMigration(UpAddAccountType, DownAddAccountType)
goose.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
} }
func LoadIsActive(m *sqlutil.Migrations) { func LoadIsActive(m *sqlutil.Migrations) {

View file

@ -4,15 +4,9 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/pressly/goose"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
) )
func init() {
goose.AddMigration(UpAddAccountType, DownAddAccountType)
}
func LoadAddAccountType(m *sqlutil.Migrations) { func LoadAddAccountType(m *sqlutil.Migrations) {
m.AddMigration(UpAddAccountType, DownAddAccountType) m.AddMigration(UpAddAccountType, DownAddAccountType)
} }

View file

@ -0,0 +1,44 @@
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadAddPolicyVersion(m *sqlutil.Migrations) {
m.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
}
func UpAddPolicyVersion(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version TEXT;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version_sent TEXT;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN server_notice_room_id TEXT;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownAddPolicyVersion(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -18,11 +18,10 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/userapi/storage/shared" "github.com/matrix-org/dendrite/userapi/storage/shared"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
@ -47,6 +46,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
deltas.LoadIsActive(m) deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m) //deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m) deltas.LoadAddAccountType(m)
deltas.LoadAddPolicyVersion(m)
if err = m.RunDeltas(db, dbProperties); err != nil { if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err return nil, err
} }

View file

@ -78,7 +78,7 @@ func Test_Accounts(t *testing.T) {
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", "v1.0", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account") assert.NoError(t, err, "failed to create account")
// verify the newly create account is the same as returned by CreateAccount // verify the newly create account is the same as returned by CreateAccount
var accGet *api.Account var accGet *api.Account
@ -102,7 +102,7 @@ func Test_Accounts(t *testing.T) {
first, err := db.GetNewNumericLocalpart(ctx) first, err := db.GetNewNumericLocalpart(ctx)
assert.NoError(t, err, "failed to get new numeric localpart") assert.NoError(t, err, "failed to get new numeric localpart")
// Create a new account to verify the numeric localpart is updated // Create a new account to verify the numeric localpart is updated
_, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest) _, err = db.CreateAccount(ctx, "", "testing", "", "v1.0", api.AccountTypeGuest)
assert.NoError(t, err, "failed to create account") assert.NoError(t, err, "failed to create account")
second, err := db.GetNewNumericLocalpart(ctx) second, err := db.GetNewNumericLocalpart(ctx)
assert.NoError(t, err) assert.NoError(t, err)
@ -350,7 +350,7 @@ func Test_Profile(t *testing.T) {
defer close() defer close()
// create account, which also creates a profile // create account, which also creates a profile
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) _, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", "v1.0", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account") assert.NoError(t, err, "failed to create account")
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart) gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)

View file

@ -32,12 +32,18 @@ type AccountDataTable interface {
} }
type AccountsTable interface { type AccountsTable interface {
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType) (*api.Account, error)
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error) UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
DeactivateAccount(ctx context.Context, localpart string) (err error) DeactivateAccount(ctx context.Context, localpart string) (err error)
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error) SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error)
BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error)
UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool) (err error)
SelectServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart string) (roomID string, err error)
UpdateServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart, roomID string) (err error)
} }
type DevicesTable interface { type DevicesTable interface {

View file

@ -88,7 +88,7 @@ func mustMakeAccountAndDevice(
appServiceID = util.RandomString(16) appServiceID = util.RandomString(16)
} }
_, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType) _, err := accDB.InsertAccount(ctx, nil, localpart, "", "", appServiceID, accType)
if err != nil { if err != nil {
t.Fatalf("unable to create account: %v", err) t.Fatalf("unable to create account: %v", err)
} }

View file

@ -75,7 +75,7 @@ func TestQueryProfile(t *testing.T) {
// only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added // only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite) userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite)
defer close() defer close()
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser) _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", "", api.AccountTypeUser)
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
} }
@ -154,7 +154,7 @@ func TestLoginToken(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
defer close() defer close()
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser) _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", "", api.AccountTypeUser)
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
} }