Compare commits
82 commits
main
...
s7evink/co
Author | SHA1 | Date | |
---|---|---|---|
cabc5f4bc9 | |||
7df7a966b8 | |||
75ca5490bc | |||
c9409078ac | |||
3d3773d3d4 | |||
4bd9a73c13 | |||
4d285fff60 | |||
10dc02f1ea | |||
94ed2d3689 | |||
964e1cef85 | |||
2eb3aab07e | |||
60ba4b5612 | |||
88612ddd0c | |||
bddf8ed3ac | |||
4d5feb2544 | |||
cd7a7606a1 | |||
dc8cea6d57 | |||
e0cdf64c33 | |||
ef62255685 | |||
1f64fc79c8 | |||
2b496be2c3 | |||
733b601aa9 | |||
2a18023a1a | |||
c99e3aff1b | |||
f1e8d19cea | |||
019f0922ea | |||
6324c1d01f | |||
b9479a6f18 | |||
e42ef1706b | |||
31ac3ac081 | |||
39d9d88b02 | |||
710007d600 | |||
dcfc0bcd43 | |||
c7d2254698 | |||
7c6a162c0f | |||
699617ee4d | |||
519ea13510 | |||
fa26aa9138 | |||
e80ca307d3 | |||
df7218e230 | |||
e6e62497c9 | |||
2ad15f308f | |||
ed16a2f107 | |||
ce658ab8f2 | |||
79e1c9e4bd | |||
2042303c6c | |||
4f2d161401 | |||
c65eb2bf52 | |||
0ae8293abd | |||
dac29c1786 | |||
c2b6019c35 | |||
61cdb714df | |||
185cb7a582 | |||
e2b0ff675b | |||
c0845ea1ad | |||
6622fda08c | |||
219a15c4c3 | |||
fb95331aa2 | |||
cb4526793d | |||
2e6987f8bd | |||
9c3a1cfd47 | |||
74da1f0fb3 | |||
26accb8c5d | |||
6482630f7b | |||
2fc1c46743 | |||
5a0ec6e443 | |||
535d388ec0 | |||
cbdbbb0839 | |||
f8bebe5e5a | |||
d19518fca5 | |||
89340cfc52 | |||
11144de92f | |||
b2045c24cb | |||
097f1d4609 | |||
a505471c90 | |||
3c5c3ea7fb | |||
9583784e8a | |||
b6ee34918c | |||
ac343861ad | |||
4da7df5e3e | |||
ccc11f94f7 | |||
5702b84dae |
|
@ -11,4 +11,5 @@ const (
|
||||||
LoginTypeRecaptcha = "m.login.recaptcha"
|
LoginTypeRecaptcha = "m.login.recaptcha"
|
||||||
LoginTypeApplicationService = "m.login.application_service"
|
LoginTypeApplicationService = "m.login.application_service"
|
||||||
LoginTypeToken = "m.login.token"
|
LoginTypeToken = "m.login.token"
|
||||||
|
LoginTypeTerms = "m.login.terms"
|
||||||
)
|
)
|
||||||
|
|
|
@ -29,6 +29,13 @@ type MatrixError struct {
|
||||||
Err string `json:"error"`
|
Err string `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConsentError is an error returned to users, who didn't accept the
|
||||||
|
// TOS of this server yet.
|
||||||
|
type ConsentError struct {
|
||||||
|
MatrixError
|
||||||
|
ConsentURI string `json:"consent_uri"`
|
||||||
|
}
|
||||||
|
|
||||||
func (e MatrixError) Error() string {
|
func (e MatrixError) Error() string {
|
||||||
return fmt.Sprintf("%s: %s", e.ErrCode, e.Err)
|
return fmt.Sprintf("%s: %s", e.ErrCode, e.Err)
|
||||||
}
|
}
|
||||||
|
@ -207,3 +214,15 @@ func NotTrusted(serverName string) *MatrixError {
|
||||||
Err: fmt.Sprintf("Untrusted server '%s'", serverName),
|
Err: fmt.Sprintf("Untrusted server '%s'", serverName),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConsentNotGiven is an error returned to users, who didn't accept the
|
||||||
|
// TOS of this server yet.
|
||||||
|
func ConsentNotGiven(consentURI string, msg string) *ConsentError {
|
||||||
|
return &ConsentError{
|
||||||
|
MatrixError: MatrixError{
|
||||||
|
ErrCode: "M_CONSENT_NOT_GIVEN",
|
||||||
|
Err: msg,
|
||||||
|
},
|
||||||
|
ConsentURI: consentURI,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
219
clientapi/routing/consent_tracking.go
Normal file
219
clientapi/routing/consent_tracking.go
Normal file
|
@ -0,0 +1,219 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The data used to populate the /consent request
|
||||||
|
type constentTemplateData struct {
|
||||||
|
UserID string
|
||||||
|
Version string
|
||||||
|
UserHMAC string
|
||||||
|
HasConsented bool
|
||||||
|
ReadOnly bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeHeaderAndText(w http.ResponseWriter, statusCode int) {
|
||||||
|
w.WriteHeader(statusCode)
|
||||||
|
_, _ = w.Write([]byte(http.StatusText(statusCode)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func consent(writer http.ResponseWriter, req *http.Request, userAPI userapi.UserConsentPolicyAPI, cfg *config.ClientAPI) {
|
||||||
|
consentCfg := cfg.Matrix.UserConsentOptions
|
||||||
|
|
||||||
|
// The data used to populate the /consent request
|
||||||
|
data := constentTemplateData{
|
||||||
|
UserID: req.FormValue("u"),
|
||||||
|
Version: req.FormValue("v"),
|
||||||
|
UserHMAC: req.FormValue("h"),
|
||||||
|
}
|
||||||
|
|
||||||
|
switch req.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
// display the privacy policy without a form
|
||||||
|
data.ReadOnly = data.UserID == "" || data.UserHMAC == "" || data.Version == ""
|
||||||
|
|
||||||
|
// let's see if the user already consented to the current version
|
||||||
|
if !data.ReadOnly {
|
||||||
|
if ok, err := validHMAC(data.UserID, data.UserHMAC, consentCfg.FormSecret); err != nil || !ok {
|
||||||
|
writeHeaderAndText(writer, http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res := &userapi.QueryPolicyVersionResponse{}
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', data.UserID)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("unable to split username")
|
||||||
|
writeHeaderAndText(writer, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = userAPI.QueryPolicyVersion(req.Context(), &userapi.QueryPolicyVersionRequest{
|
||||||
|
Localpart: localpart,
|
||||||
|
}, res); err != nil {
|
||||||
|
logrus.WithError(err).Error("unable query policy version")
|
||||||
|
writeHeaderAndText(writer, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data.HasConsented = res.PolicyVersion == consentCfg.Version
|
||||||
|
}
|
||||||
|
|
||||||
|
err := consentCfg.Templates.ExecuteTemplate(writer, consentCfg.Version+".gohtml", data)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("unable to execute consent template")
|
||||||
|
writeHeaderAndText(writer, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case http.MethodPost:
|
||||||
|
ok, err := validHMAC(data.UserID, data.UserHMAC, consentCfg.FormSecret)
|
||||||
|
if err != nil || !ok {
|
||||||
|
if !ok {
|
||||||
|
writeHeaderAndText(writer, http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeHeaderAndText(writer, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', data.UserID)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("unable to split username")
|
||||||
|
writeHeaderAndText(writer, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = userAPI.PerformUpdatePolicyVersion(
|
||||||
|
req.Context(),
|
||||||
|
&userapi.UpdatePolicyVersionRequest{
|
||||||
|
PolicyVersion: data.Version,
|
||||||
|
Localpart: localpart,
|
||||||
|
},
|
||||||
|
&userapi.UpdatePolicyVersionResponse{},
|
||||||
|
); err != nil {
|
||||||
|
writeHeaderAndText(writer, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// display the privacy policy without a form
|
||||||
|
data.ReadOnly = false
|
||||||
|
data.HasConsented = true
|
||||||
|
|
||||||
|
err = consentCfg.Templates.ExecuteTemplate(writer, consentCfg.Version+".gohtml", data)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("unable to print consent template")
|
||||||
|
writeHeaderAndText(writer, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendServerNoticeForConsent(userAPI userapi.ClientUserAPI, rsAPI api.ClientRoomserverAPI,
|
||||||
|
cfgNotices *config.ServerNotices,
|
||||||
|
cfgClient *config.ClientAPI,
|
||||||
|
senderDevice *userapi.Device,
|
||||||
|
asAPI appserviceAPI.AppServiceInternalAPI,
|
||||||
|
) {
|
||||||
|
res := &userapi.QueryOutdatedPolicyResponse{}
|
||||||
|
if err := userAPI.QueryOutdatedPolicy(context.Background(), &userapi.QueryOutdatedPolicyRequest{
|
||||||
|
PolicyVersion: cfgClient.Matrix.UserConsentOptions.Version,
|
||||||
|
}, res); err != nil {
|
||||||
|
logrus.WithError(err).Error("unable to fetch users with outdated consent policy")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
consentOpts = cfgClient.Matrix.UserConsentOptions
|
||||||
|
data = make(map[string]string)
|
||||||
|
err error
|
||||||
|
sentMessages int
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(res.UserLocalparts) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.WithField("count", len(res.UserLocalparts)).Infof("Sending server notice to users who have not yet accepted the policy")
|
||||||
|
|
||||||
|
for _, localpart := range res.UserLocalparts {
|
||||||
|
if localpart == cfgClient.Matrix.ServerNotices.LocalPart {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
userID := fmt.Sprintf("@%s:%s", localpart, cfgClient.Matrix.ServerName)
|
||||||
|
data["ConsentURL"], err = consentOpts.ConsentURL(userID)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).WithField("userID", userID).Error("unable to construct consentURI")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
msgBody := &bytes.Buffer{}
|
||||||
|
|
||||||
|
if err = consentOpts.TextTemplates.ExecuteTemplate(msgBody, "serverNoticeTemplate", data); err != nil {
|
||||||
|
logrus.WithError(err).WithField("userID", userID).Error("unable to execute serverNoticeTemplate")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
req := sendServerNoticeRequest{
|
||||||
|
UserID: userID,
|
||||||
|
Content: struct {
|
||||||
|
MsgType string `json:"msgtype,omitempty"`
|
||||||
|
Body string `json:"body,omitempty"`
|
||||||
|
}{
|
||||||
|
MsgType: consentOpts.ServerNoticeContent.MsgType,
|
||||||
|
Body: msgBody.String(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, err = sendServerNotice(context.Background(), req, rsAPI, cfgNotices, cfgClient, senderDevice, asAPI, userAPI, nil, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).WithField("userID", userID).Error("failed to send server notice for consent to user")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sentMessages++
|
||||||
|
res := &userapi.UpdatePolicyVersionResponse{}
|
||||||
|
if err = userAPI.PerformUpdatePolicyVersion(context.Background(), &userapi.UpdatePolicyVersionRequest{
|
||||||
|
PolicyVersion: consentOpts.Version,
|
||||||
|
Localpart: userID,
|
||||||
|
ServerNoticeUpdate: true,
|
||||||
|
}, res); err != nil {
|
||||||
|
logrus.WithError(err).WithField("userID", userID).Error("failed to update policy version")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sentMessages > 0 {
|
||||||
|
logrus.Infof("Sent messages to %d users", sentMessages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validHMAC(username, userHMAC, secret string) (bool, error) {
|
||||||
|
mac := hmac.New(sha256.New, []byte(secret))
|
||||||
|
_, err := mac.Write([]byte(username))
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
expectedMAC := mac.Sum(nil)
|
||||||
|
decoded, err := hex.DecodeString(userHMAC)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return hmac.Equal(decoded, expectedMAC), nil
|
||||||
|
}
|
236
clientapi/routing/consent_tracking_test.go
Normal file
236
clientapi/routing/consent_tracking_test.go
Normal file
|
@ -0,0 +1,236 @@
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_validHMAC(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
username string
|
||||||
|
userHMAC string
|
||||||
|
secret string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "invalid hmac",
|
||||||
|
args: args{},
|
||||||
|
wantErr: false,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
// $ echo -n '@alice:localhost' | openssl sha256 -hmac 'helloWorld'
|
||||||
|
//(stdin)= 121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e
|
||||||
|
//
|
||||||
|
{
|
||||||
|
name: "valid hmac",
|
||||||
|
args: args{
|
||||||
|
username: "@alice:localhost",
|
||||||
|
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
|
||||||
|
secret: "helloWorld",
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid hmac",
|
||||||
|
args: args{
|
||||||
|
username: "@bob:localhost",
|
||||||
|
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
|
||||||
|
secret: "helloWorld",
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := validHMAC(tt.args.username, tt.args.userHMAC, tt.args.secret)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("validHMAC() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("validHMAC() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type dummyAPI struct {
|
||||||
|
usersConsent map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d dummyAPI) QueryOutdatedPolicy(ctx context.Context, req *userapi.QueryOutdatedPolicyRequest, res *userapi.QueryOutdatedPolicyResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d dummyAPI) PerformUpdatePolicyVersion(ctx context.Context, req *userapi.UpdatePolicyVersionRequest, res *userapi.UpdatePolicyVersionResponse) error {
|
||||||
|
d.usersConsent[req.Localpart] = req.PolicyVersion
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d dummyAPI) QueryPolicyVersion(ctx context.Context, req *userapi.QueryPolicyVersionRequest, res *userapi.QueryPolicyVersionResponse) error {
|
||||||
|
res.PolicyVersion = "v2.0"
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const dummyTemplate = `
|
||||||
|
{{ if .HasConsented }}
|
||||||
|
Consent given.
|
||||||
|
{{ else }}
|
||||||
|
WithoutForm
|
||||||
|
{{ if not .ReadOnly }}
|
||||||
|
With Form.
|
||||||
|
{{ end }}
|
||||||
|
{{ end }}`
|
||||||
|
|
||||||
|
func Test_consent(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
username string
|
||||||
|
userHMAC string
|
||||||
|
version string
|
||||||
|
method string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantRespCode int
|
||||||
|
wantBodyContains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "not a userID, valid hmac",
|
||||||
|
args: args{
|
||||||
|
username: "notAuserID",
|
||||||
|
userHMAC: "7578bbface5ebb250a63935cebc05ca12060f58ebdbd271ecbc25e25a3da154d",
|
||||||
|
version: "v1.0",
|
||||||
|
method: http.MethodGet,
|
||||||
|
},
|
||||||
|
wantRespCode: http.StatusInternalServerError,
|
||||||
|
},
|
||||||
|
|
||||||
|
// $ echo -n '@alice:localhost' | openssl sha256 -hmac 'helloWorld'
|
||||||
|
//(stdin)= 121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e
|
||||||
|
//
|
||||||
|
{
|
||||||
|
name: "valid hmac for alice GET, not consented",
|
||||||
|
args: args{
|
||||||
|
username: "@alice:localhost",
|
||||||
|
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
|
||||||
|
version: "v1.0",
|
||||||
|
method: http.MethodGet,
|
||||||
|
},
|
||||||
|
wantRespCode: http.StatusOK,
|
||||||
|
wantBodyContains: "With form",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "alice consents successfully",
|
||||||
|
args: args{
|
||||||
|
username: "@alice:localhost",
|
||||||
|
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
|
||||||
|
version: "v1.0",
|
||||||
|
method: http.MethodPost,
|
||||||
|
},
|
||||||
|
wantRespCode: http.StatusOK,
|
||||||
|
wantBodyContains: "Consent given",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid hmac for alice GET, new version",
|
||||||
|
args: args{
|
||||||
|
username: "@alice:localhost",
|
||||||
|
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
|
||||||
|
version: "v2.0",
|
||||||
|
method: http.MethodGet,
|
||||||
|
},
|
||||||
|
wantRespCode: http.StatusOK,
|
||||||
|
wantBodyContains: "With form",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no hmac provided for alice, read only should be displayed",
|
||||||
|
args: args{
|
||||||
|
username: "@alice:localhost",
|
||||||
|
userHMAC: "",
|
||||||
|
version: "v1.0",
|
||||||
|
method: http.MethodGet,
|
||||||
|
},
|
||||||
|
wantRespCode: http.StatusOK,
|
||||||
|
wantBodyContains: "WithoutForm",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "alice trying to get bobs status is forbidden",
|
||||||
|
args: args{
|
||||||
|
username: "@bob:localhost",
|
||||||
|
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
|
||||||
|
version: "v1.0",
|
||||||
|
method: http.MethodGet,
|
||||||
|
},
|
||||||
|
wantRespCode: http.StatusForbidden,
|
||||||
|
wantBodyContains: "forbidden",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "alice trying to consent for bob is forbidden",
|
||||||
|
args: args{
|
||||||
|
username: "@bob:localhost",
|
||||||
|
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
|
||||||
|
version: "v1.0",
|
||||||
|
method: http.MethodPost,
|
||||||
|
},
|
||||||
|
wantRespCode: http.StatusForbidden,
|
||||||
|
wantBodyContains: "forbidden",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
userAPI := dummyAPI{
|
||||||
|
usersConsent: map[string]string{},
|
||||||
|
}
|
||||||
|
consentTemplates := template.Must(template.New("v1.0.gohtml").Parse(dummyTemplate))
|
||||||
|
consentTemplates = template.Must(consentTemplates.New("v2.0.gohtml").Parse(dummyTemplate))
|
||||||
|
userconsentOpts := config.UserConsentOptions{
|
||||||
|
FormSecret: "helloWorld",
|
||||||
|
Version: "v1.0",
|
||||||
|
Templates: consentTemplates,
|
||||||
|
BaseURL: "http://localhost",
|
||||||
|
}
|
||||||
|
cfg := &config.ClientAPI{
|
||||||
|
Matrix: &config.Global{
|
||||||
|
UserConsentOptions: userconsentOpts,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
url := fmt.Sprintf("%s/consent?u=%s&v=%s&h=%s",
|
||||||
|
userconsentOpts.BaseURL, tt.args.username, tt.args.version, tt.args.userHMAC,
|
||||||
|
)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(tt.args.method, url, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
consent(w, req, userAPI, cfg)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to read response body: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantRespCode {
|
||||||
|
t.Fatalf("expected http %d, got %d", tt.wantRespCode, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(strings.ToLower(string(body)), strings.ToLower(tt.wantBodyContains)) {
|
||||||
|
t.Fatalf("expected body to contain %s, but got %s", tt.wantBodyContains, string(body))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -31,13 +31,12 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/gomatrixserverlib/tokens"
|
"github.com/matrix-org/gomatrixserverlib/tokens"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth"
|
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
@ -721,6 +720,8 @@ func handleRegistrationFlow(
|
||||||
}
|
}
|
||||||
|
|
||||||
switch r.Auth.Type {
|
switch r.Auth.Type {
|
||||||
|
case authtypes.LoginTypeTerms:
|
||||||
|
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeTerms)
|
||||||
case authtypes.LoginTypeRecaptcha:
|
case authtypes.LoginTypeRecaptcha:
|
||||||
// Check given captcha response
|
// Check given captcha response
|
||||||
resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
|
resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
|
||||||
|
@ -788,11 +789,16 @@ func handleApplicationServiceRegistration(
|
||||||
return *err
|
return *err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
policyVersion := ""
|
||||||
|
if cfg.Matrix.UserConsentOptions.Enabled {
|
||||||
|
policyVersion = cfg.Matrix.UserConsentOptions.Version
|
||||||
|
}
|
||||||
|
|
||||||
// If no error, application service was successfully validated.
|
// If no error, application service was successfully validated.
|
||||||
// Don't need to worry about appending to registration stages as
|
// Don't need to worry about appending to registration stages as
|
||||||
// application service registration is entirely separate.
|
// application service registration is entirely separate.
|
||||||
return completeRegistration(
|
return completeRegistration(
|
||||||
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session,
|
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, policyVersion,
|
||||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService,
|
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -809,9 +815,13 @@ func checkAndCompleteFlow(
|
||||||
userAPI userapi.ClientUserAPI,
|
userAPI userapi.ClientUserAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
|
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
|
||||||
|
policyVersion := ""
|
||||||
|
if cfg.Matrix.UserConsentOptions.Enabled {
|
||||||
|
policyVersion = cfg.Matrix.UserConsentOptions.Version
|
||||||
|
}
|
||||||
// This flow was completed, registration can continue
|
// This flow was completed, registration can continue
|
||||||
return completeRegistration(
|
return completeRegistration(
|
||||||
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID,
|
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, policyVersion,
|
||||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser,
|
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -834,7 +844,7 @@ func checkAndCompleteFlow(
|
||||||
func completeRegistration(
|
func completeRegistration(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userAPI userapi.ClientUserAPI,
|
userAPI userapi.ClientUserAPI,
|
||||||
username, password, appserviceID, ipAddr, userAgent, sessionID string,
|
username, password, appserviceID, ipAddr, userAgent, sessionID, policyVersion string,
|
||||||
inhibitLogin eventutil.WeakBoolean,
|
inhibitLogin eventutil.WeakBoolean,
|
||||||
displayName, deviceID *string,
|
displayName, deviceID *string,
|
||||||
accType userapi.AccountType,
|
accType userapi.AccountType,
|
||||||
|
@ -866,6 +876,7 @@ func completeRegistration(
|
||||||
Password: password,
|
Password: password,
|
||||||
AccountType: accType,
|
AccountType: accType,
|
||||||
OnConflict: userapi.ConflictAbort,
|
OnConflict: userapi.ConflictAbort,
|
||||||
|
PolicyVersion: policyVersion,
|
||||||
}, &accRes)
|
}, &accRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(*userapi.ErrorConflict); ok { // user already exists
|
if _, ok := err.(*userapi.ErrorConflict); ok { // user already exists
|
||||||
|
@ -1073,5 +1084,5 @@ func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSec
|
||||||
if ssrr.Admin {
|
if ssrr.Admin {
|
||||||
accType = userapi.AccountTypeAdmin
|
accType = userapi.AccountTypeAdmin
|
||||||
}
|
}
|
||||||
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType)
|
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", "", false, &ssrr.User, &deviceID, accType)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
@ -93,7 +94,7 @@ func PutTag(
|
||||||
}
|
}
|
||||||
tagContent.Tags[tag] = properties
|
tagContent.Tags[tag] = properties
|
||||||
|
|
||||||
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil {
|
if err = saveTagData(req.Context(), userID, roomID, userAPI, tagContent); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
|
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
@ -145,7 +146,7 @@ func DeleteTag(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil {
|
if err = saveTagData(req.Context(), userID, roomID, userAPI, tagContent); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
|
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
@ -191,7 +192,7 @@ func obtainSavedTags(
|
||||||
|
|
||||||
// saveTagData saves the provided tag data into the database
|
// saveTagData saves the provided tag data into the database
|
||||||
func saveTagData(
|
func saveTagData(
|
||||||
req *http.Request,
|
context context.Context,
|
||||||
userID string,
|
userID string,
|
||||||
roomID string,
|
roomID string,
|
||||||
userAPI api.ClientUserAPI,
|
userAPI api.ClientUserAPI,
|
||||||
|
@ -208,5 +209,5 @@ func saveTagData(
|
||||||
AccountData: json.RawMessage(newTagData),
|
AccountData: json.RawMessage(newTagData),
|
||||||
}
|
}
|
||||||
dataRes := api.InputAccountDataResponse{}
|
dataRes := api.InputAccountDataResponse{}
|
||||||
return userAPI.InputAccountData(req.Context(), &dataReq, &dataRes)
|
return userAPI.InputAccountData(context, &dataReq, &dataRes)
|
||||||
}
|
}
|
||||||
|
|
|
@ -127,9 +127,13 @@ func Setup(
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
// server notifications
|
// server notifications
|
||||||
|
var (
|
||||||
|
serverNotificationSender *userapi.Device
|
||||||
|
err error
|
||||||
|
)
|
||||||
if cfg.Matrix.ServerNotices.Enabled {
|
if cfg.Matrix.ServerNotices.Enabled {
|
||||||
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")
|
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")
|
||||||
serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, cfg)
|
serverNotificationSender, err = getSenderDevice(context.Background(), userAPI, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Fatal("unable to get account for sending sending server notices")
|
logrus.WithError(err).Fatal("unable to get account for sending sending server notices")
|
||||||
}
|
}
|
||||||
|
@ -177,13 +181,27 @@ func Setup(
|
||||||
// using ?: so the final regexp becomes what is below. We also need a trailing slash to stop 'v33333' matching.
|
// using ?: so the final regexp becomes what is below. We also need a trailing slash to stop 'v33333' matching.
|
||||||
// Note that 'apiversion' is chosen because it must not collide with a variable used in any of the routing!
|
// Note that 'apiversion' is chosen because it must not collide with a variable used in any of the routing!
|
||||||
v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
|
v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
|
||||||
|
|
||||||
unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter()
|
unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter()
|
||||||
|
|
||||||
|
// NOTSPEC: consent tracking
|
||||||
|
if cfg.Matrix.UserConsentOptions.Enabled {
|
||||||
|
if !cfg.Matrix.ServerNotices.Enabled {
|
||||||
|
logrus.Warnf("Consent tracking is enabled, but server notes are not. No server notice will be sent to users")
|
||||||
|
} else {
|
||||||
|
// start a new go routine to send messages about consent
|
||||||
|
go sendServerNoticeForConsent(userAPI, rsAPI, &cfg.Matrix.ServerNotices, cfg, serverNotificationSender, asAPI)
|
||||||
|
}
|
||||||
|
publicAPIMux.HandleFunc("/consent", func(writer http.ResponseWriter, request *http.Request) {
|
||||||
|
consent(writer, request, userAPI, cfg)
|
||||||
|
}).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
consentRequiredCheck := httputil.WithConsentCheck(cfg.Matrix.UserConsentOptions, userAPI)
|
||||||
|
|
||||||
v3mux.Handle("/createRoom",
|
v3mux.Handle("/createRoom",
|
||||||
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI)
|
return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/join/{roomIDOrAlias}",
|
v3mux.Handle("/join/{roomIDOrAlias}",
|
||||||
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -213,7 +231,7 @@ func Setup(
|
||||||
return PeekRoomByIDOrAlias(
|
return PeekRoomByIDOrAlias(
|
||||||
req, device, rsAPI, vars["roomIDOrAlias"],
|
req, device, rsAPI, vars["roomIDOrAlias"],
|
||||||
)
|
)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
}
|
}
|
||||||
v3mux.Handle("/joined_rooms",
|
v3mux.Handle("/joined_rooms",
|
||||||
|
@ -258,7 +276,7 @@ func Setup(
|
||||||
return UnpeekRoomByID(
|
return UnpeekRoomByID(
|
||||||
req, device, rsAPI, vars["roomID"],
|
req, device, rsAPI, vars["roomID"],
|
||||||
)
|
)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/rooms/{roomID}/ban",
|
v3mux.Handle("/rooms/{roomID}/ban",
|
||||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -267,7 +285,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/rooms/{roomID}/invite",
|
v3mux.Handle("/rooms/{roomID}/invite",
|
||||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -279,7 +297,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/rooms/{roomID}/kick",
|
v3mux.Handle("/rooms/{roomID}/kick",
|
||||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -288,7 +306,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/rooms/{roomID}/unban",
|
v3mux.Handle("/rooms/{roomID}/unban",
|
||||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -297,7 +315,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/rooms/{roomID}/send/{eventType}",
|
v3mux.Handle("/rooms/{roomID}/send/{eventType}",
|
||||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -317,7 +335,7 @@ func Setup(
|
||||||
txnID := vars["txnID"]
|
txnID := vars["txnID"]
|
||||||
return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID,
|
return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID,
|
||||||
nil, cfg, rsAPI, transactionsCache)
|
nil, cfg, rsAPI, transactionsCache)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
v3mux.Handle("/rooms/{roomID}/event/{eventID}",
|
v3mux.Handle("/rooms/{roomID}/event/{eventID}",
|
||||||
httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -326,7 +344,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI)
|
return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -335,7 +353,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"])
|
return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"])
|
||||||
})).Methods(http.MethodGet, http.MethodOptions)
|
}, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
@ -343,7 +361,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return GetAliases(req, rsAPI, device, vars["roomID"])
|
return GetAliases(req, rsAPI, device, vars["roomID"])
|
||||||
})).Methods(http.MethodGet, http.MethodOptions)
|
}, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
@ -354,7 +372,7 @@ func Setup(
|
||||||
eventType := strings.TrimSuffix(vars["type"], "/")
|
eventType := strings.TrimSuffix(vars["type"], "/")
|
||||||
eventFormat := req.URL.Query().Get("format") == "event"
|
eventFormat := req.URL.Query().Get("format") == "event"
|
||||||
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat)
|
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat)
|
||||||
})).Methods(http.MethodGet, http.MethodOptions)
|
}, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
@ -363,7 +381,7 @@ func Setup(
|
||||||
}
|
}
|
||||||
eventFormat := req.URL.Query().Get("format") == "event"
|
eventFormat := req.URL.Query().Get("format") == "event"
|
||||||
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat)
|
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat)
|
||||||
})).Methods(http.MethodGet, http.MethodOptions)
|
}, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
|
v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
|
||||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -374,7 +392,7 @@ func Setup(
|
||||||
emptyString := ""
|
emptyString := ""
|
||||||
eventType := strings.TrimSuffix(vars["eventType"], "/")
|
eventType := strings.TrimSuffix(vars["eventType"], "/")
|
||||||
return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil)
|
return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
|
v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
|
||||||
|
@ -385,7 +403,7 @@ func Setup(
|
||||||
}
|
}
|
||||||
stateKey := vars["stateKey"]
|
stateKey := vars["stateKey"]
|
||||||
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil)
|
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
|
v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
|
||||||
|
@ -487,7 +505,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, nil, nil)
|
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, nil, nil)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
|
v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
|
||||||
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -497,7 +515,7 @@ func Setup(
|
||||||
}
|
}
|
||||||
txnID := vars["txnId"]
|
txnID := vars["txnId"]
|
||||||
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, &txnID, transactionsCache)
|
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, &txnID, transactionsCache)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/sendToDevice/{eventType}/{txnID}",
|
v3mux.Handle("/sendToDevice/{eventType}/{txnID}",
|
||||||
|
@ -508,7 +526,7 @@ func Setup(
|
||||||
}
|
}
|
||||||
txnID := vars["txnID"]
|
txnID := vars["txnID"]
|
||||||
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
|
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
// This is only here because sytest refers to /unstable for this endpoint
|
// This is only here because sytest refers to /unstable for this endpoint
|
||||||
|
@ -522,7 +540,7 @@ func Setup(
|
||||||
}
|
}
|
||||||
txnID := vars["txnID"]
|
txnID := vars["txnID"]
|
||||||
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
|
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/account/whoami",
|
v3mux.Handle("/account/whoami",
|
||||||
|
@ -531,7 +549,7 @@ func Setup(
|
||||||
return *r
|
return *r
|
||||||
}
|
}
|
||||||
return Whoami(req, device)
|
return Whoami(req, device)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/account/password",
|
v3mux.Handle("/account/password",
|
||||||
|
@ -738,7 +756,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SetAvatarURL(req, userAPI, device, vars["userID"], cfg, rsAPI)
|
return SetAvatarURL(req, userAPI, device, vars["userID"], cfg, rsAPI)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
|
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
|
||||||
// PUT requests, so we need to allow this method
|
// PUT requests, so we need to allow this method
|
||||||
|
@ -763,7 +781,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI)
|
return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
|
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
|
||||||
// PUT requests, so we need to allow this method
|
// PUT requests, so we need to allow this method
|
||||||
|
@ -771,19 +789,19 @@ func Setup(
|
||||||
v3mux.Handle("/account/3pid",
|
v3mux.Handle("/account/3pid",
|
||||||
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return GetAssociated3PIDs(req, userAPI, device)
|
return GetAssociated3PIDs(req, userAPI, device)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/account/3pid",
|
v3mux.Handle("/account/3pid",
|
||||||
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return CheckAndSave3PIDAssociation(req, userAPI, device, cfg)
|
return CheckAndSave3PIDAssociation(req, userAPI, device, cfg)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
unstableMux.Handle("/account/3pid/delete",
|
unstableMux.Handle("/account/3pid/delete",
|
||||||
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return Forget3PID(req, userAPI)
|
return Forget3PID(req, userAPI)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
|
v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
|
||||||
|
@ -798,7 +816,7 @@ func Setup(
|
||||||
return *r
|
return *r
|
||||||
}
|
}
|
||||||
return RequestTurnServer(req, device, cfg)
|
return RequestTurnServer(req, device, cfg)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/thirdparty/protocols",
|
v3mux.Handle("/thirdparty/protocols",
|
||||||
|
@ -868,7 +886,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return GetAdminWhois(req, userAPI, device, vars["userID"])
|
return GetAdminWhois(req, userAPI, device, vars["userID"])
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodGet)
|
).Methods(http.MethodGet)
|
||||||
|
|
||||||
v3mux.Handle("/user/{userID}/openid/request_token",
|
v3mux.Handle("/user/{userID}/openid/request_token",
|
||||||
|
@ -881,7 +899,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return CreateOpenIDToken(req, userAPI, device, vars["userID"], cfg)
|
return CreateOpenIDToken(req, userAPI, device, vars["userID"], cfg)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/user_directory/search",
|
v3mux.Handle("/user_directory/search",
|
||||||
|
@ -907,7 +925,7 @@ func Setup(
|
||||||
postContent.SearchString,
|
postContent.SearchString,
|
||||||
postContent.Limit,
|
postContent.Limit,
|
||||||
)
|
)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/rooms/{roomID}/members",
|
v3mux.Handle("/rooms/{roomID}/members",
|
||||||
|
@ -953,7 +971,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SendForget(req, device, vars["roomID"], rsAPI)
|
return SendForget(req, device, vars["roomID"], rsAPI)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/rooms/{roomID}/upgrade",
|
v3mux.Handle("/rooms/{roomID}/upgrade",
|
||||||
|
@ -1065,7 +1083,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
|
return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
|
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
|
||||||
|
@ -1075,7 +1093,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
|
return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodDelete, http.MethodOptions)
|
).Methods(http.MethodDelete, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/capabilities",
|
v3mux.Handle("/capabilities",
|
||||||
|
@ -1095,11 +1113,11 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return KeyBackupVersion(req, userAPI, device, vars["version"])
|
return KeyBackupVersion(req, userAPI, device, vars["version"])
|
||||||
})
|
}, consentRequiredCheck)
|
||||||
|
|
||||||
getLatestBackupKeysVersion := httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
getLatestBackupKeysVersion := httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return KeyBackupVersion(req, userAPI, device, "")
|
return KeyBackupVersion(req, userAPI, device, "")
|
||||||
})
|
}, consentRequiredCheck)
|
||||||
|
|
||||||
putBackupKeysVersion := httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
putBackupKeysVersion := httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
@ -1107,7 +1125,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"])
|
return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"])
|
||||||
})
|
}, consentRequiredCheck)
|
||||||
|
|
||||||
deleteBackupKeysVersion := httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
deleteBackupKeysVersion := httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
@ -1119,7 +1137,7 @@ func Setup(
|
||||||
|
|
||||||
postNewBackupKeysVersion := httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
postNewBackupKeysVersion := httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return CreateKeyBackupVersion(req, userAPI, device)
|
return CreateKeyBackupVersion(req, userAPI, device)
|
||||||
})
|
}, consentRequiredCheck)
|
||||||
|
|
||||||
v3mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
|
v3mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
|
||||||
v3mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
|
v3mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
@ -1150,7 +1168,7 @@ func Setup(
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
return UploadBackupKeys(req, userAPI, device, version, &reqBody)
|
return UploadBackupKeys(req, userAPI, device, version, &reqBody)
|
||||||
})
|
}, consentRequiredCheck)
|
||||||
|
|
||||||
// Single room bulk session
|
// Single room bulk session
|
||||||
putBackupKeysRoom := httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
putBackupKeysRoom := httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -1182,7 +1200,7 @@ func Setup(
|
||||||
}
|
}
|
||||||
reqBody.Rooms[roomID] = body
|
reqBody.Rooms[roomID] = body
|
||||||
return UploadBackupKeys(req, userAPI, device, version, &reqBody)
|
return UploadBackupKeys(req, userAPI, device, version, &reqBody)
|
||||||
})
|
}, consentRequiredCheck)
|
||||||
|
|
||||||
// Single room, single session
|
// Single room, single session
|
||||||
putBackupKeysRoomSession := httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
putBackupKeysRoomSession := httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -1215,7 +1233,7 @@ func Setup(
|
||||||
}
|
}
|
||||||
keyReq.Rooms[roomID].Sessions[sessionID] = reqBody
|
keyReq.Rooms[roomID].Sessions[sessionID] = reqBody
|
||||||
return UploadBackupKeys(req, userAPI, device, version, &keyReq)
|
return UploadBackupKeys(req, userAPI, device, version, &keyReq)
|
||||||
})
|
}, consentRequiredCheck)
|
||||||
|
|
||||||
v3mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
|
v3mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
|
||||||
v3mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
|
v3mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
|
||||||
|
@ -1229,7 +1247,7 @@ func Setup(
|
||||||
|
|
||||||
getBackupKeys := httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
getBackupKeys := httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "")
|
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "")
|
||||||
})
|
}, consentRequiredCheck)
|
||||||
|
|
||||||
getBackupKeysRoom := httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
getBackupKeysRoom := httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
@ -1245,7 +1263,7 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"])
|
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"])
|
||||||
})
|
}, consentRequiredCheck)
|
||||||
|
|
||||||
v3mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
|
v3mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
|
||||||
v3mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
|
v3mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
@ -1261,11 +1279,11 @@ func Setup(
|
||||||
|
|
||||||
postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, keyAPI, device, userAPI, cfg)
|
return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, keyAPI, device, userAPI, cfg)
|
||||||
})
|
}, consentRequiredCheck)
|
||||||
|
|
||||||
postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return UploadCrossSigningDeviceSignatures(req, keyAPI, device)
|
return UploadCrossSigningDeviceSignatures(req, keyAPI, device)
|
||||||
})
|
}, consentRequiredCheck)
|
||||||
|
|
||||||
v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
|
v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
|
v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
@ -1277,12 +1295,12 @@ func Setup(
|
||||||
v3mux.Handle("/keys/upload/{deviceID}",
|
v3mux.Handle("/keys/upload/{deviceID}",
|
||||||
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return UploadKeys(req, keyAPI, device)
|
return UploadKeys(req, keyAPI, device)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/keys/upload",
|
v3mux.Handle("/keys/upload",
|
||||||
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return UploadKeys(req, keyAPI, device)
|
return UploadKeys(req, keyAPI, device)
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/keys/query",
|
v3mux.Handle("/keys/query",
|
||||||
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -1305,7 +1323,7 @@ func Setup(
|
||||||
}
|
}
|
||||||
|
|
||||||
return SetReceipt(req, syncProducer, device, vars["roomId"], vars["receiptType"], vars["eventId"])
|
return SetReceipt(req, syncProducer, device, vars["roomId"], vars["receiptType"], vars["eventId"])
|
||||||
}),
|
}, consentRequiredCheck),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/presence/{userId}/status",
|
v3mux.Handle("/presence/{userId}/status",
|
||||||
httputil.MakeAuthAPI("set_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("set_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
|
|
@ -84,74 +84,66 @@ func SendServerNotice(
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
|
res, _ := sendServerNotice(ctx, r, rsAPI, cfgNotices, cfgClient, senderDevice, asAPI, userAPI, txnID, device, txnCache)
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendServerNotice(
|
||||||
|
ctx context.Context,
|
||||||
|
serverNoticeRequest sendServerNoticeRequest,
|
||||||
|
rsAPI api.ClientRoomserverAPI,
|
||||||
|
cfgNotices *config.ServerNotices,
|
||||||
|
cfgClient *config.ClientAPI,
|
||||||
|
senderDevice *userapi.Device,
|
||||||
|
asAPI appserviceAPI.AppServiceInternalAPI,
|
||||||
|
userAPI userapi.ClientUserAPI,
|
||||||
|
txnID *string,
|
||||||
|
device *userapi.Device,
|
||||||
|
txnCache *transactions.Cache,
|
||||||
|
) (util.JSONResponse, error) {
|
||||||
|
|
||||||
// check that all required fields are set
|
// check that all required fields are set
|
||||||
if !r.valid() {
|
if !serverNoticeRequest.valid() {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.BadJSON("Invalid request"),
|
JSON: jsonerror.BadJSON("Invalid request"),
|
||||||
}
|
}, fmt.Errorf("Invalid JSON")
|
||||||
}
|
}
|
||||||
|
|
||||||
// get rooms for specified user
|
qryServerNoticeRoom := &userapi.QueryServerNoticeRoomResponse{}
|
||||||
allUserRooms := []string{}
|
localpart, _, err := gomatrixserverlib.SplitID('@', serverNoticeRequest.UserID)
|
||||||
userRooms := api.QueryRoomsForUserResponse{}
|
if err != nil {
|
||||||
// Get rooms the user is either joined, invited or has left.
|
return util.JSONResponse{
|
||||||
for _, membership := range []string{"join", "invite", "leave"} {
|
Code: http.StatusBadRequest,
|
||||||
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
|
JSON: jsonerror.BadJSON("Invalid request"),
|
||||||
UserID: r.UserID,
|
}, err
|
||||||
WantMembership: membership,
|
|
||||||
}, &userRooms); err != nil {
|
|
||||||
return util.ErrorResponse(err)
|
|
||||||
}
|
}
|
||||||
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
|
err = userAPI.SelectServerNoticeRoomID(ctx, &userapi.QueryServerNoticeRoomRequest{Localpart: localpart}, qryServerNoticeRoom)
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// get rooms of the sender
|
|
||||||
senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName)
|
senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName)
|
||||||
senderRooms := api.QueryRoomsForUserResponse{}
|
roomID := qryServerNoticeRoom.RoomID
|
||||||
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
|
roomVersion := version.DefaultRoomVersion()
|
||||||
UserID: senderUserID,
|
|
||||||
WantMembership: "join",
|
|
||||||
}, &senderRooms); err != nil {
|
|
||||||
return util.ErrorResponse(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if we have rooms in common
|
|
||||||
commonRooms := []string{}
|
|
||||||
for _, userRoomID := range allUserRooms {
|
|
||||||
for _, senderRoomID := range senderRooms.RoomIDs {
|
|
||||||
if userRoomID == senderRoomID {
|
|
||||||
commonRooms = append(commonRooms, senderRoomID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(commonRooms) > 1 {
|
|
||||||
return util.ErrorResponse(fmt.Errorf("expected to find one room, but got %d", len(commonRooms)))
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
roomID string
|
|
||||||
roomVersion = version.DefaultRoomVersion()
|
|
||||||
)
|
|
||||||
|
|
||||||
// create a new room for the user
|
// create a new room for the user
|
||||||
if len(commonRooms) == 0 {
|
if qryServerNoticeRoom.RoomID == "" {
|
||||||
|
var pl, cc []byte
|
||||||
powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID)
|
powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID)
|
||||||
powerLevelContent.Users[r.UserID] = -10 // taken from Synapse
|
powerLevelContent.Users[serverNoticeRequest.UserID] = -10 // taken from Synapse
|
||||||
pl, err := json.Marshal(powerLevelContent)
|
pl, err = json.Marshal(powerLevelContent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err), err
|
||||||
}
|
}
|
||||||
createContent := map[string]interface{}{}
|
createContent := map[string]interface{}{}
|
||||||
createContent["m.federate"] = false
|
createContent["m.federate"] = false
|
||||||
cc, err := json.Marshal(createContent)
|
cc, err = json.Marshal(createContent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err), err
|
||||||
}
|
}
|
||||||
crReq := createRoomRequest{
|
crReq := createRoomRequest{
|
||||||
Invite: []string{r.UserID},
|
Invite: []string{serverNoticeRequest.UserID},
|
||||||
Name: cfgNotices.RoomName,
|
Name: cfgNotices.RoomName,
|
||||||
Visibility: "private",
|
Visibility: "private",
|
||||||
Preset: presetPrivateChat,
|
Preset: presetPrivateChat,
|
||||||
|
@ -166,36 +158,40 @@ func SendServerNotice(
|
||||||
switch data := roomRes.JSON.(type) {
|
switch data := roomRes.JSON.(type) {
|
||||||
case createRoomResponse:
|
case createRoomResponse:
|
||||||
roomID = data.RoomID
|
roomID = data.RoomID
|
||||||
|
res := &userapi.UpdateServerNoticeRoomResponse{}
|
||||||
|
err = userAPI.UpdateServerNoticeRoomID(ctx, &userapi.UpdateServerNoticeRoomRequest{RoomID: roomID, Localpart: localpart}, res)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("UpdateServerNoticeRoomID failed")
|
||||||
|
return jsonerror.InternalServerError(), err
|
||||||
|
}
|
||||||
// tag the room, so we can later check if the user tries to reject an invite
|
// tag the room, so we can later check if the user tries to reject an invite
|
||||||
serverAlertTag := gomatrix.TagContent{Tags: map[string]gomatrix.TagProperties{
|
serverAlertTag := gomatrix.TagContent{Tags: map[string]gomatrix.TagProperties{
|
||||||
"m.server_notice": {
|
"m.server_notice": {
|
||||||
Order: 1.0,
|
Order: 1.0,
|
||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
if err = saveTagData(req, r.UserID, roomID, userAPI, serverAlertTag); err != nil {
|
if err = saveTagData(ctx, serverNoticeRequest.UserID, roomID, userAPI, serverAlertTag); err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("saveTagData failed")
|
util.GetLogger(ctx).WithError(err).Error("saveTagData failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// if we didn't get a createRoomResponse, we probably received an error, so return that.
|
// if we didn't get a createRoomResponse, we probably received an error, so return that.
|
||||||
return roomRes
|
return roomRes, fmt.Errorf("Unable to create room")
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// we've found a room in common, check the membership
|
res := &api.QueryMembershipForUserResponse{}
|
||||||
roomID = commonRooms[0]
|
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: serverNoticeRequest.UserID, RoomID: roomID}, res)
|
||||||
membershipRes := api.QueryMembershipForUserResponse{}
|
|
||||||
err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("unable to query membership for user")
|
return util.ErrorResponse(err), err
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
}
|
||||||
if !membershipRes.IsInRoom {
|
|
||||||
// re-invite the user
|
// re-invite the user
|
||||||
res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now())
|
if res.Membership != gomatrixserverlib.Join {
|
||||||
|
var inviteRes util.JSONResponse
|
||||||
|
inviteRes, err = sendInvite(ctx, userAPI, senderDevice, roomID, serverNoticeRequest.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return res
|
return inviteRes, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -203,13 +199,13 @@ func SendServerNotice(
|
||||||
startedGeneratingEvent := time.Now()
|
startedGeneratingEvent := time.Now()
|
||||||
|
|
||||||
request := map[string]interface{}{
|
request := map[string]interface{}{
|
||||||
"body": r.Content.Body,
|
"body": serverNoticeRequest.Content.Body,
|
||||||
"msgtype": r.Content.MsgType,
|
"msgtype": serverNoticeRequest.Content.MsgType,
|
||||||
}
|
}
|
||||||
e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, cfgClient, rsAPI, time.Now())
|
e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, cfgClient, rsAPI, time.Now())
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
logrus.Errorf("failed to send message: %+v", resErr)
|
logrus.Errorf("failed to send message: %+v", resErr)
|
||||||
return *resErr
|
return *resErr, fmt.Errorf("Unable to send event")
|
||||||
}
|
}
|
||||||
timeToGenerateEvent := time.Since(startedGeneratingEvent)
|
timeToGenerateEvent := time.Since(startedGeneratingEvent)
|
||||||
|
|
||||||
|
@ -224,7 +220,7 @@ func SendServerNotice(
|
||||||
// pass the new event to the roomserver and receive the correct event ID
|
// pass the new event to the roomserver and receive the correct event ID
|
||||||
// event ID in case of duplicate transaction is discarded
|
// event ID in case of duplicate transaction is discarded
|
||||||
startedSubmittingEvent := time.Now()
|
startedSubmittingEvent := time.Now()
|
||||||
if err := api.SendEvents(
|
if err = api.SendEvents(
|
||||||
ctx, rsAPI,
|
ctx, rsAPI,
|
||||||
api.KindNew,
|
api.KindNew,
|
||||||
[]*gomatrixserverlib.HeaderedEvent{
|
[]*gomatrixserverlib.HeaderedEvent{
|
||||||
|
@ -236,7 +232,7 @@ func SendServerNotice(
|
||||||
false,
|
false,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
|
util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError(), err
|
||||||
}
|
}
|
||||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||||
"event_id": e.EventID(),
|
"event_id": e.EventID(),
|
||||||
|
@ -259,7 +255,7 @@ func SendServerNotice(
|
||||||
sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds()))
|
sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds()))
|
||||||
sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvent.Milliseconds()))
|
sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvent.Milliseconds()))
|
||||||
|
|
||||||
return res
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r sendServerNoticeRequest) valid() (ok bool) {
|
func (r sendServerNoticeRequest) valid() (ok bool) {
|
||||||
|
|
|
@ -146,7 +146,12 @@ func main() {
|
||||||
logrus.Fatalln("Username is already in use.")
|
logrus.Fatalln("Username is already in use.")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType)
|
policyVersion := ""
|
||||||
|
if cfg.Global.UserConsentOptions.Enabled {
|
||||||
|
policyVersion = cfg.Global.UserConsentOptions.Version
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", policyVersion, accType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Fatalln("Failed to create the account:", err.Error())
|
logrus.Fatalln("Failed to create the account:", err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -71,6 +71,35 @@ global:
|
||||||
# appear in user clients.
|
# appear in user clients.
|
||||||
room_name: "Server Alerts"
|
room_name: "Server Alerts"
|
||||||
|
|
||||||
|
# Consent tracking configuration
|
||||||
|
user_consent:
|
||||||
|
# If the user consent tracking is enabled or not
|
||||||
|
enabled: false
|
||||||
|
# The base URL this homeserver will serve clients on, e.g. https://matrix.org
|
||||||
|
base_url: http://localhost
|
||||||
|
# Randomly generated string (e.g. by using "pwgen -sy 32") to be used to calculate the HMAC
|
||||||
|
form_secret: "superSecretRandomlyGeneratedSecret"
|
||||||
|
# Require consent when user registers for the first time
|
||||||
|
require_at_registration: false
|
||||||
|
# The name to be shown to the user
|
||||||
|
policy_name: "Privacy policy"
|
||||||
|
# The directory to search for templates
|
||||||
|
template_dir: "./templates/privacy"
|
||||||
|
# The version of the policy. When loading templates, ".gohtml" template is added as a suffix
|
||||||
|
# e.g: ${template_dir}/1.0.gohtml needs to exist, if this is set to "1.0"
|
||||||
|
version: "1.0"
|
||||||
|
# Send a consent message to guest users
|
||||||
|
send_server_notice_to_guest: false
|
||||||
|
# Default message to send to users
|
||||||
|
server_notice_content:
|
||||||
|
msg_type: "m.text"
|
||||||
|
body: >-
|
||||||
|
Please give your consent to the privacy policy at {{ .ConsentURL }}.
|
||||||
|
# The error message to display if the user hasn't given their consent yet
|
||||||
|
block_events_error: >-
|
||||||
|
You can't send any messages until you consent to the privacy policy at
|
||||||
|
{{ .ConsentURL }}.
|
||||||
|
|
||||||
# Configuration for NATS JetStream
|
# Configuration for NATS JetStream
|
||||||
jetstream:
|
jetstream:
|
||||||
# A list of NATS Server addresses to connect to. If none are specified, an
|
# A list of NATS Server addresses to connect to. If none are specified, an
|
||||||
|
|
26
docs/templates/privacy/1.0.gohtml
vendored
Normal file
26
docs/templates/privacy/1.0.gohtml
vendored
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
<!doctype html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<title>Privacy policy</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
{{ if .HasConsented }}
|
||||||
|
<p>
|
||||||
|
You have already given your consent.
|
||||||
|
</p>
|
||||||
|
{{ else }}
|
||||||
|
<p>
|
||||||
|
Please give your consent to keep using this homeserver.
|
||||||
|
</p>
|
||||||
|
{{ if not .ReadOnly }}
|
||||||
|
<!-- The variables used here are only provided when the 'u' param is given to the homeserver -->
|
||||||
|
<form method="post" action="consent">
|
||||||
|
<input type="hidden" name="v" value="{{ .Version }}"/>
|
||||||
|
<input type="hidden" name="u" value="{{ .UserID }}"/>
|
||||||
|
<input type="hidden" name="h" value="{{ .UserHMAC }}"/>
|
||||||
|
<input type="submit" value="I consent"/>
|
||||||
|
</form>
|
||||||
|
{{ end }}
|
||||||
|
{{ end }}
|
||||||
|
</body>
|
||||||
|
</html>
|
|
@ -15,6 +15,8 @@
|
||||||
package httputil
|
package httputil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -25,9 +27,12 @@ import (
|
||||||
|
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth"
|
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
opentracing "github.com/opentracing/opentracing-go"
|
"github.com/opentracing/opentracing-go"
|
||||||
"github.com/opentracing/opentracing-go/ext"
|
"github.com/opentracing/opentracing-go/ext"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
|
@ -41,10 +46,24 @@ type BasicAuth struct {
|
||||||
Password string `yaml:"password"`
|
Password string `yaml:"password"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthAPICheck is an option to MakeAuthAPI to add additional checks (e.g. WithConsentCheck) to verify
|
||||||
|
// the user is allowed to do specific things.
|
||||||
|
type AuthAPICheck func(ctx context.Context, device *userapi.Device) *util.JSONResponse
|
||||||
|
|
||||||
|
// WithConsentCheck checks that a user has given his consent.
|
||||||
|
func WithConsentCheck(options config.UserConsentOptions, api userapi.QueryPolicyVersionAPI) AuthAPICheck {
|
||||||
|
return func(ctx context.Context, device *userapi.Device) *util.JSONResponse {
|
||||||
|
if !options.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return checkConsent(ctx, device.UserID, api, options)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
|
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
|
||||||
func MakeAuthAPI(
|
func MakeAuthAPI(
|
||||||
metricsName string, userAPI userapi.QueryAcccessTokenAPI,
|
metricsName string, userAPI userapi.QueryAcccessTokenAPI,
|
||||||
f func(*http.Request, *userapi.Device) util.JSONResponse,
|
f func(*http.Request, *userapi.Device) util.JSONResponse, checks ...AuthAPICheck,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
h := func(req *http.Request) util.JSONResponse {
|
h := func(req *http.Request) util.JSONResponse {
|
||||||
logger := util.GetLogger(req.Context())
|
logger := util.GetLogger(req.Context())
|
||||||
|
@ -72,6 +91,14 @@ func MakeAuthAPI(
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// apply additional checks, if any
|
||||||
|
for _, opt := range checks {
|
||||||
|
resp := opt(req.Context(), device)
|
||||||
|
if resp != nil {
|
||||||
|
return *resp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
jsonRes := f(req, device)
|
jsonRes := f(req, device)
|
||||||
// do not log 4xx as errors as they are client fails, not server fails
|
// do not log 4xx as errors as they are client fails, not server fails
|
||||||
if hub != nil && jsonRes.Code >= 500 {
|
if hub != nil && jsonRes.Code >= 500 {
|
||||||
|
@ -83,6 +110,53 @@ func MakeAuthAPI(
|
||||||
return MakeExternalAPI(metricsName, h)
|
return MakeExternalAPI(metricsName, h)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkConsent(ctx context.Context, userID string, userAPI userapi.QueryPolicyVersionAPI, userConsentCfg config.UserConsentOptions) *util.JSONResponse {
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// check which version of the policy the user accepted
|
||||||
|
res := &userapi.QueryPolicyVersionResponse{}
|
||||||
|
err = userAPI.QueryPolicyVersion(ctx, &userapi.QueryPolicyVersionRequest{
|
||||||
|
Localpart: localpart,
|
||||||
|
}, res)
|
||||||
|
if err != nil {
|
||||||
|
return &util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: jsonerror.Unknown("unable to get policy version"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// user hasn't accepted any policy, block access.
|
||||||
|
if userConsentCfg.Version != res.PolicyVersion {
|
||||||
|
uri, err := userConsentCfg.ConsentURL(userID)
|
||||||
|
if err != nil {
|
||||||
|
return &util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: jsonerror.Unknown("unable to get consent URL"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
msg := &bytes.Buffer{}
|
||||||
|
c := struct {
|
||||||
|
ConsentURL string
|
||||||
|
}{
|
||||||
|
ConsentURL: uri,
|
||||||
|
}
|
||||||
|
if err = userConsentCfg.TextTemplates.ExecuteTemplate(msg, "blockEventsError", c); err != nil {
|
||||||
|
logrus.Infof("error consent message: %+v", err)
|
||||||
|
return &util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: jsonerror.Unknown("unable to execute template"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.ConsentNotGiven(uri, msg.String()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler.
|
// MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler.
|
||||||
// This is used for APIs that are called from the internet.
|
// This is used for APIs that are called from the internet.
|
||||||
func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {
|
func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {
|
||||||
|
|
|
@ -59,15 +59,12 @@ func Setup(
|
||||||
PathToResult: map[string]*types.ThumbnailGenerationResult{},
|
PathToResult: map[string]*types.ThumbnailGenerationResult{},
|
||||||
}
|
}
|
||||||
|
|
||||||
uploadHandler := httputil.MakeAuthAPI(
|
uploadHandler := httputil.MakeAuthAPI("upload", userAPI, func(req *http.Request, dev *userapi.Device) util.JSONResponse {
|
||||||
"upload", userAPI,
|
|
||||||
func(req *http.Request, dev *userapi.Device) util.JSONResponse {
|
|
||||||
if r := rateLimits.Limit(req); r != nil {
|
if r := rateLimits.Limit(req); r != nil {
|
||||||
return *r
|
return *r
|
||||||
}
|
}
|
||||||
return Upload(req, cfg, dev, db, activeThumbnailGeneration)
|
return Upload(req, cfg, dev, db, activeThumbnailGeneration)
|
||||||
},
|
})
|
||||||
)
|
|
||||||
|
|
||||||
configHandler := httputil.MakeAuthAPI("config", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
configHandler := httputil.MakeAuthAPI("config", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
if r := rateLimits.Limit(req); r != nil {
|
if r := rateLimits.Limit(req); r != nil {
|
||||||
|
|
|
@ -268,6 +268,7 @@ func (b *BaseDendrite) Close() error {
|
||||||
func (b *BaseDendrite) DatabaseConnection(dbProperties *config.DatabaseOptions, writer sqlutil.Writer) (*sql.DB, sqlutil.Writer, error) {
|
func (b *BaseDendrite) DatabaseConnection(dbProperties *config.DatabaseOptions, writer sqlutil.Writer) (*sql.DB, sqlutil.Writer, error) {
|
||||||
if dbProperties.ConnectionString != "" || b == nil {
|
if dbProperties.ConnectionString != "" || b == nil {
|
||||||
// Open a new database connection using the supplied config.
|
// Open a new database connection using the supplied config.
|
||||||
|
logrus.Infof("Open a new database connection using the supplied config.: %+v", dbProperties.ConnectionString)
|
||||||
db, err := sqlutil.Open(dbProperties, writer)
|
db, err := sqlutil.Open(dbProperties, writer)
|
||||||
return db, writer, err
|
return db, writer, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -265,6 +265,21 @@ func loadConfig(
|
||||||
return &c, nil
|
return &c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Terms struct {
|
||||||
|
Policies Policies `json:"policies"`
|
||||||
|
}
|
||||||
|
type En struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
}
|
||||||
|
type PrivacyPolicy struct {
|
||||||
|
En En `json:"en"`
|
||||||
|
Version string `json:"version"`
|
||||||
|
}
|
||||||
|
type Policies struct {
|
||||||
|
PrivacyPolicy PrivacyPolicy `json:"privacy_policy"`
|
||||||
|
}
|
||||||
|
|
||||||
// Derive generates data that is derived from various values provided in
|
// Derive generates data that is derived from various values provided in
|
||||||
// the config file.
|
// the config file.
|
||||||
func (config *Dendrite) Derive() error {
|
func (config *Dendrite) Derive() error {
|
||||||
|
@ -275,13 +290,39 @@ func (config *Dendrite) Derive() error {
|
||||||
// TODO: Add email auth type
|
// TODO: Add email auth type
|
||||||
// TODO: Add MSISDN auth type
|
// TODO: Add MSISDN auth type
|
||||||
|
|
||||||
|
if config.Global.UserConsentOptions.Enabled && config.Global.UserConsentOptions.RequireAtRegistration {
|
||||||
|
uri := config.Global.UserConsentOptions.BaseURL + "/_matrix/client/consent?v=" + config.Global.UserConsentOptions.Version
|
||||||
|
config.Derived.Registration.Params[authtypes.LoginTypeTerms] = Terms{
|
||||||
|
Policies: Policies{
|
||||||
|
PrivacyPolicy: PrivacyPolicy{
|
||||||
|
En: En{
|
||||||
|
Name: config.Global.UserConsentOptions.PolicyName,
|
||||||
|
URL: uri,
|
||||||
|
},
|
||||||
|
Version: config.Global.UserConsentOptions.Version,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
if config.ClientAPI.RecaptchaEnabled {
|
if config.ClientAPI.RecaptchaEnabled {
|
||||||
config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey}
|
config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey}
|
||||||
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
|
}
|
||||||
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}})
|
|
||||||
} else {
|
if config.Derived.Registration.Flows == nil {
|
||||||
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
|
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, authtypes.Flow{
|
||||||
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}})
|
Stages: []authtypes.LoginType{authtypes.LoginTypeDummy},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepend each flow with LoginTypeTerms or LoginTypeRecaptcha
|
||||||
|
for i, flow := range config.Derived.Registration.Flows {
|
||||||
|
if config.Global.UserConsentOptions.Enabled && config.Global.UserConsentOptions.RequireAtRegistration {
|
||||||
|
flow.Stages = append([]authtypes.LoginType{authtypes.LoginTypeTerms}, flow.Stages...)
|
||||||
|
}
|
||||||
|
if config.ClientAPI.RecaptchaEnabled {
|
||||||
|
flow.Stages = append([]authtypes.LoginType{authtypes.LoginTypeRecaptcha}, flow.Stages...)
|
||||||
|
}
|
||||||
|
config.Derived.Registration.Flows[i] = flow
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load application service configuration files
|
// Load application service configuration files
|
||||||
|
|
|
@ -1,7 +1,15 @@
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"net/url"
|
||||||
|
"path/filepath"
|
||||||
|
textTemplate "text/template"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -71,6 +79,9 @@ type Global struct {
|
||||||
// ServerNotices configuration used for sending server notices
|
// ServerNotices configuration used for sending server notices
|
||||||
ServerNotices ServerNotices `yaml:"server_notices"`
|
ServerNotices ServerNotices `yaml:"server_notices"`
|
||||||
|
|
||||||
|
// Consent tracking options
|
||||||
|
UserConsentOptions UserConsentOptions `yaml:"user_consent"`
|
||||||
|
|
||||||
// ReportStats configures opt-in anonymous stats reporting.
|
// ReportStats configures opt-in anonymous stats reporting.
|
||||||
ReportStats ReportStats `yaml:"report_stats"`
|
ReportStats ReportStats `yaml:"report_stats"`
|
||||||
}
|
}
|
||||||
|
@ -88,6 +99,7 @@ func (c *Global) Defaults(generate bool) {
|
||||||
c.Metrics.Defaults(generate)
|
c.Metrics.Defaults(generate)
|
||||||
c.DNSCache.Defaults()
|
c.DNSCache.Defaults()
|
||||||
c.Sentry.Defaults()
|
c.Sentry.Defaults()
|
||||||
|
c.UserConsentOptions.Defaults()
|
||||||
c.ServerNotices.Defaults(generate)
|
c.ServerNotices.Defaults(generate)
|
||||||
c.ReportStats.Defaults()
|
c.ReportStats.Defaults()
|
||||||
}
|
}
|
||||||
|
@ -100,6 +112,7 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
c.Metrics.Verify(configErrs, isMonolith)
|
c.Metrics.Verify(configErrs, isMonolith)
|
||||||
c.Sentry.Verify(configErrs, isMonolith)
|
c.Sentry.Verify(configErrs, isMonolith)
|
||||||
c.DNSCache.Verify(configErrs, isMonolith)
|
c.DNSCache.Verify(configErrs, isMonolith)
|
||||||
|
c.UserConsentOptions.Verify(configErrs, isMonolith)
|
||||||
c.ServerNotices.Verify(configErrs, isMonolith)
|
c.ServerNotices.Verify(configErrs, isMonolith)
|
||||||
c.ReportStats.Verify(configErrs, isMonolith)
|
c.ReportStats.Verify(configErrs, isMonolith)
|
||||||
}
|
}
|
||||||
|
@ -261,6 +274,102 @@ func (c *DNSCacheOptions) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
checkPositive(configErrs, "cache_lifetime", int64(c.CacheLifetime))
|
checkPositive(configErrs, "cache_lifetime", int64(c.CacheLifetime))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Consent tracking configuration
|
||||||
|
// If either require_at_registration or send_server_notice_to_guest are true, consent
|
||||||
|
// messages will be sent to the users.
|
||||||
|
type UserConsentOptions struct {
|
||||||
|
// If consent tracking is enabled or not
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
// Randomly generated string to be used to calculate the HMAC
|
||||||
|
FormSecret string `yaml:"form_secret"`
|
||||||
|
// Require consent when user registers for the first time
|
||||||
|
RequireAtRegistration bool `yaml:"require_at_registration"`
|
||||||
|
// The name to be shown to the user
|
||||||
|
PolicyName string `yaml:"policy_name"`
|
||||||
|
// The directory to search for *.gohtml templates
|
||||||
|
TemplateDir string `yaml:"template_dir"`
|
||||||
|
// The version of the policy. When loading templates, ".gohtml" template is added as a suffix
|
||||||
|
// e.g: ${template_dir}/1.0.gohtml needs to exist, if this is set to 1.0
|
||||||
|
Version string `yaml:"version"`
|
||||||
|
// Send a consent message to guest users
|
||||||
|
SendServerNoticeToGuest bool `yaml:"send_server_notice_to_guest"`
|
||||||
|
// Default message to send to users
|
||||||
|
ServerNoticeContent struct {
|
||||||
|
MsgType string `yaml:"msg_type"`
|
||||||
|
Body string `yaml:"body"`
|
||||||
|
} `yaml:"server_notice_content"`
|
||||||
|
// The error message to display if the user hasn't given their consent yet
|
||||||
|
BlockEventsError string `yaml:"block_events_error"`
|
||||||
|
// All loaded templates
|
||||||
|
Templates *template.Template `yaml:"-"`
|
||||||
|
TextTemplates *textTemplate.Template `yaml:"-"`
|
||||||
|
// The base URL this homeserver will serve clients on, e.g. https://matrix.org
|
||||||
|
BaseURL string `yaml:"base_url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserConsentOptions) Defaults() {
|
||||||
|
c.Enabled = false
|
||||||
|
c.RequireAtRegistration = false
|
||||||
|
c.SendServerNoticeToGuest = false
|
||||||
|
c.PolicyName = "Privacy Policy"
|
||||||
|
c.Version = "1.0"
|
||||||
|
c.TemplateDir = "./templates/privacy"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserConsentOptions) Verify(configErrors *ConfigErrors, isMonolith bool) {
|
||||||
|
if !c.Enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
checkNotEmpty(configErrors, "template_dir", c.TemplateDir)
|
||||||
|
checkNotEmpty(configErrors, "version", c.Version)
|
||||||
|
checkNotEmpty(configErrors, "policy_name", c.PolicyName)
|
||||||
|
checkNotEmpty(configErrors, "form_secret", c.FormSecret)
|
||||||
|
checkNotEmpty(configErrors, "base_url", c.BaseURL)
|
||||||
|
if len(*configErrors) > 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := filepath.Abs(c.TemplateDir)
|
||||||
|
if err != nil {
|
||||||
|
configErrors.Add("unable to get template directory")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.TextTemplates = textTemplate.Must(textTemplate.New("blockEventsError").Parse(c.BlockEventsError))
|
||||||
|
c.TextTemplates = textTemplate.Must(c.TextTemplates.New("serverNoticeTemplate").Parse(c.ServerNoticeContent.Body))
|
||||||
|
|
||||||
|
// Read all defined *.gohtml templates
|
||||||
|
t, err := template.ParseGlob(filepath.Join(p, "*.gohtml"))
|
||||||
|
if err != nil || t == nil {
|
||||||
|
configErrors.Add(fmt.Sprintf("unable to read consent templates: %+v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Templates = t
|
||||||
|
// Verify we've got a template for the defined version
|
||||||
|
versionTemplate := c.Templates.Lookup(c.Version + ".gohtml")
|
||||||
|
if versionTemplate == nil {
|
||||||
|
configErrors.Add(fmt.Sprintf("unable to load defined '%s' policy template", c.Version))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConsentURL constructs the URL shown to users to accept the TOS
|
||||||
|
func (c *UserConsentOptions) ConsentURL(userID string) (string, error) {
|
||||||
|
mac := hmac.New(sha256.New, []byte(c.FormSecret))
|
||||||
|
_, err := mac.Write([]byte(userID))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
userMAC := hex.EncodeToString(mac.Sum(nil))
|
||||||
|
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("u", userID)
|
||||||
|
params.Add("h", userMAC)
|
||||||
|
params.Add("v", c.Version)
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s/_matrix/client/consent?%s", c.BaseURL, params.Encode()), nil
|
||||||
|
}
|
||||||
|
|
||||||
// PresenceOptions defines possible configurations for presence events.
|
// PresenceOptions defines possible configurations for presence events.
|
||||||
type PresenceOptions struct {
|
type PresenceOptions struct {
|
||||||
// Whether inbound presence events are allowed
|
// Whether inbound presence events are allowed
|
||||||
|
|
115
setup/config/config_global_test.go
Normal file
115
setup/config/config_global_test.go
Normal file
|
@ -0,0 +1,115 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUserConsentOptions_Verify(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
configErrors *ConfigErrors
|
||||||
|
isMonolith bool
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields UserConsentOptions
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "template dir not set",
|
||||||
|
fields: UserConsentOptions{
|
||||||
|
RequireAtRegistration: true,
|
||||||
|
},
|
||||||
|
args: struct {
|
||||||
|
configErrors *ConfigErrors
|
||||||
|
isMonolith bool
|
||||||
|
}{configErrors: &ConfigErrors{}, isMonolith: true},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "template dir set",
|
||||||
|
fields: UserConsentOptions{
|
||||||
|
RequireAtRegistration: true,
|
||||||
|
TemplateDir: "testdata/privacy",
|
||||||
|
},
|
||||||
|
args: struct {
|
||||||
|
configErrors *ConfigErrors
|
||||||
|
isMonolith bool
|
||||||
|
}{configErrors: &ConfigErrors{}, isMonolith: true},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "policy name not set",
|
||||||
|
fields: UserConsentOptions{
|
||||||
|
RequireAtRegistration: true,
|
||||||
|
TemplateDir: "testdata/privacy",
|
||||||
|
},
|
||||||
|
args: struct {
|
||||||
|
configErrors *ConfigErrors
|
||||||
|
isMonolith bool
|
||||||
|
}{configErrors: &ConfigErrors{}, isMonolith: true},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "policy name set",
|
||||||
|
fields: UserConsentOptions{
|
||||||
|
RequireAtRegistration: true,
|
||||||
|
TemplateDir: "testdata/privacy",
|
||||||
|
PolicyName: "Privacy policy",
|
||||||
|
},
|
||||||
|
args: struct {
|
||||||
|
configErrors *ConfigErrors
|
||||||
|
isMonolith bool
|
||||||
|
}{configErrors: &ConfigErrors{}, isMonolith: true},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "version not set",
|
||||||
|
fields: UserConsentOptions{
|
||||||
|
RequireAtRegistration: true,
|
||||||
|
TemplateDir: "testdata/privacy",
|
||||||
|
},
|
||||||
|
args: struct {
|
||||||
|
configErrors *ConfigErrors
|
||||||
|
isMonolith bool
|
||||||
|
}{configErrors: &ConfigErrors{}, isMonolith: true},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "everyhing required set",
|
||||||
|
fields: UserConsentOptions{
|
||||||
|
RequireAtRegistration: true,
|
||||||
|
TemplateDir: "./testdata/privacy",
|
||||||
|
Version: "1.0",
|
||||||
|
PolicyName: "Privacy policy",
|
||||||
|
FormSecret: "helloWorld",
|
||||||
|
BaseURL: "http://localhost",
|
||||||
|
},
|
||||||
|
args: struct {
|
||||||
|
configErrors *ConfigErrors
|
||||||
|
isMonolith bool
|
||||||
|
}{configErrors: &ConfigErrors{}, isMonolith: true},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &UserConsentOptions{
|
||||||
|
Enabled: true,
|
||||||
|
BaseURL: tt.fields.BaseURL,
|
||||||
|
FormSecret: tt.fields.FormSecret,
|
||||||
|
RequireAtRegistration: tt.fields.RequireAtRegistration,
|
||||||
|
PolicyName: tt.fields.PolicyName,
|
||||||
|
Version: tt.fields.Version,
|
||||||
|
TemplateDir: tt.fields.TemplateDir,
|
||||||
|
SendServerNoticeToGuest: tt.fields.SendServerNoticeToGuest,
|
||||||
|
ServerNoticeContent: tt.fields.ServerNoticeContent,
|
||||||
|
BlockEventsError: tt.fields.BlockEventsError,
|
||||||
|
}
|
||||||
|
c.Verify(tt.args.configErrors, tt.args.isMonolith)
|
||||||
|
if !tt.wantErr && len(*tt.args.configErrors) > 0 {
|
||||||
|
t.Errorf("expected no errors, got '%+v'", tt.args.configErrors)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
26
setup/config/testdata/privacy/1.0.gohtml
vendored
Normal file
26
setup/config/testdata/privacy/1.0.gohtml
vendored
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
<!doctype html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<title>Privacy policy</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
{{ if .HasConsented }}
|
||||||
|
<p>
|
||||||
|
You have already given your consent.
|
||||||
|
</p>
|
||||||
|
{{ else }}
|
||||||
|
<p>
|
||||||
|
Please give your consent to keep using this homeserver.
|
||||||
|
</p>
|
||||||
|
{{ if not .ReadOnly }}
|
||||||
|
<!-- The variables used here are only provided when the 'u' param is given to the homeserver -->
|
||||||
|
<form method="post" action="consent">
|
||||||
|
<input type="hidden" name="v" value="{{ .Version }}"/>
|
||||||
|
<input type="hidden" name="u" value="{{ .UserID }}"/>
|
||||||
|
<input type="hidden" name="h" value="{{ .UserHMAC }}"/>
|
||||||
|
<input type="submit" value="I consent"/>
|
||||||
|
</form>
|
||||||
|
{{ end }}
|
||||||
|
{{ end }}
|
||||||
|
</body>
|
||||||
|
</html>
|
|
@ -44,7 +44,6 @@ func ParseFlags(monolith bool) *config.Dendrite {
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg, err := config.Load(*configPath, monolith)
|
cfg, err := config.Load(*configPath, monolith)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Fatalf("Invalid config file: %s", err)
|
logrus.Fatalf("Invalid config file: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -93,6 +93,6 @@ func Setup(
|
||||||
vars["roomId"], vars["eventId"],
|
vars["roomId"], vars["eventId"],
|
||||||
lazyLoadCache,
|
lazyLoadCache,
|
||||||
)
|
)
|
||||||
}),
|
}, httputil.WithConsentCheck(cfg.Matrix.UserConsentOptions, userAPI)),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,6 +66,7 @@ type FederationUserAPI interface {
|
||||||
// api functions required by the sync api
|
// api functions required by the sync api
|
||||||
type SyncUserAPI interface {
|
type SyncUserAPI interface {
|
||||||
QueryAcccessTokenAPI
|
QueryAcccessTokenAPI
|
||||||
|
QueryPolicyVersionAPI
|
||||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||||
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
|
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
|
||||||
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
|
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
|
||||||
|
@ -78,6 +79,7 @@ type ClientUserAPI interface {
|
||||||
QueryAcccessTokenAPI
|
QueryAcccessTokenAPI
|
||||||
LoginTokenInternalAPI
|
LoginTokenInternalAPI
|
||||||
UserLoginAPI
|
UserLoginAPI
|
||||||
|
UserConsentPolicyAPI
|
||||||
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
|
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
|
||||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||||
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
||||||
|
@ -106,6 +108,18 @@ type ClientUserAPI interface {
|
||||||
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
|
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
|
||||||
PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error
|
PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error
|
||||||
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
|
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
|
||||||
|
SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) (err error)
|
||||||
|
UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) (err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserConsentPolicyAPI interface {
|
||||||
|
QueryPolicyVersionAPI
|
||||||
|
QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error
|
||||||
|
PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryPolicyVersionAPI interface {
|
||||||
|
QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// custom api functions required by pinecone / p2p demos
|
// custom api functions required by pinecone / p2p demos
|
||||||
|
@ -320,7 +334,7 @@ type QuerySearchProfilesResponse struct {
|
||||||
type PerformAccountCreationRequest struct {
|
type PerformAccountCreationRequest struct {
|
||||||
AccountType AccountType // Required: whether this is a guest or user account
|
AccountType AccountType // Required: whether this is a guest or user account
|
||||||
Localpart string // Required: The localpart for this account. Ignored if account type is guest.
|
Localpart string // Required: The localpart for this account. Ignored if account type is guest.
|
||||||
|
PolicyVersion string // optional: the privacy policy this account has accepted
|
||||||
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
|
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
|
||||||
Password string // optional: if missing then this account will be a passwordless account
|
Password string // optional: if missing then this account will be a passwordless account
|
||||||
OnConflict Conflict
|
OnConflict Conflict
|
||||||
|
@ -412,6 +426,53 @@ type QueryOpenIDTokenResponse struct {
|
||||||
ExpiresAtMS int64
|
ExpiresAtMS int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryPolicyVersionRequest is the request for QueryPolicyVersionRequest
|
||||||
|
type QueryPolicyVersionRequest struct {
|
||||||
|
Localpart string
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryPolicyVersionResponse is the response for QueryPolicyVersionRequest
|
||||||
|
type QueryPolicyVersionResponse struct {
|
||||||
|
PolicyVersion string
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryOutdatedPolicyRequest is the request for QueryOutdatedPolicyRequest
|
||||||
|
type QueryOutdatedPolicyRequest struct {
|
||||||
|
PolicyVersion string
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryOutdatedPolicyResponse is the response for QueryOutdatedPolicyRequest
|
||||||
|
type QueryOutdatedPolicyResponse struct {
|
||||||
|
UserLocalparts []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest
|
||||||
|
type UpdatePolicyVersionRequest struct {
|
||||||
|
PolicyVersion, Localpart string
|
||||||
|
ServerNoticeUpdate bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest
|
||||||
|
type UpdatePolicyVersionResponse struct{}
|
||||||
|
|
||||||
|
// QueryServerNoticeRoomRequest is the request for QueryServerNoticeRoomRequest
|
||||||
|
type QueryServerNoticeRoomRequest struct {
|
||||||
|
Localpart string
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryServerNoticeRoomResponse is the response for QueryServerNoticeRoomRequest
|
||||||
|
type QueryServerNoticeRoomResponse struct {
|
||||||
|
RoomID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateServerNoticeRoomRequest is the request for UpdateServerNoticeRoomRequest
|
||||||
|
type UpdateServerNoticeRoomRequest struct {
|
||||||
|
Localpart, RoomID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateServerNoticeRoomResponse is the response for UpdateServerNoticeRoomRequest
|
||||||
|
type UpdateServerNoticeRoomResponse struct{}
|
||||||
|
|
||||||
// Device represents a client's device (mobile, web, etc)
|
// Device represents a client's device (mobile, web, etc)
|
||||||
type Device struct {
|
type Device struct {
|
||||||
ID string
|
ID string
|
||||||
|
|
|
@ -203,6 +203,36 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *UserInternalAPITrace) QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error {
|
||||||
|
err := t.Impl.QueryPolicyVersion(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).Infof("QueryPolicyVersion req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UserInternalAPITrace) QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error {
|
||||||
|
err := t.Impl.QueryOutdatedPolicy(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).Infof("QueryOutdatedPolicy req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UserInternalAPITrace) PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error {
|
||||||
|
err := t.Impl.PerformUpdatePolicyVersion(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).Infof("PerformUpdatePolicyVersion req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UserInternalAPITrace) SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) error {
|
||||||
|
err := t.Impl.SelectServerNoticeRoomID(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).Infof("SelectServerNoticeRoomID req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UserInternalAPITrace) UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) error {
|
||||||
|
err := t.Impl.UpdateServerNoticeRoomID(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).Infof("UpdateServerNoticeRoomID req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func js(thing interface{}) string {
|
func js(thing interface{}) string {
|
||||||
b, err := json.Marshal(thing)
|
b, err := json.Marshal(thing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -66,7 +66,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
|
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
|
||||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
|
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion, req.AccountType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||||
switch req.OnConflict {
|
switch req.OnConflict {
|
||||||
|
@ -833,3 +833,60 @@ func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, re
|
||||||
}
|
}
|
||||||
|
|
||||||
const pushRulesAccountDataType = "m.push_rules"
|
const pushRulesAccountDataType = "m.push_rules"
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) QueryPolicyVersion(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.QueryPolicyVersionRequest,
|
||||||
|
res *api.QueryPolicyVersionResponse,
|
||||||
|
) error {
|
||||||
|
var err error
|
||||||
|
res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.Localpart)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) QueryOutdatedPolicy(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.QueryOutdatedPolicyRequest,
|
||||||
|
res *api.QueryOutdatedPolicyResponse,
|
||||||
|
) error {
|
||||||
|
var err error
|
||||||
|
res.UserLocalparts, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) PerformUpdatePolicyVersion(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.UpdatePolicyVersionRequest,
|
||||||
|
res *api.UpdatePolicyVersionResponse,
|
||||||
|
) error {
|
||||||
|
return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.Localpart, req.ServerNoticeUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) SelectServerNoticeRoomID(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.QueryServerNoticeRoomRequest,
|
||||||
|
res *api.QueryServerNoticeRoomResponse,
|
||||||
|
) (err error) {
|
||||||
|
roomID, err := a.DB.SelectServerNoticeRoomID(ctx, req.Localpart)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.RoomID = roomID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) UpdateServerNoticeRoomID(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.UpdateServerNoticeRoomRequest,
|
||||||
|
res *api.UpdateServerNoticeRoomResponse,
|
||||||
|
) (err error) {
|
||||||
|
return a.DB.UpdateServerNoticeRoomID(ctx, req.Localpart, req.RoomID)
|
||||||
|
}
|
||||||
|
|
|
@ -44,6 +44,8 @@ const (
|
||||||
PerformSetDisplayNamePath = "/userapi/performSetDisplayName"
|
PerformSetDisplayNamePath = "/userapi/performSetDisplayName"
|
||||||
PerformForgetThreePIDPath = "/userapi/performForgetThreePID"
|
PerformForgetThreePIDPath = "/userapi/performForgetThreePID"
|
||||||
PerformSaveThreePIDAssociationPath = "/userapi/performSaveThreePIDAssociation"
|
PerformSaveThreePIDAssociationPath = "/userapi/performSaveThreePIDAssociation"
|
||||||
|
PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion"
|
||||||
|
PerformUpdateServerNoticeRoomPath = "/userapi/performUpdateServerNoticeRoom"
|
||||||
|
|
||||||
QueryKeyBackupPath = "/userapi/queryKeyBackup"
|
QueryKeyBackupPath = "/userapi/queryKeyBackup"
|
||||||
QueryProfilePath = "/userapi/queryProfile"
|
QueryProfilePath = "/userapi/queryProfile"
|
||||||
|
@ -61,6 +63,9 @@ const (
|
||||||
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
|
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
|
||||||
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
|
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
|
||||||
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
|
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
|
||||||
|
QueryPolicyVersionPath = "/userapi/queryPolicyVersion"
|
||||||
|
QueryOutdatedPolicyUsersPath = "/userapi/queryOutdatedPolicy"
|
||||||
|
QueryServerNoticeRoomPath = "/userapi/queryServerNoticeRoom"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
||||||
|
@ -391,3 +396,43 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context
|
||||||
apiURL := h.apiURL + PerformSaveThreePIDAssociationPath
|
apiURL := h.apiURL + PerformSaveThreePIDAssociationPath
|
||||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) QueryOutdatedPolicy(ctx context.Context, req *api.QueryOutdatedPolicyRequest, res *api.QueryOutdatedPolicyResponse) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOutdatedPolicy")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + QueryOutdatedPolicyUsersPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) PerformUpdatePolicyVersion(ctx context.Context, req *api.UpdatePolicyVersionRequest, res *api.UpdatePolicyVersionResponse) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUpdatePolicyVersion")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + PerformUpdatePolicyVersionPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) SelectServerNoticeRoomID(ctx context.Context, req *api.QueryServerNoticeRoomRequest, res *api.QueryServerNoticeRoomResponse) (err error) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "SelectServerNoticeRoomID")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + QueryServerNoticeRoomPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) QueryPolicyVersion(ctx context.Context, req *api.QueryPolicyVersionRequest, res *api.QueryPolicyVersionResponse) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPolicyVersion")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + QueryPolicyVersionPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) UpdateServerNoticeRoomID(ctx context.Context, req *api.UpdateServerNoticeRoomRequest, res *api.UpdateServerNoticeRoomResponse) (err error) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "UpdateServerNoticeRoomID")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + PerformUpdateServerNoticeRoomPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
|
@ -457,4 +457,74 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(QueryPolicyVersionPath,
|
||||||
|
httputil.MakeInternalAPI("queryPolicyVersion", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryPolicyVersionRequest{}
|
||||||
|
response := api.QueryPolicyVersionResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
err := s.QueryPolicyVersion(req.Context(), &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(QueryOutdatedPolicyUsersPath,
|
||||||
|
httputil.MakeInternalAPI("queryOutdatedPolicyUsers", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryOutdatedPolicyRequest{}
|
||||||
|
response := api.QueryOutdatedPolicyResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
err := s.QueryOutdatedPolicy(req.Context(), &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(PerformUpdatePolicyVersionPath,
|
||||||
|
httputil.MakeInternalAPI("performUpdatePolicyVersionPath", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.UpdatePolicyVersionRequest{}
|
||||||
|
response := api.UpdatePolicyVersionResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
err := s.PerformUpdatePolicyVersion(req.Context(), &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(QueryServerNoticeRoomPath,
|
||||||
|
httputil.MakeInternalAPI("queryServerNoticeRoom", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryServerNoticeRoomRequest{}
|
||||||
|
response := api.QueryServerNoticeRoomResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
err := s.SelectServerNoticeRoomID(req.Context(), &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(PerformUpdateServerNoticeRoomPath,
|
||||||
|
httputil.MakeInternalAPI("performUpdateServerNoticeRoom", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.UpdateServerNoticeRoomRequest{}
|
||||||
|
response := api.UpdateServerNoticeRoomResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
err := s.UpdateServerNoticeRoomID(req.Context(), &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ type Account interface {
|
||||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||||
// account already exists, it will return nil, ErrUserExists.
|
// account already exists, it will return nil, ErrUserExists.
|
||||||
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, policyVersion string, accountType api.AccountType) (*api.Account, error)
|
||||||
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
||||||
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
||||||
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
||||||
|
@ -126,9 +126,18 @@ type Notification interface {
|
||||||
DeleteOldNotifications(ctx context.Context) error
|
DeleteOldNotifications(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ConsentTracking interface {
|
||||||
|
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
|
||||||
|
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
|
||||||
|
UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) error
|
||||||
|
SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error)
|
||||||
|
UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error)
|
||||||
|
}
|
||||||
|
|
||||||
type Database interface {
|
type Database interface {
|
||||||
Account
|
Account
|
||||||
AccountData
|
AccountData
|
||||||
|
ConsentTracking
|
||||||
Device
|
Device
|
||||||
KeyBackup
|
KeyBackup
|
||||||
LoginToken
|
LoginToken
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
@ -43,14 +44,19 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
-- If the account is currently active
|
-- If the account is currently active
|
||||||
is_deactivated BOOLEAN DEFAULT FALSE,
|
is_deactivated BOOLEAN DEFAULT FALSE,
|
||||||
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
|
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
|
||||||
account_type SMALLINT NOT NULL
|
account_type SMALLINT NOT NULL,
|
||||||
|
-- The policy version this user has accepted
|
||||||
|
policy_version TEXT,
|
||||||
|
-- The policy version the user received from the server notices room
|
||||||
|
policy_version_sent TEXT,
|
||||||
|
server_notice_room_id TEXT
|
||||||
-- TODO:
|
-- TODO:
|
||||||
-- upgraded_ts, devices, any email reset stuff?
|
-- upgraded_ts, devices, any email reset stuff?
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertAccountSQL = "" +
|
const insertAccountSQL = "" +
|
||||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type, policy_version) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||||
|
|
||||||
const updatePasswordSQL = "" +
|
const updatePasswordSQL = "" +
|
||||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||||
|
@ -67,6 +73,24 @@ const selectPasswordHashSQL = "" +
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
"SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'"
|
"SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'"
|
||||||
|
|
||||||
|
const selectPrivacyPolicySQL = "" +
|
||||||
|
"SELECT policy_version FROM account_accounts WHERE localpart = $1"
|
||||||
|
|
||||||
|
const batchSelectPrivacyPolicySQL = "" +
|
||||||
|
"SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $1)"
|
||||||
|
|
||||||
|
const updatePolicyVersionSQL = "" +
|
||||||
|
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
|
const updatePolicyVersionServerNoticeSQL = "" +
|
||||||
|
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
|
const selectServerNoticeRoomSQL = "" +
|
||||||
|
"SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1"
|
||||||
|
|
||||||
|
const updateServerNoticeRoomSQL = "" +
|
||||||
|
"UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
updatePasswordStmt *sql.Stmt
|
updatePasswordStmt *sql.Stmt
|
||||||
|
@ -74,6 +98,12 @@ type accountsStatements struct {
|
||||||
selectAccountByLocalpartStmt *sql.Stmt
|
selectAccountByLocalpartStmt *sql.Stmt
|
||||||
selectPasswordHashStmt *sql.Stmt
|
selectPasswordHashStmt *sql.Stmt
|
||||||
selectNewNumericLocalpartStmt *sql.Stmt
|
selectNewNumericLocalpartStmt *sql.Stmt
|
||||||
|
selectPrivacyPolicyStmt *sql.Stmt
|
||||||
|
batchSelectPrivacyPolicyStmt *sql.Stmt
|
||||||
|
updatePolicyVersionStmt *sql.Stmt
|
||||||
|
updatePolicyVersionServerNoticeStmt *sql.Stmt
|
||||||
|
selectServerNoticeRoomStmt *sql.Stmt
|
||||||
|
updateServerNoticeRoomStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,6 +122,12 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
|
||||||
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
||||||
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
||||||
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
||||||
|
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
||||||
|
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
||||||
|
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
||||||
|
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
|
||||||
|
{&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL},
|
||||||
|
{&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,16 +135,16 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
|
||||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||||
// on success.
|
// on success.
|
||||||
func (s *accountsStatements) InsertAccount(
|
func (s *accountsStatements) InsertAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if accountType != api.AccountTypeAppService {
|
if accountType != api.AccountTypeAppService {
|
||||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
|
||||||
} else {
|
} else {
|
||||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -178,3 +214,71 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||||
return id + 1, err
|
return id + 1, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
|
||||||
|
func (s *accountsStatements) SelectPrivacyPolicy(
|
||||||
|
ctx context.Context, txn *sql.Tx, localPart string,
|
||||||
|
) (policy string, err error) {
|
||||||
|
var policyNull sql.NullString
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx, localPart).Scan(&policyNull)
|
||||||
|
return policyNull.String, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
|
||||||
|
func (s *accountsStatements) BatchSelectPrivacyPolicy(
|
||||||
|
ctx context.Context, txn *sql.Tx, policyVersion string,
|
||||||
|
) (userIDs []string, err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, policyVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "BatchSelectPrivacyPolicy: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var userID string
|
||||||
|
if err := rows.Scan(&userID); err != nil {
|
||||||
|
return userIDs, err
|
||||||
|
}
|
||||||
|
userIDs = append(userIDs, userID)
|
||||||
|
}
|
||||||
|
return userIDs, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// updatePolicyVersion sets the policy_version for a specific user
|
||||||
|
func (s *accountsStatements) UpdatePolicyVersion(
|
||||||
|
ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool,
|
||||||
|
) (err error) {
|
||||||
|
stmt := s.updatePolicyVersionStmt
|
||||||
|
if serverNotice {
|
||||||
|
stmt = s.updatePolicyVersionServerNoticeStmt
|
||||||
|
}
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectServerNoticeRoomID queries the server notice room ID.
|
||||||
|
func (s *accountsStatements) SelectServerNoticeRoomID(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart string,
|
||||||
|
) (roomID string, err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt)
|
||||||
|
|
||||||
|
roomIDNull := sql.NullString{}
|
||||||
|
row := stmt.QueryRowContext(ctx, localpart)
|
||||||
|
err = row.Scan(&roomIDNull)
|
||||||
|
if err != nil && err != sql.ErrNoRows {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// roomIDNull.String is either the roomID or an empty string
|
||||||
|
return roomIDNull.String, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateServerNoticeRoomID sets the server notice room ID.
|
||||||
|
func (s *accountsStatements) UpdateServerNoticeRoomID(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart, roomID string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, roomID, localpart)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
func LoadFromGoose() {
|
func LoadFromGoose() {
|
||||||
goose.AddMigration(UpIsActive, DownIsActive)
|
goose.AddMigration(UpIsActive, DownIsActive)
|
||||||
goose.AddMigration(UpAddAccountType, DownAddAccountType)
|
goose.AddMigration(UpAddAccountType, DownAddAccountType)
|
||||||
|
goose.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadIsActive(m *sqlutil.Migrations) {
|
func LoadIsActive(m *sqlutil.Migrations) {
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
package deltas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoadAddPolicyVersion(m *sqlutil.Migrations) {
|
||||||
|
m.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpAddPolicyVersion(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version TEXT;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version_sent TEXT;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS server_notice_room_id TEXT;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DownAddPolicyVersion(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -46,6 +46,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
deltas.LoadIsActive(m)
|
deltas.LoadIsActive(m)
|
||||||
//deltas.LoadLastSeenTSIP(m)
|
//deltas.LoadLastSeenTSIP(m)
|
||||||
deltas.LoadAddAccountType(m)
|
deltas.LoadAddAccountType(m)
|
||||||
|
deltas.LoadAddPolicyVersion(m)
|
||||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/userapi/types"
|
"github.com/matrix-org/dendrite/userapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
@ -125,7 +126,7 @@ func (d *Database) SetPassword(
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||||
// account already exists, it will return nil, ErrUserExists.
|
// account already exists, it will return nil, ErrUserExists.
|
||||||
func (d *Database) CreateAccount(
|
func (d *Database) CreateAccount(
|
||||||
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string, accountType api.AccountType,
|
||||||
) (acc *api.Account, err error) {
|
) (acc *api.Account, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
// For guest accounts, we create a new numeric local part
|
// For guest accounts, we create a new numeric local part
|
||||||
|
@ -139,7 +140,7 @@ func (d *Database) CreateAccount(
|
||||||
plaintextPassword = ""
|
plaintextPassword = ""
|
||||||
appserviceID = ""
|
appserviceID = ""
|
||||||
}
|
}
|
||||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
|
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion, accountType)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
@ -148,7 +149,7 @@ func (d *Database) CreateAccount(
|
||||||
// WARNING! This function assumes that the relevant mutexes have already
|
// WARNING! This function assumes that the relevant mutexes have already
|
||||||
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
||||||
func (d *Database) createAccount(
|
func (d *Database) createAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string, accountType api.AccountType,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
var err error
|
var err error
|
||||||
var account *api.Account
|
var account *api.Account
|
||||||
|
@ -160,7 +161,8 @@ func (d *Database) createAccount(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion, accountType); err != nil {
|
||||||
|
logrus.WithError(err).Error("d.Accounts.InsertAccount error")
|
||||||
return nil, sqlutil.ErrUserExists
|
return nil, sqlutil.ErrUserExists
|
||||||
}
|
}
|
||||||
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
|
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
|
||||||
|
@ -763,3 +765,42 @@ func (d *Database) RemovePushers(
|
||||||
func (d *Database) UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) {
|
func (d *Database) UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) {
|
||||||
return d.Stats.UserStatistics(ctx, nil)
|
return d.Stats.UserStatistics(ctx, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
|
||||||
|
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
policyVersion, err = d.Accounts.SelectPrivacyPolicy(ctx, txn, localpart)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOutdatedPolicy queries all users which didn't accept the current policy version
|
||||||
|
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
userIDs, err = d.Accounts.BatchSelectPrivacyPolicy(ctx, txn, policyVersion)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePolicyVersion sets the accepted policy_version for a user.
|
||||||
|
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) (err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart, serverNotice)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectServerNoticeRoomID returns the server notice room, if one is set.
|
||||||
|
func (d *Database) SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error) {
|
||||||
|
return d.Accounts.SelectServerNoticeRoomID(ctx, nil, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateServerNoticeRoomID updates the server notice room
|
||||||
|
func (d *Database) UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.Accounts.UpdateServerNoticeRoomID(ctx, txn, localpart, roomID)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
@ -43,14 +44,19 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
-- If the account is currently active
|
-- If the account is currently active
|
||||||
is_deactivated BOOLEAN DEFAULT 0,
|
is_deactivated BOOLEAN DEFAULT 0,
|
||||||
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
|
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
|
||||||
account_type INTEGER NOT NULL
|
account_type INTEGER NOT NULL,
|
||||||
|
-- The policy version this user has accepted
|
||||||
|
policy_version TEXT,
|
||||||
|
-- The policy version the user received from the server notices room
|
||||||
|
policy_version_sent TEXT,
|
||||||
|
server_notice_room_id TEXT
|
||||||
-- TODO:
|
-- TODO:
|
||||||
-- upgraded_ts, devices, any email reset stuff?
|
-- upgraded_ts, devices, any email reset stuff?
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertAccountSQL = "" +
|
const insertAccountSQL = "" +
|
||||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type, policy_version) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||||
|
|
||||||
const updatePasswordSQL = "" +
|
const updatePasswordSQL = "" +
|
||||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||||
|
@ -67,6 +73,24 @@ const selectPasswordHashSQL = "" +
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0"
|
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0"
|
||||||
|
|
||||||
|
const selectPrivacyPolicySQL = "" +
|
||||||
|
"SELECT policy_version FROM account_accounts WHERE localpart = $1"
|
||||||
|
|
||||||
|
const batchSelectPrivacyPolicySQL = "" +
|
||||||
|
"SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $2)"
|
||||||
|
|
||||||
|
const updatePolicyVersionSQL = "" +
|
||||||
|
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
|
const updatePolicyVersionServerNoticeSQL = "" +
|
||||||
|
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
|
const selectServerNoticeRoomSQL = "" +
|
||||||
|
"SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1"
|
||||||
|
|
||||||
|
const updateServerNoticeRoomSQL = "" +
|
||||||
|
"UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
|
@ -75,6 +99,12 @@ type accountsStatements struct {
|
||||||
selectAccountByLocalpartStmt *sql.Stmt
|
selectAccountByLocalpartStmt *sql.Stmt
|
||||||
selectPasswordHashStmt *sql.Stmt
|
selectPasswordHashStmt *sql.Stmt
|
||||||
selectNewNumericLocalpartStmt *sql.Stmt
|
selectNewNumericLocalpartStmt *sql.Stmt
|
||||||
|
selectPrivacyPolicyStmt *sql.Stmt
|
||||||
|
batchSelectPrivacyPolicyStmt *sql.Stmt
|
||||||
|
updatePolicyVersionStmt *sql.Stmt
|
||||||
|
updatePolicyVersionServerNoticeStmt *sql.Stmt
|
||||||
|
selectServerNoticeRoomStmt *sql.Stmt
|
||||||
|
updateServerNoticeRoomStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,6 +124,12 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
||||||
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
||||||
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
||||||
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
||||||
|
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
||||||
|
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
||||||
|
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
||||||
|
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
|
||||||
|
{&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL},
|
||||||
|
{&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,16 +137,16 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
||||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||||
// on success.
|
// on success.
|
||||||
func (s *accountsStatements) InsertAccount(
|
func (s *accountsStatements) InsertAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := s.insertAccountStmt
|
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if accountType != api.AccountTypeAppService {
|
if accountType != api.AccountTypeAppService {
|
||||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
|
||||||
} else {
|
} else {
|
||||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -183,3 +219,72 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||||
}
|
}
|
||||||
return id + 1, err
|
return id + 1, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
|
||||||
|
|
||||||
|
func (s *accountsStatements) SelectPrivacyPolicy(
|
||||||
|
ctx context.Context, txn *sql.Tx, localPart string,
|
||||||
|
) (policy string, err error) {
|
||||||
|
var policyNull sql.NullString
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx, localPart).Scan(&policyNull)
|
||||||
|
return policyNull.String, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
|
||||||
|
func (s *accountsStatements) BatchSelectPrivacyPolicy(
|
||||||
|
ctx context.Context, txn *sql.Tx, policyVersion string,
|
||||||
|
) (userIDs []string, err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, policyVersion, policyVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "BatchSelectPrivacyPolicy: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var userID string
|
||||||
|
if err := rows.Scan(&userID); err != nil {
|
||||||
|
return userIDs, err
|
||||||
|
}
|
||||||
|
userIDs = append(userIDs, userID)
|
||||||
|
}
|
||||||
|
return userIDs, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// updatePolicyVersion sets the policy_version for a specific user
|
||||||
|
func (s *accountsStatements) UpdatePolicyVersion(
|
||||||
|
ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool,
|
||||||
|
) (err error) {
|
||||||
|
stmt := s.updatePolicyVersionStmt
|
||||||
|
if serverNotice {
|
||||||
|
stmt = s.updatePolicyVersionServerNoticeStmt
|
||||||
|
}
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectServerNoticeRoomID queries the server notice room ID.
|
||||||
|
func (s *accountsStatements) SelectServerNoticeRoomID(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart string,
|
||||||
|
) (roomID string, err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt)
|
||||||
|
|
||||||
|
roomIDNull := sql.NullString{}
|
||||||
|
row := stmt.QueryRowContext(ctx, localpart)
|
||||||
|
err = row.Scan(&roomIDNull)
|
||||||
|
if err != nil && err != sql.ErrNoRows {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// roomIDNull.String is either the roomID or an empty string
|
||||||
|
return roomIDNull.String, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateServerNoticeRoomID sets the server notice room ID.
|
||||||
|
func (s *accountsStatements) UpdateServerNoticeRoomID(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart, roomID string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, roomID, localpart)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
func LoadFromGoose() {
|
func LoadFromGoose() {
|
||||||
goose.AddMigration(UpIsActive, DownIsActive)
|
goose.AddMigration(UpIsActive, DownIsActive)
|
||||||
goose.AddMigration(UpAddAccountType, DownAddAccountType)
|
goose.AddMigration(UpAddAccountType, DownAddAccountType)
|
||||||
|
goose.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadIsActive(m *sqlutil.Migrations) {
|
func LoadIsActive(m *sqlutil.Migrations) {
|
||||||
|
|
|
@ -4,15 +4,9 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/pressly/goose"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
|
||||||
goose.AddMigration(UpAddAccountType, DownAddAccountType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadAddAccountType(m *sqlutil.Migrations) {
|
func LoadAddAccountType(m *sqlutil.Migrations) {
|
||||||
m.AddMigration(UpAddAccountType, DownAddAccountType)
|
m.AddMigration(UpAddAccountType, DownAddAccountType)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
package deltas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoadAddPolicyVersion(m *sqlutil.Migrations) {
|
||||||
|
m.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpAddPolicyVersion(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version TEXT;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version_sent TEXT;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN server_notice_room_id TEXT;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DownAddPolicyVersion(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -18,11 +18,10 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/shared"
|
"github.com/matrix-org/dendrite/userapi/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
||||||
|
@ -47,6 +46,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
deltas.LoadIsActive(m)
|
deltas.LoadIsActive(m)
|
||||||
//deltas.LoadLastSeenTSIP(m)
|
//deltas.LoadLastSeenTSIP(m)
|
||||||
deltas.LoadAddAccountType(m)
|
deltas.LoadAddAccountType(m)
|
||||||
|
deltas.LoadAddPolicyVersion(m)
|
||||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -78,7 +78,7 @@ func Test_Accounts(t *testing.T) {
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", "v1.0", api.AccountTypeAdmin)
|
||||||
assert.NoError(t, err, "failed to create account")
|
assert.NoError(t, err, "failed to create account")
|
||||||
// verify the newly create account is the same as returned by CreateAccount
|
// verify the newly create account is the same as returned by CreateAccount
|
||||||
var accGet *api.Account
|
var accGet *api.Account
|
||||||
|
@ -102,7 +102,7 @@ func Test_Accounts(t *testing.T) {
|
||||||
first, err := db.GetNewNumericLocalpart(ctx)
|
first, err := db.GetNewNumericLocalpart(ctx)
|
||||||
assert.NoError(t, err, "failed to get new numeric localpart")
|
assert.NoError(t, err, "failed to get new numeric localpart")
|
||||||
// Create a new account to verify the numeric localpart is updated
|
// Create a new account to verify the numeric localpart is updated
|
||||||
_, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest)
|
_, err = db.CreateAccount(ctx, "", "testing", "", "v1.0", api.AccountTypeGuest)
|
||||||
assert.NoError(t, err, "failed to create account")
|
assert.NoError(t, err, "failed to create account")
|
||||||
second, err := db.GetNewNumericLocalpart(ctx)
|
second, err := db.GetNewNumericLocalpart(ctx)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -350,7 +350,7 @@ func Test_Profile(t *testing.T) {
|
||||||
defer close()
|
defer close()
|
||||||
|
|
||||||
// create account, which also creates a profile
|
// create account, which also creates a profile
|
||||||
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", "v1.0", api.AccountTypeAdmin)
|
||||||
assert.NoError(t, err, "failed to create account")
|
assert.NoError(t, err, "failed to create account")
|
||||||
|
|
||||||
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)
|
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)
|
||||||
|
|
|
@ -32,12 +32,18 @@ type AccountDataTable interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccountsTable interface {
|
type AccountsTable interface {
|
||||||
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType) (*api.Account, error)
|
||||||
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
|
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
|
||||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||||
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
|
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
|
||||||
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||||
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
||||||
|
|
||||||
|
SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error)
|
||||||
|
BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error)
|
||||||
|
UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool) (err error)
|
||||||
|
SelectServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart string) (roomID string, err error)
|
||||||
|
UpdateServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart, roomID string) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DevicesTable interface {
|
type DevicesTable interface {
|
||||||
|
|
|
@ -88,7 +88,7 @@ func mustMakeAccountAndDevice(
|
||||||
appServiceID = util.RandomString(16)
|
appServiceID = util.RandomString(16)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType)
|
_, err := accDB.InsertAccount(ctx, nil, localpart, "", "", appServiceID, accType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create account: %v", err)
|
t.Fatalf("unable to create account: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,7 +75,7 @@ func TestQueryProfile(t *testing.T) {
|
||||||
// only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added
|
// only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added
|
||||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite)
|
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite)
|
||||||
defer close()
|
defer close()
|
||||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser)
|
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", "", api.AccountTypeUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to make account: %s", err)
|
t.Fatalf("failed to make account: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -154,7 +154,7 @@ func TestLoginToken(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser)
|
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", "", api.AccountTypeUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to make account: %s", err)
|
t.Fatalf("failed to make account: %s", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue