Compare commits
82 commits
main
...
s7evink/co
Author | SHA1 | Date | |
---|---|---|---|
cabc5f4bc9 | |||
7df7a966b8 | |||
75ca5490bc | |||
c9409078ac | |||
3d3773d3d4 | |||
4bd9a73c13 | |||
4d285fff60 | |||
10dc02f1ea | |||
94ed2d3689 | |||
964e1cef85 | |||
2eb3aab07e | |||
60ba4b5612 | |||
88612ddd0c | |||
bddf8ed3ac | |||
4d5feb2544 | |||
cd7a7606a1 | |||
dc8cea6d57 | |||
e0cdf64c33 | |||
ef62255685 | |||
1f64fc79c8 | |||
2b496be2c3 | |||
733b601aa9 | |||
2a18023a1a | |||
c99e3aff1b | |||
f1e8d19cea | |||
019f0922ea | |||
6324c1d01f | |||
b9479a6f18 | |||
e42ef1706b | |||
31ac3ac081 | |||
39d9d88b02 | |||
710007d600 | |||
dcfc0bcd43 | |||
c7d2254698 | |||
7c6a162c0f | |||
699617ee4d | |||
519ea13510 | |||
fa26aa9138 | |||
e80ca307d3 | |||
df7218e230 | |||
e6e62497c9 | |||
2ad15f308f | |||
ed16a2f107 | |||
ce658ab8f2 | |||
79e1c9e4bd | |||
2042303c6c | |||
4f2d161401 | |||
c65eb2bf52 | |||
0ae8293abd | |||
dac29c1786 | |||
c2b6019c35 | |||
61cdb714df | |||
185cb7a582 | |||
e2b0ff675b | |||
c0845ea1ad | |||
6622fda08c | |||
219a15c4c3 | |||
fb95331aa2 | |||
cb4526793d | |||
2e6987f8bd | |||
9c3a1cfd47 | |||
74da1f0fb3 | |||
26accb8c5d | |||
6482630f7b | |||
2fc1c46743 | |||
5a0ec6e443 | |||
535d388ec0 | |||
cbdbbb0839 | |||
f8bebe5e5a | |||
d19518fca5 | |||
89340cfc52 | |||
11144de92f | |||
b2045c24cb | |||
097f1d4609 | |||
a505471c90 | |||
3c5c3ea7fb | |||
9583784e8a | |||
b6ee34918c | |||
ac343861ad | |||
4da7df5e3e | |||
ccc11f94f7 | |||
5702b84dae |
|
@ -11,4 +11,5 @@ const (
|
|||
LoginTypeRecaptcha = "m.login.recaptcha"
|
||||
LoginTypeApplicationService = "m.login.application_service"
|
||||
LoginTypeToken = "m.login.token"
|
||||
LoginTypeTerms = "m.login.terms"
|
||||
)
|
||||
|
|
|
@ -29,6 +29,13 @@ type MatrixError struct {
|
|||
Err string `json:"error"`
|
||||
}
|
||||
|
||||
// ConsentError is an error returned to users, who didn't accept the
|
||||
// TOS of this server yet.
|
||||
type ConsentError struct {
|
||||
MatrixError
|
||||
ConsentURI string `json:"consent_uri"`
|
||||
}
|
||||
|
||||
func (e MatrixError) Error() string {
|
||||
return fmt.Sprintf("%s: %s", e.ErrCode, e.Err)
|
||||
}
|
||||
|
@ -207,3 +214,15 @@ func NotTrusted(serverName string) *MatrixError {
|
|||
Err: fmt.Sprintf("Untrusted server '%s'", serverName),
|
||||
}
|
||||
}
|
||||
|
||||
// ConsentNotGiven is an error returned to users, who didn't accept the
|
||||
// TOS of this server yet.
|
||||
func ConsentNotGiven(consentURI string, msg string) *ConsentError {
|
||||
return &ConsentError{
|
||||
MatrixError: MatrixError{
|
||||
ErrCode: "M_CONSENT_NOT_GIVEN",
|
||||
Err: msg,
|
||||
},
|
||||
ConsentURI: consentURI,
|
||||
}
|
||||
}
|
||||
|
|
219
clientapi/routing/consent_tracking.go
Normal file
219
clientapi/routing/consent_tracking.go
Normal file
|
@ -0,0 +1,219 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// The data used to populate the /consent request
|
||||
type constentTemplateData struct {
|
||||
UserID string
|
||||
Version string
|
||||
UserHMAC string
|
||||
HasConsented bool
|
||||
ReadOnly bool
|
||||
}
|
||||
|
||||
func writeHeaderAndText(w http.ResponseWriter, statusCode int) {
|
||||
w.WriteHeader(statusCode)
|
||||
_, _ = w.Write([]byte(http.StatusText(statusCode)))
|
||||
}
|
||||
|
||||
func consent(writer http.ResponseWriter, req *http.Request, userAPI userapi.UserConsentPolicyAPI, cfg *config.ClientAPI) {
|
||||
consentCfg := cfg.Matrix.UserConsentOptions
|
||||
|
||||
// The data used to populate the /consent request
|
||||
data := constentTemplateData{
|
||||
UserID: req.FormValue("u"),
|
||||
Version: req.FormValue("v"),
|
||||
UserHMAC: req.FormValue("h"),
|
||||
}
|
||||
|
||||
switch req.Method {
|
||||
case http.MethodGet:
|
||||
// display the privacy policy without a form
|
||||
data.ReadOnly = data.UserID == "" || data.UserHMAC == "" || data.Version == ""
|
||||
|
||||
// let's see if the user already consented to the current version
|
||||
if !data.ReadOnly {
|
||||
if ok, err := validHMAC(data.UserID, data.UserHMAC, consentCfg.FormSecret); err != nil || !ok {
|
||||
writeHeaderAndText(writer, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
res := &userapi.QueryPolicyVersionResponse{}
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', data.UserID)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("unable to split username")
|
||||
writeHeaderAndText(writer, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err = userAPI.QueryPolicyVersion(req.Context(), &userapi.QueryPolicyVersionRequest{
|
||||
Localpart: localpart,
|
||||
}, res); err != nil {
|
||||
logrus.WithError(err).Error("unable query policy version")
|
||||
writeHeaderAndText(writer, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
data.HasConsented = res.PolicyVersion == consentCfg.Version
|
||||
}
|
||||
|
||||
err := consentCfg.Templates.ExecuteTemplate(writer, consentCfg.Version+".gohtml", data)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("unable to execute consent template")
|
||||
writeHeaderAndText(writer, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
case http.MethodPost:
|
||||
ok, err := validHMAC(data.UserID, data.UserHMAC, consentCfg.FormSecret)
|
||||
if err != nil || !ok {
|
||||
if !ok {
|
||||
writeHeaderAndText(writer, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
writeHeaderAndText(writer, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', data.UserID)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("unable to split username")
|
||||
writeHeaderAndText(writer, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err = userAPI.PerformUpdatePolicyVersion(
|
||||
req.Context(),
|
||||
&userapi.UpdatePolicyVersionRequest{
|
||||
PolicyVersion: data.Version,
|
||||
Localpart: localpart,
|
||||
},
|
||||
&userapi.UpdatePolicyVersionResponse{},
|
||||
); err != nil {
|
||||
writeHeaderAndText(writer, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// display the privacy policy without a form
|
||||
data.ReadOnly = false
|
||||
data.HasConsented = true
|
||||
|
||||
err = consentCfg.Templates.ExecuteTemplate(writer, consentCfg.Version+".gohtml", data)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("unable to print consent template")
|
||||
writeHeaderAndText(writer, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sendServerNoticeForConsent(userAPI userapi.ClientUserAPI, rsAPI api.ClientRoomserverAPI,
|
||||
cfgNotices *config.ServerNotices,
|
||||
cfgClient *config.ClientAPI,
|
||||
senderDevice *userapi.Device,
|
||||
asAPI appserviceAPI.AppServiceInternalAPI,
|
||||
) {
|
||||
res := &userapi.QueryOutdatedPolicyResponse{}
|
||||
if err := userAPI.QueryOutdatedPolicy(context.Background(), &userapi.QueryOutdatedPolicyRequest{
|
||||
PolicyVersion: cfgClient.Matrix.UserConsentOptions.Version,
|
||||
}, res); err != nil {
|
||||
logrus.WithError(err).Error("unable to fetch users with outdated consent policy")
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
consentOpts = cfgClient.Matrix.UserConsentOptions
|
||||
data = make(map[string]string)
|
||||
err error
|
||||
sentMessages int
|
||||
)
|
||||
|
||||
if len(res.UserLocalparts) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
logrus.WithField("count", len(res.UserLocalparts)).Infof("Sending server notice to users who have not yet accepted the policy")
|
||||
|
||||
for _, localpart := range res.UserLocalparts {
|
||||
if localpart == cfgClient.Matrix.ServerNotices.LocalPart {
|
||||
continue
|
||||
}
|
||||
userID := fmt.Sprintf("@%s:%s", localpart, cfgClient.Matrix.ServerName)
|
||||
data["ConsentURL"], err = consentOpts.ConsentURL(userID)
|
||||
if err != nil {
|
||||
logrus.WithError(err).WithField("userID", userID).Error("unable to construct consentURI")
|
||||
continue
|
||||
}
|
||||
msgBody := &bytes.Buffer{}
|
||||
|
||||
if err = consentOpts.TextTemplates.ExecuteTemplate(msgBody, "serverNoticeTemplate", data); err != nil {
|
||||
logrus.WithError(err).WithField("userID", userID).Error("unable to execute serverNoticeTemplate")
|
||||
continue
|
||||
}
|
||||
|
||||
req := sendServerNoticeRequest{
|
||||
UserID: userID,
|
||||
Content: struct {
|
||||
MsgType string `json:"msgtype,omitempty"`
|
||||
Body string `json:"body,omitempty"`
|
||||
}{
|
||||
MsgType: consentOpts.ServerNoticeContent.MsgType,
|
||||
Body: msgBody.String(),
|
||||
},
|
||||
}
|
||||
_, err = sendServerNotice(context.Background(), req, rsAPI, cfgNotices, cfgClient, senderDevice, asAPI, userAPI, nil, nil, nil)
|
||||
if err != nil {
|
||||
logrus.WithError(err).WithField("userID", userID).Error("failed to send server notice for consent to user")
|
||||
continue
|
||||
}
|
||||
sentMessages++
|
||||
res := &userapi.UpdatePolicyVersionResponse{}
|
||||
if err = userAPI.PerformUpdatePolicyVersion(context.Background(), &userapi.UpdatePolicyVersionRequest{
|
||||
PolicyVersion: consentOpts.Version,
|
||||
Localpart: userID,
|
||||
ServerNoticeUpdate: true,
|
||||
}, res); err != nil {
|
||||
logrus.WithError(err).WithField("userID", userID).Error("failed to update policy version")
|
||||
continue
|
||||
}
|
||||
}
|
||||
if sentMessages > 0 {
|
||||
logrus.Infof("Sent messages to %d users", sentMessages)
|
||||
}
|
||||
}
|
||||
|
||||
func validHMAC(username, userHMAC, secret string) (bool, error) {
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
_, err := mac.Write([]byte(username))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
expectedMAC := mac.Sum(nil)
|
||||
decoded, err := hex.DecodeString(userHMAC)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return hmac.Equal(decoded, expectedMAC), nil
|
||||
}
|
236
clientapi/routing/consent_tracking_test.go
Normal file
236
clientapi/routing/consent_tracking_test.go
Normal file
|
@ -0,0 +1,236 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
func Test_validHMAC(t *testing.T) {
|
||||
type args struct {
|
||||
username string
|
||||
userHMAC string
|
||||
secret string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "invalid hmac",
|
||||
args: args{},
|
||||
wantErr: false,
|
||||
want: false,
|
||||
},
|
||||
// $ echo -n '@alice:localhost' | openssl sha256 -hmac 'helloWorld'
|
||||
//(stdin)= 121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e
|
||||
//
|
||||
{
|
||||
name: "valid hmac",
|
||||
args: args{
|
||||
username: "@alice:localhost",
|
||||
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
|
||||
secret: "helloWorld",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "invalid hmac",
|
||||
args: args{
|
||||
username: "@bob:localhost",
|
||||
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
|
||||
secret: "helloWorld",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := validHMAC(tt.args.username, tt.args.userHMAC, tt.args.secret)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validHMAC() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("validHMAC() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type dummyAPI struct {
|
||||
usersConsent map[string]string
|
||||
}
|
||||
|
||||
func (d dummyAPI) QueryOutdatedPolicy(ctx context.Context, req *userapi.QueryOutdatedPolicyRequest, res *userapi.QueryOutdatedPolicyResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummyAPI) PerformUpdatePolicyVersion(ctx context.Context, req *userapi.UpdatePolicyVersionRequest, res *userapi.UpdatePolicyVersionResponse) error {
|
||||
d.usersConsent[req.Localpart] = req.PolicyVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummyAPI) QueryPolicyVersion(ctx context.Context, req *userapi.QueryPolicyVersionRequest, res *userapi.QueryPolicyVersionResponse) error {
|
||||
res.PolicyVersion = "v2.0"
|
||||
return nil
|
||||
}
|
||||
|
||||
const dummyTemplate = `
|
||||
{{ if .HasConsented }}
|
||||
Consent given.
|
||||
{{ else }}
|
||||
WithoutForm
|
||||
{{ if not .ReadOnly }}
|
||||
With Form.
|
||||
{{ end }}
|
||||
{{ end }}`
|
||||
|
||||
func Test_consent(t *testing.T) {
|
||||
type args struct {
|
||||
username string
|
||||
userHMAC string
|
||||
version string
|
||||
method string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantRespCode int
|
||||
wantBodyContains string
|
||||
}{
|
||||
{
|
||||
name: "not a userID, valid hmac",
|
||||
args: args{
|
||||
username: "notAuserID",
|
||||
userHMAC: "7578bbface5ebb250a63935cebc05ca12060f58ebdbd271ecbc25e25a3da154d",
|
||||
version: "v1.0",
|
||||
method: http.MethodGet,
|
||||
},
|
||||
wantRespCode: http.StatusInternalServerError,
|
||||
},
|
||||
|
||||
// $ echo -n '@alice:localhost' | openssl sha256 -hmac 'helloWorld'
|
||||
//(stdin)= 121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e
|
||||
//
|
||||
{
|
||||
name: "valid hmac for alice GET, not consented",
|
||||
args: args{
|
||||
username: "@alice:localhost",
|
||||
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
|
||||
version: "v1.0",
|
||||
method: http.MethodGet,
|
||||
},
|
||||
wantRespCode: http.StatusOK,
|
||||
wantBodyContains: "With form",
|
||||
},
|
||||
{
|
||||
name: "alice consents successfully",
|
||||
args: args{
|
||||
username: "@alice:localhost",
|
||||
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
|
||||
version: "v1.0",
|
||||
method: http.MethodPost,
|
||||
},
|
||||
wantRespCode: http.StatusOK,
|
||||
wantBodyContains: "Consent given",
|
||||
},
|
||||
{
|
||||
name: "valid hmac for alice GET, new version",
|
||||
args: args{
|
||||
username: "@alice:localhost",
|
||||
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
|
||||
version: "v2.0",
|
||||
method: http.MethodGet,
|
||||
},
|
||||
wantRespCode: http.StatusOK,
|
||||
wantBodyContains: "With form",
|
||||
},
|
||||
{
|
||||
name: "no hmac provided for alice, read only should be displayed",
|
||||
args: args{
|
||||
username: "@alice:localhost",
|
||||
userHMAC: "",
|
||||
version: "v1.0",
|
||||
method: http.MethodGet,
|
||||
},
|
||||
wantRespCode: http.StatusOK,
|
||||
wantBodyContains: "WithoutForm",
|
||||
},
|
||||
{
|
||||
name: "alice trying to get bobs status is forbidden",
|
||||
args: args{
|
||||
username: "@bob:localhost",
|
||||
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
|
||||
version: "v1.0",
|
||||
method: http.MethodGet,
|
||||
},
|
||||
wantRespCode: http.StatusForbidden,
|
||||
wantBodyContains: "forbidden",
|
||||
},
|
||||
{
|
||||
name: "alice trying to consent for bob is forbidden",
|
||||
args: args{
|
||||
username: "@bob:localhost",
|
||||
userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e",
|
||||
version: "v1.0",
|
||||
method: http.MethodPost,
|
||||
},
|
||||
wantRespCode: http.StatusForbidden,
|
||||
wantBodyContains: "forbidden",
|
||||
},
|
||||
}
|
||||
|
||||
userAPI := dummyAPI{
|
||||
usersConsent: map[string]string{},
|
||||
}
|
||||
consentTemplates := template.Must(template.New("v1.0.gohtml").Parse(dummyTemplate))
|
||||
consentTemplates = template.Must(consentTemplates.New("v2.0.gohtml").Parse(dummyTemplate))
|
||||
userconsentOpts := config.UserConsentOptions{
|
||||
FormSecret: "helloWorld",
|
||||
Version: "v1.0",
|
||||
Templates: consentTemplates,
|
||||
BaseURL: "http://localhost",
|
||||
}
|
||||
cfg := &config.ClientAPI{
|
||||
Matrix: &config.Global{
|
||||
UserConsentOptions: userconsentOpts,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
url := fmt.Sprintf("%s/consent?u=%s&v=%s&h=%s",
|
||||
userconsentOpts.BaseURL, tt.args.username, tt.args.version, tt.args.userHMAC,
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(tt.args.method, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
consent(w, req, userAPI, cfg)
|
||||
|
||||
resp := w.Result()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to read response body: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantRespCode {
|
||||
t.Fatalf("expected http %d, got %d", tt.wantRespCode, resp.StatusCode)
|
||||
}
|
||||
|
||||
if !strings.Contains(strings.ToLower(string(body)), strings.ToLower(tt.wantBodyContains)) {
|
||||
t.Fatalf("expected body to contain %s, but got %s", tt.wantBodyContains, string(body))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -31,13 +31,12 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/tokens"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
|
@ -721,6 +720,8 @@ func handleRegistrationFlow(
|
|||
}
|
||||
|
||||
switch r.Auth.Type {
|
||||
case authtypes.LoginTypeTerms:
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeTerms)
|
||||
case authtypes.LoginTypeRecaptcha:
|
||||
// Check given captcha response
|
||||
resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
|
||||
|
@ -788,11 +789,16 @@ func handleApplicationServiceRegistration(
|
|||
return *err
|
||||
}
|
||||
|
||||
policyVersion := ""
|
||||
if cfg.Matrix.UserConsentOptions.Enabled {
|
||||
policyVersion = cfg.Matrix.UserConsentOptions.Version
|
||||
}
|
||||
|
||||
// If no error, application service was successfully validated.
|
||||
// Don't need to worry about appending to registration stages as
|
||||
// application service registration is entirely separate.
|
||||
return completeRegistration(
|
||||
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session,
|
||||
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, policyVersion,
|
||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService,
|
||||
)
|
||||
}
|
||||
|
@ -809,9 +815,13 @@ func checkAndCompleteFlow(
|
|||
userAPI userapi.ClientUserAPI,
|
||||
) util.JSONResponse {
|
||||
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
|
||||
policyVersion := ""
|
||||
if cfg.Matrix.UserConsentOptions.Enabled {
|
||||
policyVersion = cfg.Matrix.UserConsentOptions.Version
|
||||
}
|
||||
// This flow was completed, registration can continue
|
||||
return completeRegistration(
|
||||
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID,
|
||||
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, policyVersion,
|
||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser,
|
||||
)
|
||||
}
|
||||
|
@ -834,7 +844,7 @@ func checkAndCompleteFlow(
|
|||
func completeRegistration(
|
||||
ctx context.Context,
|
||||
userAPI userapi.ClientUserAPI,
|
||||
username, password, appserviceID, ipAddr, userAgent, sessionID string,
|
||||
username, password, appserviceID, ipAddr, userAgent, sessionID, policyVersion string,
|
||||
inhibitLogin eventutil.WeakBoolean,
|
||||
displayName, deviceID *string,
|
||||
accType userapi.AccountType,
|
||||
|
@ -861,11 +871,12 @@ func completeRegistration(
|
|||
}
|
||||
var accRes userapi.PerformAccountCreationResponse
|
||||
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
|
||||
AppServiceID: appserviceID,
|
||||
Localpart: username,
|
||||
Password: password,
|
||||
AccountType: accType,
|
||||
OnConflict: userapi.ConflictAbort,
|
||||
AppServiceID: appserviceID,
|
||||
Localpart: username,
|
||||
Password: password,
|
||||
AccountType: accType,
|
||||
OnConflict: userapi.ConflictAbort,
|
||||
PolicyVersion: policyVersion,
|
||||
}, &accRes)
|
||||
if err != nil {
|
||||
if _, ok := err.(*userapi.ErrorConflict); ok { // user already exists
|
||||
|
@ -1073,5 +1084,5 @@ func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSec
|
|||
if ssrr.Admin {
|
||||
accType = userapi.AccountTypeAdmin
|
||||
}
|
||||
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType)
|
||||
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", "", false, &ssrr.User, &deviceID, accType)
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
|
@ -93,7 +94,7 @@ func PutTag(
|
|||
}
|
||||
tagContent.Tags[tag] = properties
|
||||
|
||||
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil {
|
||||
if err = saveTagData(req.Context(), userID, roomID, userAPI, tagContent); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
@ -145,7 +146,7 @@ func DeleteTag(
|
|||
}
|
||||
}
|
||||
|
||||
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil {
|
||||
if err = saveTagData(req.Context(), userID, roomID, userAPI, tagContent); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
@ -191,7 +192,7 @@ func obtainSavedTags(
|
|||
|
||||
// saveTagData saves the provided tag data into the database
|
||||
func saveTagData(
|
||||
req *http.Request,
|
||||
context context.Context,
|
||||
userID string,
|
||||
roomID string,
|
||||
userAPI api.ClientUserAPI,
|
||||
|
@ -208,5 +209,5 @@ func saveTagData(
|
|||
AccountData: json.RawMessage(newTagData),
|
||||
}
|
||||
dataRes := api.InputAccountDataResponse{}
|
||||
return userAPI.InputAccountData(req.Context(), &dataReq, &dataRes)
|
||||
return userAPI.InputAccountData(context, &dataReq, &dataRes)
|
||||
}
|
||||
|
|
|
@ -127,9 +127,13 @@ func Setup(
|
|||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
// server notifications
|
||||
var (
|
||||
serverNotificationSender *userapi.Device
|
||||
err error
|
||||
)
|
||||
if cfg.Matrix.ServerNotices.Enabled {
|
||||
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")
|
||||
serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, cfg)
|
||||
serverNotificationSender, err = getSenderDevice(context.Background(), userAPI, cfg)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatal("unable to get account for sending sending server notices")
|
||||
}
|
||||
|
@ -177,13 +181,27 @@ func Setup(
|
|||
// using ?: so the final regexp becomes what is below. We also need a trailing slash to stop 'v33333' matching.
|
||||
// Note that 'apiversion' is chosen because it must not collide with a variable used in any of the routing!
|
||||
v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
|
||||
|
||||
unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter()
|
||||
|
||||
// NOTSPEC: consent tracking
|
||||
if cfg.Matrix.UserConsentOptions.Enabled {
|
||||
if !cfg.Matrix.ServerNotices.Enabled {
|
||||
logrus.Warnf("Consent tracking is enabled, but server notes are not. No server notice will be sent to users")
|
||||
} else {
|
||||
// start a new go routine to send messages about consent
|
||||
go sendServerNoticeForConsent(userAPI, rsAPI, &cfg.Matrix.ServerNotices, cfg, serverNotificationSender, asAPI)
|
||||
}
|
||||
publicAPIMux.HandleFunc("/consent", func(writer http.ResponseWriter, request *http.Request) {
|
||||
consent(writer, request, userAPI, cfg)
|
||||
}).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||
}
|
||||
|
||||
consentRequiredCheck := httputil.WithConsentCheck(cfg.Matrix.UserConsentOptions, userAPI)
|
||||
|
||||
v3mux.Handle("/createRoom",
|
||||
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/join/{roomIDOrAlias}",
|
||||
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
@ -213,7 +231,7 @@ func Setup(
|
|||
return PeekRoomByIDOrAlias(
|
||||
req, device, rsAPI, vars["roomIDOrAlias"],
|
||||
)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
}
|
||||
v3mux.Handle("/joined_rooms",
|
||||
|
@ -258,7 +276,7 @@ func Setup(
|
|||
return UnpeekRoomByID(
|
||||
req, device, rsAPI, vars["roomID"],
|
||||
)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/ban",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
@ -267,7 +285,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/invite",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
@ -279,7 +297,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/kick",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
@ -288,7 +306,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/unban",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
@ -297,7 +315,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/send/{eventType}",
|
||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
@ -317,7 +335,7 @@ func Setup(
|
|||
txnID := vars["txnID"]
|
||||
return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID,
|
||||
nil, cfg, rsAPI, transactionsCache)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/event/{eventID}",
|
||||
httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
@ -326,7 +344,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
@ -335,7 +353,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"])
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
}, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
@ -343,7 +361,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetAliases(req, rsAPI, device, vars["roomID"])
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
}, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
@ -354,7 +372,7 @@ func Setup(
|
|||
eventType := strings.TrimSuffix(vars["type"], "/")
|
||||
eventFormat := req.URL.Query().Get("format") == "event"
|
||||
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat)
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
}, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
@ -363,7 +381,7 @@ func Setup(
|
|||
}
|
||||
eventFormat := req.URL.Query().Get("format") == "event"
|
||||
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat)
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
}, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
|
||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
@ -374,7 +392,7 @@ func Setup(
|
|||
emptyString := ""
|
||||
eventType := strings.TrimSuffix(vars["eventType"], "/")
|
||||
return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
|
||||
|
@ -385,7 +403,7 @@ func Setup(
|
|||
}
|
||||
stateKey := vars["stateKey"]
|
||||
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
|
||||
|
@ -487,7 +505,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, nil, nil)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
|
||||
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
@ -497,7 +515,7 @@ func Setup(
|
|||
}
|
||||
txnID := vars["txnId"]
|
||||
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, &txnID, transactionsCache)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/sendToDevice/{eventType}/{txnID}",
|
||||
|
@ -508,7 +526,7 @@ func Setup(
|
|||
}
|
||||
txnID := vars["txnID"]
|
||||
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
// This is only here because sytest refers to /unstable for this endpoint
|
||||
|
@ -522,7 +540,7 @@ func Setup(
|
|||
}
|
||||
txnID := vars["txnID"]
|
||||
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/account/whoami",
|
||||
|
@ -531,7 +549,7 @@ func Setup(
|
|||
return *r
|
||||
}
|
||||
return Whoami(req, device)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/account/password",
|
||||
|
@ -738,7 +756,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SetAvatarURL(req, userAPI, device, vars["userID"], cfg, rsAPI)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
|
||||
// PUT requests, so we need to allow this method
|
||||
|
@ -763,7 +781,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
|
||||
// PUT requests, so we need to allow this method
|
||||
|
@ -771,19 +789,19 @@ func Setup(
|
|||
v3mux.Handle("/account/3pid",
|
||||
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetAssociated3PIDs(req, userAPI, device)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/account/3pid",
|
||||
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return CheckAndSave3PIDAssociation(req, userAPI, device, cfg)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
unstableMux.Handle("/account/3pid/delete",
|
||||
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return Forget3PID(req, userAPI)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
|
||||
|
@ -798,7 +816,7 @@ func Setup(
|
|||
return *r
|
||||
}
|
||||
return RequestTurnServer(req, device, cfg)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/thirdparty/protocols",
|
||||
|
@ -868,7 +886,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetAdminWhois(req, userAPI, device, vars["userID"])
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodGet)
|
||||
|
||||
v3mux.Handle("/user/{userID}/openid/request_token",
|
||||
|
@ -881,7 +899,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return CreateOpenIDToken(req, userAPI, device, vars["userID"], cfg)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/user_directory/search",
|
||||
|
@ -907,7 +925,7 @@ func Setup(
|
|||
postContent.SearchString,
|
||||
postContent.Limit,
|
||||
)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/members",
|
||||
|
@ -953,7 +971,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendForget(req, device, vars["roomID"], rsAPI)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/upgrade",
|
||||
|
@ -1065,7 +1083,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
|
||||
|
@ -1075,7 +1093,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodDelete, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/capabilities",
|
||||
|
@ -1095,11 +1113,11 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return KeyBackupVersion(req, userAPI, device, vars["version"])
|
||||
})
|
||||
}, consentRequiredCheck)
|
||||
|
||||
getLatestBackupKeysVersion := httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return KeyBackupVersion(req, userAPI, device, "")
|
||||
})
|
||||
}, consentRequiredCheck)
|
||||
|
||||
putBackupKeysVersion := httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
@ -1107,7 +1125,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"])
|
||||
})
|
||||
}, consentRequiredCheck)
|
||||
|
||||
deleteBackupKeysVersion := httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
@ -1119,7 +1137,7 @@ func Setup(
|
|||
|
||||
postNewBackupKeysVersion := httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return CreateKeyBackupVersion(req, userAPI, device)
|
||||
})
|
||||
}, consentRequiredCheck)
|
||||
|
||||
v3mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
|
||||
v3mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
@ -1150,7 +1168,7 @@ func Setup(
|
|||
return *resErr
|
||||
}
|
||||
return UploadBackupKeys(req, userAPI, device, version, &reqBody)
|
||||
})
|
||||
}, consentRequiredCheck)
|
||||
|
||||
// Single room bulk session
|
||||
putBackupKeysRoom := httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
@ -1182,7 +1200,7 @@ func Setup(
|
|||
}
|
||||
reqBody.Rooms[roomID] = body
|
||||
return UploadBackupKeys(req, userAPI, device, version, &reqBody)
|
||||
})
|
||||
}, consentRequiredCheck)
|
||||
|
||||
// Single room, single session
|
||||
putBackupKeysRoomSession := httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
@ -1215,7 +1233,7 @@ func Setup(
|
|||
}
|
||||
keyReq.Rooms[roomID].Sessions[sessionID] = reqBody
|
||||
return UploadBackupKeys(req, userAPI, device, version, &keyReq)
|
||||
})
|
||||
}, consentRequiredCheck)
|
||||
|
||||
v3mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
|
||||
v3mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
|
||||
|
@ -1229,7 +1247,7 @@ func Setup(
|
|||
|
||||
getBackupKeys := httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "")
|
||||
})
|
||||
}, consentRequiredCheck)
|
||||
|
||||
getBackupKeysRoom := httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
@ -1245,7 +1263,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"])
|
||||
})
|
||||
}, consentRequiredCheck)
|
||||
|
||||
v3mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
|
||||
v3mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
@ -1261,11 +1279,11 @@ func Setup(
|
|||
|
||||
postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, keyAPI, device, userAPI, cfg)
|
||||
})
|
||||
}, consentRequiredCheck)
|
||||
|
||||
postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return UploadCrossSigningDeviceSignatures(req, keyAPI, device)
|
||||
})
|
||||
}, consentRequiredCheck)
|
||||
|
||||
v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
@ -1277,12 +1295,12 @@ func Setup(
|
|||
v3mux.Handle("/keys/upload/{deviceID}",
|
||||
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return UploadKeys(req, keyAPI, device)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/keys/upload",
|
||||
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return UploadKeys(req, keyAPI, device)
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/keys/query",
|
||||
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
@ -1305,7 +1323,7 @@ func Setup(
|
|||
}
|
||||
|
||||
return SetReceipt(req, syncProducer, device, vars["roomId"], vars["receiptType"], vars["eventId"])
|
||||
}),
|
||||
}, consentRequiredCheck),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/presence/{userId}/status",
|
||||
httputil.MakeAuthAPI("set_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
|
|
@ -84,74 +84,66 @@ func SendServerNotice(
|
|||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
res, _ := sendServerNotice(ctx, r, rsAPI, cfgNotices, cfgClient, senderDevice, asAPI, userAPI, txnID, device, txnCache)
|
||||
return res
|
||||
}
|
||||
|
||||
func sendServerNotice(
|
||||
ctx context.Context,
|
||||
serverNoticeRequest sendServerNoticeRequest,
|
||||
rsAPI api.ClientRoomserverAPI,
|
||||
cfgNotices *config.ServerNotices,
|
||||
cfgClient *config.ClientAPI,
|
||||
senderDevice *userapi.Device,
|
||||
asAPI appserviceAPI.AppServiceInternalAPI,
|
||||
userAPI userapi.ClientUserAPI,
|
||||
txnID *string,
|
||||
device *userapi.Device,
|
||||
txnCache *transactions.Cache,
|
||||
) (util.JSONResponse, error) {
|
||||
|
||||
// check that all required fields are set
|
||||
if !r.valid() {
|
||||
if !serverNoticeRequest.valid() {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON("Invalid request"),
|
||||
}
|
||||
}, fmt.Errorf("Invalid JSON")
|
||||
}
|
||||
|
||||
// get rooms for specified user
|
||||
allUserRooms := []string{}
|
||||
userRooms := api.QueryRoomsForUserResponse{}
|
||||
// Get rooms the user is either joined, invited or has left.
|
||||
for _, membership := range []string{"join", "invite", "leave"} {
|
||||
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
|
||||
UserID: r.UserID,
|
||||
WantMembership: membership,
|
||||
}, &userRooms); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
|
||||
qryServerNoticeRoom := &userapi.QueryServerNoticeRoomResponse{}
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', serverNoticeRequest.UserID)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON("Invalid request"),
|
||||
}, err
|
||||
}
|
||||
err = userAPI.SelectServerNoticeRoomID(ctx, &userapi.QueryServerNoticeRoomRequest{Localpart: localpart}, qryServerNoticeRoom)
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err), err
|
||||
}
|
||||
|
||||
// get rooms of the sender
|
||||
senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName)
|
||||
senderRooms := api.QueryRoomsForUserResponse{}
|
||||
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
|
||||
UserID: senderUserID,
|
||||
WantMembership: "join",
|
||||
}, &senderRooms); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
|
||||
// check if we have rooms in common
|
||||
commonRooms := []string{}
|
||||
for _, userRoomID := range allUserRooms {
|
||||
for _, senderRoomID := range senderRooms.RoomIDs {
|
||||
if userRoomID == senderRoomID {
|
||||
commonRooms = append(commonRooms, senderRoomID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(commonRooms) > 1 {
|
||||
return util.ErrorResponse(fmt.Errorf("expected to find one room, but got %d", len(commonRooms)))
|
||||
}
|
||||
|
||||
var (
|
||||
roomID string
|
||||
roomVersion = version.DefaultRoomVersion()
|
||||
)
|
||||
roomID := qryServerNoticeRoom.RoomID
|
||||
roomVersion := version.DefaultRoomVersion()
|
||||
|
||||
// create a new room for the user
|
||||
if len(commonRooms) == 0 {
|
||||
if qryServerNoticeRoom.RoomID == "" {
|
||||
var pl, cc []byte
|
||||
powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID)
|
||||
powerLevelContent.Users[r.UserID] = -10 // taken from Synapse
|
||||
pl, err := json.Marshal(powerLevelContent)
|
||||
powerLevelContent.Users[serverNoticeRequest.UserID] = -10 // taken from Synapse
|
||||
pl, err = json.Marshal(powerLevelContent)
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
return util.ErrorResponse(err), err
|
||||
}
|
||||
createContent := map[string]interface{}{}
|
||||
createContent["m.federate"] = false
|
||||
cc, err := json.Marshal(createContent)
|
||||
cc, err = json.Marshal(createContent)
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
return util.ErrorResponse(err), err
|
||||
}
|
||||
crReq := createRoomRequest{
|
||||
Invite: []string{r.UserID},
|
||||
Invite: []string{serverNoticeRequest.UserID},
|
||||
Name: cfgNotices.RoomName,
|
||||
Visibility: "private",
|
||||
Preset: presetPrivateChat,
|
||||
|
@ -166,36 +158,40 @@ func SendServerNotice(
|
|||
switch data := roomRes.JSON.(type) {
|
||||
case createRoomResponse:
|
||||
roomID = data.RoomID
|
||||
|
||||
res := &userapi.UpdateServerNoticeRoomResponse{}
|
||||
err = userAPI.UpdateServerNoticeRoomID(ctx, &userapi.UpdateServerNoticeRoomRequest{RoomID: roomID, Localpart: localpart}, res)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("UpdateServerNoticeRoomID failed")
|
||||
return jsonerror.InternalServerError(), err
|
||||
}
|
||||
// tag the room, so we can later check if the user tries to reject an invite
|
||||
serverAlertTag := gomatrix.TagContent{Tags: map[string]gomatrix.TagProperties{
|
||||
"m.server_notice": {
|
||||
Order: 1.0,
|
||||
},
|
||||
}}
|
||||
if err = saveTagData(req, r.UserID, roomID, userAPI, serverAlertTag); err != nil {
|
||||
if err = saveTagData(ctx, serverNoticeRequest.UserID, roomID, userAPI, serverAlertTag); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("saveTagData failed")
|
||||
return jsonerror.InternalServerError()
|
||||
return jsonerror.InternalServerError(), err
|
||||
}
|
||||
|
||||
default:
|
||||
// if we didn't get a createRoomResponse, we probably received an error, so return that.
|
||||
return roomRes
|
||||
return roomRes, fmt.Errorf("Unable to create room")
|
||||
}
|
||||
|
||||
} else {
|
||||
// we've found a room in common, check the membership
|
||||
roomID = commonRooms[0]
|
||||
membershipRes := api.QueryMembershipForUserResponse{}
|
||||
err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes)
|
||||
res := &api.QueryMembershipForUserResponse{}
|
||||
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: serverNoticeRequest.UserID, RoomID: roomID}, res)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("unable to query membership for user")
|
||||
return jsonerror.InternalServerError()
|
||||
return util.ErrorResponse(err), err
|
||||
}
|
||||
if !membershipRes.IsInRoom {
|
||||
// re-invite the user
|
||||
res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now())
|
||||
// re-invite the user
|
||||
if res.Membership != gomatrixserverlib.Join {
|
||||
var inviteRes util.JSONResponse
|
||||
inviteRes, err = sendInvite(ctx, userAPI, senderDevice, roomID, serverNoticeRequest.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now())
|
||||
if err != nil {
|
||||
return res
|
||||
return inviteRes, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -203,13 +199,13 @@ func SendServerNotice(
|
|||
startedGeneratingEvent := time.Now()
|
||||
|
||||
request := map[string]interface{}{
|
||||
"body": r.Content.Body,
|
||||
"msgtype": r.Content.MsgType,
|
||||
"body": serverNoticeRequest.Content.Body,
|
||||
"msgtype": serverNoticeRequest.Content.MsgType,
|
||||
}
|
||||
e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, cfgClient, rsAPI, time.Now())
|
||||
if resErr != nil {
|
||||
logrus.Errorf("failed to send message: %+v", resErr)
|
||||
return *resErr
|
||||
return *resErr, fmt.Errorf("Unable to send event")
|
||||
}
|
||||
timeToGenerateEvent := time.Since(startedGeneratingEvent)
|
||||
|
||||
|
@ -224,7 +220,7 @@ func SendServerNotice(
|
|||
// pass the new event to the roomserver and receive the correct event ID
|
||||
// event ID in case of duplicate transaction is discarded
|
||||
startedSubmittingEvent := time.Now()
|
||||
if err := api.SendEvents(
|
||||
if err = api.SendEvents(
|
||||
ctx, rsAPI,
|
||||
api.KindNew,
|
||||
[]*gomatrixserverlib.HeaderedEvent{
|
||||
|
@ -236,7 +232,7 @@ func SendServerNotice(
|
|||
false,
|
||||
); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
|
||||
return jsonerror.InternalServerError()
|
||||
return jsonerror.InternalServerError(), err
|
||||
}
|
||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||
"event_id": e.EventID(),
|
||||
|
@ -259,7 +255,7 @@ func SendServerNotice(
|
|||
sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds()))
|
||||
sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvent.Milliseconds()))
|
||||
|
||||
return res
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (r sendServerNoticeRequest) valid() (ok bool) {
|
||||
|
|
|
@ -146,7 +146,12 @@ func main() {
|
|||
logrus.Fatalln("Username is already in use.")
|
||||
}
|
||||
|
||||
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType)
|
||||
policyVersion := ""
|
||||
if cfg.Global.UserConsentOptions.Enabled {
|
||||
policyVersion = cfg.Global.UserConsentOptions.Version
|
||||
}
|
||||
|
||||
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", policyVersion, accType)
|
||||
if err != nil {
|
||||
logrus.Fatalln("Failed to create the account:", err.Error())
|
||||
}
|
||||
|
|
|
@ -71,6 +71,35 @@ global:
|
|||
# appear in user clients.
|
||||
room_name: "Server Alerts"
|
||||
|
||||
# Consent tracking configuration
|
||||
user_consent:
|
||||
# If the user consent tracking is enabled or not
|
||||
enabled: false
|
||||
# The base URL this homeserver will serve clients on, e.g. https://matrix.org
|
||||
base_url: http://localhost
|
||||
# Randomly generated string (e.g. by using "pwgen -sy 32") to be used to calculate the HMAC
|
||||
form_secret: "superSecretRandomlyGeneratedSecret"
|
||||
# Require consent when user registers for the first time
|
||||
require_at_registration: false
|
||||
# The name to be shown to the user
|
||||
policy_name: "Privacy policy"
|
||||
# The directory to search for templates
|
||||
template_dir: "./templates/privacy"
|
||||
# The version of the policy. When loading templates, ".gohtml" template is added as a suffix
|
||||
# e.g: ${template_dir}/1.0.gohtml needs to exist, if this is set to "1.0"
|
||||
version: "1.0"
|
||||
# Send a consent message to guest users
|
||||
send_server_notice_to_guest: false
|
||||
# Default message to send to users
|
||||
server_notice_content:
|
||||
msg_type: "m.text"
|
||||
body: >-
|
||||
Please give your consent to the privacy policy at {{ .ConsentURL }}.
|
||||
# The error message to display if the user hasn't given their consent yet
|
||||
block_events_error: >-
|
||||
You can't send any messages until you consent to the privacy policy at
|
||||
{{ .ConsentURL }}.
|
||||
|
||||
# Configuration for NATS JetStream
|
||||
jetstream:
|
||||
# A list of NATS Server addresses to connect to. If none are specified, an
|
||||
|
|
26
docs/templates/privacy/1.0.gohtml
vendored
Normal file
26
docs/templates/privacy/1.0.gohtml
vendored
Normal file
|
@ -0,0 +1,26 @@
|
|||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>Privacy policy</title>
|
||||
</head>
|
||||
<body>
|
||||
{{ if .HasConsented }}
|
||||
<p>
|
||||
You have already given your consent.
|
||||
</p>
|
||||
{{ else }}
|
||||
<p>
|
||||
Please give your consent to keep using this homeserver.
|
||||
</p>
|
||||
{{ if not .ReadOnly }}
|
||||
<!-- The variables used here are only provided when the 'u' param is given to the homeserver -->
|
||||
<form method="post" action="consent">
|
||||
<input type="hidden" name="v" value="{{ .Version }}"/>
|
||||
<input type="hidden" name="u" value="{{ .UserID }}"/>
|
||||
<input type="hidden" name="h" value="{{ .UserHMAC }}"/>
|
||||
<input type="submit" value="I consent"/>
|
||||
</form>
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
</body>
|
||||
</html>
|
|
@ -15,6 +15,8 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
@ -25,9 +27,12 @@ import (
|
|||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
opentracing "github.com/opentracing/opentracing-go"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
|
@ -41,10 +46,24 @@ type BasicAuth struct {
|
|||
Password string `yaml:"password"`
|
||||
}
|
||||
|
||||
// AuthAPICheck is an option to MakeAuthAPI to add additional checks (e.g. WithConsentCheck) to verify
|
||||
// the user is allowed to do specific things.
|
||||
type AuthAPICheck func(ctx context.Context, device *userapi.Device) *util.JSONResponse
|
||||
|
||||
// WithConsentCheck checks that a user has given his consent.
|
||||
func WithConsentCheck(options config.UserConsentOptions, api userapi.QueryPolicyVersionAPI) AuthAPICheck {
|
||||
return func(ctx context.Context, device *userapi.Device) *util.JSONResponse {
|
||||
if !options.Enabled {
|
||||
return nil
|
||||
}
|
||||
return checkConsent(ctx, device.UserID, api, options)
|
||||
}
|
||||
}
|
||||
|
||||
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
|
||||
func MakeAuthAPI(
|
||||
metricsName string, userAPI userapi.QueryAcccessTokenAPI,
|
||||
f func(*http.Request, *userapi.Device) util.JSONResponse,
|
||||
f func(*http.Request, *userapi.Device) util.JSONResponse, checks ...AuthAPICheck,
|
||||
) http.Handler {
|
||||
h := func(req *http.Request) util.JSONResponse {
|
||||
logger := util.GetLogger(req.Context())
|
||||
|
@ -72,6 +91,14 @@ func MakeAuthAPI(
|
|||
}
|
||||
}()
|
||||
|
||||
// apply additional checks, if any
|
||||
for _, opt := range checks {
|
||||
resp := opt(req.Context(), device)
|
||||
if resp != nil {
|
||||
return *resp
|
||||
}
|
||||
}
|
||||
|
||||
jsonRes := f(req, device)
|
||||
// do not log 4xx as errors as they are client fails, not server fails
|
||||
if hub != nil && jsonRes.Code >= 500 {
|
||||
|
@ -83,6 +110,53 @@ func MakeAuthAPI(
|
|||
return MakeExternalAPI(metricsName, h)
|
||||
}
|
||||
|
||||
func checkConsent(ctx context.Context, userID string, userAPI userapi.QueryPolicyVersionAPI, userConsentCfg config.UserConsentOptions) *util.JSONResponse {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
// check which version of the policy the user accepted
|
||||
res := &userapi.QueryPolicyVersionResponse{}
|
||||
err = userAPI.QueryPolicyVersion(ctx, &userapi.QueryPolicyVersionRequest{
|
||||
Localpart: localpart,
|
||||
}, res)
|
||||
if err != nil {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: jsonerror.Unknown("unable to get policy version"),
|
||||
}
|
||||
}
|
||||
|
||||
// user hasn't accepted any policy, block access.
|
||||
if userConsentCfg.Version != res.PolicyVersion {
|
||||
uri, err := userConsentCfg.ConsentURL(userID)
|
||||
if err != nil {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: jsonerror.Unknown("unable to get consent URL"),
|
||||
}
|
||||
}
|
||||
msg := &bytes.Buffer{}
|
||||
c := struct {
|
||||
ConsentURL string
|
||||
}{
|
||||
ConsentURL: uri,
|
||||
}
|
||||
if err = userConsentCfg.TextTemplates.ExecuteTemplate(msg, "blockEventsError", c); err != nil {
|
||||
logrus.Infof("error consent message: %+v", err)
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: jsonerror.Unknown("unable to execute template"),
|
||||
}
|
||||
}
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: jsonerror.ConsentNotGiven(uri, msg.String()),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler.
|
||||
// This is used for APIs that are called from the internet.
|
||||
func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {
|
||||
|
|
|
@ -59,15 +59,12 @@ func Setup(
|
|||
PathToResult: map[string]*types.ThumbnailGenerationResult{},
|
||||
}
|
||||
|
||||
uploadHandler := httputil.MakeAuthAPI(
|
||||
"upload", userAPI,
|
||||
func(req *http.Request, dev *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
}
|
||||
return Upload(req, cfg, dev, db, activeThumbnailGeneration)
|
||||
},
|
||||
)
|
||||
uploadHandler := httputil.MakeAuthAPI("upload", userAPI, func(req *http.Request, dev *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
}
|
||||
return Upload(req, cfg, dev, db, activeThumbnailGeneration)
|
||||
})
|
||||
|
||||
configHandler := httputil.MakeAuthAPI("config", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
|
|
|
@ -268,6 +268,7 @@ func (b *BaseDendrite) Close() error {
|
|||
func (b *BaseDendrite) DatabaseConnection(dbProperties *config.DatabaseOptions, writer sqlutil.Writer) (*sql.DB, sqlutil.Writer, error) {
|
||||
if dbProperties.ConnectionString != "" || b == nil {
|
||||
// Open a new database connection using the supplied config.
|
||||
logrus.Infof("Open a new database connection using the supplied config.: %+v", dbProperties.ConnectionString)
|
||||
db, err := sqlutil.Open(dbProperties, writer)
|
||||
return db, writer, err
|
||||
}
|
||||
|
|
|
@ -265,6 +265,21 @@ func loadConfig(
|
|||
return &c, nil
|
||||
}
|
||||
|
||||
type Terms struct {
|
||||
Policies Policies `json:"policies"`
|
||||
}
|
||||
type En struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
type PrivacyPolicy struct {
|
||||
En En `json:"en"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
type Policies struct {
|
||||
PrivacyPolicy PrivacyPolicy `json:"privacy_policy"`
|
||||
}
|
||||
|
||||
// Derive generates data that is derived from various values provided in
|
||||
// the config file.
|
||||
func (config *Dendrite) Derive() error {
|
||||
|
@ -275,13 +290,39 @@ func (config *Dendrite) Derive() error {
|
|||
// TODO: Add email auth type
|
||||
// TODO: Add MSISDN auth type
|
||||
|
||||
if config.Global.UserConsentOptions.Enabled && config.Global.UserConsentOptions.RequireAtRegistration {
|
||||
uri := config.Global.UserConsentOptions.BaseURL + "/_matrix/client/consent?v=" + config.Global.UserConsentOptions.Version
|
||||
config.Derived.Registration.Params[authtypes.LoginTypeTerms] = Terms{
|
||||
Policies: Policies{
|
||||
PrivacyPolicy: PrivacyPolicy{
|
||||
En: En{
|
||||
Name: config.Global.UserConsentOptions.PolicyName,
|
||||
URL: uri,
|
||||
},
|
||||
Version: config.Global.UserConsentOptions.Version,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
if config.ClientAPI.RecaptchaEnabled {
|
||||
config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey}
|
||||
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
|
||||
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}})
|
||||
} else {
|
||||
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
|
||||
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}})
|
||||
}
|
||||
|
||||
if config.Derived.Registration.Flows == nil {
|
||||
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, authtypes.Flow{
|
||||
Stages: []authtypes.LoginType{authtypes.LoginTypeDummy},
|
||||
})
|
||||
}
|
||||
|
||||
// prepend each flow with LoginTypeTerms or LoginTypeRecaptcha
|
||||
for i, flow := range config.Derived.Registration.Flows {
|
||||
if config.Global.UserConsentOptions.Enabled && config.Global.UserConsentOptions.RequireAtRegistration {
|
||||
flow.Stages = append([]authtypes.LoginType{authtypes.LoginTypeTerms}, flow.Stages...)
|
||||
}
|
||||
if config.ClientAPI.RecaptchaEnabled {
|
||||
flow.Stages = append([]authtypes.LoginType{authtypes.LoginTypeRecaptcha}, flow.Stages...)
|
||||
}
|
||||
config.Derived.Registration.Flows[i] = flow
|
||||
}
|
||||
|
||||
// Load application service configuration files
|
||||
|
|
|
@ -1,7 +1,15 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
textTemplate "text/template"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
@ -71,6 +79,9 @@ type Global struct {
|
|||
// ServerNotices configuration used for sending server notices
|
||||
ServerNotices ServerNotices `yaml:"server_notices"`
|
||||
|
||||
// Consent tracking options
|
||||
UserConsentOptions UserConsentOptions `yaml:"user_consent"`
|
||||
|
||||
// ReportStats configures opt-in anonymous stats reporting.
|
||||
ReportStats ReportStats `yaml:"report_stats"`
|
||||
}
|
||||
|
@ -88,6 +99,7 @@ func (c *Global) Defaults(generate bool) {
|
|||
c.Metrics.Defaults(generate)
|
||||
c.DNSCache.Defaults()
|
||||
c.Sentry.Defaults()
|
||||
c.UserConsentOptions.Defaults()
|
||||
c.ServerNotices.Defaults(generate)
|
||||
c.ReportStats.Defaults()
|
||||
}
|
||||
|
@ -100,6 +112,7 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
|||
c.Metrics.Verify(configErrs, isMonolith)
|
||||
c.Sentry.Verify(configErrs, isMonolith)
|
||||
c.DNSCache.Verify(configErrs, isMonolith)
|
||||
c.UserConsentOptions.Verify(configErrs, isMonolith)
|
||||
c.ServerNotices.Verify(configErrs, isMonolith)
|
||||
c.ReportStats.Verify(configErrs, isMonolith)
|
||||
}
|
||||
|
@ -261,6 +274,102 @@ func (c *DNSCacheOptions) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
|||
checkPositive(configErrs, "cache_lifetime", int64(c.CacheLifetime))
|
||||
}
|
||||
|
||||
// Consent tracking configuration
|
||||
// If either require_at_registration or send_server_notice_to_guest are true, consent
|
||||
// messages will be sent to the users.
|
||||
type UserConsentOptions struct {
|
||||
// If consent tracking is enabled or not
|
||||
Enabled bool `yaml:"enabled"`
|
||||
// Randomly generated string to be used to calculate the HMAC
|
||||
FormSecret string `yaml:"form_secret"`
|
||||
// Require consent when user registers for the first time
|
||||
RequireAtRegistration bool `yaml:"require_at_registration"`
|
||||
// The name to be shown to the user
|
||||
PolicyName string `yaml:"policy_name"`
|
||||
// The directory to search for *.gohtml templates
|
||||
TemplateDir string `yaml:"template_dir"`
|
||||
// The version of the policy. When loading templates, ".gohtml" template is added as a suffix
|
||||
// e.g: ${template_dir}/1.0.gohtml needs to exist, if this is set to 1.0
|
||||
Version string `yaml:"version"`
|
||||
// Send a consent message to guest users
|
||||
SendServerNoticeToGuest bool `yaml:"send_server_notice_to_guest"`
|
||||
// Default message to send to users
|
||||
ServerNoticeContent struct {
|
||||
MsgType string `yaml:"msg_type"`
|
||||
Body string `yaml:"body"`
|
||||
} `yaml:"server_notice_content"`
|
||||
// The error message to display if the user hasn't given their consent yet
|
||||
BlockEventsError string `yaml:"block_events_error"`
|
||||
// All loaded templates
|
||||
Templates *template.Template `yaml:"-"`
|
||||
TextTemplates *textTemplate.Template `yaml:"-"`
|
||||
// The base URL this homeserver will serve clients on, e.g. https://matrix.org
|
||||
BaseURL string `yaml:"base_url"`
|
||||
}
|
||||
|
||||
func (c *UserConsentOptions) Defaults() {
|
||||
c.Enabled = false
|
||||
c.RequireAtRegistration = false
|
||||
c.SendServerNoticeToGuest = false
|
||||
c.PolicyName = "Privacy Policy"
|
||||
c.Version = "1.0"
|
||||
c.TemplateDir = "./templates/privacy"
|
||||
}
|
||||
|
||||
func (c *UserConsentOptions) Verify(configErrors *ConfigErrors, isMonolith bool) {
|
||||
if !c.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
checkNotEmpty(configErrors, "template_dir", c.TemplateDir)
|
||||
checkNotEmpty(configErrors, "version", c.Version)
|
||||
checkNotEmpty(configErrors, "policy_name", c.PolicyName)
|
||||
checkNotEmpty(configErrors, "form_secret", c.FormSecret)
|
||||
checkNotEmpty(configErrors, "base_url", c.BaseURL)
|
||||
if len(*configErrors) > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
p, err := filepath.Abs(c.TemplateDir)
|
||||
if err != nil {
|
||||
configErrors.Add("unable to get template directory")
|
||||
return
|
||||
}
|
||||
|
||||
c.TextTemplates = textTemplate.Must(textTemplate.New("blockEventsError").Parse(c.BlockEventsError))
|
||||
c.TextTemplates = textTemplate.Must(c.TextTemplates.New("serverNoticeTemplate").Parse(c.ServerNoticeContent.Body))
|
||||
|
||||
// Read all defined *.gohtml templates
|
||||
t, err := template.ParseGlob(filepath.Join(p, "*.gohtml"))
|
||||
if err != nil || t == nil {
|
||||
configErrors.Add(fmt.Sprintf("unable to read consent templates: %+v", err))
|
||||
return
|
||||
}
|
||||
c.Templates = t
|
||||
// Verify we've got a template for the defined version
|
||||
versionTemplate := c.Templates.Lookup(c.Version + ".gohtml")
|
||||
if versionTemplate == nil {
|
||||
configErrors.Add(fmt.Sprintf("unable to load defined '%s' policy template", c.Version))
|
||||
}
|
||||
}
|
||||
|
||||
// ConsentURL constructs the URL shown to users to accept the TOS
|
||||
func (c *UserConsentOptions) ConsentURL(userID string) (string, error) {
|
||||
mac := hmac.New(sha256.New, []byte(c.FormSecret))
|
||||
_, err := mac.Write([]byte(userID))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
userMAC := hex.EncodeToString(mac.Sum(nil))
|
||||
|
||||
params := url.Values{}
|
||||
params.Add("u", userID)
|
||||
params.Add("h", userMAC)
|
||||
params.Add("v", c.Version)
|
||||
|
||||
return fmt.Sprintf("%s/_matrix/client/consent?%s", c.BaseURL, params.Encode()), nil
|
||||
}
|
||||
|
||||
// PresenceOptions defines possible configurations for presence events.
|
||||
type PresenceOptions struct {
|
||||
// Whether inbound presence events are allowed
|
||||
|
|
115
setup/config/config_global_test.go
Normal file
115
setup/config/config_global_test.go
Normal file
|
@ -0,0 +1,115 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUserConsentOptions_Verify(t *testing.T) {
|
||||
type args struct {
|
||||
configErrors *ConfigErrors
|
||||
isMonolith bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields UserConsentOptions
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "template dir not set",
|
||||
fields: UserConsentOptions{
|
||||
RequireAtRegistration: true,
|
||||
},
|
||||
args: struct {
|
||||
configErrors *ConfigErrors
|
||||
isMonolith bool
|
||||
}{configErrors: &ConfigErrors{}, isMonolith: true},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "template dir set",
|
||||
fields: UserConsentOptions{
|
||||
RequireAtRegistration: true,
|
||||
TemplateDir: "testdata/privacy",
|
||||
},
|
||||
args: struct {
|
||||
configErrors *ConfigErrors
|
||||
isMonolith bool
|
||||
}{configErrors: &ConfigErrors{}, isMonolith: true},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "policy name not set",
|
||||
fields: UserConsentOptions{
|
||||
RequireAtRegistration: true,
|
||||
TemplateDir: "testdata/privacy",
|
||||
},
|
||||
args: struct {
|
||||
configErrors *ConfigErrors
|
||||
isMonolith bool
|
||||
}{configErrors: &ConfigErrors{}, isMonolith: true},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "policy name set",
|
||||
fields: UserConsentOptions{
|
||||
RequireAtRegistration: true,
|
||||
TemplateDir: "testdata/privacy",
|
||||
PolicyName: "Privacy policy",
|
||||
},
|
||||
args: struct {
|
||||
configErrors *ConfigErrors
|
||||
isMonolith bool
|
||||
}{configErrors: &ConfigErrors{}, isMonolith: true},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "version not set",
|
||||
fields: UserConsentOptions{
|
||||
RequireAtRegistration: true,
|
||||
TemplateDir: "testdata/privacy",
|
||||
},
|
||||
args: struct {
|
||||
configErrors *ConfigErrors
|
||||
isMonolith bool
|
||||
}{configErrors: &ConfigErrors{}, isMonolith: true},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "everyhing required set",
|
||||
fields: UserConsentOptions{
|
||||
RequireAtRegistration: true,
|
||||
TemplateDir: "./testdata/privacy",
|
||||
Version: "1.0",
|
||||
PolicyName: "Privacy policy",
|
||||
FormSecret: "helloWorld",
|
||||
BaseURL: "http://localhost",
|
||||
},
|
||||
args: struct {
|
||||
configErrors *ConfigErrors
|
||||
isMonolith bool
|
||||
}{configErrors: &ConfigErrors{}, isMonolith: true},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &UserConsentOptions{
|
||||
Enabled: true,
|
||||
BaseURL: tt.fields.BaseURL,
|
||||
FormSecret: tt.fields.FormSecret,
|
||||
RequireAtRegistration: tt.fields.RequireAtRegistration,
|
||||
PolicyName: tt.fields.PolicyName,
|
||||
Version: tt.fields.Version,
|
||||
TemplateDir: tt.fields.TemplateDir,
|
||||
SendServerNoticeToGuest: tt.fields.SendServerNoticeToGuest,
|
||||
ServerNoticeContent: tt.fields.ServerNoticeContent,
|
||||
BlockEventsError: tt.fields.BlockEventsError,
|
||||
}
|
||||
c.Verify(tt.args.configErrors, tt.args.isMonolith)
|
||||
if !tt.wantErr && len(*tt.args.configErrors) > 0 {
|
||||
t.Errorf("expected no errors, got '%+v'", tt.args.configErrors)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
26
setup/config/testdata/privacy/1.0.gohtml
vendored
Normal file
26
setup/config/testdata/privacy/1.0.gohtml
vendored
Normal file
|
@ -0,0 +1,26 @@
|
|||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>Privacy policy</title>
|
||||
</head>
|
||||
<body>
|
||||
{{ if .HasConsented }}
|
||||
<p>
|
||||
You have already given your consent.
|
||||
</p>
|
||||
{{ else }}
|
||||
<p>
|
||||
Please give your consent to keep using this homeserver.
|
||||
</p>
|
||||
{{ if not .ReadOnly }}
|
||||
<!-- The variables used here are only provided when the 'u' param is given to the homeserver -->
|
||||
<form method="post" action="consent">
|
||||
<input type="hidden" name="v" value="{{ .Version }}"/>
|
||||
<input type="hidden" name="u" value="{{ .UserID }}"/>
|
||||
<input type="hidden" name="h" value="{{ .UserHMAC }}"/>
|
||||
<input type="submit" value="I consent"/>
|
||||
</form>
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
</body>
|
||||
</html>
|
|
@ -44,7 +44,6 @@ func ParseFlags(monolith bool) *config.Dendrite {
|
|||
}
|
||||
|
||||
cfg, err := config.Load(*configPath, monolith)
|
||||
|
||||
if err != nil {
|
||||
logrus.Fatalf("Invalid config file: %s", err)
|
||||
}
|
||||
|
|
|
@ -93,6 +93,6 @@ func Setup(
|
|||
vars["roomId"], vars["eventId"],
|
||||
lazyLoadCache,
|
||||
)
|
||||
}),
|
||||
}, httputil.WithConsentCheck(cfg.Matrix.UserConsentOptions, userAPI)),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
}
|
||||
|
|
|
@ -66,6 +66,7 @@ type FederationUserAPI interface {
|
|||
// api functions required by the sync api
|
||||
type SyncUserAPI interface {
|
||||
QueryAcccessTokenAPI
|
||||
QueryPolicyVersionAPI
|
||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
|
||||
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
|
||||
|
@ -78,6 +79,7 @@ type ClientUserAPI interface {
|
|||
QueryAcccessTokenAPI
|
||||
LoginTokenInternalAPI
|
||||
UserLoginAPI
|
||||
UserConsentPolicyAPI
|
||||
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
|
||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
||||
|
@ -106,6 +108,18 @@ type ClientUserAPI interface {
|
|||
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
|
||||
PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error
|
||||
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
|
||||
SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) (err error)
|
||||
UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) (err error)
|
||||
}
|
||||
|
||||
type UserConsentPolicyAPI interface {
|
||||
QueryPolicyVersionAPI
|
||||
QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error
|
||||
PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error
|
||||
}
|
||||
|
||||
type QueryPolicyVersionAPI interface {
|
||||
QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error
|
||||
}
|
||||
|
||||
// custom api functions required by pinecone / p2p demos
|
||||
|
@ -318,12 +332,12 @@ type QuerySearchProfilesResponse struct {
|
|||
|
||||
// PerformAccountCreationRequest is the request for PerformAccountCreation
|
||||
type PerformAccountCreationRequest struct {
|
||||
AccountType AccountType // Required: whether this is a guest or user account
|
||||
Localpart string // Required: The localpart for this account. Ignored if account type is guest.
|
||||
|
||||
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
|
||||
Password string // optional: if missing then this account will be a passwordless account
|
||||
OnConflict Conflict
|
||||
AccountType AccountType // Required: whether this is a guest or user account
|
||||
Localpart string // Required: The localpart for this account. Ignored if account type is guest.
|
||||
PolicyVersion string // optional: the privacy policy this account has accepted
|
||||
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
|
||||
Password string // optional: if missing then this account will be a passwordless account
|
||||
OnConflict Conflict
|
||||
}
|
||||
|
||||
// PerformAccountCreationResponse is the response for PerformAccountCreation
|
||||
|
@ -412,6 +426,53 @@ type QueryOpenIDTokenResponse struct {
|
|||
ExpiresAtMS int64
|
||||
}
|
||||
|
||||
// QueryPolicyVersionRequest is the request for QueryPolicyVersionRequest
|
||||
type QueryPolicyVersionRequest struct {
|
||||
Localpart string
|
||||
}
|
||||
|
||||
// QueryPolicyVersionResponse is the response for QueryPolicyVersionRequest
|
||||
type QueryPolicyVersionResponse struct {
|
||||
PolicyVersion string
|
||||
}
|
||||
|
||||
// QueryOutdatedPolicyRequest is the request for QueryOutdatedPolicyRequest
|
||||
type QueryOutdatedPolicyRequest struct {
|
||||
PolicyVersion string
|
||||
}
|
||||
|
||||
// QueryOutdatedPolicyResponse is the response for QueryOutdatedPolicyRequest
|
||||
type QueryOutdatedPolicyResponse struct {
|
||||
UserLocalparts []string
|
||||
}
|
||||
|
||||
// UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest
|
||||
type UpdatePolicyVersionRequest struct {
|
||||
PolicyVersion, Localpart string
|
||||
ServerNoticeUpdate bool
|
||||
}
|
||||
|
||||
// UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest
|
||||
type UpdatePolicyVersionResponse struct{}
|
||||
|
||||
// QueryServerNoticeRoomRequest is the request for QueryServerNoticeRoomRequest
|
||||
type QueryServerNoticeRoomRequest struct {
|
||||
Localpart string
|
||||
}
|
||||
|
||||
// QueryServerNoticeRoomResponse is the response for QueryServerNoticeRoomRequest
|
||||
type QueryServerNoticeRoomResponse struct {
|
||||
RoomID string
|
||||
}
|
||||
|
||||
// UpdateServerNoticeRoomRequest is the request for UpdateServerNoticeRoomRequest
|
||||
type UpdateServerNoticeRoomRequest struct {
|
||||
Localpart, RoomID string
|
||||
}
|
||||
|
||||
// UpdateServerNoticeRoomResponse is the response for UpdateServerNoticeRoomRequest
|
||||
type UpdateServerNoticeRoomResponse struct{}
|
||||
|
||||
// Device represents a client's device (mobile, web, etc)
|
||||
type Device struct {
|
||||
ID string
|
||||
|
|
|
@ -203,6 +203,36 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex
|
|||
return err
|
||||
}
|
||||
|
||||
func (t *UserInternalAPITrace) QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error {
|
||||
err := t.Impl.QueryPolicyVersion(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("QueryPolicyVersion req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *UserInternalAPITrace) QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error {
|
||||
err := t.Impl.QueryOutdatedPolicy(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("QueryOutdatedPolicy req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *UserInternalAPITrace) PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error {
|
||||
err := t.Impl.PerformUpdatePolicyVersion(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("PerformUpdatePolicyVersion req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *UserInternalAPITrace) SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) error {
|
||||
err := t.Impl.SelectServerNoticeRoomID(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("SelectServerNoticeRoomID req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *UserInternalAPITrace) UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) error {
|
||||
err := t.Impl.UpdateServerNoticeRoomID(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("UpdateServerNoticeRoomID req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
func js(thing interface{}) string {
|
||||
b, err := json.Marshal(thing)
|
||||
if err != nil {
|
||||
|
|
|
@ -66,7 +66,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
|
||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
|
||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion, req.AccountType)
|
||||
if err != nil {
|
||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||
switch req.OnConflict {
|
||||
|
@ -833,3 +833,60 @@ func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, re
|
|||
}
|
||||
|
||||
const pushRulesAccountDataType = "m.push_rules"
|
||||
|
||||
func (a *UserInternalAPI) QueryPolicyVersion(
|
||||
ctx context.Context,
|
||||
req *api.QueryPolicyVersionRequest,
|
||||
res *api.QueryPolicyVersionResponse,
|
||||
) error {
|
||||
var err error
|
||||
res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.Localpart)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryOutdatedPolicy(
|
||||
ctx context.Context,
|
||||
req *api.QueryOutdatedPolicyRequest,
|
||||
res *api.QueryOutdatedPolicyResponse,
|
||||
) error {
|
||||
var err error
|
||||
res.UserLocalparts, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformUpdatePolicyVersion(
|
||||
ctx context.Context,
|
||||
req *api.UpdatePolicyVersionRequest,
|
||||
res *api.UpdatePolicyVersionResponse,
|
||||
) error {
|
||||
return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.Localpart, req.ServerNoticeUpdate)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) SelectServerNoticeRoomID(
|
||||
ctx context.Context,
|
||||
req *api.QueryServerNoticeRoomRequest,
|
||||
res *api.QueryServerNoticeRoomResponse,
|
||||
) (err error) {
|
||||
roomID, err := a.DB.SelectServerNoticeRoomID(ctx, req.Localpart)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.RoomID = roomID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) UpdateServerNoticeRoomID(
|
||||
ctx context.Context,
|
||||
req *api.UpdateServerNoticeRoomRequest,
|
||||
res *api.UpdateServerNoticeRoomResponse,
|
||||
) (err error) {
|
||||
return a.DB.UpdateServerNoticeRoomID(ctx, req.Localpart, req.RoomID)
|
||||
}
|
||||
|
|
|
@ -44,6 +44,8 @@ const (
|
|||
PerformSetDisplayNamePath = "/userapi/performSetDisplayName"
|
||||
PerformForgetThreePIDPath = "/userapi/performForgetThreePID"
|
||||
PerformSaveThreePIDAssociationPath = "/userapi/performSaveThreePIDAssociation"
|
||||
PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion"
|
||||
PerformUpdateServerNoticeRoomPath = "/userapi/performUpdateServerNoticeRoom"
|
||||
|
||||
QueryKeyBackupPath = "/userapi/queryKeyBackup"
|
||||
QueryProfilePath = "/userapi/queryProfile"
|
||||
|
@ -61,6 +63,9 @@ const (
|
|||
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
|
||||
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
|
||||
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
|
||||
QueryPolicyVersionPath = "/userapi/queryPolicyVersion"
|
||||
QueryOutdatedPolicyUsersPath = "/userapi/queryOutdatedPolicy"
|
||||
QueryServerNoticeRoomPath = "/userapi/queryServerNoticeRoom"
|
||||
)
|
||||
|
||||
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
||||
|
@ -391,3 +396,43 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context
|
|||
apiURL := h.apiURL + PerformSaveThreePIDAssociationPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) QueryOutdatedPolicy(ctx context.Context, req *api.QueryOutdatedPolicyRequest, res *api.QueryOutdatedPolicyResponse) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOutdatedPolicy")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + QueryOutdatedPolicyUsersPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) PerformUpdatePolicyVersion(ctx context.Context, req *api.UpdatePolicyVersionRequest, res *api.UpdatePolicyVersionResponse) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUpdatePolicyVersion")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + PerformUpdatePolicyVersionPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) SelectServerNoticeRoomID(ctx context.Context, req *api.QueryServerNoticeRoomRequest, res *api.QueryServerNoticeRoomResponse) (err error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "SelectServerNoticeRoomID")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + QueryServerNoticeRoomPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) QueryPolicyVersion(ctx context.Context, req *api.QueryPolicyVersionRequest, res *api.QueryPolicyVersionResponse) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPolicyVersion")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + QueryPolicyVersionPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) UpdateServerNoticeRoomID(ctx context.Context, req *api.UpdateServerNoticeRoomRequest, res *api.UpdateServerNoticeRoomResponse) (err error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "UpdateServerNoticeRoomID")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + PerformUpdateServerNoticeRoomPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
|
|
@ -457,4 +457,74 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
|||
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(QueryPolicyVersionPath,
|
||||
httputil.MakeInternalAPI("queryPolicyVersion", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryPolicyVersionRequest{}
|
||||
response := api.QueryPolicyVersionResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
err := s.QueryPolicyVersion(req.Context(), &request, &response)
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(QueryOutdatedPolicyUsersPath,
|
||||
httputil.MakeInternalAPI("queryOutdatedPolicyUsers", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryOutdatedPolicyRequest{}
|
||||
response := api.QueryOutdatedPolicyResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
err := s.QueryOutdatedPolicy(req.Context(), &request, &response)
|
||||
if err != nil {
|
||||
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(PerformUpdatePolicyVersionPath,
|
||||
httputil.MakeInternalAPI("performUpdatePolicyVersionPath", func(req *http.Request) util.JSONResponse {
|
||||
request := api.UpdatePolicyVersionRequest{}
|
||||
response := api.UpdatePolicyVersionResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
err := s.PerformUpdatePolicyVersion(req.Context(), &request, &response)
|
||||
if err != nil {
|
||||
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(QueryServerNoticeRoomPath,
|
||||
httputil.MakeInternalAPI("queryServerNoticeRoom", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryServerNoticeRoomRequest{}
|
||||
response := api.QueryServerNoticeRoomResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
err := s.SelectServerNoticeRoomID(req.Context(), &request, &response)
|
||||
if err != nil {
|
||||
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(PerformUpdateServerNoticeRoomPath,
|
||||
httputil.MakeInternalAPI("performUpdateServerNoticeRoom", func(req *http.Request) util.JSONResponse {
|
||||
request := api.UpdateServerNoticeRoomRequest{}
|
||||
response := api.UpdateServerNoticeRoomResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
err := s.UpdateServerNoticeRoomID(req.Context(), &request, &response)
|
||||
if err != nil {
|
||||
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ type Account interface {
|
|||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, policyVersion string, accountType api.AccountType) (*api.Account, error)
|
||||
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
||||
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
||||
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
||||
|
@ -126,9 +126,18 @@ type Notification interface {
|
|||
DeleteOldNotifications(ctx context.Context) error
|
||||
}
|
||||
|
||||
type ConsentTracking interface {
|
||||
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
|
||||
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
|
||||
UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) error
|
||||
SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error)
|
||||
UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error)
|
||||
}
|
||||
|
||||
type Database interface {
|
||||
Account
|
||||
AccountData
|
||||
ConsentTracking
|
||||
Device
|
||||
KeyBackup
|
||||
LoginToken
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
|
@ -43,14 +44,19 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
|||
-- If the account is currently active
|
||||
is_deactivated BOOLEAN DEFAULT FALSE,
|
||||
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
|
||||
account_type SMALLINT NOT NULL
|
||||
account_type SMALLINT NOT NULL,
|
||||
-- The policy version this user has accepted
|
||||
policy_version TEXT,
|
||||
-- The policy version the user received from the server notices room
|
||||
policy_version_sent TEXT,
|
||||
server_notice_room_id TEXT
|
||||
-- TODO:
|
||||
-- upgraded_ts, devices, any email reset stuff?
|
||||
);
|
||||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type, policy_version) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
|
@ -67,14 +73,38 @@ const selectPasswordHashSQL = "" +
|
|||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'"
|
||||
|
||||
const selectPrivacyPolicySQL = "" +
|
||||
"SELECT policy_version FROM account_accounts WHERE localpart = $1"
|
||||
|
||||
const batchSelectPrivacyPolicySQL = "" +
|
||||
"SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $1)"
|
||||
|
||||
const updatePolicyVersionSQL = "" +
|
||||
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
|
||||
|
||||
const updatePolicyVersionServerNoticeSQL = "" +
|
||||
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
|
||||
|
||||
const selectServerNoticeRoomSQL = "" +
|
||||
"SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1"
|
||||
|
||||
const updateServerNoticeRoomSQL = "" +
|
||||
"UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2"
|
||||
|
||||
type accountsStatements struct {
|
||||
insertAccountStmt *sql.Stmt
|
||||
updatePasswordStmt *sql.Stmt
|
||||
deactivateAccountStmt *sql.Stmt
|
||||
selectAccountByLocalpartStmt *sql.Stmt
|
||||
selectPasswordHashStmt *sql.Stmt
|
||||
selectNewNumericLocalpartStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
insertAccountStmt *sql.Stmt
|
||||
updatePasswordStmt *sql.Stmt
|
||||
deactivateAccountStmt *sql.Stmt
|
||||
selectAccountByLocalpartStmt *sql.Stmt
|
||||
selectPasswordHashStmt *sql.Stmt
|
||||
selectNewNumericLocalpartStmt *sql.Stmt
|
||||
selectPrivacyPolicyStmt *sql.Stmt
|
||||
batchSelectPrivacyPolicyStmt *sql.Stmt
|
||||
updatePolicyVersionStmt *sql.Stmt
|
||||
updatePolicyVersionServerNoticeStmt *sql.Stmt
|
||||
selectServerNoticeRoomStmt *sql.Stmt
|
||||
updateServerNoticeRoomStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
|
||||
|
@ -92,6 +122,12 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
|
|||
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
||||
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
||||
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
||||
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
||||
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
||||
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
||||
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
|
||||
{&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL},
|
||||
{&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
|
@ -99,16 +135,16 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
|
|||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) InsertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||
|
||||
var err error
|
||||
if accountType != api.AccountTypeAppService {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
|
||||
} else {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -178,3 +214,71 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
|||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||
return id + 1, err
|
||||
}
|
||||
|
||||
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
|
||||
func (s *accountsStatements) SelectPrivacyPolicy(
|
||||
ctx context.Context, txn *sql.Tx, localPart string,
|
||||
) (policy string, err error) {
|
||||
var policyNull sql.NullString
|
||||
stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt)
|
||||
err = stmt.QueryRowContext(ctx, localPart).Scan(&policyNull)
|
||||
return policyNull.String, err
|
||||
}
|
||||
|
||||
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
|
||||
func (s *accountsStatements) BatchSelectPrivacyPolicy(
|
||||
ctx context.Context, txn *sql.Tx, policyVersion string,
|
||||
) (userIDs []string, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt)
|
||||
rows, err := stmt.QueryContext(ctx, policyVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "BatchSelectPrivacyPolicy: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return userIDs, err
|
||||
}
|
||||
userIDs = append(userIDs, userID)
|
||||
}
|
||||
return userIDs, rows.Err()
|
||||
}
|
||||
|
||||
// updatePolicyVersion sets the policy_version for a specific user
|
||||
func (s *accountsStatements) UpdatePolicyVersion(
|
||||
ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool,
|
||||
) (err error) {
|
||||
stmt := s.updatePolicyVersionStmt
|
||||
if serverNotice {
|
||||
stmt = s.updatePolicyVersionServerNoticeStmt
|
||||
}
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
|
||||
return err
|
||||
}
|
||||
|
||||
// SelectServerNoticeRoomID queries the server notice room ID.
|
||||
func (s *accountsStatements) SelectServerNoticeRoomID(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
) (roomID string, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt)
|
||||
|
||||
roomIDNull := sql.NullString{}
|
||||
row := stmt.QueryRowContext(ctx, localpart)
|
||||
err = row.Scan(&roomIDNull)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return "", err
|
||||
}
|
||||
// roomIDNull.String is either the roomID or an empty string
|
||||
return roomIDNull.String, nil
|
||||
}
|
||||
|
||||
// UpdateServerNoticeRoomID sets the server notice room ID.
|
||||
func (s *accountsStatements) UpdateServerNoticeRoomID(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID string,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt)
|
||||
_, err = stmt.ExecContext(ctx, roomID, localpart)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
func LoadFromGoose() {
|
||||
goose.AddMigration(UpIsActive, DownIsActive)
|
||||
goose.AddMigration(UpAddAccountType, DownAddAccountType)
|
||||
goose.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
|
||||
}
|
||||
|
||||
func LoadIsActive(m *sqlutil.Migrations) {
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
package deltas
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
func LoadAddPolicyVersion(m *sqlutil.Migrations) {
|
||||
m.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
|
||||
}
|
||||
|
||||
func UpAddPolicyVersion(tx *sql.Tx) error {
|
||||
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version TEXT;")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version_sent TEXT;")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS server_notice_room_id TEXT;")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownAddPolicyVersion(tx *sql.Tx) error {
|
||||
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -46,6 +46,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
deltas.LoadIsActive(m)
|
||||
//deltas.LoadLastSeenTSIP(m)
|
||||
deltas.LoadAddAccountType(m)
|
||||
deltas.LoadAddPolicyVersion(m)
|
||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
|
@ -125,7 +126,7 @@ func (d *Database) SetPassword(
|
|||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
func (d *Database) CreateAccount(
|
||||
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string, accountType api.AccountType,
|
||||
) (acc *api.Account, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
// For guest accounts, we create a new numeric local part
|
||||
|
@ -139,7 +140,7 @@ func (d *Database) CreateAccount(
|
|||
plaintextPassword = ""
|
||||
appserviceID = ""
|
||||
}
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion, accountType)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
@ -148,7 +149,7 @@ func (d *Database) CreateAccount(
|
|||
// WARNING! This function assumes that the relevant mutexes have already
|
||||
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
||||
func (d *Database) createAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
var err error
|
||||
var account *api.Account
|
||||
|
@ -160,7 +161,8 @@ func (d *Database) createAccount(
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion, accountType); err != nil {
|
||||
logrus.WithError(err).Error("d.Accounts.InsertAccount error")
|
||||
return nil, sqlutil.ErrUserExists
|
||||
}
|
||||
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
|
||||
|
@ -763,3 +765,42 @@ func (d *Database) RemovePushers(
|
|||
func (d *Database) UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) {
|
||||
return d.Stats.UserStatistics(ctx, nil)
|
||||
}
|
||||
|
||||
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
|
||||
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
policyVersion, err = d.Accounts.SelectPrivacyPolicy(ctx, txn, localpart)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// GetOutdatedPolicy queries all users which didn't accept the current policy version
|
||||
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
userIDs, err = d.Accounts.BatchSelectPrivacyPolicy(ctx, txn, policyVersion)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// UpdatePolicyVersion sets the accepted policy_version for a user.
|
||||
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) (err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart, serverNotice)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// SelectServerNoticeRoomID returns the server notice room, if one is set.
|
||||
func (d *Database) SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error) {
|
||||
return d.Accounts.SelectServerNoticeRoomID(ctx, nil, localpart)
|
||||
}
|
||||
|
||||
// UpdateServerNoticeRoomID updates the server notice room
|
||||
func (d *Database) UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Accounts.UpdateServerNoticeRoomID(ctx, txn, localpart, roomID)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
|
@ -43,14 +44,19 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
|||
-- If the account is currently active
|
||||
is_deactivated BOOLEAN DEFAULT 0,
|
||||
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
|
||||
account_type INTEGER NOT NULL
|
||||
account_type INTEGER NOT NULL,
|
||||
-- The policy version this user has accepted
|
||||
policy_version TEXT,
|
||||
-- The policy version the user received from the server notices room
|
||||
policy_version_sent TEXT,
|
||||
server_notice_room_id TEXT
|
||||
-- TODO:
|
||||
-- upgraded_ts, devices, any email reset stuff?
|
||||
);
|
||||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type, policy_version) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
|
@ -67,15 +73,39 @@ const selectPasswordHashSQL = "" +
|
|||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0"
|
||||
|
||||
const selectPrivacyPolicySQL = "" +
|
||||
"SELECT policy_version FROM account_accounts WHERE localpart = $1"
|
||||
|
||||
const batchSelectPrivacyPolicySQL = "" +
|
||||
"SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $2)"
|
||||
|
||||
const updatePolicyVersionSQL = "" +
|
||||
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
|
||||
|
||||
const updatePolicyVersionServerNoticeSQL = "" +
|
||||
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
|
||||
|
||||
const selectServerNoticeRoomSQL = "" +
|
||||
"SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1"
|
||||
|
||||
const updateServerNoticeRoomSQL = "" +
|
||||
"UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2"
|
||||
|
||||
type accountsStatements struct {
|
||||
db *sql.DB
|
||||
insertAccountStmt *sql.Stmt
|
||||
updatePasswordStmt *sql.Stmt
|
||||
deactivateAccountStmt *sql.Stmt
|
||||
selectAccountByLocalpartStmt *sql.Stmt
|
||||
selectPasswordHashStmt *sql.Stmt
|
||||
selectNewNumericLocalpartStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
db *sql.DB
|
||||
insertAccountStmt *sql.Stmt
|
||||
updatePasswordStmt *sql.Stmt
|
||||
deactivateAccountStmt *sql.Stmt
|
||||
selectAccountByLocalpartStmt *sql.Stmt
|
||||
selectPasswordHashStmt *sql.Stmt
|
||||
selectNewNumericLocalpartStmt *sql.Stmt
|
||||
selectPrivacyPolicyStmt *sql.Stmt
|
||||
batchSelectPrivacyPolicyStmt *sql.Stmt
|
||||
updatePolicyVersionStmt *sql.Stmt
|
||||
updatePolicyVersionServerNoticeStmt *sql.Stmt
|
||||
selectServerNoticeRoomStmt *sql.Stmt
|
||||
updateServerNoticeRoomStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
|
||||
|
@ -94,6 +124,12 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
|||
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
||||
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
||||
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
||||
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
||||
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
||||
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
||||
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
|
||||
{&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL},
|
||||
{&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
|
@ -101,16 +137,16 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
|||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) InsertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := s.insertAccountStmt
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||
|
||||
var err error
|
||||
if accountType != api.AccountTypeAppService {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
|
||||
} else {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -183,3 +219,72 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
|||
}
|
||||
return id + 1, err
|
||||
}
|
||||
|
||||
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
|
||||
|
||||
func (s *accountsStatements) SelectPrivacyPolicy(
|
||||
ctx context.Context, txn *sql.Tx, localPart string,
|
||||
) (policy string, err error) {
|
||||
var policyNull sql.NullString
|
||||
stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt)
|
||||
err = stmt.QueryRowContext(ctx, localPart).Scan(&policyNull)
|
||||
return policyNull.String, err
|
||||
}
|
||||
|
||||
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
|
||||
func (s *accountsStatements) BatchSelectPrivacyPolicy(
|
||||
ctx context.Context, txn *sql.Tx, policyVersion string,
|
||||
) (userIDs []string, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt)
|
||||
rows, err := stmt.QueryContext(ctx, policyVersion, policyVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "BatchSelectPrivacyPolicy: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return userIDs, err
|
||||
}
|
||||
userIDs = append(userIDs, userID)
|
||||
}
|
||||
return userIDs, rows.Err()
|
||||
}
|
||||
|
||||
// updatePolicyVersion sets the policy_version for a specific user
|
||||
func (s *accountsStatements) UpdatePolicyVersion(
|
||||
ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool,
|
||||
) (err error) {
|
||||
stmt := s.updatePolicyVersionStmt
|
||||
if serverNotice {
|
||||
stmt = s.updatePolicyVersionServerNoticeStmt
|
||||
}
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
|
||||
return err
|
||||
}
|
||||
|
||||
// SelectServerNoticeRoomID queries the server notice room ID.
|
||||
func (s *accountsStatements) SelectServerNoticeRoomID(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
) (roomID string, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt)
|
||||
|
||||
roomIDNull := sql.NullString{}
|
||||
row := stmt.QueryRowContext(ctx, localpart)
|
||||
err = row.Scan(&roomIDNull)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return "", err
|
||||
}
|
||||
// roomIDNull.String is either the roomID or an empty string
|
||||
return roomIDNull.String, nil
|
||||
}
|
||||
|
||||
// UpdateServerNoticeRoomID sets the server notice room ID.
|
||||
func (s *accountsStatements) UpdateServerNoticeRoomID(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID string,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt)
|
||||
_, err = stmt.ExecContext(ctx, roomID, localpart)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
func LoadFromGoose() {
|
||||
goose.AddMigration(UpIsActive, DownIsActive)
|
||||
goose.AddMigration(UpAddAccountType, DownAddAccountType)
|
||||
goose.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
|
||||
}
|
||||
|
||||
func LoadIsActive(m *sqlutil.Migrations) {
|
||||
|
|
|
@ -4,15 +4,9 @@ import (
|
|||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/pressly/goose"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
goose.AddMigration(UpAddAccountType, DownAddAccountType)
|
||||
}
|
||||
|
||||
func LoadAddAccountType(m *sqlutil.Migrations) {
|
||||
m.AddMigration(UpAddAccountType, DownAddAccountType)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
package deltas
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
func LoadAddPolicyVersion(m *sqlutil.Migrations) {
|
||||
m.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
|
||||
}
|
||||
|
||||
func UpAddPolicyVersion(tx *sql.Tx) error {
|
||||
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version TEXT;")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version_sent TEXT;")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN server_notice_room_id TEXT;")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownAddPolicyVersion(tx *sql.Tx) error {
|
||||
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -18,11 +18,10 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
||||
|
@ -47,6 +46,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
deltas.LoadIsActive(m)
|
||||
//deltas.LoadLastSeenTSIP(m)
|
||||
deltas.LoadAddAccountType(m)
|
||||
deltas.LoadAddPolicyVersion(m)
|
||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -78,7 +78,7 @@ func Test_Accounts(t *testing.T) {
|
|||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
||||
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", "v1.0", api.AccountTypeAdmin)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
// verify the newly create account is the same as returned by CreateAccount
|
||||
var accGet *api.Account
|
||||
|
@ -102,7 +102,7 @@ func Test_Accounts(t *testing.T) {
|
|||
first, err := db.GetNewNumericLocalpart(ctx)
|
||||
assert.NoError(t, err, "failed to get new numeric localpart")
|
||||
// Create a new account to verify the numeric localpart is updated
|
||||
_, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest)
|
||||
_, err = db.CreateAccount(ctx, "", "testing", "", "v1.0", api.AccountTypeGuest)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
second, err := db.GetNewNumericLocalpart(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
@ -350,7 +350,7 @@ func Test_Profile(t *testing.T) {
|
|||
defer close()
|
||||
|
||||
// create account, which also creates a profile
|
||||
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
||||
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", "v1.0", api.AccountTypeAdmin)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)
|
||||
|
|
|
@ -32,12 +32,18 @@ type AccountDataTable interface {
|
|||
}
|
||||
|
||||
type AccountsTable interface {
|
||||
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType) (*api.Account, error)
|
||||
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
|
||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
|
||||
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
||||
|
||||
SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error)
|
||||
BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error)
|
||||
UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool) (err error)
|
||||
SelectServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart string) (roomID string, err error)
|
||||
UpdateServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart, roomID string) (err error)
|
||||
}
|
||||
|
||||
type DevicesTable interface {
|
||||
|
|
|
@ -88,7 +88,7 @@ func mustMakeAccountAndDevice(
|
|||
appServiceID = util.RandomString(16)
|
||||
}
|
||||
|
||||
_, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType)
|
||||
_, err := accDB.InsertAccount(ctx, nil, localpart, "", "", appServiceID, accType)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create account: %v", err)
|
||||
}
|
||||
|
|
|
@ -75,7 +75,7 @@ func TestQueryProfile(t *testing.T) {
|
|||
// only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added
|
||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite)
|
||||
defer close()
|
||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser)
|
||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", "", api.AccountTypeUser)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make account: %s", err)
|
||||
}
|
||||
|
@ -154,7 +154,7 @@ func TestLoginToken(t *testing.T) {
|
|||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
||||
defer close()
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser)
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", "", api.AccountTypeUser)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make account: %s", err)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue