diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 124940f71..4a1720295 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -63,7 +63,7 @@ jobs: # Run Complement - run: | set -o pipefail && - go test -v -p 1 -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt + go test -v -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt shell: bash name: Run Complement Tests env: diff --git a/.gitignore b/.gitignore index 092f4501c..2a8c2cf55 100644 --- a/.gitignore +++ b/.gitignore @@ -54,7 +54,7 @@ dendrite.yaml *.db # Log files -*.log* +*.log* # Generated code cmd/dendrite-demo-yggdrasil/embed/fs*.go @@ -62,5 +62,7 @@ cmd/dendrite-demo-yggdrasil/embed/fs*.go # Test dependencies test/wasm/node_modules -media_store/ +# Ignore complement folder when running locally +complement/ +media_store/ diff --git a/build/docker/config/dendrite.yaml b/build/docker/config/dendrite.yaml index 6d5ebc9fd..ebae50132 100644 --- a/build/docker/config/dendrite.yaml +++ b/build/docker/config/dendrite.yaml @@ -318,6 +318,17 @@ user_api: max_idle_conns: 2 conn_max_lifetime: -1 +# Configuration for the Push Server API. +push_server: + internal_api: + listen: http://localhost:7782 + connect: http://localhost:7782 + database: + connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_pushserver?sslmode=disable + max_open_conns: 10 + max_idle_conns: 2 + conn_max_lifetime: -1 + # Configuration for Opentracing. # See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on # how this works and how to set it up. diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index aa8cc6e6e..5ab90adaf 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -312,7 +312,7 @@ func (m *DendriteMonolith) Start() { ) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) - m.userAPI = userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) + m.userAPI = userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(m.userAPI) eduInputAPI := eduserver.NewInternalAPI( diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 8b9c88f2a..3329485aa 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -116,7 +116,7 @@ func (m *DendriteMonolith) Start() { ) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) + userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) eduInputAPI := eduserver.NewInternalAPI( diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 9cab7956c..4db75809f 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -144,21 +144,23 @@ func (u *UserInteractive) AddCompletedStage(sessionID, authType string) { delete(u.Sessions, sessionID) } +type Challenge struct { + Completed []string `json:"completed"` + Flows []userInteractiveFlow `json:"flows"` + Session string `json:"session"` + // TODO: Return any additional `params` + Params map[string]interface{} `json:"params"` +} + // Challenge returns an HTTP 401 with the supported flows for authenticating func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse { return &util.JSONResponse{ Code: 401, - JSON: struct { - Completed []string `json:"completed"` - Flows []userInteractiveFlow `json:"flows"` - Session string `json:"session"` - // TODO: Return any additional `params` - Params map[string]interface{} `json:"params"` - }{ - u.Completed, - u.Flows, - sessionID, - make(map[string]interface{}), + JSON: Challenge{ + Completed: u.Completed, + Flows: u.Flows, + Session: sessionID, + Params: make(map[string]interface{}), }, } } diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index b1eb391f8..91bc2fdb3 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -59,7 +59,8 @@ func AddPublicRoutes( routing.Setup( router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI, accountsDB, userAPI, federation, - syncProducer, transactionsCache, fsAPI, keyAPI, extRoomsProvider, mscCfg, + syncProducer, transactionsCache, fsAPI, keyAPI, + extRoomsProvider, mscCfg, ) } diff --git a/clientapi/producers/syncapi.go b/clientapi/producers/syncapi.go index 9b1d6b1a2..9ab90391d 100644 --- a/clientapi/producers/syncapi.go +++ b/clientapi/producers/syncapi.go @@ -30,7 +30,7 @@ type SyncAPIProducer struct { } // SendData sends account data to the sync API server -func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string) error { +func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string, readMarker *eventutil.ReadMarkerJSON) error { m := &nats.Msg{ Subject: p.Topic, Header: nats.Header{}, @@ -38,8 +38,9 @@ func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string m.Header.Set(jetstream.UserID, userID) data := eventutil.AccountData{ - RoomID: roomID, - Type: dataType, + RoomID: roomID, + Type: dataType, + ReadMarker: readMarker, } var err error m.Data, err = json.Marshal(data) diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 03025f1da..d8e982690 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal/eventutil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" @@ -127,7 +128,7 @@ func SaveAccountData( } // TODO: user API should do this since it's account data - if err := syncProducer.SendData(userID, roomID, dataType); err != nil { + if err := syncProducer.SendData(userID, roomID, dataType, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") return jsonerror.InternalServerError() } @@ -138,11 +139,6 @@ func SaveAccountData( } } -type readMarkerJSON struct { - FullyRead string `json:"m.fully_read"` - Read string `json:"m.read"` -} - type fullyReadEvent struct { EventID string `json:"event_id"` } @@ -159,7 +155,7 @@ func SaveReadMarker( return *resErr } - var r readMarkerJSON + var r eventutil.ReadMarkerJSON resErr = httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { return *resErr @@ -189,7 +185,7 @@ func SaveReadMarker( return util.ErrorResponse(err) } - if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read"); err != nil { + if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read", &r); err != nil { util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index 9f54a625a..161bc2731 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/tidwall/gjson" ) // https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-devices @@ -163,6 +164,15 @@ func DeleteDeviceById( req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.UserInternalAPI, device *api.Device, deviceID string, ) util.JSONResponse { + var ( + deleteOK bool + sessionID string + ) + defer func() { + if deleteOK { + sessions.deleteSession(sessionID) + } + }() ctx := req.Context() defer req.Body.Close() // nolint:errcheck bodyBytes, err := ioutil.ReadAll(req.Body) @@ -172,8 +182,29 @@ func DeleteDeviceById( JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()), } } + + // check that we know this session, and it matches with the device to delete + s := gjson.GetBytes(bodyBytes, "auth.session").Str + if dev, ok := sessions.getDeviceToDelete(s); ok { + if dev != deviceID { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("session & device mismatch"), + } + } + } + + if s != "" { + sessionID = s + } + login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, device) if errRes != nil { + switch data := errRes.JSON.(type) { + case auth.Challenge: + sessions.addDeviceToDelete(data.Session, deviceID) + default: + } return *errRes } @@ -201,6 +232,8 @@ func DeleteDeviceById( return jsonerror.InternalServerError() } + deleteOK = true + return util.JSONResponse{ Code: http.StatusOK, JSON: struct{}{}, diff --git a/clientapi/routing/notification.go b/clientapi/routing/notification.go new file mode 100644 index 000000000..ee715d323 --- /dev/null +++ b/clientapi/routing/notification.go @@ -0,0 +1,63 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + "strconv" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// GetNotifications handles /_matrix/client/r0/notifications +func GetNotifications( + req *http.Request, device *userapi.Device, + userAPI userapi.UserInternalAPI, +) util.JSONResponse { + var limit int64 + if limitStr := req.URL.Query().Get("limit"); limitStr != "" { + var err error + limit, err = strconv.ParseInt(limitStr, 10, 64) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("ParseInt(limit) failed") + return jsonerror.InternalServerError() + } + } + + var queryRes userapi.QueryNotificationsResponse + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") + return jsonerror.InternalServerError() + } + err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{ + Localpart: localpart, + From: req.URL.Query().Get("from"), + Limit: int(limit), + Only: req.URL.Query().Get("only"), + }, &queryRes) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed") + return jsonerror.InternalServerError() + } + util.GetLogger(req.Context()).WithField("from", req.URL.Query().Get("from")).WithField("limit", limit).WithField("only", req.URL.Query().Get("only")).WithField("next", queryRes.NextToken).Infof("QueryNotifications: len %d", len(queryRes.Notifications)) + return util.JSONResponse{ + Code: http.StatusOK, + JSON: queryRes, + } +} diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index acac60fa5..c63412d08 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -12,6 +12,7 @@ import ( userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) type newPasswordRequest struct { @@ -37,6 +38,11 @@ func Password( var r newPasswordRequest r.LogoutDevices = true + logrus.WithFields(logrus.Fields{ + "sessionId": device.SessionID, + "userId": device.UserID, + }).Debug("Changing password") + // Unmarshal the request. resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { @@ -116,6 +122,15 @@ func Password( util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") return jsonerror.InternalServerError() } + + pushersReq := &api.PerformPusherDeletionRequest{ + Localpart: localpart, + SessionID: device.SessionID, + } + if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed") + return jsonerror.InternalServerError() + } } // Return a success code. diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 717cbda75..4aab03714 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -286,7 +286,7 @@ func SetDisplayName( return jsonerror.InternalServerError() } - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, false); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/pusher.go b/clientapi/routing/pusher.go new file mode 100644 index 000000000..9d6bef8bd --- /dev/null +++ b/clientapi/routing/pusher.go @@ -0,0 +1,114 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + "net/url" + + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// GetPushers handles /_matrix/client/r0/pushers +func GetPushers( + req *http.Request, device *userapi.Device, + userAPI userapi.UserInternalAPI, +) util.JSONResponse { + var queryRes userapi.QueryPushersResponse + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") + return jsonerror.InternalServerError() + } + err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{ + Localpart: localpart, + }, &queryRes) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed") + return jsonerror.InternalServerError() + } + for i := range queryRes.Pushers { + queryRes.Pushers[i].SessionID = 0 + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: queryRes, + } +} + +// SetPusher handles /_matrix/client/r0/pushers/set +// This endpoint allows the creation, modification and deletion of pushers for this user ID. +// The behaviour of this endpoint varies depending on the values in the JSON body. +func SetPusher( + req *http.Request, device *userapi.Device, + userAPI userapi.UserInternalAPI, +) util.JSONResponse { + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") + return jsonerror.InternalServerError() + } + body := userapi.PerformPusherSetRequest{} + if resErr := httputil.UnmarshalJSONRequest(req, &body); resErr != nil { + return *resErr + } + if len(body.AppID) > 64 { + return invalidParam("length of app_id must be no more than 64 characters") + } + if len(body.PushKey) > 512 { + return invalidParam("length of pushkey must be no more than 512 bytes") + } + uInt := body.Data["url"] + if uInt != nil { + u, ok := uInt.(string) + if !ok { + return invalidParam("url must be string") + } + if u != "" { + var pushUrl *url.URL + pushUrl, err = url.Parse(u) + if err != nil { + return invalidParam("malformed url passed") + } + if pushUrl.Scheme != "https" { + return invalidParam("only https scheme is allowed") + } + } + + } + body.Localpart = localpart + body.SessionID = device.SessionID + err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{}) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformPusherSet failed") + return jsonerror.InternalServerError() + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +func invalidParam(msg string) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidParam(msg), + } +} diff --git a/clientapi/routing/pushrules.go b/clientapi/routing/pushrules.go new file mode 100644 index 000000000..81a33b25a --- /dev/null +++ b/clientapi/routing/pushrules.go @@ -0,0 +1,386 @@ +package routing + +import ( + "context" + "encoding/json" + "io" + "net/http" + "reflect" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/pushrules" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +func errorResponse(ctx context.Context, err error, msg string, args ...interface{}) util.JSONResponse { + if eerr, ok := err.(*jsonerror.MatrixError); ok { + var status int + switch eerr.ErrCode { + case "M_INVALID_ARGUMENT_VALUE": + status = http.StatusBadRequest + case "M_NOT_FOUND": + status = http.StatusNotFound + default: + status = http.StatusInternalServerError + } + return util.MatrixErrorResponse(status, eerr.ErrCode, eerr.Err) + } + util.GetLogger(ctx).WithError(err).Errorf(msg, args...) + return jsonerror.InternalServerError() +} + +func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRulesJSON failed") + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: ruleSets, + } +} + +func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRulesJSON failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: ruleSet, + } +} + +func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: *rulesPtr, + } +} + +func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + i := pushRuleIndexByID(*rulesPtr, ruleID) + if i < 0 { + return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: (*rulesPtr)[i], + } +} + +func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + var newRule pushrules.Rule + if err := json.NewDecoder(body).Decode(&newRule); err != nil { + return errorResponse(ctx, err, "JSON Decode failed") + } + newRule.RuleID = ruleID + + errs := pushrules.ValidateRule(pushrules.Kind(kind), &newRule) + if len(errs) > 0 { + return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs) + } + + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + i := pushRuleIndexByID(*rulesPtr, ruleID) + if i >= 0 && afterRuleID == "" && beforeRuleID == "" { + // Modify rule at the same index. + + // TODO: The spec does not say what to do in this case, but + // this feels reasonable. + *((*rulesPtr)[i]) = newRule + util.GetLogger(ctx).Infof("Modified existing push rule at %d", i) + } else { + if i >= 0 { + // Delete old rule. + *rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...) + util.GetLogger(ctx).Infof("Deleted old push rule at %d", i) + } else { + // SPEC: When creating push rules, they MUST be enabled by default. + // + // TODO: it's unclear if we must reject disabled rules, or force + // the value to true. Sytests fail if we don't force it. + newRule.Enabled = true + } + + // Add new rule. + i, err := findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID) + if err != nil { + return errorResponse(ctx, err, "findPushRuleInsertionIndex failed") + } + + *rulesPtr = append((*rulesPtr)[:i], append([]*pushrules.Rule{&newRule}, (*rulesPtr)[i:]...)...) + util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i) + } + + if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { + return errorResponse(ctx, err, "putPushRules failed") + } + + return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} +} + +func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + i := pushRuleIndexByID(*rulesPtr, ruleID) + if i < 0 { + return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + } + + *rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...) + + if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { + return errorResponse(ctx, err, "putPushRules failed") + } + + return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} +} + +func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + attrGet, err := pushRuleAttrGetter(attr) + if err != nil { + return errorResponse(ctx, err, "pushRuleAttrGetter failed") + } + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + i := pushRuleIndexByID(*rulesPtr, ruleID) + if i < 0 { + return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: map[string]interface{}{ + attr: attrGet((*rulesPtr)[i]), + }, + } +} + +func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + var newPartialRule pushrules.Rule + if err := json.NewDecoder(body).Decode(&newPartialRule); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(err.Error()), + } + } + if newPartialRule.Actions == nil { + // This ensures json.Marshal encodes the empty list as [] rather than null. + newPartialRule.Actions = []*pushrules.Action{} + } + + attrGet, err := pushRuleAttrGetter(attr) + if err != nil { + return errorResponse(ctx, err, "pushRuleAttrGetter failed") + } + attrSet, err := pushRuleAttrSetter(attr) + if err != nil { + return errorResponse(ctx, err, "pushRuleAttrSetter failed") + } + + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + i := pushRuleIndexByID(*rulesPtr, ruleID) + if i < 0 { + return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + } + + if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) { + attrSet((*rulesPtr)[i], &newPartialRule) + + if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { + return errorResponse(ctx, err, "putPushRules failed") + } + } + + return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} +} + +func queryPushRules(ctx context.Context, userID string, userAPI userapi.UserInternalAPI) (*pushrules.AccountRuleSets, error) { + var res userapi.QueryPushRulesResponse + if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed") + return nil, err + } + return res.RuleSets, nil +} + +func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.UserInternalAPI) error { + req := userapi.PerformPushRulesPutRequest{ + UserID: userID, + RuleSets: ruleSets, + } + var res struct{} + if err := userAPI.PerformPushRulesPut(ctx, &req, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("userAPI.PerformPushRulesPut failed") + return err + } + return nil +} + +func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet { + switch scope { + case pushrules.GlobalScope: + return &ruleSets.Global + default: + return nil + } +} + +func pushRuleSetKindPointer(ruleSet *pushrules.RuleSet, kind pushrules.Kind) *[]*pushrules.Rule { + switch kind { + case pushrules.OverrideKind: + return &ruleSet.Override + case pushrules.ContentKind: + return &ruleSet.Content + case pushrules.RoomKind: + return &ruleSet.Room + case pushrules.SenderKind: + return &ruleSet.Sender + case pushrules.UnderrideKind: + return &ruleSet.Underride + default: + return nil + } +} + +func pushRuleIndexByID(rules []*pushrules.Rule, id string) int { + for i, rule := range rules { + if rule.RuleID == id { + return i + } + } + return -1 +} + +func pushRuleAttrGetter(attr string) (func(*pushrules.Rule) interface{}, error) { + switch attr { + case "actions": + return func(rule *pushrules.Rule) interface{} { return rule.Actions }, nil + case "enabled": + return func(rule *pushrules.Rule) interface{} { return rule.Enabled }, nil + default: + return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute") + } +} + +func pushRuleAttrSetter(attr string) (func(dest, src *pushrules.Rule), error) { + switch attr { + case "actions": + return func(dest, src *pushrules.Rule) { dest.Actions = src.Actions }, nil + case "enabled": + return func(dest, src *pushrules.Rule) { dest.Enabled = src.Enabled }, nil + default: + return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute") + } +} + +func findPushRuleInsertionIndex(rules []*pushrules.Rule, afterID, beforeID string) (int, error) { + var i int + + if afterID != "" { + for ; i < len(rules); i++ { + if rules[i].RuleID == afterID { + break + } + } + if i == len(rules) { + return 0, jsonerror.NotFound("after: rule ID not found") + } + if rules[i].Default { + return 0, jsonerror.NotFound("after: rule ID must not be a default rule") + } + // We stopped on the "after" match to differentiate + // not-found from is-last-entry. Now we move to the earliest + // insertion point. + i++ + } + + if beforeID != "" { + for ; i < len(rules); i++ { + if rules[i].RuleID == beforeID { + break + } + } + if i == len(rules) { + return 0, jsonerror.NotFound("before: rule ID not found") + } + if rules[i].Default { + return 0, jsonerror.NotFound("before: rule ID must not be a default rule") + } + } + + // UNSPEC: The spec does not say what to do if no after/before is + // given. Sytest fails if it doesn't go first. + return i, nil +} diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 2b6bfa83c..5754afd86 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -76,6 +76,10 @@ type sessionsDict struct { sessions map[string][]authtypes.LoginType params map[string]registerRequest timer map[string]*time.Timer + // deleteSessionToDeviceID protects requests to DELETE /devices/{deviceID} from being abused. + // If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2, + // the delete request will fail for device2 since the UIA was initiated by trying to delete device1. + deleteSessionToDeviceID map[string]string } // defaultTimeout is the timeout used to clean up sessions @@ -115,6 +119,7 @@ func (d *sessionsDict) deleteSession(sessionID string) { defer d.Unlock() delete(d.params, sessionID) delete(d.sessions, sessionID) + delete(d.deleteSessionToDeviceID, sessionID) // stop the timer, e.g. because the registration was completed if t, ok := d.timer[sessionID]; ok { if !t.Stop() { @@ -129,9 +134,10 @@ func (d *sessionsDict) deleteSession(sessionID string) { func newSessionsDict() *sessionsDict { return &sessionsDict{ - sessions: make(map[string][]authtypes.LoginType), - params: make(map[string]registerRequest), - timer: make(map[string]*time.Timer), + sessions: make(map[string][]authtypes.LoginType), + params: make(map[string]registerRequest), + timer: make(map[string]*time.Timer), + deleteSessionToDeviceID: make(map[string]string), } } @@ -165,6 +171,20 @@ func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtype d.sessions[sessionID] = append(sessions.sessions[sessionID], stage) } +func (d *sessionsDict) addDeviceToDelete(sessionID, deviceID string) { + d.startTimer(defaultTimeOut, sessionID) + d.Lock() + defer d.Unlock() + d.deleteSessionToDeviceID[sessionID] = deviceID +} + +func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) { + d.RLock() + defer d.RUnlock() + deviceID, ok := d.deleteSessionToDeviceID[sessionID] + return deviceID, ok +} + var ( sessions = newSessionsDict() validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 5df58fc35..0507116f1 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -214,19 +214,19 @@ func TestSessionCleanUp(t *testing.T) { s := newSessionsDict() t.Run("session is cleaned up after a while", func(t *testing.T) { - t.Parallel() + // t.Parallel() dummySession := "helloWorld" // manually added, as s.addParams() would start the timer with the default timeout s.params[dummySession] = registerRequest{Username: "Testing"} s.startTimer(time.Millisecond, dummySession) - time.Sleep(time.Millisecond * 5) + time.Sleep(time.Millisecond * 50) if data, ok := s.getParams(dummySession); ok { t.Errorf("expected session to be deleted: %+v", data) } }) t.Run("session is deleted, once the registration completed", func(t *testing.T) { - t.Parallel() + // t.Parallel() dummySession := "helloWorld2" s.startTimer(time.Minute, dummySession) s.deleteSession(dummySession) @@ -236,18 +236,28 @@ func TestSessionCleanUp(t *testing.T) { }) t.Run("session timer is restarted after second call", func(t *testing.T) { - t.Parallel() + // t.Parallel() dummySession := "helloWorld3" // the following will start a timer with the default timeout of 5min s.addParams(dummySession, registerRequest{Username: "Testing"}) s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha) s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy) + s.addDeviceToDelete(dummySession, "dummyDevice") s.getCompletedStages(dummySession) // reset the timer with a lower timeout s.startTimer(time.Millisecond, dummySession) - time.Sleep(time.Millisecond * 5) + time.Sleep(time.Millisecond * 50) if data, ok := s.getParams(dummySession); ok { t.Errorf("expected session to be deleted: %+v", data) } + if _, ok := s.timer[dummySession]; ok { + t.Error("expected timer to be delete") + } + if _, ok := s.sessions[dummySession]; ok { + t.Error("expected session to be delete") + } + if _, ok := s.getDeviceToDelete(dummySession); ok { + t.Error("expected session to device to be delete") + } }) -} \ No newline at end of file +} diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index 85effcddc..beeefa908 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -99,7 +99,7 @@ func PutTag( return jsonerror.InternalServerError() } - if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + if err = syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil { logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") } @@ -152,7 +152,7 @@ func DeleteTag( } // TODO: user API should do this since it's account data - if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + if err := syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil { logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 06243cd4b..444024b02 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -16,7 +16,6 @@ package routing import ( "context" - "encoding/json" "net/http" "strings" @@ -581,25 +580,142 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - v3mux.Handle("/pushrules/", - httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse { - // TODO: Implement push rules API - res := json.RawMessage(`{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`) + // Push rules + + v3mux.Handle("/pushrules", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ - Code: http.StatusOK, - JSON: &res, + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("missing trailing slash"), } }), ).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/pushrules/", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return GetAllPushRules(req.Context(), device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("scope, kind and rule ID must be specified"), + } + }), + ).Methods(http.MethodPut) + + v3mux.Handle("/pushrules/{scope}/", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetPushRulesByScope(req.Context(), vars["scope"], device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("missing trailing slash after scope"), + } + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope:[^/]+/?}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("kind and rule ID must be specified"), + } + }), + ).Methods(http.MethodPut) + + v3mux.Handle("/pushrules/{scope}/{kind}/", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetPushRulesByKind(req.Context(), vars["scope"], vars["kind"], device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope}/{kind}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("missing trailing slash after kind"), + } + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope}/{kind:[^/]+/?}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("rule ID must be specified"), + } + }), + ).Methods(http.MethodPut) + + v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.Limit(req); r != nil { + return *r + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + query := req.URL.Query() + return PutPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], query.Get("after"), query.Get("before"), req.Body, device, userAPI) + }), + ).Methods(http.MethodPut) + + v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return DeletePushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI) + }), + ).Methods(http.MethodDelete) + + v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return PutPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], req.Body, device, userAPI) + }), + ).Methods(http.MethodPut) + // Element user settings v3mux.Handle("/profile/{userID}", @@ -905,6 +1021,27 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/notifications", + httputil.MakeAuthAPI("get_notifications", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return GetNotifications(req, device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushers", + httputil.MakeAuthAPI("get_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return GetPushers(req, device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushers/set", + httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.Limit(req); r != nil { + return *r + } + return SetPusher(req, device, userAPI) + }), + ).Methods(http.MethodPost, http.MethodOptions) + // Stub implementations for sytest v3mux.Handle("/events", httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index d6edead80..a5dd50861 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -48,24 +48,27 @@ Example: # read password from stdin %s --config dendrite.yaml -username alice -passwordstdin < my.pass cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin + # reset password for a user, can be used with a combination above to read the password + %s --config dendrite.yaml -reset-password -username alice -password foobarbaz Arguments: ` var ( - username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") - password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") - pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") - pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") - askPass = flag.Bool("ask-pass", false, "Ask for the password to use") - isAdmin = flag.Bool("admin", false, "Create an admin account") + username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") + password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") + pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") + pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") + askPass = flag.Bool("ask-pass", false, "Ask for the password to use") + isAdmin = flag.Bool("admin", false, "Create an admin account") + resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username") ) func main() { name := os.Args[0] flag.Usage = func() { - _, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name) + _, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name, name) flag.PrintDefaults() } cfg := setup.ParseFlags(true) @@ -93,6 +96,19 @@ func main() { if *isAdmin { accType = api.AccountTypeAdmin } + + if *resetPassword { + err = accountDB.SetPassword(context.Background(), *username, pass) + if err != nil { + logrus.Fatalf("Failed to update password for user %s: %s", *username, err.Error()) + } + if _, err = accountDB.RemoveAllDevices(context.Background(), *username, ""); err != nil { + logrus.Fatalf("Failed to remove all devices: %s", err.Error()) + } + logrus.Infof("Updated password for user %s and invalidated all logins\n", *username) + return + } + policyVersion := "" if cfg.Global.UserConsentOptions.Enabled { policyVersion = cfg.Global.UserConsentOptions.Version diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 78536901c..8ce641914 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -144,12 +144,14 @@ func main() { accountDB := base.Base.CreateAccountsDB() federation := createFederationClient(base) keyAPI := keyserver.NewInternalAPI(&base.Base, &base.Base.Cfg.KeyServer, federation) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) - keyAPI.SetUserAPI(userAPI) rsAPI := roomserver.NewInternalAPI( &base.Base, ) + + userAPI := userapi.NewInternalAPI(&base.Base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.Base.PushGatewayHTTPClient()) + keyAPI.SetUserAPI(userAPI) + eduInputAPI := eduserver.NewInternalAPI( &base.Base, cache.New(), userAPI, ) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 5810a7f18..45f186985 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -187,7 +187,7 @@ func main() { ) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) + userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) eduInputAPI := eduserver.NewInternalAPI( diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index d16f0e9e5..b7e30ba2e 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -111,14 +111,15 @@ func main() { keyRing := serverKeyAPI.KeyRing() keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) - keyAPI.SetUserAPI(userAPI) rsComponent := roomserver.NewInternalAPI( base, ) rsAPI := rsComponent + userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) + keyAPI.SetUserAPI(userAPI) + eduInputAPI := eduserver.NewInternalAPI( base, cache.New(), userAPI, ) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index bb2685208..3b952504b 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -106,7 +106,8 @@ func main() { keyAPI = base.KeyServerHTTPClient() } - userImpl := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) + pgClient := base.PushGatewayHTTPClient() + userImpl := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient) userAPI := userImpl if base.UseHTTPAPIs { userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) diff --git a/cmd/dendrite-polylith-multi/personalities/userapi.go b/cmd/dendrite-polylith-multi/personalities/userapi.go index f147cda14..f1fa379c7 100644 --- a/cmd/dendrite-polylith-multi/personalities/userapi.go +++ b/cmd/dendrite-polylith-multi/personalities/userapi.go @@ -23,7 +23,11 @@ import ( func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { accountDB := base.CreateAccountsDB() - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient()) + userAPI := userapi.NewInternalAPI( + base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, + base.KeyServerHTTPClient(), base.RoomserverHTTPClient(), + base.PushGatewayHTTPClient(), + ) userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) diff --git a/cmd/dendritejs-pinecone/main.go b/cmd/dendritejs-pinecone/main.go index 664f644f3..407081f59 100644 --- a/cmd/dendritejs-pinecone/main.go +++ b/cmd/dendritejs-pinecone/main.go @@ -184,13 +184,15 @@ func startup() { accountDB := base.CreateAccountsDB() federation := conn.CreateFederationClient(base, pSessions) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) - keyAPI.SetUserAPI(userAPI) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() rsAPI := roomserver.NewInternalAPI(base) + + userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) + keyAPI.SetUserAPI(userAPI) + eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI) asQuery := appservice.NewInternalAPI( base, userAPI, rsAPI, diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index 0ea41b4c4..37cbb12dd 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -212,6 +212,8 @@ func main() { rsAPI.SetFederationAPI(fedSenderAPI, keyRing) p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation) + psAPI := pushserver.NewInternalAPI(base) + monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, @@ -225,6 +227,7 @@ func main() { RoomserverAPI: rsAPI, UserAPI: userAPI, KeyAPI: keyAPI, + PushserverAPI: psAPI, //ServerKeyAPI: serverKeyAPI, ExtPublicRoomsProvider: p2pPublicRoomProvider, } diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 66dc5e049..c537bd1ef 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -6,7 +6,7 @@ # # At a minimum, to get started, you will need to update the settings in the # "global" section for your deployment, and you will need to check that the -# database "connection_string" line in each component section is correct. +# database "connection_string" line in each component section is correct. # # Each component with a "database" section can accept the following formats # for "connection_string": @@ -21,13 +21,13 @@ # small number of users and likely will perform worse still with a higher volume # of users. # -# The "max_open_conns" and "max_idle_conns" settings configure the maximum +# The "max_open_conns" and "max_idle_conns" settings configure the maximum # number of open/idle database connections. The value 0 will use the database # engine default, and a negative value will use unlimited connections. The # "conn_max_lifetime" option controls the maximum length of time a database # connection can be idle in seconds - a negative value is unlimited. -# The version of the configuration file. +# The version of the configuration file. version: 2 # Global Matrix configuration. This configuration applies to all components. @@ -61,8 +61,8 @@ global: # Lists of domains that the server will trust as identity servers to verify third # party identifiers such as phone numbers and email addresses. trusted_third_party_id_servers: - - matrix.org - - vector.im + - matrix.org + - vector.im # Disables federation. Dendrite will not be able to make any outbound HTTP requests # to other servers and the federation API will not be exposed. @@ -116,14 +116,14 @@ global: # in monolith mode. It is required to specify the address of at least one # NATS Server node if running in polylith mode. addresses: - # - localhost:4222 + # - localhost:4222 # Keep all NATS streams in memory, rather than persisting it to the storage # path below. This option is present primarily for integration testing and # should not be used on a real world Dendrite deployment. in_memory: false - # Persistent directory to store JetStream streams in. This directory + # Persistent directory to store JetStream streams in. This directory # should be preserved across Dendrite restarts. storage_path: ./ @@ -155,7 +155,7 @@ global: # Configuration for the Appservice API. app_service_api: internal_api: - listen: http://localhost:7777 # Only used in polylith deployments + listen: http://localhost:7777 # Only used in polylith deployments connect: http://localhost:7777 # Only used in polylith deployments database: connection_string: file:appservice.db @@ -174,7 +174,7 @@ app_service_api: # Configuration for the Client API. client_api: internal_api: - listen: http://localhost:7771 # Only used in polylith deployments + listen: http://localhost:7771 # Only used in polylith deployments connect: http://localhost:7771 # Only used in polylith deployments external_api: listen: http://[::]:8071 @@ -194,13 +194,13 @@ client_api: # Whether to require reCAPTCHA for registration. enable_registration_captcha: false - # Settings for ReCAPTCHA. + # Settings for ReCAPTCHA. recaptcha_public_key: "" recaptcha_private_key: "" recaptcha_bypass_secret: "" recaptcha_siteverify_api: "" - # TURN server information that this homeserver should send to clients. + # TURN server information that this homeserver should send to clients. turn: turn_user_lifetime: "" turn_uris: [] @@ -209,7 +209,7 @@ client_api: turn_password: "" # Settings for rate-limited endpoints. Rate limiting will kick in after the - # threshold number of "slots" have been taken by requests from a specific + # threshold number of "slots" have been taken by requests from a specific # host. Each "slot" will be released after the cooloff time in milliseconds. rate_limiting: enabled: true @@ -219,13 +219,13 @@ client_api: # Configuration for the EDU server. edu_server: internal_api: - listen: http://localhost:7778 # Only used in polylith deployments + listen: http://localhost:7778 # Only used in polylith deployments connect: http://localhost:7778 # Only used in polylith deployments # Configuration for the Federation API. federation_api: internal_api: - listen: http://localhost:7772 # Only used in polylith deployments + listen: http://localhost:7772 # Only used in polylith deployments connect: http://localhost:7772 # Only used in polylith deployments external_api: listen: http://[::]:8072 @@ -253,12 +253,12 @@ federation_api: # be required to satisfy key requests for servers that are no longer online when # joining some rooms. key_perspectives: - - server_name: matrix.org - keys: - - key_id: ed25519:auto - public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw - - key_id: ed25519:a_RXGa - public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ + - server_name: matrix.org + keys: + - key_id: ed25519:auto + public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw + - key_id: ed25519:a_RXGa + public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ # This option will control whether Dendrite will prefer to look up keys directly # or whether it should try perspective servers first, using direct fetches as a @@ -268,7 +268,7 @@ federation_api: # Configuration for the Key Server (for end-to-end encryption). key_server: internal_api: - listen: http://localhost:7779 # Only used in polylith deployments + listen: http://localhost:7779 # Only used in polylith deployments connect: http://localhost:7779 # Only used in polylith deployments database: connection_string: file:keyserver.db @@ -279,7 +279,7 @@ key_server: # Configuration for the Media API. media_api: internal_api: - listen: http://localhost:7774 # Only used in polylith deployments + listen: http://localhost:7774 # Only used in polylith deployments connect: http://localhost:7774 # Only used in polylith deployments external_api: listen: http://[::]:8074 @@ -305,15 +305,15 @@ media_api: # A list of thumbnail sizes to be generated for media content. thumbnail_sizes: - - width: 32 - height: 32 - method: crop - - width: 96 - height: 96 - method: crop - - width: 640 - height: 480 - method: scale + - width: 32 + height: 32 + method: crop + - width: 96 + height: 96 + method: crop + - width: 640 + height: 480 + method: scale # Configuration for experimental MSC's mscs: @@ -331,7 +331,7 @@ mscs: # Configuration for the Room Server. room_server: internal_api: - listen: http://localhost:7770 # Only used in polylith deployments + listen: http://localhost:7770 # Only used in polylith deployments connect: http://localhost:7770 # Only used in polylith deployments database: connection_string: file:roomserver.db @@ -342,7 +342,7 @@ room_server: # Configuration for the Sync API. sync_api: internal_api: - listen: http://localhost:7773 # Only used in polylith deployments + listen: http://localhost:7773 # Only used in polylith deployments connect: http://localhost:7773 # Only used in polylith deployments external_api: listen: http://[::]:8073 @@ -367,21 +367,16 @@ user_api: # This value can be low if performing tests or on embedded Dendrite instances (e.g WASM builds) # bcrypt_cost: 10 internal_api: - listen: http://localhost:7781 # Only used in polylith deployments + listen: http://localhost:7781 # Only used in polylith deployments connect: http://localhost:7781 # Only used in polylith deployments account_database: connection_string: file:userapi_accounts.db max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 - device_database: - connection_string: file:userapi_devices.db - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 - # The length of time that a token issued for a relying party from + # The length of time that a token issued for a relying party from # /_matrix/client/r0/user/{userId}/openid/request_token endpoint - # is considered to be valid in milliseconds. + # is considered to be valid in milliseconds. # The default lifetime is 3600000ms (60 minutes). # openid_token_lifetime_ms: 3600000 @@ -403,10 +398,10 @@ tracing: # Logging configuration logging: -- type: std - level: info -- type: file - # The logging level, must be one of debug, info, warn, error, fatal, panic. - level: info - params: - path: ./logs + - type: std + level: info + - type: file + # The logging level, must be one of debug, info, warn, error, fatal, panic. + level: info + params: + path: ./logs diff --git a/docs/systemd/monolith-example.service b/docs/systemd/monolith-example.service index 731c6159b..237120ffb 100644 --- a/docs/systemd/monolith-example.service +++ b/docs/systemd/monolith-example.service @@ -13,6 +13,7 @@ Group=dendrite WorkingDirectory=/opt/dendrite/ ExecStart=/opt/dendrite/bin/dendrite-monolith-server Restart=always +LimitNOFILE=65535 [Install] WantedBy=multi-user.target diff --git a/federationapi/api/api.go b/federationapi/api/api.go index f5ee75b4b..4d6b0211c 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -21,7 +21,7 @@ type FederationClient interface { QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error) + MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index b31db466c..b8bd5beda 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( } func (a *FederationInternalAPI) MSC2946Spaces( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest, + ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.MSC2946Spaces(ctx, s, roomID, r) + return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly) }) if err != nil { return res, err diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index f9b2a33d2..01ca6595d 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -526,23 +526,23 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships( } type spacesReq struct { - S gomatrixserverlib.ServerName - Req gomatrixserverlib.MSC2946SpacesRequest - RoomID string - Res gomatrixserverlib.MSC2946SpacesResponse - Err *api.FederationClientError + S gomatrixserverlib.ServerName + SuggestedOnly bool + RoomID string + Res gomatrixserverlib.MSC2946SpacesResponse + Err *api.FederationClientError } func (h *httpFederationInternalAPI) MSC2946Spaces( - ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest, + ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces") defer span.Finish() request := spacesReq{ - S: dst, - Req: r, - RoomID: roomID, + S: dst, + SuggestedOnly: suggestedOnly, + RoomID: roomID, } var response spacesReq apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 8d193d9c9..ca4930f20 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -378,7 +378,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { if err := json.NewDecoder(req.Body).Decode(&request); err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.Req) + res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly) if err != nil { ferr, ok := err.(*api.FederationClientError) if ok { diff --git a/go.mod b/go.mod index a416ec98f..525950daa 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/matrix-org/dendrite -replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad +replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.4-0.20220302103432-6b04b9f12740 replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c @@ -18,12 +18,13 @@ require ( github.com/frankban/quicktest v1.14.0 // indirect github.com/getsentry/sentry-go v0.12.0 github.com/gologme/log v1.3.0 + github.com/google/go-cmp v0.5.6 + github.com/google/uuid v1.2.0 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.4.2 github.com/h2non/filetype v1.1.3 // indirect github.com/hashicorp/golang-lru v0.5.4 github.com/juju/testing v0.0.0-20211215003918-77eb13d6cad2 // indirect - github.com/klauspost/compress v1.14.2 // indirect github.com/lib/pq v1.10.4 github.com/libp2p/go-libp2p v0.13.0 github.com/libp2p/go-libp2p-circuit v0.4.0 @@ -39,12 +40,12 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 github.com/morikuni/aec v1.0.0 // indirect - github.com/nats-io/nats-server/v2 v2.3.2 + github.com/nats-io/nats-server/v2 v2.7.3 github.com/nats-io/nats.go v1.13.1-0.20220121202836-972a071d373d github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 @@ -61,11 +62,11 @@ require ( github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.2 go.uber.org/atomic v1.9.0 - golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a + golang.org/x/crypto v0.0.0-20220214200702-86341886e292 golang.org/x/image v0.0.0-20211028202545-6944b10bf410 golang.org/x/mobile v0.0.0-20220112015953-858099ff7816 golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd - golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect + golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 gopkg.in/h2non/bimg.v1 v1.1.5 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index ecb0496bc..25697f59c 100644 --- a/go.sum +++ b/go.sum @@ -480,7 +480,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gologme/log v1.2.0/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U= github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo= github.com/gologme/log v1.3.0/go.mod h1:yKT+DvIPdDdDoPtqFrFxheooyVmoqi0BAsw+erN3wA4= @@ -712,9 +711,8 @@ github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0 github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.13/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= -github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= -github.com/klauspost/compress v1.14.2 h1:S0OHlFk/Gbon/yauFJ4FfJJF5V0fc5HbBTJazi28pRw= -github.com/klauspost/compress v1.14.2/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.14.4 h1:eijASRJcobkVtSt81Olfh7JX43osYLwy5krOJo6YEu4= +github.com/klauspost/compress v1.14.4/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -983,8 +981,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052 h1:+4Q/JQ3fGgA7sIHaLMlqREX8yEpsI+HlVoW9WId7SNc= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 h1:WHlrE8BYh/hzn1RKwq3YMAlhHivX47jQKAjZFtkJyPE= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= @@ -1029,8 +1027,8 @@ github.com/miekg/dns v1.1.31/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7 github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 h1:lYpkrQH5ajf0OXOcUbGjvZxxijuBwbbmlSxLiuofa+g= github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1/go.mod h1:pD8RvIylQ358TN4wwqatJ8rNavkEINozVn9DtGI3dfQ= -github.com/minio/highwayhash v1.0.1 h1:dZ6IIu8Z14VlC0VpfKofAhCy74wu/Qb5gcn52yWoz/0= -github.com/minio/highwayhash v1.0.1/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= +github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= +github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= github.com/minio/sha256-simd v0.0.0-20190131020904-2d45a736cd16/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U= github.com/minio/sha256-simd v0.0.0-20190328051042-05b4dd3047e5/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U= github.com/minio/sha256-simd v0.1.0/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U= @@ -1132,8 +1130,8 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad h1:Z2nWMQsXWWqzj89nW6OaLJSdkFknqhaR5whEOz4++Y8= -github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad/go.mod h1:tckmrt0M6bVaDT3kmh9UrIq/CBOBBse+TpXQi5ldaa8= +github.com/neilalexander/nats-server/v2 v2.7.4-0.20220302103432-6b04b9f12740 h1:RJrc+z35RHZlrjR6UBt9UmVRAlFh4SgYyEA0YpQdPHM= +github.com/neilalexander/nats-server/v2 v2.7.4-0.20220302103432-6b04b9f12740/go.mod h1:eJUrA5gm0ch6sJTEv85xmXIgQWsB0OyjkTsKXvlHbYc= github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q= github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= @@ -1510,8 +1508,8 @@ golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5 golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo= -golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1737,8 +1735,8 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc= -golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= diff --git a/internal/caching/cache_federationevents.go b/internal/caching/cache_federationevents.go index d10b333a6..b79cc809f 100644 --- a/internal/caching/cache_federationevents.go +++ b/internal/caching/cache_federationevents.go @@ -10,6 +10,7 @@ const ( FederationEventCacheName = "federation_event" FederationEventCacheMaxEntries = 256 FederationEventCacheMutable = true // to allow use of Unset only + FederationEventCacheMaxAge = CacheNoMaxAge ) // FederationCache contains the subset of functions needed for diff --git a/internal/caching/cache_roominfo.go b/internal/caching/cache_roominfo.go index f32d6ba9b..60d221285 100644 --- a/internal/caching/cache_roominfo.go +++ b/internal/caching/cache_roominfo.go @@ -1,6 +1,8 @@ package caching import ( + "time" + "github.com/matrix-org/dendrite/roomserver/types" ) @@ -16,6 +18,7 @@ const ( RoomInfoCacheName = "roominfo" RoomInfoCacheMaxEntries = 1024 RoomInfoCacheMutable = true + RoomInfoCacheMaxAge = time.Minute * 5 ) // RoomInfosCache contains the subset of functions needed for diff --git a/internal/caching/cache_roomservernids.go b/internal/caching/cache_roomservernids.go index 6d413093f..1918a2f1e 100644 --- a/internal/caching/cache_roomservernids.go +++ b/internal/caching/cache_roomservernids.go @@ -10,6 +10,7 @@ const ( RoomServerRoomIDsCacheName = "roomserver_room_ids" RoomServerRoomIDsCacheMaxEntries = 1024 RoomServerRoomIDsCacheMutable = false + RoomServerRoomIDsCacheMaxAge = CacheNoMaxAge ) type RoomServerCaches interface { diff --git a/internal/caching/cache_roomversions.go b/internal/caching/cache_roomversions.go index 0b46d3d4b..92d2eab08 100644 --- a/internal/caching/cache_roomversions.go +++ b/internal/caching/cache_roomversions.go @@ -6,6 +6,7 @@ const ( RoomVersionCacheName = "room_versions" RoomVersionCacheMaxEntries = 1024 RoomVersionCacheMutable = false + RoomVersionCacheMaxAge = CacheNoMaxAge ) // RoomVersionsCache contains the subset of functions needed for diff --git a/internal/caching/cache_serverkeys.go b/internal/caching/cache_serverkeys.go index 4697fb4d2..4eb10fe6f 100644 --- a/internal/caching/cache_serverkeys.go +++ b/internal/caching/cache_serverkeys.go @@ -10,6 +10,7 @@ const ( ServerKeyCacheName = "server_key" ServerKeyCacheMaxEntries = 4096 ServerKeyCacheMutable = true + ServerKeyCacheMaxAge = CacheNoMaxAge ) // ServerKeyCache contains the subset of functions needed for diff --git a/internal/caching/cache_space_rooms.go b/internal/caching/cache_space_rooms.go new file mode 100644 index 000000000..6d56cce5f --- /dev/null +++ b/internal/caching/cache_space_rooms.go @@ -0,0 +1,33 @@ +package caching + +import ( + "time" + + "github.com/matrix-org/gomatrixserverlib" +) + +const ( + SpaceSummaryRoomsCacheName = "space_summary_rooms" + SpaceSummaryRoomsCacheMaxEntries = 100 + SpaceSummaryRoomsCacheMutable = true + SpaceSummaryRoomsCacheMaxAge = time.Minute * 5 +) + +type SpaceSummaryRoomsCache interface { + GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool) + StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse) +} + +func (c Caches) GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool) { + val, found := c.SpaceSummaryRooms.Get(roomID) + if found && val != nil { + if resp, ok := val.(gomatrixserverlib.MSC2946SpacesResponse); ok { + return resp, true + } + } + return r, false +} + +func (c Caches) StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse) { + c.SpaceSummaryRooms.Set(roomID, r) +} diff --git a/internal/caching/caches.go b/internal/caching/caches.go index e1642a663..722405de6 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -1,5 +1,7 @@ package caching +import "time" + // Caches contains a set of references to caches. They may be // different implementations as long as they satisfy the Cache // interface. @@ -10,6 +12,7 @@ type Caches struct { RoomServerRoomIDs Cache // RoomServerNIDsCache RoomInfos Cache // RoomInfoCache FederationEvents Cache // FederationEventsCache + SpaceSummaryRooms Cache // SpaceSummaryRoomsCache } // Cache is the interface that an implementation must satisfy. @@ -18,3 +21,5 @@ type Cache interface { Set(key string, value interface{}) Unset(key string) } + +const CacheNoMaxAge = time.Duration(0) diff --git a/internal/caching/impl_inmemorylru.go b/internal/caching/impl_inmemorylru.go index ccb92852b..94fdd1a9b 100644 --- a/internal/caching/impl_inmemorylru.go +++ b/internal/caching/impl_inmemorylru.go @@ -14,6 +14,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { RoomVersionCacheName, RoomVersionCacheMutable, RoomVersionCacheMaxEntries, + RoomVersionCacheMaxAge, enablePrometheus, ) if err != nil { @@ -23,6 +24,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { ServerKeyCacheName, ServerKeyCacheMutable, ServerKeyCacheMaxEntries, + ServerKeyCacheMaxAge, enablePrometheus, ) if err != nil { @@ -32,6 +34,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { RoomServerRoomIDsCacheName, RoomServerRoomIDsCacheMutable, RoomServerRoomIDsCacheMaxEntries, + RoomServerRoomIDsCacheMaxAge, enablePrometheus, ) if err != nil { @@ -41,6 +44,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { RoomInfoCacheName, RoomInfoCacheMutable, RoomInfoCacheMaxEntries, + RoomInfoCacheMaxAge, enablePrometheus, ) if err != nil { @@ -50,6 +54,17 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { FederationEventCacheName, FederationEventCacheMutable, FederationEventCacheMaxEntries, + FederationEventCacheMaxAge, + enablePrometheus, + ) + if err != nil { + return nil, err + } + spaceRooms, err := NewInMemoryLRUCachePartition( + SpaceSummaryRoomsCacheName, + SpaceSummaryRoomsCacheMutable, + SpaceSummaryRoomsCacheMaxEntries, + SpaceSummaryRoomsCacheMaxAge, enablePrometheus, ) if err != nil { @@ -57,7 +72,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { } go cacheCleaner( roomVersions, serverKeys, roomServerRoomIDs, - roomInfos, federationEvents, + roomInfos, federationEvents, spaceRooms, ) return &Caches{ RoomVersions: roomVersions, @@ -65,6 +80,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { RoomServerRoomIDs: roomServerRoomIDs, RoomInfos: roomInfos, FederationEvents: federationEvents, + SpaceSummaryRooms: spaceRooms, }, nil } @@ -86,15 +102,22 @@ type InMemoryLRUCachePartition struct { name string mutable bool maxEntries int + maxAge time.Duration lru *lru.Cache } -func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, enablePrometheus bool) (*InMemoryLRUCachePartition, error) { +type inMemoryLRUCacheEntry struct { + value interface{} + created time.Time +} + +func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, maxAge time.Duration, enablePrometheus bool) (*InMemoryLRUCachePartition, error) { var err error cache := InMemoryLRUCachePartition{ name: name, mutable: mutable, maxEntries: maxEntries, + maxAge: maxAge, } cache.lru, err = lru.New(maxEntries) if err != nil { @@ -114,11 +137,16 @@ func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, ena func (c *InMemoryLRUCachePartition) Set(key string, value interface{}) { if !c.mutable { - if peek, ok := c.lru.Peek(key); ok && peek != value { - panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key)) + if peek, ok := c.lru.Peek(key); ok { + if entry, ok := peek.(*inMemoryLRUCacheEntry); ok && entry.value != value { + panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key)) + } } } - c.lru.Add(key, value) + c.lru.Add(key, &inMemoryLRUCacheEntry{ + value: value, + created: time.Now(), + }) } func (c *InMemoryLRUCachePartition) Unset(key string) { @@ -129,5 +157,20 @@ func (c *InMemoryLRUCachePartition) Unset(key string) { } func (c *InMemoryLRUCachePartition) Get(key string) (value interface{}, ok bool) { - return c.lru.Get(key) + v, ok := c.lru.Get(key) + if !ok { + return nil, false + } + entry, ok := v.(*inMemoryLRUCacheEntry) + switch { + case ok && c.maxAge == CacheNoMaxAge: + return entry.value, ok // There's no maximum age policy + case ok && time.Since(entry.created) < c.maxAge: + return entry.value, ok // The value for the key isn't stale + default: + // Either the key was found and it was stale, or the key + // wasn't found at all + c.lru.Remove(key) + return nil, false + } } diff --git a/internal/eventutil/types.go b/internal/eventutil/types.go index 6d119ce6d..17861d6c5 100644 --- a/internal/eventutil/types.go +++ b/internal/eventutil/types.go @@ -26,8 +26,30 @@ var ErrProfileNoExists = errors.New("no known profile for given user ID") // AccountData represents account data sent from the client API server to the // sync API server type AccountData struct { + RoomID string `json:"room_id"` + Type string `json:"type"` + ReadMarker *ReadMarkerJSON `json:"read_marker,omitempty"` // optional +} + +type ReadMarkerJSON struct { + FullyRead string `json:"m.fully_read"` + Read string `json:"m.read"` +} + +// NotificationData contains statistics about notifications, sent from +// the Push Server to the Sync API server. +type NotificationData struct { + // RoomID identifies the scope of the statistics, together with + // MXID (which is encoded in the Kafka key). RoomID string `json:"room_id"` - Type string `json:"type"` + + // HighlightCount is the number of unread notifications with the + // highlight tweak. + UnreadHighlightCount int `json:"unread_highlight_count"` + + // UnreadNotificationCount is the total number of unread + // notifications. + UnreadNotificationCount int `json:"unread_notification_count"` } // ProfileResponse is a struct containing all known user profile data diff --git a/internal/pushgateway/client.go b/internal/pushgateway/client.go new file mode 100644 index 000000000..49907cee8 --- /dev/null +++ b/internal/pushgateway/client.go @@ -0,0 +1,66 @@ +package pushgateway + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/opentracing/opentracing-go" +) + +type httpClient struct { + hc *http.Client +} + +// NewHTTPClient creates a new Push Gateway client. +func NewHTTPClient(disableTLSValidation bool) Client { + hc := &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + DisableKeepAlives: true, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: disableTLSValidation, + }, + }, + } + return &httpClient{hc: hc} +} + +func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "Notify") + defer span.Finish() + + body, err := json.Marshal(req) + if err != nil { + return err + } + hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return err + } + hreq.Header.Set("Content-Type", "application/json") + + hresp, err := h.hc.Do(hreq) + if err != nil { + return err + } + + //nolint:errcheck + defer hresp.Body.Close() + + if hresp.StatusCode == http.StatusOK { + return json.NewDecoder(hresp.Body).Decode(resp) + } + + var errorBody struct { + Message string `json:"message"` + } + if err := json.NewDecoder(hresp.Body).Decode(&errorBody); err == nil { + return fmt.Errorf("push gateway: %d from %s: %s", hresp.StatusCode, url, errorBody.Message) + } + return fmt.Errorf("push gateway: %d from %s", hresp.StatusCode, url) +} diff --git a/internal/pushgateway/pushgateway.go b/internal/pushgateway/pushgateway.go new file mode 100644 index 000000000..88c326eb2 --- /dev/null +++ b/internal/pushgateway/pushgateway.go @@ -0,0 +1,62 @@ +package pushgateway + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/gomatrixserverlib" +) + +// A Client is how interactions with a Push Gateway is done. +type Client interface { + // Notify sends a notification to the gateway at the given URL. + Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error +} + +type NotifyRequest struct { + Notification Notification `json:"notification"` // Required +} + +type NotifyResponse struct { + // Rejected is the list of device push keys that were rejected + // during the push. The caller should remove the push keys so they + // are not used again. + Rejected []string `json:"rejected"` // Required +} + +type Notification struct { + Content json.RawMessage `json:"content,omitempty"` + Counts *Counts `json:"counts,omitempty"` + Devices []*Device `json:"devices"` // Required + EventID string `json:"event_id,omitempty"` + ID string `json:"id,omitempty"` // Deprecated name for EventID. + Membership string `json:"membership,omitempty"` // UNSPEC: required for Sytest. + Prio Prio `json:"prio,omitempty"` + RoomAlias string `json:"room_alias,omitempty"` + RoomID string `json:"room_id,omitempty"` + RoomName string `json:"room_name,omitempty"` + Sender string `json:"sender,omitempty"` + SenderDisplayName string `json:"sender_display_name,omitempty"` + Type string `json:"type,omitempty"` + UserIsTarget bool `json:"user_is_target,omitempty"` +} + +type Counts struct { + MissedCalls int `json:"missed_calls,omitempty"` + Unread int `json:"unread"` // TODO: UNSPEC: the spec says zero must be omitted, but Sytest 61push/01message-pushed.pl requires it. +} + +type Device struct { + AppID string `json:"app_id"` // Required + Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys. + PushKey string `json:"pushkey"` // Required + PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"` + Tweaks map[string]interface{} `json:"tweaks,omitempty"` +} + +type Prio string + +const ( + HighPrio Prio = "high" + LowPrio Prio = "low" +) diff --git a/internal/pushrules/action.go b/internal/pushrules/action.go new file mode 100644 index 000000000..c7b8cec83 --- /dev/null +++ b/internal/pushrules/action.go @@ -0,0 +1,102 @@ +package pushrules + +import ( + "bytes" + "encoding/json" + "fmt" +) + +// An Action is (part of) an outcome of a rule. There are +// (unofficially) terminal actions, and modifier actions. +type Action struct { + // Kind is the type of action. Has custom encoding in JSON. + Kind ActionKind `json:"-"` + + // Tweak is the property to tweak. Has custom encoding in JSON. + Tweak TweakKey `json:"-"` + + // Value is some value interpreted according to Kind and Tweak. + Value interface{} `json:"value"` +} + +func (a *Action) MarshalJSON() ([]byte, error) { + if a.Tweak == UnknownTweak && a.Value == nil { + return json.Marshal(a.Kind) + } + + if a.Kind != SetTweakAction { + return nil, fmt.Errorf("only set_tweak actions may have a value, but got kind %q", a.Kind) + } + + m := map[string]interface{}{ + string(a.Kind): a.Tweak, + } + if a.Value != nil { + m["value"] = a.Value + } + + return json.Marshal(m) +} + +func (a *Action) UnmarshalJSON(bs []byte) error { + if bytes.HasPrefix(bs, []byte("\"")) { + return json.Unmarshal(bs, &a.Kind) + } + + var raw struct { + SetTweak TweakKey `json:"set_tweak"` + Value interface{} `json:"value"` + } + if err := json.Unmarshal(bs, &raw); err != nil { + return err + } + if raw.SetTweak == UnknownTweak { + return fmt.Errorf("got unknown action JSON: %s", string(bs)) + } + a.Kind = SetTweakAction + a.Tweak = raw.SetTweak + if raw.Value != nil { + a.Value = raw.Value + } + + return nil +} + +// ActionKind is the primary discriminator for actions. +type ActionKind string + +const ( + UnknownAction ActionKind = "" + + // NotifyAction indicates the clients should show a notification. + NotifyAction ActionKind = "notify" + + // DontNotifyAction indicates the clients should not show a notification. + DontNotifyAction ActionKind = "dont_notify" + + // CoalesceAction tells the clients to show a notification, and + // tells both servers and clients that multiple events can be + // coalesced into a single notification. The behaviour is + // implementation-specific. + CoalesceAction ActionKind = "coalesce" + + // SetTweakAction uses the Tweak and Value fields to add a + // tweak. Multiple SetTweakAction can be provided in a rule, + // combined with NotifyAction or CoalesceAction. + SetTweakAction ActionKind = "set_tweak" +) + +// A TweakKey describes a property to be modified/tweaked for events +// that match the rule. +type TweakKey string + +const ( + UnknownTweak TweakKey = "" + + // SoundTweak describes which sound to play. Using "default" means + // "enable sound". + SoundTweak TweakKey = "sound" + + // HighlightTweak asks the clients to highlight the conversation. + HighlightTweak TweakKey = "highlight" +) diff --git a/internal/pushrules/action_test.go b/internal/pushrules/action_test.go new file mode 100644 index 000000000..72db9c998 --- /dev/null +++ b/internal/pushrules/action_test.go @@ -0,0 +1,39 @@ +package pushrules + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestActionJSON(t *testing.T) { + tsts := []struct { + Want Action + }{ + {Action{Kind: NotifyAction}}, + {Action{Kind: DontNotifyAction}}, + {Action{Kind: CoalesceAction}}, + {Action{Kind: SetTweakAction}}, + + {Action{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, + {Action{Kind: SetTweakAction, Tweak: HighlightTweak}}, + {Action{Kind: SetTweakAction, Tweak: HighlightTweak, Value: "false"}}, + } + for _, tst := range tsts { + t.Run(fmt.Sprintf("%+v", tst.Want), func(t *testing.T) { + bs, err := json.Marshal(&tst.Want) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + var got Action + if err := json.Unmarshal(bs, &got); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if diff := cmp.Diff(tst.Want, got); diff != "" { + t.Errorf("+got -want:\n%s", diff) + } + }) + } +} diff --git a/internal/pushrules/condition.go b/internal/pushrules/condition.go new file mode 100644 index 000000000..2d9773c0f --- /dev/null +++ b/internal/pushrules/condition.go @@ -0,0 +1,49 @@ +package pushrules + +// A Condition dictates extra conditions for a matching rules. See +// ConditionKind. +type Condition struct { + // Kind is the primary discriminator for the condition + // type. Required. + Kind ConditionKind `json:"kind"` + + // Key indicates the dot-separated path of Event fields to + // match. Required for EventMatchCondition and + // SenderNotificationPermissionCondition. + Key string `json:"key,omitempty"` + + // Pattern indicates the value pattern that must match. Required + // for EventMatchCondition. + Pattern string `json:"pattern,omitempty"` + + // Is indicates the condition that must be fulfilled. Required for + // RoomMemberCountCondition. + Is string `json:"is,omitempty"` +} + +// ConditionKind represents a kind of condition. +// +// SPEC: Unrecognised conditions MUST NOT match any events, +// effectively making the push rule disabled. +type ConditionKind string + +const ( + UnknownCondition ConditionKind = "" + + // EventMatchCondition indicates the condition looks for a key + // path and matches a pattern. How paths that don't reference a + // simple value match against rules is implementation-specific. + EventMatchCondition ConditionKind = "event_match" + + // ContainsDisplayNameCondition indicates the current user's + // display name must be found in the content body. + ContainsDisplayNameCondition ConditionKind = "contains_display_name" + + // RoomMemberCountCondition matches a simple arithmetic comparison + // against the total number of members in a room. + RoomMemberCountCondition ConditionKind = "room_member_count" + + // SenderNotificationPermissionCondition compares power level for + // the sender in the event's room. + SenderNotificationPermissionCondition ConditionKind = "sender_notification_permission" +) diff --git a/internal/pushrules/default.go b/internal/pushrules/default.go new file mode 100644 index 000000000..996985514 --- /dev/null +++ b/internal/pushrules/default.go @@ -0,0 +1,23 @@ +package pushrules + +import ( + "github.com/matrix-org/gomatrixserverlib" +) + +// DefaultAccountRuleSets is the complete set of default push rules +// for an account. +func DefaultAccountRuleSets(localpart string, serverName gomatrixserverlib.ServerName) *AccountRuleSets { + return &AccountRuleSets{ + Global: *DefaultGlobalRuleSet(localpart, serverName), + } +} + +// DefaultGlobalRuleSet returns the default ruleset for a given (fully +// qualified) MXID. +func DefaultGlobalRuleSet(localpart string, serverName gomatrixserverlib.ServerName) *RuleSet { + return &RuleSet{ + Override: defaultOverrideRules("@" + localpart + ":" + string(serverName)), + Content: defaultContentRules(localpart), + Underride: defaultUnderrideRules, + } +} diff --git a/internal/pushrules/default_content.go b/internal/pushrules/default_content.go new file mode 100644 index 000000000..158afd18b --- /dev/null +++ b/internal/pushrules/default_content.go @@ -0,0 +1,33 @@ +package pushrules + +func defaultContentRules(localpart string) []*Rule { + return []*Rule{ + mRuleContainsUserNameDefinition(localpart), + } +} + +const ( + MRuleContainsUserName = ".m.rule.contains_user_name" +) + +func mRuleContainsUserNameDefinition(localpart string) *Rule { + return &Rule{ + RuleID: MRuleContainsUserName, + Default: true, + Enabled: true, + Pattern: localpart, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "default", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: true, + }, + }, + } +} diff --git a/internal/pushrules/default_override.go b/internal/pushrules/default_override.go new file mode 100644 index 000000000..6f66fd66a --- /dev/null +++ b/internal/pushrules/default_override.go @@ -0,0 +1,165 @@ +package pushrules + +func defaultOverrideRules(userID string) []*Rule { + return []*Rule{ + &mRuleMasterDefinition, + &mRuleSuppressNoticesDefinition, + mRuleInviteForMeDefinition(userID), + &mRuleMemberEventDefinition, + &mRuleContainsDisplayNameDefinition, + &mRuleTombstoneDefinition, + &mRuleRoomNotifDefinition, + } +} + +const ( + MRuleMaster = ".m.rule.master" + MRuleSuppressNotices = ".m.rule.suppress_notices" + MRuleInviteForMe = ".m.rule.invite_for_me" + MRuleMemberEvent = ".m.rule.member_event" + MRuleContainsDisplayName = ".m.rule.contains_display_name" + MRuleTombstone = ".m.rule.tombstone" + MRuleRoomNotif = ".m.rule.roomnotif" +) + +var ( + mRuleMasterDefinition = Rule{ + RuleID: MRuleMaster, + Default: true, + Enabled: false, + Conditions: []*Condition{}, + Actions: []*Action{{Kind: DontNotifyAction}}, + } + mRuleSuppressNoticesDefinition = Rule{ + RuleID: MRuleSuppressNotices, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "content.msgtype", + Pattern: "m.notice", + }, + }, + Actions: []*Action{{Kind: DontNotifyAction}}, + } + mRuleMemberEventDefinition = Rule{ + RuleID: MRuleMemberEvent, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.member", + }, + }, + Actions: []*Action{{Kind: DontNotifyAction}}, + } + mRuleContainsDisplayNameDefinition = Rule{ + RuleID: MRuleContainsDisplayName, + Default: true, + Enabled: true, + Conditions: []*Condition{{Kind: ContainsDisplayNameCondition}}, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "default", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: true, + }, + }, + } + mRuleTombstoneDefinition = Rule{ + RuleID: MRuleTombstone, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.tombstone", + }, + { + Kind: EventMatchCondition, + Key: "state_key", + Pattern: "", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleRoomNotifDefinition = Rule{ + RuleID: MRuleRoomNotif, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "content.body", + Pattern: "@room", + }, + { + Kind: SenderNotificationPermissionCondition, + Key: "room", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } +) + +func mRuleInviteForMeDefinition(userID string) *Rule { + return &Rule{ + RuleID: MRuleInviteForMe, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.member", + }, + { + Kind: EventMatchCondition, + Key: "content.membership", + Pattern: "invite", + }, + { + Kind: EventMatchCondition, + Key: "state_key", + Pattern: userID, + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "default", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } +} diff --git a/internal/pushrules/default_underride.go b/internal/pushrules/default_underride.go new file mode 100644 index 000000000..de72bd526 --- /dev/null +++ b/internal/pushrules/default_underride.go @@ -0,0 +1,119 @@ +package pushrules + +const ( + MRuleCall = ".m.rule.call" + MRuleEncryptedRoomOneToOne = ".m.rule.encrypted_room_one_to_one" + MRuleRoomOneToOne = ".m.rule.room_one_to_one" + MRuleMessage = ".m.rule.message" + MRuleEncrypted = ".m.rule.encrypted" +) + +var defaultUnderrideRules = []*Rule{ + &mRuleCallDefinition, + &mRuleEncryptedRoomOneToOneDefinition, + &mRuleRoomOneToOneDefinition, + &mRuleMessageDefinition, + &mRuleEncryptedDefinition, +} + +var ( + mRuleCallDefinition = Rule{ + RuleID: MRuleCall, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.call.invite", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "ring", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleEncryptedRoomOneToOneDefinition = Rule{ + RuleID: MRuleEncryptedRoomOneToOne, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: RoomMemberCountCondition, + Is: "2", + }, + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.encrypted", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleRoomOneToOneDefinition = Rule{ + RuleID: MRuleRoomOneToOne, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: RoomMemberCountCondition, + Is: "2", + }, + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.message", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleMessageDefinition = Rule{ + RuleID: MRuleMessage, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.message", + }, + }, + Actions: []*Action{{Kind: NotifyAction}}, + } + mRuleEncryptedDefinition = Rule{ + RuleID: MRuleEncrypted, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.encrypted", + }, + }, + Actions: []*Action{{Kind: NotifyAction}}, + } +) diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go new file mode 100644 index 000000000..df22cb042 --- /dev/null +++ b/internal/pushrules/evaluate.go @@ -0,0 +1,165 @@ +package pushrules + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/matrix-org/gomatrixserverlib" +) + +// A RuleSetEvaluator encapsulates context to evaluate an event +// against a rule set. +type RuleSetEvaluator struct { + ec EvaluationContext + ruleSet []kindAndRules +} + +// An EvaluationContext gives a RuleSetEvaluator access to the +// environment, for rules that require that. +type EvaluationContext interface { + // UserDisplayName returns the current user's display name. + UserDisplayName() string + + // RoomMemberCount returns the number of members in the room of + // the current event. + RoomMemberCount() (int, error) + + // HasPowerLevel returns whether the user has at least the given + // power in the room of the current event. + HasPowerLevel(userID, levelKey string) (bool, error) +} + +// A kindAndRules is just here to simplify iteration of the (ordered) +// kinds of rules. +type kindAndRules struct { + Kind Kind + Rules []*Rule +} + +// NewRuleSetEvaluator creates a new evaluator for the given rule set. +func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluator { + return &RuleSetEvaluator{ + ec: ec, + ruleSet: []kindAndRules{ + {OverrideKind, ruleSet.Override}, + {ContentKind, ruleSet.Content}, + {RoomKind, ruleSet.Room}, + {SenderKind, ruleSet.Sender}, + {UnderrideKind, ruleSet.Underride}, + }, + } +} + +// MatchEvent returns the first matching rule. Returns nil if there +// was no match rule. +func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule, error) { + // TODO: server-default rules have lower priority than user rules, + // but they are stored together with the user rules. It's a bit + // unclear what the specification (11.14.1.4 Predefined rules) + // means the ordering should be. + // + // The most reasonable interpretation is that default overrides + // still have lower priority than user content rules, so we + // iterate twice. + for _, rsat := range rse.ruleSet { + for _, defRules := range []bool{false, true} { + for _, rule := range rsat.Rules { + if rule.Default != defRules { + continue + } + ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec) + if err != nil { + return nil, err + } + if ok { + return rule, nil + } + } + } + } + + // No matching rule. + return nil, nil +} + +func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { + if !rule.Enabled { + return false, nil + } + + switch kind { + case OverrideKind, UnderrideKind: + for _, cond := range rule.Conditions { + ok, err := conditionMatches(cond, event, ec) + if err != nil { + return false, err + } + if !ok { + return false, nil + } + } + return true, nil + + case ContentKind: + // TODO: "These configure behaviour for (unencrypted) messages + // that match certain patterns." - Does that mean "content.body"? + return patternMatches("content.body", rule.Pattern, event) + + case RoomKind: + return rule.RuleID == event.RoomID(), nil + + case SenderKind: + return rule.RuleID == event.Sender(), nil + + default: + return false, nil + } +} + +func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { + switch cond.Kind { + case EventMatchCondition: + return patternMatches(cond.Key, cond.Pattern, event) + + case ContainsDisplayNameCondition: + return patternMatches("content.body", ec.UserDisplayName(), event) + + case RoomMemberCountCondition: + cmp, err := parseRoomMemberCountCondition(cond.Is) + if err != nil { + return false, fmt.Errorf("parsing room_member_count condition: %w", err) + } + n, err := ec.RoomMemberCount() + if err != nil { + return false, fmt.Errorf("RoomMemberCount failed: %w", err) + } + return cmp(n), nil + + case SenderNotificationPermissionCondition: + return ec.HasPowerLevel(event.Sender(), cond.Key) + + default: + return false, nil + } +} + +func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) { + re, err := globToRegexp(pattern) + if err != nil { + return false, err + } + + var eventMap map[string]interface{} + if err = json.Unmarshal(event.JSON(), &eventMap); err != nil { + return false, fmt.Errorf("parsing event: %w", err) + } + v, err := lookupMapPath(strings.Split(key, "."), eventMap) + if err != nil { + // An unknown path is a benign error that shouldn't stop rule + // processing. It's just a non-match. + return false, nil + } + + return re.MatchString(fmt.Sprint(v)), nil +} diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go new file mode 100644 index 000000000..50e703365 --- /dev/null +++ b/internal/pushrules/evaluate_test.go @@ -0,0 +1,189 @@ +package pushrules + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/matrix-org/gomatrixserverlib" +) + +func TestRuleSetEvaluatorMatchEvent(t *testing.T) { + ev := mustEventFromJSON(t, `{}`) + defaultEnabled := &Rule{ + RuleID: ".default.enabled", + Default: true, + Enabled: true, + } + userEnabled := &Rule{ + RuleID: ".user.enabled", + Default: false, + Enabled: true, + } + userEnabled2 := &Rule{ + RuleID: ".user.enabled.2", + Default: false, + Enabled: true, + } + tsts := []struct { + Name string + RuleSet RuleSet + Want *Rule + }{ + {"empty", RuleSet{}, nil}, + {"defaultCanWin", RuleSet{Override: []*Rule{defaultEnabled}}, defaultEnabled}, + {"userWins", RuleSet{Override: []*Rule{defaultEnabled, userEnabled}}, userEnabled}, + {"defaultOverrideWins", RuleSet{Override: []*Rule{defaultEnabled}, Underride: []*Rule{userEnabled}}, defaultEnabled}, + {"overrideContent", RuleSet{Override: []*Rule{userEnabled}, Content: []*Rule{userEnabled2}}, userEnabled}, + {"overrideRoom", RuleSet{Override: []*Rule{userEnabled}, Room: []*Rule{userEnabled2}}, userEnabled}, + {"overrideSender", RuleSet{Override: []*Rule{userEnabled}, Sender: []*Rule{userEnabled2}}, userEnabled}, + {"overrideUnderride", RuleSet{Override: []*Rule{userEnabled}, Underride: []*Rule{userEnabled2}}, userEnabled}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + rse := NewRuleSetEvaluator(nil, &tst.RuleSet) + got, err := rse.MatchEvent(ev) + if err != nil { + t.Fatalf("MatchEvent failed: %v", err) + } + if diff := cmp.Diff(tst.Want, got); diff != "" { + t.Errorf("MatchEvent rule: +got -want:\n%s", diff) + } + }) + } +} + +func TestRuleMatches(t *testing.T) { + emptyRule := Rule{Enabled: true} + tsts := []struct { + Name string + Kind Kind + Rule Rule + EventJSON string + Want bool + }{ + {"emptyOverride", OverrideKind, emptyRule, `{}`, true}, + {"emptyContent", ContentKind, emptyRule, `{}`, false}, + {"emptyRoom", RoomKind, emptyRule, `{}`, true}, + {"emptySender", SenderKind, emptyRule, `{}`, true}, + {"emptyUnderride", UnderrideKind, emptyRule, `{}`, true}, + + {"disabled", OverrideKind, Rule{}, `{}`, false}, + + {"overrideConditionMatch", OverrideKind, Rule{Enabled: true}, `{}`, true}, + {"overrideConditionNoMatch", OverrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, + + {"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true}, + {"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, + + {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true}, + {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false}, + + {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true}, + {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false}, + + {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true}, + {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil) + if err != nil { + t.Fatalf("ruleMatches failed: %v", err) + } + if got != tst.Want { + t.Errorf("ruleMatches: got %v, want %v", got, tst.Want) + } + }) + } +} + +func TestConditionMatches(t *testing.T) { + tsts := []struct { + Name string + Cond Condition + EventJSON string + Want bool + }{ + {"empty", Condition{}, `{}`, false}, + {"empty", Condition{Kind: "unknownstring"}, `{}`, false}, + + {"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, true}, + + {"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false}, + {"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true}, + + {"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false}, + {"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true}, + {"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false}, + {"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true}, + {"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false}, + {"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true}, + {"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false}, + {"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true}, + {"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false}, + {"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true}, + + {"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true}, + {"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{}) + if err != nil { + t.Fatalf("conditionMatches failed: %v", err) + } + if got != tst.Want { + t.Errorf("conditionMatches: got %v, want %v", got, tst.Want) + } + }) + } +} + +type fakeEvaluationContext struct{} + +func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" } +func (fakeEvaluationContext) RoomMemberCount() (int, error) { return 2, nil } +func (fakeEvaluationContext) HasPowerLevel(userID, levelKey string) (bool, error) { + return userID == "@poweruser:example.com" && levelKey == "powerlevel", nil +} + +func TestPatternMatches(t *testing.T) { + tsts := []struct { + Name string + Key string + Pattern string + EventJSON string + Want bool + }{ + {"empty", "", "", `{}`, false}, + + // Note that an empty pattern contains no wildcard characters, + // which implicitly means "*". + {"patternEmpty", "content", "", `{"content":{}}`, true}, + + {"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true}, + {"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true}, + {"singlePattern", "content.creator", "acr?ator", `{"content":{"creator":"acreator"}}`, true}, + {"multiPattern", "content.creator", "a*ea*r", `{"content":{"creator":"acreator"}}`, true}, + {"patternNoSubstring", "content.creator", "r*t", `{"content":{"creator":"acreator"}}`, false}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got, err := patternMatches(tst.Key, tst.Pattern, mustEventFromJSON(t, tst.EventJSON)) + if err != nil { + t.Fatalf("patternMatches failed: %v", err) + } + if got != tst.Want { + t.Errorf("patternMatches: got %v, want %v", got, tst.Want) + } + }) + } +} + +func mustEventFromJSON(t *testing.T, json string) *gomatrixserverlib.Event { + ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(json), false, gomatrixserverlib.RoomVersionV7) + if err != nil { + t.Fatal(err) + } + return ev +} diff --git a/internal/pushrules/pushrules.go b/internal/pushrules/pushrules.go new file mode 100644 index 000000000..bbed1f95f --- /dev/null +++ b/internal/pushrules/pushrules.go @@ -0,0 +1,71 @@ +package pushrules + +// An AccountRuleSets carries the rule sets associated with an +// account. +type AccountRuleSets struct { + Global RuleSet `json:"global"` // Required +} + +// A RuleSet contains all the various push rules for an +// account. Listed in decreasing order of priority. +type RuleSet struct { + Override []*Rule `json:"override,omitempty"` + Content []*Rule `json:"content,omitempty"` + Room []*Rule `json:"room,omitempty"` + Sender []*Rule `json:"sender,omitempty"` + Underride []*Rule `json:"underride,omitempty"` +} + +// A Rule contains matchers, conditions and final actions. While +// evaluating, at most one rule is considered matching. +// +// Kind and scope are part of the push rules request/responses, but +// not of the core data model. +type Rule struct { + // RuleID is either a free identifier, or the sender's MXID for + // SenderKind. Required. + RuleID string `json:"rule_id"` + + // Default indicates whether this is a server-defined default, or + // a user-provided rule. Required. + // + // The server-default rules have the lowest priority. + Default bool `json:"default"` + + // Enabled allows the user to disable rules while keeping them + // around. Required. + Enabled bool `json:"enabled"` + + // Actions describe the desired outcome, should the rule + // match. Required. + Actions []*Action `json:"actions"` + + // Conditions provide the rule's conditions for OverrideKind and + // UnderrideKind. Not allowed for other kinds. + Conditions []*Condition `json:"conditions"` + + // Pattern is the body pattern to match for ContentKind. Required + // for that kind. The interpretation is the same as that of + // Condition.Pattern. + Pattern string `json:"pattern"` +} + +// Scope only has one valid value. See also AccountRuleSets. +type Scope string + +const ( + UnknownScope Scope = "" + GlobalScope Scope = "global" +) + +// Kind is the type of push rule. See also RuleSet. +type Kind string + +const ( + UnknownKind Kind = "" + OverrideKind Kind = "override" + ContentKind Kind = "content" + RoomKind Kind = "room" + SenderKind Kind = "sender" + UnderrideKind Kind = "underride" +) diff --git a/internal/pushrules/util.go b/internal/pushrules/util.go new file mode 100644 index 000000000..027d35ef6 --- /dev/null +++ b/internal/pushrules/util.go @@ -0,0 +1,125 @@ +package pushrules + +import ( + "fmt" + "regexp" + "strconv" + "strings" +) + +// ActionsToTweaks converts a list of actions into a primary action +// kind and a tweaks map. Returns a nil map if it would have been +// empty. +func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) { + var kind ActionKind + tweaks := map[string]interface{}{} + + for _, a := range as { + if a.Kind == SetTweakAction { + tweaks[string(a.Tweak)] = a.Value + continue + } + if kind != UnknownAction { + return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind) + } + kind = a.Kind + } + + if len(tweaks) == 0 { + tweaks = nil + } + + return kind, tweaks, nil +} + +// BoolTweakOr returns the named tweak as a boolean, and returns `def` +// on failure. +func BoolTweakOr(tweaks map[string]interface{}, key TweakKey, def bool) bool { + v, ok := tweaks[string(key)] + if !ok { + return def + } + b, ok := v.(bool) + if !ok { + return def + } + return b +} + +// globToRegexp converts a Matrix glob-style pattern to a Regular expression. +func globToRegexp(pattern string) (*regexp.Regexp, error) { + // TODO: It's unclear which glob characters are supported. The only + // place this is discussed is for the unrelated "m.policy.rule.*" + // events. Assuming, the same: /[*?]/ + if !strings.ContainsAny(pattern, "*?") { + pattern = "*" + pattern + "*" + } + + // The defined syntax doesn't allow escaping the glob wildcard + // characters, which makes this a straight-forward + // replace-after-quote. + pattern = globNonMetaRegexp.ReplaceAllStringFunc(pattern, regexp.QuoteMeta) + pattern = strings.Replace(pattern, "*", ".*", -1) + pattern = strings.Replace(pattern, "?", ".", -1) + return regexp.Compile("^(" + pattern + ")$") +} + +// globNonMetaRegexp are the characters that are not considered glob +// meta-characters (i.e. may need escaping). +var globNonMetaRegexp = regexp.MustCompile("[^*?]+") + +// lookupMapPath traverses a hierarchical map structure, like the one +// produced by json.Unmarshal, to return the leaf value. Traversing +// arrays/slices is not supported, only objects/maps. +func lookupMapPath(path []string, m map[string]interface{}) (interface{}, error) { + if len(path) == 0 { + return nil, fmt.Errorf("empty path") + } + + var v interface{} = m + for i, key := range path { + m, ok := v.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("expected an object for path %q, but got %T", strings.Join(path[:i+1], "."), v) + } + + v, ok = m[key] + if !ok { + return nil, fmt.Errorf("path not found: %s", strings.Join(path[:i+1], ".")) + } + } + + return v, nil +} + +// parseRoomMemberCountCondition parses a string like "2", "==2", "<2" +// into a function that checks if the argument to it fulfils the +// condition. +func parseRoomMemberCountCondition(s string) (func(int) bool, error) { + var b int + var cmp = func(a int) bool { return a == b } + switch { + case strings.HasPrefix(s, "<="): + cmp = func(a int) bool { return a <= b } + s = s[2:] + case strings.HasPrefix(s, ">="): + cmp = func(a int) bool { return a >= b } + s = s[2:] + case strings.HasPrefix(s, "<"): + cmp = func(a int) bool { return a < b } + s = s[1:] + case strings.HasPrefix(s, ">"): + cmp = func(a int) bool { return a > b } + s = s[1:] + case strings.HasPrefix(s, "=="): + // Same cmp as the default. + s = s[2:] + } + + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return nil, err + } + b = int(v) + return cmp, nil +} diff --git a/internal/pushrules/util_test.go b/internal/pushrules/util_test.go new file mode 100644 index 000000000..a951c55a2 --- /dev/null +++ b/internal/pushrules/util_test.go @@ -0,0 +1,169 @@ +package pushrules + +import ( + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestActionsToTweaks(t *testing.T) { + tsts := []struct { + Name string + Input []*Action + WantKind ActionKind + WantTweaks map[string]interface{} + }{ + {"empty", nil, UnknownAction, nil}, + {"zero", []*Action{{}}, UnknownAction, nil}, + {"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil}, + {"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}}, + {"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}}, + { + "all", + []*Action{ + {Kind: CoalesceAction}, + {Kind: SetTweakAction, Tweak: HighlightTweak}, + {Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}, + }, + CoalesceAction, + map[string]interface{}{"highlight": nil, "sound": "default"}, + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + gotKind, gotTweaks, err := ActionsToTweaks(tst.Input) + if err != nil { + t.Fatalf("ActionsToTweaks failed: %v", err) + } + if gotKind != tst.WantKind { + t.Errorf("kind: got %v, want %v", gotKind, tst.WantKind) + } + if diff := cmp.Diff(tst.WantTweaks, gotTweaks); diff != "" { + t.Errorf("tweaks: +got -want:\n%s", diff) + } + }) + } +} + +func TestBoolTweakOr(t *testing.T) { + tsts := []struct { + Name string + Input map[string]interface{} + Def bool + Want bool + }{ + {"nil", nil, false, false}, + {"nilValue", map[string]interface{}{"highlight": nil}, true, true}, + {"false", map[string]interface{}{"highlight": false}, true, false}, + {"true", map[string]interface{}{"highlight": true}, false, true}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got := BoolTweakOr(tst.Input, HighlightTweak, tst.Def) + if got != tst.Want { + t.Errorf("BoolTweakOr: got %v, want %v", got, tst.Want) + } + }) + } +} + +func TestGlobToRegexp(t *testing.T) { + tsts := []struct { + Input string + Want string + }{ + {"", "^(.*.*)$"}, + {"a", "^(.*a.*)$"}, + {"a.b", "^(.*a\\.b.*)$"}, + {"a?b", "^(a.b)$"}, + {"a*b*", "^(a.*b.*)$"}, + {"a*b?", "^(a.*b.)$"}, + } + for _, tst := range tsts { + t.Run(tst.Want, func(t *testing.T) { + got, err := globToRegexp(tst.Input) + if err != nil { + t.Fatalf("globToRegexp failed: %v", err) + } + if got.String() != tst.Want { + t.Errorf("got %v, want %v", got.String(), tst.Want) + } + }) + } +} + +func TestLookupMapPath(t *testing.T) { + tsts := []struct { + Path []string + Root map[string]interface{} + Want interface{} + }{ + {[]string{"a"}, map[string]interface{}{"a": "b"}, "b"}, + {[]string{"a"}, map[string]interface{}{"a": 42}, 42}, + {[]string{"a", "b"}, map[string]interface{}{"a": map[string]interface{}{"b": "c"}}, "c"}, + } + for _, tst := range tsts { + t.Run(fmt.Sprint(tst.Path, "/", tst.Want), func(t *testing.T) { + got, err := lookupMapPath(tst.Path, tst.Root) + if err != nil { + t.Fatalf("lookupMapPath failed: %v", err) + } + if diff := cmp.Diff(tst.Want, got); diff != "" { + t.Errorf("+got -want:\n%s", diff) + } + }) + } +} + +func TestLookupMapPathInvalid(t *testing.T) { + tsts := []struct { + Path []string + Root map[string]interface{} + }{ + {nil, nil}, + {[]string{"a"}, nil}, + {[]string{"a", "b"}, map[string]interface{}{"a": "c"}}, + } + for _, tst := range tsts { + t.Run(fmt.Sprint(tst.Path), func(t *testing.T) { + got, err := lookupMapPath(tst.Path, tst.Root) + if err == nil { + t.Fatalf("lookupMapPath succeeded with %#v, but want failure", got) + } + }) + } +} + +func TestParseRoomMemberCountCondition(t *testing.T) { + tsts := []struct { + Input string + WantTrue []int + WantFalse []int + }{ + {"1", []int{1}, []int{0, 2}}, + {"==1", []int{1}, []int{0, 2}}, + {"<1", []int{0}, []int{1, 2}}, + {"<=1", []int{0, 1}, []int{2}}, + {">1", []int{2}, []int{0, 1}}, + {">=42", []int{42, 43}, []int{41}}, + } + for _, tst := range tsts { + t.Run(tst.Input, func(t *testing.T) { + got, err := parseRoomMemberCountCondition(tst.Input) + if err != nil { + t.Fatalf("parseRoomMemberCountCondition failed: %v", err) + } + for _, v := range tst.WantTrue { + if !got(v) { + t.Errorf("parseRoomMemberCountCondition(%q)(%d): got false, want true", tst.Input, v) + } + } + for _, v := range tst.WantFalse { + if got(v) { + t.Errorf("parseRoomMemberCountCondition(%q)(%d): got true, want false", tst.Input, v) + } + } + }) + } +} diff --git a/internal/pushrules/validate.go b/internal/pushrules/validate.go new file mode 100644 index 000000000..5d260f0b9 --- /dev/null +++ b/internal/pushrules/validate.go @@ -0,0 +1,85 @@ +package pushrules + +import ( + "fmt" + "regexp" +) + +// ValidateRule checks the rule for errors. These follow from Sytests +// and the specification. +func ValidateRule(kind Kind, rule *Rule) []error { + var errs []error + + if !validRuleIDRE.MatchString(rule.RuleID) { + errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID)) + } + + if len(rule.Actions) == 0 { + errs = append(errs, fmt.Errorf("missing actions")) + } + for _, action := range rule.Actions { + errs = append(errs, validateAction(action)...) + } + + for _, cond := range rule.Conditions { + errs = append(errs, validateCondition(cond)...) + } + + switch kind { + case OverrideKind, UnderrideKind: + // The empty list is allowed, but for JSON-encoding reasons, + // it must not be nil. + if rule.Conditions == nil { + errs = append(errs, fmt.Errorf("missing rule conditions")) + } + + case ContentKind: + if rule.Pattern == "" { + errs = append(errs, fmt.Errorf("missing content rule pattern")) + } + + case RoomKind, SenderKind: + // Do nothing. + + default: + errs = append(errs, fmt.Errorf("invalid rule kind: %s", kind)) + } + + return errs +} + +// validRuleIDRE is a regexp for valid IDs. +// +// TODO: the specification doesn't seem to say what the rule ID syntax +// is. A Sytest fails if it contains a backslash. +var validRuleIDRE = regexp.MustCompile(`^([^\\]+)$`) + +// validateAction returns issues with an Action. +func validateAction(action *Action) []error { + var errs []error + + switch action.Kind { + case NotifyAction, DontNotifyAction, CoalesceAction, SetTweakAction: + // Do nothing. + + default: + errs = append(errs, fmt.Errorf("invalid rule action kind: %s", action.Kind)) + } + + return errs +} + +// validateCondition returns issues with a Condition. +func validateCondition(cond *Condition) []error { + var errs []error + + switch cond.Kind { + case EventMatchCondition, ContainsDisplayNameCondition, RoomMemberCountCondition, SenderNotificationPermissionCondition: + // Do nothing. + + default: + errs = append(errs, fmt.Errorf("invalid rule condition kind: %s", cond.Kind)) + } + + return errs +} diff --git a/internal/pushrules/validate_test.go b/internal/pushrules/validate_test.go new file mode 100644 index 000000000..b276eb551 --- /dev/null +++ b/internal/pushrules/validate_test.go @@ -0,0 +1,163 @@ +package pushrules + +import ( + "strings" + "testing" +) + +func TestValidateRuleNegatives(t *testing.T) { + tsts := []struct { + Name string + Kind Kind + Rule Rule + WantErrString string + }{ + {"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"}, + {"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"}, + {"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"}, + {"noActions", OverrideKind, Rule{}, "missing actions"}, + {"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"}, + {"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"}, + {"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"}, + {"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"}, + {"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := ValidateRule(tst.Kind, &tst.Rule) + var foundErr error + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantErrString) { + foundErr = err + } + } + if foundErr == nil { + t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString) + } + }) + } +} + +func TestValidateRulePositives(t *testing.T) { + tsts := []struct { + Name string + Kind Kind + Rule Rule + WantNoErrString string + }{ + {"invalidKind", OverrideKind, Rule{}, "invalid rule kind"}, + {"invalidActionNoActions", OverrideKind, Rule{}, "invalid rule action kind"}, + {"invalidConditionNoConditions", OverrideKind, Rule{}, "invalid rule condition kind"}, + {"contentNoCondition", ContentKind, Rule{}, "missing rule conditions"}, + {"roomNoCondition", RoomKind, Rule{}, "missing rule conditions"}, + {"senderNoCondition", SenderKind, Rule{}, "missing rule conditions"}, + {"overrideNoPattern", OverrideKind, Rule{}, "missing content rule pattern"}, + {"overrideEmptyConditions", OverrideKind, Rule{Conditions: []*Condition{}}, "missing rule conditions"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := ValidateRule(tst.Kind, &tst.Rule) + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantNoErrString) { + t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString) + } + } + }) + } +} + +func TestValidateActionNegatives(t *testing.T) { + tsts := []struct { + Name string + Action Action + WantErrString string + }{ + {"emptyKind", Action{}, "invalid rule action kind"}, + {"invalidKind", Action{Kind: ActionKind("something else")}, "invalid rule action kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateAction(&tst.Action) + var foundErr error + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantErrString) { + foundErr = err + } + } + if foundErr == nil { + t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString) + } + }) + } +} + +func TestValidateActionPositives(t *testing.T) { + tsts := []struct { + Name string + Action Action + WantNoErrString string + }{ + {"invalidKind", Action{Kind: NotifyAction}, "invalid rule action kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateAction(&tst.Action) + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantNoErrString) { + t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString) + } + } + }) + } +} + +func TestValidateConditionNegatives(t *testing.T) { + tsts := []struct { + Name string + Condition Condition + WantErrString string + }{ + {"emptyKind", Condition{}, "invalid rule condition kind"}, + {"invalidKind", Condition{Kind: ConditionKind("something else")}, "invalid rule condition kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateCondition(&tst.Condition) + var foundErr error + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantErrString) { + foundErr = err + } + } + if foundErr == nil { + t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString) + } + }) + } +} + +func TestValidateConditionPositives(t *testing.T) { + tsts := []struct { + Name string + Condition Condition + WantNoErrString string + }{ + {"invalidKind", Condition{Kind: EventMatchCondition}, "invalid rule condition kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateCondition(&tst.Condition) + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantNoErrString) { + t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString) + } + } + }) + } +} diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 8d0d2dfa5..19483b268 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -163,6 +163,7 @@ type StatementList []struct { func (s StatementList) Prepare(db *sql.DB) (err error) { for _, statement := range s { if *statement.Statement, err = db.Prepare(statement.SQL); err != nil { + err = fmt.Errorf("Error %q while preparing statement: %s", err, statement.SQL) return } } diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index bfb2037f8..5124f37e6 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -166,26 +166,53 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P } // We can't have a self-signing or user-signing key without a master - // key, so make sure we have one of those. - if !hasMasterKey { - existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) - if err != nil { - res.Error = &api.KeyError{ - Err: "Retrieving cross-signing keys from database failed: " + err.Error(), - } - return + // key, so make sure we have one of those. We will also only actually do + // something if any of the specified keys in the request are different + // to what we've got in the database, to avoid generating key change + // notifications unnecessarily. + existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) + if err != nil { + res.Error = &api.KeyError{ + Err: "Retrieving cross-signing keys from database failed: " + err.Error(), } - - _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster] + return } // If we still can't find a master key for the user then stop the upload. // This satisfies the "Fails to upload self-signing key without master key" test. if !hasMasterKey { - res.Error = &api.KeyError{ - Err: "No master key was found", - IsMissingParam: true, + if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey { + res.Error = &api.KeyError{ + Err: "No master key was found", + IsMissingParam: true, + } + return } + } + + // Check if anything actually changed compared to what we have in the database. + changed := false + for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{ + gomatrixserverlib.CrossSigningKeyPurposeMaster, + gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, + gomatrixserverlib.CrossSigningKeyPurposeUserSigning, + } { + old, gotOld := existingKeys[purpose] + new, gotNew := toStore[purpose] + if gotOld != gotNew { + // A new key purpose has been specified that we didn't know before, + // or one has been removed. + changed = true + break + } + if !bytes.Equal(old, new) { + // One of the existing keys for a purpose we already knew about has + // changed. + changed = true + break + } + } + if !changed { return } diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index 164be6bec..ff939355f 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -48,13 +48,19 @@ type mockDeviceListUpdaterDatabase struct { staleUsers map[string]bool prevIDsExist func(string, []int) bool storedKeys []api.DeviceMessage + mu sync.Mutex // protect staleUsers } // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + d.mu.Lock() + defer d.mu.Unlock() var result []string - for userID := range d.staleUsers { + for userID, isStale := range d.staleUsers { + if !isStale { + continue + } _, remoteServer, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return nil, err @@ -75,10 +81,18 @@ func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, do // MarkDeviceListStale sets the stale bit for this user to isStale. func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { + d.mu.Lock() + defer d.mu.Unlock() d.staleUsers[userID] = isStale return nil } +func (d *mockDeviceListUpdaterDatabase) isStale(userID string) bool { + d.mu.Lock() + defer d.mu.Unlock() + return d.staleUsers[userID] +} + // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key // for this (user, device). Does not modify the stream ID for keys. func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error { @@ -161,7 +175,7 @@ func TestUpdateHavePrevID(t *testing.T) { if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) } - if db.staleUsers[event.UserID] { + if db.isStale(event.UserID) { t.Errorf("%s incorrectly marked as stale", event.UserID) } } @@ -235,7 +249,7 @@ func TestUpdateNoPrevID(t *testing.T) { }, } // Now we should have a fresh list and the keys and emitted something - if db.staleUsers[event.UserID] { + if db.isStale(event.UserID) { t.Errorf("%s still marked as stale", event.UserID) } if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { @@ -247,3 +261,83 @@ func TestUpdateNoPrevID(t *testing.T) { } } + +// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the +// update is still ongoing. +func TestDebounce(t *testing.T) { + t.Skipf("panic on closed channel on GHA") + db := &mockDeviceListUpdaterDatabase{ + staleUsers: make(map[string]bool), + prevIDsExist: func(string, []int) bool { + return true + }, + } + ap := &mockDeviceListUpdaterAPI{} + producer := &mockKeyChangeProducer{} + fedCh := make(chan *http.Response, 1) + srv := gomatrixserverlib.ServerName("example.com") + userID := "@alice:example.com" + keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` + incomingFedReq := make(chan struct{}) + fedClient := newFedClient(func(req *http.Request) (*http.Response, error) { + if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) { + return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path) + } + close(incomingFedReq) + return <-fedCh, nil + }) + updater := NewDeviceListUpdater(db, ap, producer, fedClient, 1) + if err := updater.Start(); err != nil { + t.Fatalf("failed to start updater: %s", err) + } + + // hit this 5 times + var wg sync.WaitGroup + wg.Add(5) + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil { + t.Errorf("ManualUpdate: %s", err) + } + }() + } + + // wait until the updater hits federation + select { + case <-incomingFedReq: + case <-time.After(time.Second): + t.Fatalf("timed out waiting for updater to hit federation") + } + + // user should be marked as stale + if !db.isStale(userID) { + t.Errorf("user %s not marked as stale", userID) + } + // now send the response over federation + fedCh <- &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(` + { + "user_id": "` + userID + `", + "stream_id": 5, + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": ` + keyJSON + `, + "device_display_name": "Mobile Phone" + } + ] + } + `)), + } + close(fedCh) + // wait until all 5 ManualUpdates return. If we hit federation again we won't send a response + // and should panic with read on a closed channel + wg.Wait() + + // user is no longer stale now + if db.isStale(userID) { + t.Errorf("user %s is marked as stale", userID) + } +} diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 6507505fe..2a65f57e3 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -53,8 +53,7 @@ func Setup( ) { rateLimits := httputil.NewRateLimits(rateLimit) - r0mux := publicAPIMux.PathPrefix("/r0").Subrouter() - v1mux := publicAPIMux.PathPrefix("/v1").Subrouter() + v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v1|v3)}/").Subrouter() activeThumbnailGeneration := &types.ActiveThumbnailGeneration{ PathToResult: map[string]*types.ThumbnailGenerationResult{}, @@ -77,21 +76,18 @@ func Setup( } }) - r0mux.Handle("/upload", uploadHandler).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions) - v1mux.Handle("/upload", uploadHandler).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/upload", uploadHandler).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions) activeRemoteRequests := &types.ActiveRemoteRequests{ MXCToResult: map[string]*types.RemoteRequestResult{}, } downloadHandler := makeDownloadAPI("download", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration) - r0mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) - v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) // TODO: remove when synapse is fixed - v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) // TODO: remove when synapse is fixed + v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/thumbnail/{serverName}/{mediaId}", + v3mux.Handle("/thumbnail/{serverName}/{mediaId}", makeDownloadAPI("thumbnail", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration), ).Methods(http.MethodGet, http.MethodOptions) } diff --git a/q.sqlite b/q.sqlite new file mode 100644 index 000000000..b7d6268e2 Binary files /dev/null and b/q.sqlite differ diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 96d6711c6..66e85f2f3 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -269,6 +269,7 @@ type QueryAuthChainResponse struct { type QuerySharedUsersRequest struct { UserID string + OtherUserIDs []string ExcludeRoomIDs []string IncludeRoomIDs []string } @@ -312,7 +313,10 @@ type QueryBulkStateContentResponse struct { } type QueryCurrentStateRequest struct { - RoomID string + RoomID string + AllowWildcards bool + // State key tuples. If a state_key has '*' and AllowWidlcards is true, returns all matching + // state events with that event type. StateTuples []gomatrixserverlib.StateKeyTuple } diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 012094c62..5491d36b3 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -51,12 +51,8 @@ func SendEventWithState( state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, ) error { - outliers, err := state.Events(event.RoomVersion) - if err != nil { - return err - } - - var ires []InputRoomEvent + outliers := state.Events(event.RoomVersion) + ires := make([]InputRoomEvent, 0, len(outliers)) for _, outlier := range outliers { if haveEventIDs[outlier.EventID()] { continue diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 4655e92a9..a7da9b06d 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -23,6 +23,21 @@ type parsedRespState struct { StateEvents []*gomatrixserverlib.Event } +func (p *parsedRespState) Events() []*gomatrixserverlib.Event { + eventsByID := make(map[string]*gomatrixserverlib.Event, len(p.AuthEvents)+len(p.StateEvents)) + for i, event := range p.AuthEvents { + eventsByID[event.EventID()] = p.AuthEvents[i] + } + for i, event := range p.StateEvents { + eventsByID[event.EventID()] = p.StateEvents[i] + } + allEvents := make([]*gomatrixserverlib.Event, 0, len(eventsByID)) + for _, event := range eventsByID { + allEvents = append(allEvents, event) + } + return gomatrixserverlib.ReverseTopologicalOrdering(allEvents, gomatrixserverlib.TopologicalOrderByAuthEvents) +} + type missingStateReq struct { origin gomatrixserverlib.ServerName db storage.Database @@ -124,11 +139,8 @@ func (t *missingStateReq) processEventWithMissingState( t.hadEventsMutex.Unlock() sendOutliers := func(resolvedState *parsedRespState) error { - outliers, oerr := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion) - if oerr != nil { - return fmt.Errorf("gomatrixserverlib.OrderAuthAndStateEvents: %w", oerr) - } - var outlierRoomEvents []api.InputRoomEvent + outliers := resolvedState.Events() + outlierRoomEvents := make([]api.InputRoomEvent, 0, len(outliers)) for _, outlier := range outliers { if hadEvents[outlier.EventID()] { continue diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index c8bbe7705..70cc5d62c 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -621,12 +621,25 @@ func (r *Queryer) QueryPublishedRooms( func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) for _, tuple := range req.StateTuples { - ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey) - if err != nil { - return err - } - if ev != nil { - res.StateEvents[tuple] = ev + if tuple.StateKey == "*" && req.AllowWildcards { + events, err := r.DB.GetStateEventsWithEventType(ctx, req.RoomID, tuple.EventType) + if err != nil { + return err + } + for _, e := range events { + res.StateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: e.Type(), + StateKey: *e.StateKey(), + }] = e + } + } else { + ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey) + if err != nil { + return err + } + if ev != nil { + res.StateEvents[tuple] = ev + } } } return nil @@ -696,7 +709,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser } roomIDs = roomIDs[:j] - users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs) + users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs) if err != nil { return err } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 685505d52..a2b22b401 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -146,13 +146,14 @@ type Database interface { // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) - // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. - JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) + // JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms. + JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) // GetServerInRoom returns true if we think a server is in a given room or false otherwise. diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index e98a87b04..33a965457 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -66,7 +66,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( ` var selectJoinedUsersSetForRoomsSQL = "" + - "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + + " WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " GROUP BY target_nid" @@ -322,13 +323,10 @@ func (s *membershipStatements) SelectRoomsWithMembership( func (s *membershipStatements) SelectJoinedUsersSetForRooms( ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, + userNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]int, error) { - roomIDarray := make([]int64, len(roomNIDs)) - for i := range roomNIDs { - roomIDarray[i] = int64(roomNIDs[i]) - } stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) - rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) + rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index bb39d7fbf..eae8cd2cb 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -13,7 +13,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/sirupsen/logrus" "github.com/tidwall/gjson" ) @@ -979,6 +978,62 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s return nil, nil } +// Same as GetStateEvent but returns all matching state events with this event type. Returns no error +// if there are no events with this event type. +func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) { + roomInfo, err := d.RoomInfo(ctx, roomID) + if err != nil { + return nil, err + } + if roomInfo == nil { + return nil, fmt.Errorf("room %s doesn't exist", roomID) + } + // e.g invited rooms + if roomInfo.IsStub { + return nil, nil + } + eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) + if err == sql.ErrNoRows { + // No rooms have an event of this type, otherwise we'd have an event type NID + return nil, nil + } + if err != nil { + return nil, err + } + entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) + if err != nil { + return nil, err + } + var eventNIDs []types.EventNID + for _, e := range entries { + if e.EventTypeNID == eventTypeNID { + eventNIDs = append(eventNIDs, e.EventNID) + } + } + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) + if err != nil { + eventIDs = map[types.EventNID]string{} + } + // return the events requested + eventPairs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs) + if err != nil { + return nil, err + } + if len(eventPairs) == 0 { + return nil, nil + } + var result []*gomatrixserverlib.HeaderedEvent + for _, pair := range eventPairs { + ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[pair.EventNID], pair.EventJSON, false, roomInfo.RoomVersion) + if err != nil { + return nil, err + } + result = append(result, ev.Headered(roomInfo.RoomVersion)) + } + + return result, nil +} + // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { var membershipState tables.MembershipState @@ -1106,13 +1161,23 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu return result, nil } -// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. -func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { +// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms. +func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) { roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs) if err != nil { return nil, err } - userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs) + userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs) + if err != nil { + return nil, err + } + userNIDs := make([]types.EventStateKeyNID, 0, len(userNIDsMap)) + nidToUserID := make(map[types.EventStateKeyNID]string, len(userNIDsMap)) + for id, nid := range userNIDsMap { + userNIDs = append(userNIDs, nid) + nidToUserID[nid] = id + } + userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs) if err != nil { return nil, err } @@ -1122,13 +1187,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) stateKeyNIDs[i] = nid i++ } - nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs) - if err != nil { - return nil, err - } - if len(nidToUserID) != len(userNIDToCount) { - logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID)) - } result := make(map[string]int, len(userNIDToCount)) for nid, count := range userNIDToCount { result[nidToUserID[nid]] = count diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index f852422e7..ff291a3a7 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -42,7 +42,8 @@ const membershipSchema = ` ` var selectJoinedUsersSetForRoomsSQL = "" + - "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + + " WHERE room_nid IN ($1) AND target_nid IN ($2) AND" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " GROUP BY target_nid" @@ -296,18 +297,22 @@ func (s *membershipStatements) SelectRoomsWithMembership( return roomNIDs, nil } -func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { - iRoomNIDs := make([]interface{}, len(roomNIDs)) - for i, v := range roomNIDs { - iRoomNIDs[i] = v +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) { + params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs)) + for _, v := range roomNIDs { + params = append(params, v) } - query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) + for _, v := range userNIDs { + params = append(params, v) + } + query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) + query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1) var rows *sql.Rows var err error if txn != nil { - rows, err = txn.QueryContext(ctx, query, iRoomNIDs...) + rows, err = txn.QueryContext(ctx, query, params...) } else { - rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...) + rows, err = s.db.QueryContext(ctx, query, params...) } if err != nil { return nil, err diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 468a427ff..e814fb0f3 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -128,9 +128,8 @@ type Membership interface { SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) - // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the - // counts of how many rooms they are joined. - SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) + // SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms. + SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) diff --git a/run-sytest.sh b/run-sytest.sh new file mode 100755 index 000000000..47635fd12 --- /dev/null +++ b/run-sytest.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# +# Runs SyTest either from Docker Hub, or from ../sytest. If it's run +# locally, the Docker image is rebuilt first. +# +# Logs are stored in ../sytestout/logs. + +set -e +set -o pipefail + +main() { + local tag=buster + local base_image=debian:$tag + local runargs=() + + cd "$(dirname "$0")" + + if [ -d ../sytest ]; then + local tmpdir + tmpdir="$(mktemp -d --tmpdir run-systest.XXXXXXXXXX)" + trap "rm -r '$tmpdir'" EXIT + + if [ -z "$DISABLE_BUILDING_SYTEST" ]; then + echo "Re-building ../sytest Docker images..." + + local status + ( + cd ../sytest + + docker build -f docker/base.Dockerfile --build-arg BASE_IMAGE="$base_image" --tag matrixdotorg/sytest:"$tag" . + docker build -f docker/dendrite.Dockerfile --build-arg SYTEST_IMAGE_TAG="$tag" --tag matrixdotorg/sytest-dendrite:latest . + ) &>"$tmpdir/buildlog" || status=$? + if (( status != 0 )); then + # Docker is very verbose, and we don't really care about + # building SyTest. So we accumulate and only output on + # failure. + cat "$tmpdir/buildlog" >&2 + return $status + fi + fi + + runargs+=( -v "$PWD/../sytest:/sytest:ro" ) + fi + if [ -n "$SYTEST_POSTGRES" ]; then + runargs+=( -e POSTGRES=1 ) + fi + + local sytestout=$PWD/../sytestout + mkdir -p "$sytestout"/{logs,cache/go-build,cache/go-pkg} + docker run \ + --rm \ + --name "sytest-dendrite-${LOGNAME}" \ + -e LOGS_USER=$(id -u) \ + -e LOGS_GROUP=$(id -g) \ + -v "$PWD:/src/:ro" \ + -v "$sytestout/logs:/logs/" \ + -v "$sytestout/cache/go-build:/root/.cache/go-build" \ + -v "$sytestout/cache/go-pkg:/gopath/pkg" \ + "${runargs[@]}" \ + matrixdotorg/sytest-dendrite:latest "$@" +} + +main "$@" diff --git a/setup/base/base.go b/setup/base/base.go index 4ab8769b6..4dab65e51 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -31,6 +31,7 @@ import ( sentryhttp "github.com/getsentry/sentry-go/http" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/gomatrixserverlib" "github.com/prometheus/client_golang/prometheus/promhttp" "go.uber.org/atomic" @@ -270,6 +271,11 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI { return f } +// PushGatewayHTTPClient returns a new client for interacting with (external) Push Gateways. +func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client { + return pushgateway.NewHTTPClient(b.Cfg.UserAPI.PushGatewayDisableTLSValidation) +} + // CreateAccountsDB creates a new instance of the accounts database. Should only // be called once per component. func (b *BaseDendrite) CreateAccountsDB() userdb.Database { diff --git a/setup/config/config_test.go b/setup/config/config_test.go index 8f7611f0a..6467b7c82 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -205,6 +205,11 @@ user_api: max_open_conns: 100 max_idle_conns: 2 conn_max_lifetime: -1 + pusher_database: + connection_string: file:pushserver.db + max_open_conns: 100 + max_idle_conns: 2 + conn_max_lifetime: -1 tracing: enabled: false jaeger: diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go index 1cb5eba18..570dc6030 100644 --- a/setup/config/config_userapi.go +++ b/setup/config/config_userapi.go @@ -13,6 +13,9 @@ type UserAPI struct { // The length of time an OpenID token is condidered valid in milliseconds OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"` + // Disable TLS validation on HTTPS calls to push gatways. NOT RECOMMENDED! + PushGatewayDisableTLSValidation bool `yaml:"push_gateway_disable_tls_validation"` + // The Account database stores the login details and account information // for local users. It is accessed by the UserAPI. AccountDatabase DatabaseOptions `yaml:"account_database"` diff --git a/setup/jetstream/streams.go b/setup/jetstream/streams.go index 5810a2a91..3f07488f9 100644 --- a/setup/jetstream/streams.go +++ b/setup/jetstream/streams.go @@ -18,7 +18,10 @@ var ( OutputKeyChangeEvent = "OutputKeyChangeEvent" OutputTypingEvent = "OutputTypingEvent" OutputClientData = "OutputClientData" + OutputNotificationData = "OutputNotificationData" OutputReceiptEvent = "OutputReceiptEvent" + OutputStreamEvent = "OutputStreamEvent" + OutputReadUpdate = "OutputReadUpdate" ) var streams = []*nats.StreamConfig{ @@ -58,4 +61,19 @@ var streams = []*nats.StreamConfig{ Retention: nats.InterestPolicy, Storage: nats.FileStorage, }, + { + Name: OutputNotificationData, + Retention: nats.InterestPolicy, + Storage: nats.FileStorage, + }, + { + Name: OutputStreamEvent, + Retention: nats.InterestPolicy, + Storage: nats.FileStorage, + }, + { + Name: OutputReadUpdate, + Retention: nats.InterestPolicy, + Storage: nats.FileStorage, + }, } diff --git a/setup/monolith.go b/setup/monolith.go index 61125e4a9..7dbd2eeaa 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -60,8 +60,8 @@ func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ss csMux, synapseMux, &m.Config.ClientAPI, m.AccountDB, m.FedClient, m.RoomserverAPI, m.EDUInternalAPI, m.AppserviceAPI, transactions.New(), - m.FederationAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider, - &m.Config.MSCs, + m.FederationAPI, m.UserAPI, m.KeyAPI, + m.ExtPublicRoomsProvider, &m.Config.MSCs, ) federationapi.AddPublicRoutes( ssMux, keyMux, wkMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient, diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 6ad79be96..7678dcb86 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -654,11 +654,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo AuthEvents: res.AuthChain, StateEvents: stateEvents, } - eventsInOrder, err := respState.Events(rc.roomVersion) - if err != nil { - util.GetLogger(rc.ctx).WithError(err).Error("failed to calculate order to send events in MSC2836EventRelationshipsResponse") - return - } + eventsInOrder := respState.Events(rc.roomVersion) // everything gets sent as an outlier because auth chain events may be disjoint from the DAG // as may the threaded events. var ires []roomserver.InputRoomEvent @@ -669,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo }) } // we've got the data by this point so use a background context - err = roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) + err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver") } diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 7be453554..7ab50c32e 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -18,17 +18,19 @@ package msc2946 import ( "context" "encoding/json" - "fmt" "net/http" + "net/url" + "sort" + "strconv" "strings" "sync" "time" + "github.com/google/uuid" "github.com/gorilla/mux" - chttputil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" fs "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/hooks" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" roomserver "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" @@ -39,42 +41,27 @@ import ( ) const ( - ConstCreateEventContentKey = "type" - ConstSpaceChildEventType = "m.space.child" - ConstSpaceParentEventType = "m.space.parent" + ConstCreateEventContentKey = "type" + ConstCreateEventContentValueSpace = "m.space" + ConstSpaceChildEventType = "m.space.child" + ConstSpaceParentEventType = "m.space.parent" ) -// Defaults sets the request defaults -func Defaults(r *gomatrixserverlib.MSC2946SpacesRequest) { - r.Limit = 2000 - r.MaxRoomsPerSpace = -1 +type MSC2946ClientResponse struct { + Rooms []gomatrixserverlib.MSC2946Room `json:"rooms"` + NextBatch string `json:"next_batch,omitempty"` } // Enable this MSC func Enable( base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI, - fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, + fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, cache caching.SpaceSummaryRoomsCache, ) error { - db, err := NewDatabase(&base.Cfg.MSCs.Database) - if err != nil { - return fmt.Errorf("cannot enable MSC2946: %w", err) - } - hooks.Enable() - hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { - he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) - hookErr := db.StoreReference(context.Background(), he) - if hookErr != nil { - util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error( - "failed to StoreReference", - ) - } - }) + clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName)) + base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) + base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) - base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/spaces", - httputil.MakeAuthAPI("spaces", userAPI, base.Cfg.Global.UserConsentOptions, false, spacesHandler(db, rsAPI, fsAPI, base.Cfg.Global.ServerName)), - ).Methods(http.MethodPost, http.MethodOptions) - - base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/spaces/{roomID}", httputil.MakeExternalAPI( + fedAPI := httputil.MakeExternalAPI( "msc2946_fed_spaces", func(req *http.Request) util.JSONResponse { fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( req, time.Now(), base.Cfg.Global.ServerName, keyRing, @@ -88,252 +75,308 @@ func Enable( return util.ErrorResponse(err) } roomID := params["roomID"] - return federatedSpacesHandler(req.Context(), fedReq, roomID, db, rsAPI, fsAPI, base.Cfg.Global.ServerName) + return federatedSpacesHandler(req.Context(), fedReq, roomID, cache, rsAPI, fsAPI, base.Cfg.Global.ServerName) }, - )).Methods(http.MethodPost, http.MethodOptions) + ) + base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) + base.PublicFederationAPIMux.Handle("/v1/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) return nil } func federatedSpacesHandler( - ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, db Database, + ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, + cache caching.SpaceSummaryRoomsCache, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, thisServer gomatrixserverlib.ServerName, ) util.JSONResponse { - inMemoryBatchCache := make(map[string]set) - var r gomatrixserverlib.MSC2946SpacesRequest - Defaults(&r) - if err := json.Unmarshal(fedReq.Content(), &r); err != nil { + u, err := url.Parse(fedReq.RequestURI()) + if err != nil { return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + Code: 400, + JSON: jsonerror.InvalidParam("bad request uri"), } } - w := walker{ - req: &r, - rootRoomID: roomID, - serverName: fedReq.Origin(), - thisServer: thisServer, - ctx: ctx, - db: db, - rsAPI: rsAPI, - fsAPI: fsAPI, - inMemoryBatchCache: inMemoryBatchCache, - } - res := w.walk() - return util.JSONResponse{ - Code: 200, - JSON: res, + w := walker{ + rootRoomID: roomID, + serverName: fedReq.Origin(), + thisServer: thisServer, + ctx: ctx, + cache: cache, + suggestedOnly: u.Query().Get("suggested_only") == "true", + limit: 1000, + // The main difference is that it does not recurse into spaces and does not support pagination. + // This is somewhat equivalent to a Client-Server request with a max_depth=1. + maxDepth: 1, + + rsAPI: rsAPI, + fsAPI: fsAPI, + // inline cache as we don't have pagination in federation mode + paginationCache: make(map[string]paginationInfo), } + return w.walk() } func spacesHandler( - db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, + rsAPI roomserver.RoomserverInternalAPI, + fsAPI fs.FederationInternalAPI, + cache caching.SpaceSummaryRoomsCache, thisServer gomatrixserverlib.ServerName, ) func(*http.Request, *userapi.Device) util.JSONResponse { + // declared outside the returned handler so it persists between calls + // TODO: clear based on... time? + paginationCache := make(map[string]paginationInfo) + return func(req *http.Request, device *userapi.Device) util.JSONResponse { - inMemoryBatchCache := make(map[string]set) // Extract the room ID from the request. Sanity check request data. params, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } roomID := params["roomID"] - var r gomatrixserverlib.MSC2946SpacesRequest - Defaults(&r) - if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil { - return *resErr - } w := walker{ - req: &r, - rootRoomID: roomID, - caller: device, - thisServer: thisServer, - ctx: req.Context(), + suggestedOnly: req.URL.Query().Get("suggested_only") == "true", + limit: parseInt(req.URL.Query().Get("limit"), 1000), + maxDepth: parseInt(req.URL.Query().Get("max_depth"), -1), + paginationToken: req.URL.Query().Get("from"), + rootRoomID: roomID, + caller: device, + thisServer: thisServer, + ctx: req.Context(), + cache: cache, - db: db, - rsAPI: rsAPI, - fsAPI: fsAPI, - inMemoryBatchCache: inMemoryBatchCache, - } - res := w.walk() - return util.JSONResponse{ - Code: 200, - JSON: res, + rsAPI: rsAPI, + fsAPI: fsAPI, + paginationCache: paginationCache, } + return w.walk() } } +type paginationInfo struct { + processed set + unvisited []roomVisit +} + type walker struct { - req *gomatrixserverlib.MSC2946SpacesRequest - rootRoomID string - caller *userapi.Device - serverName gomatrixserverlib.ServerName - thisServer gomatrixserverlib.ServerName - db Database - rsAPI roomserver.RoomserverInternalAPI - fsAPI fs.FederationInternalAPI - ctx context.Context + rootRoomID string + caller *userapi.Device + serverName gomatrixserverlib.ServerName + thisServer gomatrixserverlib.ServerName + rsAPI roomserver.RoomserverInternalAPI + fsAPI fs.FederationInternalAPI + ctx context.Context + cache caching.SpaceSummaryRoomsCache + suggestedOnly bool + limit int + maxDepth int + paginationToken string - // user ID|device ID|batch_num => event/room IDs sent to client - inMemoryBatchCache map[string]set - mu sync.Mutex + paginationCache map[string]paginationInfo + mu sync.Mutex } -func (w *walker) roomIsExcluded(roomID string) bool { - for _, exclRoom := range w.req.ExcludeRooms { - if exclRoom == roomID { - return true +func (w *walker) newPaginationCache() (string, paginationInfo) { + p := paginationInfo{ + processed: make(set), + unvisited: nil, + } + tok := uuid.NewString() + return tok, p +} + +func (w *walker) loadPaginationCache(paginationToken string) *paginationInfo { + w.mu.Lock() + defer w.mu.Unlock() + p := w.paginationCache[paginationToken] + return &p +} + +func (w *walker) storePaginationCache(paginationToken string, cache paginationInfo) { + w.mu.Lock() + defer w.mu.Unlock() + w.paginationCache[paginationToken] = cache +} + +type roomVisit struct { + roomID string + depth int + vias []string // vias to query this room by +} + +func (w *walker) walk() util.JSONResponse { + if !w.authorised(w.rootRoomID) { + if w.caller != nil { + // CS API format + return util.JSONResponse{ + Code: 403, + JSON: jsonerror.Forbidden("room is unknown/forbidden"), + } + } else { + // SS API format + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("room is unknown/forbidden"), + } } } - return false -} -func (w *walker) callerID() string { - if w.caller != nil { - return w.caller.UserID + "|" + w.caller.ID + var discoveredRooms []gomatrixserverlib.MSC2946Room + + var cache *paginationInfo + if w.paginationToken != "" { + cache = w.loadPaginationCache(w.paginationToken) + if cache == nil { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("invalid from"), + } + } + } else { + tok, c := w.newPaginationCache() + cache = &c + w.paginationToken = tok + // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms + c.unvisited = append(c.unvisited, roomVisit{ + roomID: w.rootRoomID, + depth: 0, + }) } - return string(w.serverName) -} -func (w *walker) alreadySent(id string) bool { - w.mu.Lock() - defer w.mu.Unlock() - m, ok := w.inMemoryBatchCache[w.callerID()] - if !ok { - return false - } - return m[id] -} + processed := cache.processed + unvisited := cache.unvisited -func (w *walker) markSent(id string) { - w.mu.Lock() - defer w.mu.Unlock() - m := w.inMemoryBatchCache[w.callerID()] - if m == nil { - m = make(set) - } - m[id] = true - w.inMemoryBatchCache[w.callerID()] = m -} - -func (w *walker) walk() *gomatrixserverlib.MSC2946SpacesResponse { - var res gomatrixserverlib.MSC2946SpacesResponse - // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms - unvisited := []string{w.rootRoomID} - processed := make(set) + // Depth first -> stack data structure for len(unvisited) > 0 { - roomID := unvisited[0] - unvisited = unvisited[1:] - // If this room has already been processed, skip. NB: do not remember this between calls - if processed[roomID] || roomID == "" { + if len(discoveredRooms) >= w.limit { + break + } + + // pop the stack + rv := unvisited[len(unvisited)-1] + unvisited = unvisited[:len(unvisited)-1] + // If this room has already been processed, skip. + // If this room exceeds the specified depth, skip. + if processed.isSet(rv.roomID) || rv.roomID == "" || (w.maxDepth > 0 && rv.depth > w.maxDepth) { continue } + // Mark this room as processed. - processed[roomID] = true + processed.set(rv.roomID) + + // if this room is not a space room, skip. + var roomType string + create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "") + if create != nil { + // escape the `.`s so gjson doesn't think it's nested + roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str + } // Collect rooms/events to send back (either locally or fetched via federation) - var discoveredRooms []gomatrixserverlib.MSC2946Room - var discoveredEvents []gomatrixserverlib.MSC2946StrippedEvent + var discoveredChildEvents []gomatrixserverlib.MSC2946StrippedEvent // If we know about this room and the caller is authorised (joined/world_readable) then pull // events locally - if w.roomExists(roomID) && w.authorised(roomID) { - // Get all `m.space.child` and `m.space.parent` state events for the room. *In addition*, get - // all `m.space.child` and `m.space.parent` state events which *point to* (via `state_key` or `content.room_id`) - // this room. This requires servers to store reverse lookups. - events, err := w.references(roomID) + if w.roomExists(rv.roomID) && w.authorised(rv.roomID) { + // Get all `m.space.child` state events for this room + events, err := w.childReferences(rv.roomID) if err != nil { - util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Error("failed to extract references for room") + util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Error("failed to extract references for room") continue } - discoveredEvents = events + discoveredChildEvents = events - pubRoom := w.publicRoomsChunk(roomID) - roomType := "" - create := w.stateEvent(roomID, gomatrixserverlib.MRoomCreate, "") - if create != nil { - // escape the `.`s so gjson doesn't think it's nested - roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str - } + pubRoom := w.publicRoomsChunk(rv.roomID) - // Add the total number of events to `PublicRoomsChunk` under `num_refs`. Add `PublicRoomsChunk` to `rooms`. discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{ - PublicRoom: *pubRoom, - NumRefs: len(discoveredEvents), - RoomType: roomType, + PublicRoom: *pubRoom, + RoomType: roomType, + ChildrenState: events, }) } else { // attempt to query this room over federation, as either we've never heard of it before // or we've left it and hence are not authorised (but info may be exposed regardless) - fedRes, err := w.federatedRoomInfo(roomID) + fedRes, err := w.federatedRoomInfo(rv.roomID, rv.vias) if err != nil { - util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Errorf("failed to query federated spaces") + util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Errorf("failed to query federated spaces") continue } if fedRes != nil { - discoveredRooms = fedRes.Rooms - discoveredEvents = fedRes.Events + discoveredChildEvents = fedRes.Room.ChildrenState + discoveredRooms = append(discoveredRooms, fedRes.Room) + if len(fedRes.Children) > 0 { + discoveredRooms = append(discoveredRooms, fedRes.Children...) + } + // mark this room as a space room as the federated server responded. + // we need to do this so we add the children of this room to the unvisited stack + // as these children may be rooms we do know about. + roomType = ConstCreateEventContentValueSpace } } - // If this room has not ever been in `rooms` (across multiple requests), send it now - for _, room := range discoveredRooms { - if !w.alreadySent(room.RoomID) && !w.roomIsExcluded(room.RoomID) { - res.Rooms = append(res.Rooms, room) - w.markSent(room.RoomID) - } + // don't walk the children + // if the parent is not a space room + if roomType != ConstCreateEventContentValueSpace { + continue } - uniqueRooms := make(set) - - // If this is the root room from the original request, insert all these events into `events` if - // they haven't been added before (across multiple requests). - if w.rootRoomID == roomID { - for _, ev := range discoveredEvents { - if !w.alreadySent(eventKey(&ev)) { - res.Events = append(res.Events, ev) - uniqueRooms[ev.RoomID] = true - uniqueRooms[spaceTargetStripped(&ev)] = true - w.markSent(eventKey(&ev)) - } - } - } else { - // Else add them to `events` honouring the `limit` and `max_rooms_per_space` values. If either - // are exceeded, stop adding events. If the event has already been added, do not add it again. - numAdded := 0 - for _, ev := range discoveredEvents { - if w.req.Limit > 0 && len(res.Events) >= w.req.Limit { - break - } - if w.req.MaxRoomsPerSpace > 0 && numAdded >= w.req.MaxRoomsPerSpace { - break - } - if w.alreadySent(eventKey(&ev)) { - continue - } - // Skip the room if it's part of exclude_rooms but ONLY IF the source matches, as we still - // want to catch arrows which point to excluded rooms. - if w.roomIsExcluded(ev.RoomID) { - continue - } - res.Events = append(res.Events, ev) - uniqueRooms[ev.RoomID] = true - uniqueRooms[spaceTargetStripped(&ev)] = true - w.markSent(eventKey(&ev)) - // we don't distinguish between child state events and parent state events for the purposes of - // max_rooms_per_space, maybe we should? - numAdded++ - } - } - - // For each referenced room ID in the events being returned to the caller (both parent and child) + // For each referenced room ID in the child events being returned to the caller // add the room ID to the queue of unvisited rooms. Loop from the beginning. - for roomID := range uniqueRooms { - unvisited = append(unvisited, roomID) + // We need to invert the order here because the child events are lo->hi on the timestamp, + // so we need to ensure we pop in the same lo->hi order, which won't be the case if we + // insert the highest timestamp last in a stack. + for i := len(discoveredChildEvents) - 1; i >= 0; i-- { + spaceContent := struct { + Via []string `json:"via"` + }{} + ev := discoveredChildEvents[i] + _ = json.Unmarshal(ev.Content, &spaceContent) + unvisited = append(unvisited, roomVisit{ + roomID: ev.StateKey, + depth: rv.depth + 1, + vias: spaceContent.Via, + }) } } - return &res + + if len(unvisited) > 0 { + // we still have more rooms so we need to send back a pagination token, + // we probably hit a room limit + cache.processed = processed + cache.unvisited = unvisited + w.storePaginationCache(w.paginationToken, *cache) + } else { + // clear the pagination token so we don't send it back to the client + // Note we do NOT nuke the cache just in case this response is lost + // and the client retries it. + w.paginationToken = "" + } + + if w.caller != nil { + // return CS API format + return util.JSONResponse{ + Code: 200, + JSON: MSC2946ClientResponse{ + Rooms: discoveredRooms, + NextBatch: w.paginationToken, + }, + } + } + // return SS API format + // the first discovered room will be the room asked for, and subsequent ones the depth=1 children + if len(discoveredRooms) == 0 { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("room is unknown/forbidden"), + } + } + return util.JSONResponse{ + Code: 200, + JSON: gomatrixserverlib.MSC2946SpacesResponse{ + Room: discoveredRooms[0], + Children: discoveredRooms[1:], + }, + } } func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent { @@ -366,46 +409,41 @@ func (w *walker) publicRoomsChunk(roomID string) *gomatrixserverlib.PublicRoom { // federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was // unsuccessful. -func (w *walker) federatedRoomInfo(roomID string) (*gomatrixserverlib.MSC2946SpacesResponse, error) { +func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserverlib.MSC2946SpacesResponse, error) { // only do federated requests for client requests if w.caller == nil { return nil, nil } - // extract events which point to this room ID and extract their vias - events, err := w.db.References(w.ctx, roomID) - if err != nil { - return nil, fmt.Errorf("failed to get References events: %w", err) + resp, ok := w.cache.GetSpaceSummary(roomID) + if ok { + util.GetLogger(w.ctx).Debugf("Returning cached response for %s", roomID) + return &resp, nil } - vias := make(set) - for _, ev := range events { - if ev.StateKeyEquals(roomID) { - // event points at this room, extract vias - content := struct { - Vias []string `json:"via"` - }{} - if err = json.Unmarshal(ev.Content(), &content); err != nil { - continue // silently ignore corrupted state events - } - for _, v := range content.Vias { - vias[v] = true - } - } - } - util.GetLogger(w.ctx).Infof("Querying federatedRoomInfo via %+v", vias) + util.GetLogger(w.ctx).Debugf("Querying %s via %+v", roomID, vias) ctx := context.Background() // query more of the spaces graph using these servers - for serverName := range vias { + for _, serverName := range vias { if serverName == string(w.thisServer) { continue } - res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, gomatrixserverlib.MSC2946SpacesRequest{ - Limit: w.req.Limit, - MaxRoomsPerSpace: w.req.MaxRoomsPerSpace, - }) + res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) if err != nil { util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName) continue } + // ensure nil slices are empty as we send this to the client sometimes + if res.Room.ChildrenState == nil { + res.Room.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{} + } + for i := 0; i < len(res.Children); i++ { + child := res.Children[i] + if child.ChildrenState == nil { + child.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{} + } + res.Children[i] = child + } + w.cache.StoreSpaceSummary(roomID, res) + return &res, nil } return nil, nil @@ -501,7 +539,7 @@ func (w *walker) authorisedUser(roomID string) bool { hisVisEv := queryRes.StateEvents[hisVisTuple] if memberEv != nil { membership, _ := memberEv.Membership() - if membership == gomatrixserverlib.Join { + if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite { return true } } @@ -514,40 +552,85 @@ func (w *walker) authorisedUser(roomID string) bool { return false } -// references returns all references pointing to or from this room. -func (w *walker) references(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) { - events, err := w.db.References(w.ctx, roomID) +// references returns all child references pointing to or from this room. +func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) { + createTuple := gomatrixserverlib.StateKeyTuple{ + EventType: gomatrixserverlib.MRoomCreate, + StateKey: "", + } + var res roomserver.QueryCurrentStateResponse + err := w.rsAPI.QueryCurrentState(context.Background(), &roomserver.QueryCurrentStateRequest{ + RoomID: roomID, + AllowWildcards: true, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + createTuple, { + EventType: ConstSpaceChildEventType, + StateKey: "*", + }, + }, + }, &res) if err != nil { return nil, err } - el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(events)) - for _, ev := range events { + + // don't return any child refs if the room is not a space room + if res.StateEvents[createTuple] != nil { + // escape the `.`s so gjson doesn't think it's nested + roomType := gjson.GetBytes(res.StateEvents[createTuple].Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str + if roomType != ConstCreateEventContentValueSpace { + return []gomatrixserverlib.MSC2946StrippedEvent{}, nil + } + } + delete(res.StateEvents, createTuple) + + el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(res.StateEvents)) + for _, ev := range res.StateEvents { + content := gjson.ParseBytes(ev.Content()) // only return events that have a `via` key as per MSC1772 // else we'll incorrectly walk redacted events (as the link // is in the state_key) - if gjson.GetBytes(ev.Content(), "via").Exists() { + if content.Get("via").Exists() { strip := stripped(ev.Event) if strip == nil { continue } + // if suggested only and this child isn't suggested, skip it. + // if suggested only = false we include everything so don't need to check the content. + if w.suggestedOnly && !content.Get("suggested").Bool() { + continue + } el = append(el, *strip) } } + // sort by origin_server_ts as per MSC2946 + sort.Slice(el, func(i, j int) bool { + return el[i].OriginServerTS < el[j].OriginServerTS + }) + return el, nil } -type set map[string]bool +type set map[string]struct{} + +func (s set) set(val string) { + s[val] = struct{}{} +} +func (s set) isSet(val string) bool { + _, ok := s[val] + return ok +} func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEvent { if ev.StateKey() == nil { return nil } return &gomatrixserverlib.MSC2946StrippedEvent{ - Type: ev.Type(), - StateKey: *ev.StateKey(), - Content: ev.Content(), - Sender: ev.Sender(), - RoomID: ev.RoomID(), + Type: ev.Type(), + StateKey: *ev.StateKey(), + Content: ev.Content(), + Sender: ev.Sender(), + RoomID: ev.RoomID(), + OriginServerTS: ev.OriginServerTS(), } } @@ -567,3 +650,11 @@ func spaceTargetStripped(event *gomatrixserverlib.MSC2946StrippedEvent) string { } return "" } + +func parseInt(intstr string, defaultVal int) int { + i, err := strconv.ParseInt(intstr, 10, 32) + if err != nil { + return defaultVal + } + return int(i) +} diff --git a/setup/mscs/msc2946/msc2946_test.go b/setup/mscs/msc2946/msc2946_test.go deleted file mode 100644 index e8066c34d..000000000 --- a/setup/mscs/msc2946/msc2946_test.go +++ /dev/null @@ -1,464 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package msc2946_test - -import ( - "bytes" - "context" - "crypto/ed25519" - "encoding/json" - "io/ioutil" - "net/http" - "net/url" - "testing" - "time" - - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/internal/hooks" - "github.com/matrix-org/dendrite/internal/httputil" - roomserver "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/mscs/msc2946" - userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" -) - -var ( - client = &http.Client{ - Timeout: 10 * time.Second, - } - roomVer = gomatrixserverlib.RoomVersionV6 -) - -// Basic sanity check of MSC2946 logic. Tests a single room with a few state events -// and a bit of recursion to subspaces. Makes a graph like: -// Root -// ____|_____ -// | | | -// R1 R2 S1 -// |_________ -// | | | -// R3 R4 S2 -// | <-- this link is just a parent, not a child -// R5 -// -// Alice is not joined to R4, but R4 is "world_readable". -func TestMSC2946(t *testing.T) { - alice := "@alice:localhost" - // give access token to alice - nopUserAPI := &testUserAPI{ - accessTokens: make(map[string]userapi.Device), - } - nopUserAPI.accessTokens["alice"] = userapi.Device{ - AccessToken: "alice", - DisplayName: "Alice", - UserID: alice, - } - rootSpace := "!rootspace:localhost" - subSpaceS1 := "!subspaceS1:localhost" - subSpaceS2 := "!subspaceS2:localhost" - room1 := "!room1:localhost" - room2 := "!room2:localhost" - room3 := "!room3:localhost" - room4 := "!room4:localhost" - empty := "" - room5 := "!room5:localhost" - allRooms := []string{ - rootSpace, subSpaceS1, subSpaceS2, - room1, room2, room3, room4, room5, - } - rootToR1 := mustCreateEvent(t, fledglingEvent{ - RoomID: rootSpace, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room1, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - rootToR2 := mustCreateEvent(t, fledglingEvent{ - RoomID: rootSpace, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room2, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - rootToS1 := mustCreateEvent(t, fledglingEvent{ - RoomID: rootSpace, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &subSpaceS1, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - s1ToR3 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room3, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - s1ToR4 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room4, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - s1ToS2 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &subSpaceS2, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - // This is a parent link only - s2ToR5 := mustCreateEvent(t, fledglingEvent{ - RoomID: room5, - Sender: alice, - Type: msc2946.ConstSpaceParentEventType, - StateKey: &subSpaceS2, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - // history visibility for R4 - r4HisVis := mustCreateEvent(t, fledglingEvent{ - RoomID: room4, - Sender: "@someone:localhost", - Type: gomatrixserverlib.MRoomHistoryVisibility, - StateKey: &empty, - Content: map[string]interface{}{ - "history_visibility": "world_readable", - }, - }) - var joinEvents []*gomatrixserverlib.HeaderedEvent - for _, roomID := range allRooms { - if roomID == room4 { - continue // not joined to that room - } - joinEvents = append(joinEvents, mustCreateEvent(t, fledglingEvent{ - RoomID: roomID, - Sender: alice, - StateKey: &alice, - Type: gomatrixserverlib.MRoomMember, - Content: map[string]interface{}{ - "membership": "join", - }, - })) - } - roomNameTuple := gomatrixserverlib.StateKeyTuple{ - EventType: "m.room.name", - StateKey: "", - } - hisVisTuple := gomatrixserverlib.StateKeyTuple{ - EventType: "m.room.history_visibility", - StateKey: "", - } - nopRsAPI := &testRoomserverAPI{ - joinEvents: joinEvents, - events: map[string]*gomatrixserverlib.HeaderedEvent{ - rootToR1.EventID(): rootToR1, - rootToR2.EventID(): rootToR2, - rootToS1.EventID(): rootToS1, - s1ToR3.EventID(): s1ToR3, - s1ToR4.EventID(): s1ToR4, - s1ToS2.EventID(): s1ToS2, - s2ToR5.EventID(): s2ToR5, - r4HisVis.EventID(): r4HisVis, - }, - pubRoomState: map[string]map[gomatrixserverlib.StateKeyTuple]string{ - rootSpace: { - roomNameTuple: "Root", - hisVisTuple: "shared", - }, - subSpaceS1: { - roomNameTuple: "Sub-Space 1", - hisVisTuple: "joined", - }, - subSpaceS2: { - roomNameTuple: "Sub-Space 2", - hisVisTuple: "shared", - }, - room1: { - hisVisTuple: "joined", - }, - room2: { - hisVisTuple: "joined", - }, - room3: { - hisVisTuple: "joined", - }, - room4: { - hisVisTuple: "world_readable", - }, - room5: { - hisVisTuple: "joined", - }, - }, - } - allEvents := []*gomatrixserverlib.HeaderedEvent{ - rootToR1, rootToR2, rootToS1, - s1ToR3, s1ToR4, s1ToS2, - s2ToR5, r4HisVis, - } - allEvents = append(allEvents, joinEvents...) - router := injectEvents(t, nopUserAPI, nopRsAPI, allEvents) - cancel := runServer(t, router) - defer cancel() - - t.Run("returns no events for unknown rooms", func(t *testing.T) { - res := postSpaces(t, 200, "alice", "!unknown:localhost", newReq(t, map[string]interface{}{})) - if len(res.Events) > 0 { - t.Errorf("got %d events, want 0", len(res.Events)) - } - if len(res.Rooms) > 0 { - t.Errorf("got %d rooms, want 0", len(res.Rooms)) - } - }) - t.Run("returns the entire graph", func(t *testing.T) { - res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{})) - if len(res.Events) != 7 { - t.Errorf("got %d events, want 7", len(res.Events)) - } - if len(res.Rooms) != len(allRooms) { - t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)) - } - }) - t.Run("can update the graph", func(t *testing.T) { - // remove R3 from the graph - rmS1ToR3 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room3, - Content: map[string]interface{}{}, // redacted - }) - nopRsAPI.events[rmS1ToR3.EventID()] = rmS1ToR3 - hooks.Run(hooks.KindNewEventPersisted, rmS1ToR3) - - res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{})) - if len(res.Events) != 6 { // one less since we don't return redacted events - t.Errorf("got %d events, want 6", len(res.Events)) - } - if len(res.Rooms) != (len(allRooms) - 1) { // one less due to lack of R3 - t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)-1) - } - }) -} - -func newReq(t *testing.T, jsonBody map[string]interface{}) *gomatrixserverlib.MSC2946SpacesRequest { - t.Helper() - b, err := json.Marshal(jsonBody) - if err != nil { - t.Fatalf("Failed to marshal request: %s", err) - } - var r gomatrixserverlib.MSC2946SpacesRequest - if err := json.Unmarshal(b, &r); err != nil { - t.Fatalf("Failed to unmarshal request: %s", err) - } - return &r -} - -func runServer(t *testing.T, router *mux.Router) func() { - t.Helper() - externalServ := &http.Server{ - Addr: string(":8010"), - WriteTimeout: 60 * time.Second, - Handler: router, - } - go func() { - externalServ.ListenAndServe() - }() - // wait to listen on the port - time.Sleep(500 * time.Millisecond) - return func() { - externalServ.Shutdown(context.TODO()) - } -} - -func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *gomatrixserverlib.MSC2946SpacesRequest) *gomatrixserverlib.MSC2946SpacesResponse { - t.Helper() - var r gomatrixserverlib.MSC2946SpacesRequest - msc2946.Defaults(&r) - data, err := json.Marshal(req) - if err != nil { - t.Fatalf("failed to marshal request: %s", err) - } - httpReq, err := http.NewRequest( - "POST", "http://localhost:8010/_matrix/client/unstable/org.matrix.msc2946/rooms/"+url.PathEscape(roomID)+"/spaces", - bytes.NewBuffer(data), - ) - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - if err != nil { - t.Fatalf("failed to prepare request: %s", err) - } - res, err := client.Do(httpReq) - if err != nil { - t.Fatalf("failed to do request: %s", err) - } - if res.StatusCode != expectCode { - body, _ := ioutil.ReadAll(res.Body) - t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body)) - } - if res.StatusCode == 200 { - var result gomatrixserverlib.MSC2946SpacesResponse - body, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Fatalf("response 200 OK but failed to read response body: %s", err) - } - t.Logf("Body: %s", string(body)) - if err := json.Unmarshal(body, &result); err != nil { - t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body)) - } - return &result - } - return nil -} - -type testUserAPI struct { - userapi.UserInternalAPITrace - accessTokens map[string]userapi.Device -} - -func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error { - dev, ok := u.accessTokens[req.AccessToken] - if !ok { - res.Err = "unknown token" - return nil - } - res.Device = &dev - return nil -} - -type testRoomserverAPI struct { - // use a trace API as it implements method stubs so we don't need to have them here. - // We'll override the functions we care about. - roomserver.RoomserverInternalAPITrace - joinEvents []*gomatrixserverlib.HeaderedEvent - events map[string]*gomatrixserverlib.HeaderedEvent - pubRoomState map[string]map[gomatrixserverlib.StateKeyTuple]string -} - -func (r *testRoomserverAPI) QueryServerJoinedToRoom(ctx context.Context, req *roomserver.QueryServerJoinedToRoomRequest, res *roomserver.QueryServerJoinedToRoomResponse) error { - res.IsInRoom = true - res.RoomExists = true - return nil -} - -func (r *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *roomserver.QueryBulkStateContentRequest, res *roomserver.QueryBulkStateContentResponse) error { - res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) - for _, roomID := range req.RoomIDs { - pubRoomData, ok := r.pubRoomState[roomID] - if ok { - res.Rooms[roomID] = pubRoomData - } - } - return nil -} - -func (r *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *roomserver.QueryCurrentStateRequest, res *roomserver.QueryCurrentStateResponse) error { - res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) - checkEvent := func(he *gomatrixserverlib.HeaderedEvent) { - if he.RoomID() != req.RoomID { - return - } - if he.StateKey() == nil { - return - } - tuple := gomatrixserverlib.StateKeyTuple{ - EventType: he.Type(), - StateKey: *he.StateKey(), - } - for _, t := range req.StateTuples { - if t == tuple { - res.StateEvents[t] = he - } - } - } - for _, he := range r.joinEvents { - checkEvent(he) - } - for _, he := range r.events { - checkEvent(he) - } - return nil -} - -func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router { - t.Helper() - cfg := &config.Dendrite{} - cfg.Defaults(true) - cfg.Global.ServerName = "localhost" - cfg.MSCs.Database.ConnectionString = "file:msc2946_test.db" - cfg.MSCs.MSCs = []string{"msc2946"} - base := &base.BaseDendrite{ - Cfg: cfg, - PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(), - PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(), - } - - err := msc2946.Enable(base, rsAPI, userAPI, nil, nil) - if err != nil { - t.Fatalf("failed to enable MSC2946: %s", err) - } - for _, ev := range events { - hooks.Run(hooks.KindNewEventPersisted, ev) - } - return base.PublicClientAPIMux -} - -type fledglingEvent struct { - Type string - StateKey *string - Content interface{} - Sender string - RoomID string -} - -func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) { - t.Helper() - seed := make([]byte, ed25519.SeedSize) // zero seed - key := ed25519.NewKeyFromSeed(seed) - eb := gomatrixserverlib.EventBuilder{ - Sender: ev.Sender, - Depth: 999, - Type: ev.Type, - StateKey: ev.StateKey, - RoomID: ev.RoomID, - } - err := eb.SetContent(ev.Content) - if err != nil { - t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content) - } - // make sure the origin_server_ts changes so we can test recency - time.Sleep(1 * time.Millisecond) - signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer) - if err != nil { - t.Fatalf("mustCreateEvent: failed to sign event: %s", err) - } - h := signedEvent.Headered(roomVer) - return h -} diff --git a/setup/mscs/msc2946/storage.go b/setup/mscs/msc2946/storage.go deleted file mode 100644 index 20db18594..000000000 --- a/setup/mscs/msc2946/storage.go +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package msc2946 - -import ( - "context" - "database/sql" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" -) - -var ( - relTypes = map[string]int{ - ConstSpaceChildEventType: 1, - ConstSpaceParentEventType: 2, - } -) - -type Database interface { - // StoreReference persists a child or parent space mapping. - StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error - // References returns all events which have the given roomID as a parent or child space. - References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) -} - -type DB struct { - db *sql.DB - writer sqlutil.Writer - insertEdgeStmt *sql.Stmt - selectEdgesStmt *sql.Stmt -} - -// NewDatabase loads the database for msc2836 -func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - if dbOpts.ConnectionString.IsPostgres() { - return newPostgresDatabase(dbOpts) - } - return newSQLiteDatabase(dbOpts) -} - -func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - d := DB{ - writer: sqlutil.NewDummyWriter(), - } - var err error - if d.db, err = sqlutil.Open(dbOpts); err != nil { - return nil, err - } - _, err = d.db.Exec(` - CREATE TABLE IF NOT EXISTS msc2946_edges ( - room_version TEXT NOT NULL, - -- the room ID of the event, the source of the arrow - source_room_id TEXT NOT NULL, - -- the target room ID, the arrow destination - dest_room_id TEXT NOT NULL, - -- the kind of relation, either child or parent (1,2) - rel_type SMALLINT NOT NULL, - event_json TEXT NOT NULL, - CONSTRAINT msc2946_edges_uniq UNIQUE (source_room_id, dest_room_id, rel_type) - ); - `) - if err != nil { - return nil, err - } - if d.insertEdgeStmt, err = d.db.Prepare(` - INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json) - VALUES($1, $2, $3, $4, $5) - ON CONFLICT ON CONSTRAINT msc2946_edges_uniq DO UPDATE SET event_json = $5 - `); err != nil { - return nil, err - } - if d.selectEdgesStmt, err = d.db.Prepare(` - SELECT room_version, event_json FROM msc2946_edges - WHERE source_room_id = $1 OR dest_room_id = $2 - `); err != nil { - return nil, err - } - return &d, err -} - -func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - d := DB{ - writer: sqlutil.NewExclusiveWriter(), - } - var err error - if d.db, err = sqlutil.Open(dbOpts); err != nil { - return nil, err - } - _, err = d.db.Exec(` - CREATE TABLE IF NOT EXISTS msc2946_edges ( - room_version TEXT NOT NULL, - -- the room ID of the event, the source of the arrow - source_room_id TEXT NOT NULL, - -- the target room ID, the arrow destination - dest_room_id TEXT NOT NULL, - -- the kind of relation, either child or parent (1,2) - rel_type SMALLINT NOT NULL, - event_json TEXT NOT NULL, - UNIQUE (source_room_id, dest_room_id, rel_type) - ); - `) - if err != nil { - return nil, err - } - if d.insertEdgeStmt, err = d.db.Prepare(` - INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json) - VALUES($1, $2, $3, $4, $5) - ON CONFLICT (source_room_id, dest_room_id, rel_type) DO UPDATE SET event_json = $5 - `); err != nil { - return nil, err - } - if d.selectEdgesStmt, err = d.db.Prepare(` - SELECT room_version, event_json FROM msc2946_edges - WHERE source_room_id = $1 OR dest_room_id = $2 - `); err != nil { - return nil, err - } - return &d, err -} - -func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error { - target := SpaceTarget(he) - if target == "" { - return nil // malformed event - } - relType := relTypes[he.Type()] - _, err := d.insertEdgeStmt.ExecContext(ctx, he.RoomVersion, he.RoomID(), target, relType, he.JSON()) - return err -} - -func (d *DB) References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) { - rows, err := d.selectEdgesStmt.QueryContext(ctx, roomID, roomID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "failed to close References") - refs := make([]*gomatrixserverlib.HeaderedEvent, 0) - for rows.Next() { - var roomVer string - var jsonBytes []byte - if err := rows.Scan(&roomVer, &jsonBytes); err != nil { - return nil, err - } - ev, err := gomatrixserverlib.NewEventFromTrustedJSON(jsonBytes, false, gomatrixserverlib.RoomVersion(roomVer)) - if err != nil { - return nil, err - } - he := ev.Headered(gomatrixserverlib.RoomVersion(roomVer)) - refs = append(refs, he) - } - return refs, nil -} - -// SpaceTarget returns the destination room ID for the space event. This is either a child or a parent -// depending on the event type. -func SpaceTarget(he *gomatrixserverlib.HeaderedEvent) string { - if he.StateKey() == nil { - return "" // no-op - } - switch he.Type() { - case ConstSpaceParentEventType: - return *he.StateKey() - case ConstSpaceChildEventType: - return *he.StateKey() - } - return "" -} diff --git a/setup/mscs/mscs.go b/setup/mscs/mscs.go index 10df6a15d..35b7bba3b 100644 --- a/setup/mscs/mscs.go +++ b/setup/mscs/mscs.go @@ -42,7 +42,7 @@ func EnableMSC(base *base.BaseDendrite, monolith *setup.Monolith, msc string) er case "msc2836": return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing) case "msc2946": - return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing) + return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing, base.Caches) case "msc2444": // enabled inside federationapi case "msc2753": // enabled inside clientapi default: diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index c3650085f..fcb7b5b1c 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -16,7 +16,9 @@ package consumers import ( "context" + "database/sql" "encoding/json" + "fmt" "github.com/getsentry/sentry-go" "github.com/matrix-org/dendrite/internal/eventutil" @@ -24,21 +26,26 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" ) // OutputClientDataConsumer consumes events that originated in the client API server. type OutputClientDataConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - stream types.StreamProvider - notifier *notifier.Notifier + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + stream types.StreamProvider + notifier *notifier.Notifier + serverName gomatrixserverlib.ServerName + producer *producers.UserAPIReadProducer } // NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers. @@ -49,15 +56,18 @@ func NewOutputClientDataConsumer( store storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, + producer *producers.UserAPIReadProducer, ) *OutputClientDataConsumer { return &OutputClientDataConsumer{ - ctx: process.Context(), - jetstream: js, - topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData), - durable: cfg.Matrix.JetStream.Durable("SyncAPIClientAPIConsumer"), - db: store, - notifier: notifier, - stream: stream, + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData), + durable: cfg.Matrix.JetStream.Durable("SyncAPIClientAPIConsumer"), + db: store, + notifier: notifier, + stream: stream, + serverName: cfg.Matrix.ServerName, + producer: producer, } } @@ -100,8 +110,48 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) }).Panicf("could not save account data") } + if err = s.sendReadUpdate(ctx, userID, output); err != nil { + log.WithError(err).WithFields(logrus.Fields{ + "user_id": userID, + "room_id": output.RoomID, + }).Errorf("Failed to generate read update") + sentry.CaptureException(err) + return false + } + s.stream.Advance(streamPos) s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos}) return true } + +func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error { + if output.Type != "m.fully_read" || output.ReadMarker == nil { + return nil + } + _, serverName, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + } + if serverName != s.serverName { + return nil + } + var readPos types.StreamPosition + var fullyReadPos types.StreamPosition + if output.ReadMarker.Read != "" { + if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("s.db.PositionInTopology (Read): %w", err) + } + } + if output.ReadMarker.FullyRead != "" { + if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err) + } + } + if readPos > 0 || fullyReadPos > 0 { + if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil { + return fmt.Errorf("s.producer.SendReadUpdate: %w", err) + } + } + return nil +} diff --git a/syncapi/consumers/eduserver_receipts.go b/syncapi/consumers/eduserver_receipts.go index 392840ece..4e4c61c67 100644 --- a/syncapi/consumers/eduserver_receipts.go +++ b/syncapi/consumers/eduserver_receipts.go @@ -16,7 +16,9 @@ package consumers import ( "context" + "database/sql" "encoding/json" + "fmt" "github.com/getsentry/sentry-go" "github.com/matrix-org/dendrite/eduserver/api" @@ -24,21 +26,26 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" ) // OutputReceiptEventConsumer consumes events that originated in the EDU server. type OutputReceiptEventConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - stream types.StreamProvider - notifier *notifier.Notifier + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + stream types.StreamProvider + notifier *notifier.Notifier + serverName gomatrixserverlib.ServerName + producer *producers.UserAPIReadProducer } // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. @@ -50,15 +57,18 @@ func NewOutputReceiptEventConsumer( store storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, + producer *producers.UserAPIReadProducer, ) *OutputReceiptEventConsumer { return &OutputReceiptEventConsumer{ - ctx: process.Context(), - jetstream: js, - topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent), - durable: cfg.Matrix.JetStream.Durable("SyncAPIEDUServerReceiptConsumer"), - db: store, - notifier: notifier, - stream: stream, + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent), + durable: cfg.Matrix.JetStream.Durable("SyncAPIEDUServerReceiptConsumer"), + db: store, + notifier: notifier, + stream: stream, + serverName: cfg.Matrix.ServerName, + producer: producer, } } @@ -92,8 +102,42 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Ms return true } + if err = s.sendReadUpdate(ctx, output); err != nil { + log.WithError(err).WithFields(logrus.Fields{ + "user_id": output.UserID, + "room_id": output.RoomID, + }).Errorf("Failed to generate read update") + sentry.CaptureException(err) + return false + } + s.stream.Advance(streamPos) s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) return true } + +func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output api.OutputReceiptEvent) error { + if output.Type != "m.read" { + return nil + } + _, serverName, err := gomatrixserverlib.SplitID('@', output.UserID) + if err != nil { + return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + } + if serverName != s.serverName { + return nil + } + var readPos types.StreamPosition + if output.EventID != "" { + if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("s.db.PositionInTopology (Read): %w", err) + } + } + if readPos > 0 { + if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil { + return fmt.Errorf("s.producer.SendReadUpdate: %w", err) + } + } + return nil +} diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 15485bb35..159657f9f 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -45,6 +46,7 @@ type OutputRoomEventConsumer struct { pduStream types.StreamProvider inviteStream types.StreamProvider notifier *notifier.Notifier + producer *producers.UserAPIStreamEventProducer } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -57,6 +59,7 @@ func NewOutputRoomEventConsumer( pduStream types.StreamProvider, inviteStream types.StreamProvider, rsAPI api.RoomserverInternalAPI, + producer *producers.UserAPIStreamEventProducer, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ ctx: process.Context(), @@ -69,6 +72,7 @@ func NewOutputRoomEventConsumer( pduStream: pduStream, inviteStream: inviteStream, rsAPI: rsAPI, + producer: producer, } } @@ -194,6 +198,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( return nil } + if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil { + log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID()) + sentry.CaptureException(err) + return err + } + if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) sentry.CaptureException(err) diff --git a/syncapi/consumers/userapi.go b/syncapi/consumers/userapi.go new file mode 100644 index 000000000..a3b2dd53d --- /dev/null +++ b/syncapi/consumers/userapi.go @@ -0,0 +1,110 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + "encoding/json" + + "github.com/getsentry/sentry-go" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +// OutputNotificationDataConsumer consumes events that originated in +// the Push server. +type OutputNotificationDataConsumer struct { + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + notifier *notifier.Notifier + stream types.StreamProvider +} + +// NewOutputNotificationDataConsumer creates a new consumer. Call +// Start() to begin consuming. +func NewOutputNotificationDataConsumer( + process *process.ProcessContext, + cfg *config.SyncAPI, + js nats.JetStreamContext, + store storage.Database, + notifier *notifier.Notifier, + stream types.StreamProvider, +) *OutputNotificationDataConsumer { + s := &OutputNotificationDataConsumer{ + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Durable("SyncAPINotificationDataConsumer"), + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData), + db: store, + notifier: notifier, + stream: stream, + } + return s +} + +// Start starts consumption. +func (s *OutputNotificationDataConsumer) Start() error { + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) +} + +// onMessage is called when the Sync server receives a new event from +// the push server. It is not safe for this function to be called from +// multiple goroutines, or else the sync stream position may race and +// be incorrectly calculated. +func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + userID := string(msg.Header.Get(jetstream.UserID)) + + // Parse out the event JSON + var data eventutil.NotificationData + if err := json.Unmarshal(msg.Data, &data); err != nil { + sentry.CaptureException(err) + log.WithField("user_id", userID).WithError(err).Error("user API consumer: message parse failure") + return true + } + + streamPos, err := s.db.UpsertRoomUnreadNotificationCounts(ctx, userID, data.RoomID, data.UnreadNotificationCount, data.UnreadHighlightCount) + if err != nil { + sentry.CaptureException(err) + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": data.RoomID, + }).WithError(err).Error("Could not save notification counts") + return false + } + + s.stream.Advance(streamPos) + s.notifier.OnNewNotificationData(userID, types.StreamingToken{NotificationDataPosition: streamPos}) + + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": data.RoomID, + "streamPos": streamPos, + }).Trace("Received notification data from user API") + + return true +} diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 37a9e2d39..dc4acd8da 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -82,7 +82,16 @@ func DeviceListCatchup( util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed") return to, hasNew, nil } - // QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user. + + // Work out which user IDs we care about — that includes those in the original request, + // the response from QueryKeyChanges (which includes ALL users who have changed keys) + // as well as every user who has a join or leave event in the current sync response. We + // will request information about which rooms these users are joined to, so that we can + // see if we still share any rooms with them. + joinUserIDs, leaveUserIDs := membershipEvents(res) + queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...) + queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...) + queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs) var sharedUsersMap map[string]int sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs) util.GetLogger(ctx).Debugf( @@ -100,9 +109,8 @@ func DeviceListCatchup( userSet[userID] = true } } - // if the response has any join/leave events, add them now. + // Finally, add in users who have joined or left. // TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them. - joinUserIDs, leaveUserIDs := membershipEvents(res) for _, userID := range joinUserIDs { if !userSet[userID] { res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) @@ -213,7 +221,8 @@ func filterSharedUsers( var result []string var sharedUsersRes roomserverAPI.QuerySharedUsersResponse err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ - UserID: userID, + UserID: userID, + OtherUserIDs: usersWithChangedKeys, }, &sharedUsersRes) if err != nil { // default to all users so we do needless queries rather than miss some important device update diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index d853cc0e4..6a641e6f8 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -217,6 +217,17 @@ func (n *Notifier) OnNewInvite( n.wakeupUsers([]string{wakeUserID}, nil, n.currPos) } +func (n *Notifier) OnNewNotificationData( + userID string, + posUpdate types.StreamingToken, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + n.currPos.ApplyUpdates(posUpdate) + n.wakeupUsers([]string{userID}, nil, n.currPos) +} + // GetListener returns a UserStreamListener that can be used to wait for // updates for a user. Must be closed. // notify for anything before sincePos diff --git a/syncapi/notifier/notifier_test.go b/syncapi/notifier/notifier_test.go index c6d3df7ee..60403d5d5 100644 --- a/syncapi/notifier/notifier_test.go +++ b/syncapi/notifier/notifier_test.go @@ -219,7 +219,7 @@ func TestEDUWakeup(t *testing.T) { go func() { pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter)) if err != nil { - t.Errorf("TestNewInviteEventForUser error: %w", err) + t.Errorf("TestNewInviteEventForUser error: %v", err) } mustEqualPositions(t, pos, syncPositionNewEDU) wg.Done() diff --git a/syncapi/producers/userapi_readupdate.go b/syncapi/producers/userapi_readupdate.go new file mode 100644 index 000000000..d56cab776 --- /dev/null +++ b/syncapi/producers/userapi_readupdate.go @@ -0,0 +1,62 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package producers + +import ( + "encoding/json" + + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +// UserAPIProducer produces events for the user API server to consume +type UserAPIReadProducer struct { + Topic string + JetStream nats.JetStreamContext +} + +// SendData sends account data to the user API server +func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error { + m := &nats.Msg{ + Subject: p.Topic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, userID) + m.Header.Set(jetstream.RoomID, roomID) + + data := types.ReadUpdate{ + UserID: userID, + RoomID: roomID, + Read: readPos, + FullyRead: fullyReadPos, + } + var err error + m.Data, err = json.Marshal(data) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": roomID, + "read_pos": readPos, + "fully_read_pos": fullyReadPos, + }).Tracef("Producing to topic '%s'", p.Topic) + + _, err = p.JetStream.PublishMsg(m) + return err +} diff --git a/syncapi/producers/userapi_streamevent.go b/syncapi/producers/userapi_streamevent.go new file mode 100644 index 000000000..2bbd19c0b --- /dev/null +++ b/syncapi/producers/userapi_streamevent.go @@ -0,0 +1,60 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package producers + +import ( + "encoding/json" + + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +// UserAPIProducer produces events for the user API server to consume +type UserAPIStreamEventProducer struct { + Topic string + JetStream nats.JetStreamContext +} + +// SendData sends account data to the user API server +func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error { + m := &nats.Msg{ + Subject: p.Topic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.RoomID, roomID) + + data := types.StreamedEvent{ + Event: event, + StreamPosition: pos, + } + var err error + m.Data, err = json.Marshal(data) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "room_id": roomID, + "event_id": event.EventID(), + "event_type": event.Type(), + "stream_pos": pos, + }).Tracef("Producing to topic '%s'", p.Topic) + + _, err = p.JetStream.PublishMsg(m) + return err +} diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index ef7efa2b0..d07fc0c64 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -17,6 +17,7 @@ package routing import ( "database/sql" "encoding/json" + "fmt" "net/http" "strconv" @@ -44,7 +45,7 @@ func Context( syncDB storage.Database, roomID, eventID string, ) util.JSONResponse { - filter, err := parseContextParams(req) + filter, err := parseRoomEventFilter(req) if err != nil { errMsg := "" switch err.(type) { @@ -102,6 +103,12 @@ func Context( id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID) if err != nil { + if err == sql.ErrNoRows { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound(fmt.Sprintf("Event %s not found", eventID)), + } + } logrus.WithError(err).WithField("eventID", eventID).Error("unable to find requested event") return jsonerror.InternalServerError() } @@ -164,7 +171,7 @@ func applyLazyLoadMembers(filter *gomatrixserverlib.RoomEventFilter, eventsAfter return newState } -func parseContextParams(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) { +func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) { // Default room filter filter := &gomatrixserverlib.RoomEventFilter{Limit: 10} diff --git a/syncapi/routing/context_test.go b/syncapi/routing/context_test.go index 1b430d83a..e79a5d5f1 100644 --- a/syncapi/routing/context_test.go +++ b/syncapi/routing/context_test.go @@ -55,13 +55,13 @@ func Test_parseContextParams(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotFilter, err := parseContextParams(tt.req) + gotFilter, err := parseRoomEventFilter(tt.req) if (err != nil) != tt.wantErr { - t.Errorf("parseContextParams() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("parseRoomEventFilter() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(gotFilter, tt.wantFilter) { - t.Errorf("parseContextParams() gotFilter = %v, want %v", gotFilter, tt.wantFilter) + t.Errorf("parseRoomEventFilter() gotFilter = %v, want %v", gotFilter, tt.wantFilter) } }) } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 9bb8c6d24..7cd54eef0 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -19,7 +19,6 @@ import ( "fmt" "net/http" "sort" - "strconv" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" @@ -45,8 +44,8 @@ type messagesReq struct { fromStream *types.StreamingToken device *userapi.Device wasToProvided bool - limit int backwardOrdering bool + filter *gomatrixserverlib.RoomEventFilter } type messagesResp struct { @@ -54,10 +53,9 @@ type messagesResp struct { StartStream string `json:"start_stream,omitempty"` // NOTSPEC: so clients can hit /messages then immediately /sync with a latest sync token End string `json:"end"` Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` + State []gomatrixserverlib.ClientEvent `json:"state"` } -const defaultMessagesLimit = 10 - // OnIncomingMessagesRequest implements the /messages endpoint from the // client-server API. // See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages @@ -83,6 +81,14 @@ func OnIncomingMessagesRequest( } } + filter, err := parseRoomEventFilter(req) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("unable to parse filter"), + } + } + // Extract parameters from the request's URL. // Pagination tokens. var fromStream *types.StreamingToken @@ -143,18 +149,6 @@ func OnIncomingMessagesRequest( wasToProvided = false } - // Maximum number of events to return; defaults to 10. - limit := defaultMessagesLimit - if len(req.URL.Query().Get("limit")) > 0 { - limit, err = strconv.Atoi(req.URL.Query().Get("limit")) - - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("limit could not be parsed into an integer: " + err.Error()), - } - } - } // TODO: Implement filtering (#587) // Check the room ID's format. @@ -176,7 +170,7 @@ func OnIncomingMessagesRequest( to: &to, fromStream: fromStream, wasToProvided: wasToProvided, - limit: limit, + filter: filter, backwardOrdering: backwardOrdering, device: device, } @@ -187,10 +181,27 @@ func OnIncomingMessagesRequest( return jsonerror.InternalServerError() } + // at least fetch the membership events for the users returned in chunk if LazyLoadMembers is set + state := []gomatrixserverlib.ClientEvent{} + if filter.LazyLoadMembers { + memberShipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent) + for _, evt := range clientEvents { + memberShip, err := db.GetStateEvent(req.Context(), roomID, gomatrixserverlib.MRoomMember, evt.Sender) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("failed to get membership event for user") + continue + } + memberShipToUser[evt.Sender] = memberShip + } + for _, evt := range memberShipToUser { + state = append(state, gomatrixserverlib.HeaderedToClientEvent(evt, gomatrixserverlib.FormatAll)) + } + } + util.GetLogger(req.Context()).WithFields(logrus.Fields{ "from": from.String(), "to": to.String(), - "limit": limit, + "limit": filter.Limit, "backwards": backwardOrdering, "return_start": start.String(), "return_end": end.String(), @@ -200,6 +211,7 @@ func OnIncomingMessagesRequest( Chunk: clientEvents, Start: start.String(), End: end.String(), + State: state, } if emptyFromSupplied { res.StartStream = fromStream.String() @@ -234,19 +246,18 @@ func (r *messagesReq) retrieveEvents() ( clientEvents []gomatrixserverlib.ClientEvent, start, end types.TopologyToken, err error, ) { - eventFilter := gomatrixserverlib.DefaultRoomEventFilter() - eventFilter.Limit = r.limit + eventFilter := r.filter // Retrieve the events from the local database. var streamEvents []types.StreamEvent if r.fromStream != nil { toStream := r.to.StreamToken() streamEvents, err = r.db.GetEventsInStreamingRange( - r.ctx, r.fromStream, &toStream, r.roomID, &eventFilter, r.backwardOrdering, + r.ctx, r.fromStream, &toStream, r.roomID, eventFilter, r.backwardOrdering, ) } else { streamEvents, err = r.db.GetEventsInTopologicalRange( - r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, + r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering, ) } if err != nil { @@ -434,7 +445,7 @@ func (r *messagesReq) handleEmptyEventsSlice() ( // Check if we have backward extremities for this room. if len(backwardExtremities) > 0 { // If so, retrieve as much events as needed through backfilling. - events, err = r.backfill(r.roomID, backwardExtremities, r.limit) + events, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit) if err != nil { return } @@ -456,7 +467,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent events []*gomatrixserverlib.HeaderedEvent, err error, ) { // Check if we have enough events. - isSetLargeEnough := len(streamEvents) >= r.limit + isSetLargeEnough := len(streamEvents) >= r.filter.Limit if !isSetLargeEnough { // it might be fine we don't have up to 'limit' events, let's find out if r.backwardOrdering { @@ -483,7 +494,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering { var pdus []*gomatrixserverlib.HeaderedEvent // Only ask the remote server for enough events to reach the limit. - pdus, err = r.backfill(r.roomID, backwardExtremities, r.limit-len(streamEvents)) + pdus, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit-len(streamEvents)) if err != nil { return } diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 126bc8658..e44766338 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -18,6 +18,7 @@ import ( "context" eduAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" @@ -31,6 +32,7 @@ type Database interface { MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) + MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) @@ -138,6 +140,12 @@ type Database interface { // GetRoomReceipts gets all receipts for a given roomID GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) + // UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key. + UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) + + // GetUserUnreadNotificationCounts returns statistics per room a user is interested in. + GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) + SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) diff --git a/syncapi/storage/postgres/notification_data_table.go b/syncapi/storage/postgres/notification_data_table.go new file mode 100644 index 000000000..f3fc4451f --- /dev/null +++ b/syncapi/storage/postgres/notification_data_table.go @@ -0,0 +1,108 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, error) { + _, err := db.Exec(notificationDataSchema) + if err != nil { + return nil, err + } + r := ¬ificationDataStatements{} + return r, sqlutil.StatementList{ + {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, + {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, + {&r.selectMaxID, selectMaxNotificationIDSQL}, + }.Prepare(db) +} + +type notificationDataStatements struct { + upsertRoomUnreadCounts *sql.Stmt + selectUserUnreadCounts *sql.Stmt + selectMaxID *sql.Stmt +} + +const notificationDataSchema = ` +CREATE TABLE IF NOT EXISTS syncapi_notification_data ( + id BIGSERIAL PRIMARY KEY, + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notification_count BIGINT NOT NULL DEFAULT 0, + highlight_count BIGINT NOT NULL DEFAULT 0, + CONSTRAINT syncapi_notification_data_unique UNIQUE (user_id, room_id) +);` + +const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_data + (user_id, room_id, notification_count, highlight_count) + VALUES ($1, $2, $3, $4) + ON CONFLICT (user_id, room_id) + DO UPDATE SET notification_count = $3, highlight_count = $4 + RETURNING id` + +const selectUserUnreadNotificationCountsSQL = `SELECT + id, room_id, notification_count, highlight_count + FROM syncapi_notification_data + WHERE + user_id = $1 AND + id BETWEEN $2 + 1 AND $3` + +const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` + +func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { + err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos) + return +} + +func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { + rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") + + roomCounts := map[string]*eventutil.NotificationData{} + for rows.Next() { + var id types.StreamPosition + var roomID string + var notificationCount, highlightCount int + + if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil { + return nil, err + } + + roomCounts[roomID] = &eventutil.NotificationData{ + RoomID: roomID, + UnreadNotificationCount: notificationCount, + UnreadHighlightCount: highlightCount, + } + } + return roomCounts, rows.Err() +} + +func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) { + var id int64 + err := r.selectMaxID.QueryRowContext(ctx).Scan(&id) + return id, err +} diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go index f93081e1a..474d0c020 100644 --- a/syncapi/storage/postgres/receipt_table.go +++ b/syncapi/storage/postgres/receipt_table.go @@ -96,7 +96,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room } func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) { - lastPos := streamPos + var lastPos types.StreamPosition rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos) if err != nil { return 0, nil, fmt.Errorf("unable to query room receipts: %w", err) diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 6f4e7749d..60fe5b54d 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -90,6 +90,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e if err != nil { return nil, err } + notificationData, err := NewPostgresNotificationDataTable(d.db) + if err != nil { + return nil, err + } m := sqlutil.NewMigrations() deltas.LoadFixSequences(m) deltas.LoadRemoveSendToDeviceSentColumn(m) @@ -110,6 +114,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e SendToDevice: sendToDevice, Receipts: receipts, Memberships: memberships, + NotificationData: notificationData, } return &d, nil } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 819851b33..87d7c6df7 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -48,6 +48,7 @@ type Database struct { Filter tables.Filter Receipts tables.Receipts Memberships tables.Memberships + NotificationData tables.NotificationData } func (d *Database) readOnlySnapshot(ctx context.Context) (*sql.Tx, error) { @@ -102,6 +103,14 @@ func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.S return types.StreamPosition(id), nil } +func (d *Database) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) { + id, err := d.NotificationData.SelectMaxID(ctx) + if err != nil { + return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err) + } + return types.StreamPosition(id), nil +} + func (d *Database) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { return d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilterPart, excludeEventIDs) } @@ -956,6 +965,18 @@ func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, stream return receipts, err } +func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { + err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, userID, roomID, notificationCount, highlightCount) + return err + }) + return +} + +func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) { + return d.NotificationData.SelectUserUnreadCounts(ctx, userID, from, to) +} + func (s *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { return s.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) } diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go new file mode 100644 index 000000000..4b3f074db --- /dev/null +++ b/syncapi/storage/sqlite3/notification_data_table.go @@ -0,0 +1,108 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +func NewSqliteNotificationDataTable(db *sql.DB) (tables.NotificationData, error) { + _, err := db.Exec(notificationDataSchema) + if err != nil { + return nil, err + } + r := ¬ificationDataStatements{} + return r, sqlutil.StatementList{ + {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, + {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, + {&r.selectMaxID, selectMaxNotificationIDSQL}, + }.Prepare(db) +} + +type notificationDataStatements struct { + upsertRoomUnreadCounts *sql.Stmt + selectUserUnreadCounts *sql.Stmt + selectMaxID *sql.Stmt +} + +const notificationDataSchema = ` +CREATE TABLE IF NOT EXISTS syncapi_notification_data ( + id INTEGER PRIMARY KEY, + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notification_count BIGINT NOT NULL DEFAULT 0, + highlight_count BIGINT NOT NULL DEFAULT 0, + CONSTRAINT syncapi_notifications_unique UNIQUE (user_id, room_id) +);` + +const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_data + (user_id, room_id, notification_count, highlight_count) + VALUES ($1, $2, $3, $4) + ON CONFLICT (user_id, room_id) + DO UPDATE SET notification_count = $3, highlight_count = $4 + RETURNING id` + +const selectUserUnreadNotificationCountsSQL = `SELECT + id, room_id, notification_count, highlight_count + FROM syncapi_notification_data + WHERE + user_id = $1 AND + id BETWEEN $2 + 1 AND $3` + +const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` + +func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { + err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos) + return +} + +func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { + rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") + + roomCounts := map[string]*eventutil.NotificationData{} + for rows.Next() { + var id types.StreamPosition + var roomID string + var notificationCount, highlightCount int + + if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil { + return nil, err + } + + roomCounts[roomID] = &eventutil.NotificationData{ + RoomID: roomID, + UnreadNotificationCount: notificationCount, + UnreadHighlightCount: highlightCount, + } + } + return roomCounts, rows.Err() +} + +func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) { + var id int64 + err := r.selectMaxID.QueryRowContext(ctx).Scan(&id) + return id, err +} diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 581ee6928..1b256f91a 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -62,16 +62,19 @@ const selectEventsSQL = "" + const selectRecentEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const selectRecentEventsForSyncSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const selectEarlyEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const selectMaxEventIDSQL = "" + @@ -85,6 +88,7 @@ const selectStateInRangeSQL = "" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2)" + " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const deleteEventsForRoomSQL = "" + @@ -95,10 +99,12 @@ const selectContextEventSQL = "" + const selectContextBeforeEventSQL = "" + "SELECT headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id < $2" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const selectContextAfterEventSQL = "" + "SELECT id, headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters type outputRoomEventsStatements struct { diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index 6b39ee879..9111a39f6 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -101,7 +101,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room // SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) { selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) - lastPos := streamPos + var lastPos types.StreamPosition params := make([]interface{}, len(roomIDs)+1) params[0] = streamPos for k, v := range roomIDs { diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 706d43f81..f5ae9fdd7 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -100,6 +100,10 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er if err != nil { return err } + notificationData, err := NewSqliteNotificationDataTable(d.db) + if err != nil { + return err + } m := sqlutil.NewMigrations() deltas.LoadFixSequences(m) deltas.LoadRemoveSendToDeviceSentColumn(m) @@ -120,6 +124,7 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er SendToDevice: sendToDevice, Receipts: receipts, Memberships: memberships, + NotificationData: notificationData, } return nil } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 1d807ee6b..1ebb42651 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -19,6 +19,7 @@ import ( "database/sql" eduAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -171,3 +172,9 @@ type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error) } + +type NotificationData interface { + UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) + SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) + SelectMaxID(ctx context.Context) (int64, error) +} diff --git a/syncapi/streams/stream_notificationdata.go b/syncapi/streams/stream_notificationdata.go new file mode 100644 index 000000000..8ba9e07ca --- /dev/null +++ b/syncapi/streams/stream_notificationdata.go @@ -0,0 +1,55 @@ +package streams + +import ( + "context" + + "github.com/matrix-org/dendrite/syncapi/types" +) + +type NotificationDataStreamProvider struct { + StreamProvider +} + +func (p *NotificationDataStreamProvider) Setup() { + p.StreamProvider.Setup() + + id, err := p.DB.MaxStreamPositionForNotificationData(context.Background()) + if err != nil { + panic(err) + } + p.latest = id +} + +func (p *NotificationDataStreamProvider) CompleteSync( + ctx context.Context, + req *types.SyncRequest, +) types.StreamPosition { + return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) +} + +func (p *NotificationDataStreamProvider) IncrementalSync( + ctx context.Context, + req *types.SyncRequest, + from, to types.StreamPosition, +) types.StreamPosition { + // We want counts for all possible rooms, so always start from zero. + countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to) + if err != nil { + req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed") + return from + } + + // We're merely decorating existing rooms. Note that the Join map + // values are not pointers. + for roomID, jr := range req.Response.Rooms.Join { + counts := countsByRoom[roomID] + if counts == nil { + continue + } + + jr.UnreadNotifications.HighlightCount = counts.UnreadHighlightCount + jr.UnreadNotifications.NotificationCount = counts.UnreadNotificationCount + req.Response.Rooms.Join[roomID] = jr + } + return to +} diff --git a/syncapi/streams/stream_receipt.go b/syncapi/streams/stream_receipt.go index cccadb525..35ffd3a1e 100644 --- a/syncapi/streams/stream_receipt.go +++ b/syncapi/streams/stream_receipt.go @@ -63,7 +63,6 @@ func (p *ReceiptStreamProvider) IncrementalSync( if existing, ok := req.Response.Rooms.Join[roomID]; ok { jr = existing } - var ok bool ev := gomatrixserverlib.ClientEvent{ Type: gomatrixserverlib.MReceipt, @@ -71,8 +70,8 @@ func (p *ReceiptStreamProvider) IncrementalSync( } content := make(map[string]eduAPI.ReceiptMRead) for _, receipt := range receipts { - var read eduAPI.ReceiptMRead - if read, ok = content[receipt.EventID]; !ok { + read, ok := content[receipt.EventID] + if !ok { read = eduAPI.ReceiptMRead{ User: make(map[string]eduAPI.ReceiptTS), } diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index 6b02c75ea..17951acb4 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -12,13 +12,14 @@ import ( ) type Streams struct { - PDUStreamProvider types.StreamProvider - TypingStreamProvider types.StreamProvider - ReceiptStreamProvider types.StreamProvider - InviteStreamProvider types.StreamProvider - SendToDeviceStreamProvider types.StreamProvider - AccountDataStreamProvider types.StreamProvider - DeviceListStreamProvider types.StreamProvider + PDUStreamProvider types.StreamProvider + TypingStreamProvider types.StreamProvider + ReceiptStreamProvider types.StreamProvider + InviteStreamProvider types.StreamProvider + SendToDeviceStreamProvider types.StreamProvider + AccountDataStreamProvider types.StreamProvider + DeviceListStreamProvider types.StreamProvider + NotificationDataStreamProvider types.StreamProvider } func NewSyncStreamProviders( @@ -47,6 +48,9 @@ func NewSyncStreamProviders( StreamProvider: StreamProvider{DB: d}, userAPI: userAPI, }, + NotificationDataStreamProvider: &NotificationDataStreamProvider{ + StreamProvider: StreamProvider{DB: d}, + }, DeviceListStreamProvider: &DeviceListStreamProvider{ StreamProvider: StreamProvider{DB: d}, rsAPI: rsAPI, @@ -60,6 +64,7 @@ func NewSyncStreamProviders( streams.InviteStreamProvider.Setup() streams.SendToDeviceStreamProvider.Setup() streams.AccountDataStreamProvider.Setup() + streams.NotificationDataStreamProvider.Setup() streams.DeviceListStreamProvider.Setup() return streams @@ -67,12 +72,13 @@ func NewSyncStreamProviders( func (s *Streams) Latest(ctx context.Context) types.StreamingToken { return types.StreamingToken{ - PDUPosition: s.PDUStreamProvider.LatestPosition(ctx), - TypingPosition: s.TypingStreamProvider.LatestPosition(ctx), - ReceiptPosition: s.PDUStreamProvider.LatestPosition(ctx), - InvitePosition: s.InviteStreamProvider.LatestPosition(ctx), - SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx), - AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx), - DeviceListPosition: s.DeviceListStreamProvider.LatestPosition(ctx), + PDUPosition: s.PDUStreamProvider.LatestPosition(ctx), + TypingPosition: s.TypingStreamProvider.LatestPosition(ctx), + ReceiptPosition: s.ReceiptStreamProvider.LatestPosition(ctx), + InvitePosition: s.InviteStreamProvider.LatestPosition(ctx), + SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx), + AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx), + NotificationDataPosition: s.NotificationDataStreamProvider.LatestPosition(ctx), + DeviceListPosition: s.DeviceListStreamProvider.LatestPosition(ctx), } } diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index ca35951a0..2c9920d18 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -189,7 +189,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. currentPos.ApplyUpdates(userStreamListener.GetSyncPosition()) } } else { - syncReq.Log.Debugln("Responding to sync immediately") + syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately") } if syncReq.Since.IsEmpty() { @@ -213,6 +213,9 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync( syncReq.Context, syncReq, ), + NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync( + syncReq.Context, syncReq, + ), DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync( syncReq.Context, syncReq, ), @@ -244,6 +247,10 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. syncReq.Context, syncReq, syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition, ), + NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync( + syncReq.Context, syncReq, + syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition, + ), DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync( syncReq.Context, syncReq, syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition, diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 72462459c..cb9890ff7 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/consumers" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/routing" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" @@ -64,6 +65,18 @@ func AddPublicRoutes( requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier) + userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{ + JetStream: js, + Topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputStreamEvent), + } + + userAPIReadUpdateProducer := &producers.UserAPIReadProducer{ + JetStream: js, + Topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReadUpdate), + } + + _ = userAPIReadUpdateProducer + keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( process, cfg, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), js, keyAPI, rsAPI, syncDB, notifier, @@ -75,7 +88,7 @@ func AddPublicRoutes( roomConsumer := consumers.NewOutputRoomEventConsumer( process, cfg, js, syncDB, notifier, streams.PDUStreamProvider, - streams.InviteStreamProvider, rsAPI, + streams.InviteStreamProvider, rsAPI, userAPIStreamEventProducer, ) if err = roomConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start room server consumer") @@ -83,11 +96,19 @@ func AddPublicRoutes( clientConsumer := consumers.NewOutputClientDataConsumer( process, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider, + userAPIReadUpdateProducer, ) if err = clientConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start client data consumer") } + notificationConsumer := consumers.NewOutputNotificationDataConsumer( + process, cfg, js, syncDB, notifier, streams.NotificationDataStreamProvider, + ) + if err = notificationConsumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start notification data consumer") + } + typingConsumer := consumers.NewOutputTypingEventConsumer( process, cfg, js, syncDB, eduCache, notifier, streams.TypingStreamProvider, ) @@ -104,6 +125,7 @@ func AddPublicRoutes( receiptConsumer := consumers.NewOutputReceiptEventConsumer( process, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider, + userAPIReadUpdateProducer, ) if err = receiptConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start receipts consumer") diff --git a/syncapi/types/types.go b/syncapi/types/types.go index c2e8ed01c..4150e6c98 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -95,13 +95,14 @@ const ( ) type StreamingToken struct { - PDUPosition StreamPosition - TypingPosition StreamPosition - ReceiptPosition StreamPosition - SendToDevicePosition StreamPosition - InvitePosition StreamPosition - AccountDataPosition StreamPosition - DeviceListPosition StreamPosition + PDUPosition StreamPosition + TypingPosition StreamPosition + ReceiptPosition StreamPosition + SendToDevicePosition StreamPosition + InvitePosition StreamPosition + AccountDataPosition StreamPosition + DeviceListPosition StreamPosition + NotificationDataPosition StreamPosition } // This will be used as a fallback by json.Marshal. @@ -117,10 +118,11 @@ func (s *StreamingToken) UnmarshalText(text []byte) (err error) { func (t StreamingToken) String() string { posStr := fmt.Sprintf( - "s%d_%d_%d_%d_%d_%d_%d", + "s%d_%d_%d_%d_%d_%d_%d_%d", t.PDUPosition, t.TypingPosition, t.ReceiptPosition, t.SendToDevicePosition, - t.InvitePosition, t.AccountDataPosition, t.DeviceListPosition, + t.InvitePosition, t.AccountDataPosition, + t.DeviceListPosition, t.NotificationDataPosition, ) return posStr } @@ -142,12 +144,14 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool { return true case t.DeviceListPosition > other.DeviceListPosition: return true + case t.NotificationDataPosition > other.NotificationDataPosition: + return true } return false } func (t *StreamingToken) IsEmpty() bool { - return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition == 0 + return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition+t.NotificationDataPosition == 0 } // WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken. @@ -185,6 +189,9 @@ func (t *StreamingToken) ApplyUpdates(other StreamingToken) { if other.DeviceListPosition > t.DeviceListPosition { t.DeviceListPosition = other.DeviceListPosition } + if other.NotificationDataPosition > t.NotificationDataPosition { + t.NotificationDataPosition = other.NotificationDataPosition + } } type TopologyToken struct { @@ -277,7 +284,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { // s478_0_0_0_0_13.dl-0-2 but we have now removed partitioned stream positions tok = strings.Split(tok, ".")[0] parts := strings.Split(tok[1:], "_") - var positions [7]StreamPosition + var positions [8]StreamPosition for i, p := range parts { if i >= len(positions) { break @@ -291,13 +298,14 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { positions[i] = StreamPosition(pos) } token = StreamingToken{ - PDUPosition: positions[0], - TypingPosition: positions[1], - ReceiptPosition: positions[2], - SendToDevicePosition: positions[3], - InvitePosition: positions[4], - AccountDataPosition: positions[5], - DeviceListPosition: positions[6], + PDUPosition: positions[0], + TypingPosition: positions[1], + ReceiptPosition: positions[2], + SendToDevicePosition: positions[3], + InvitePosition: positions[4], + AccountDataPosition: positions[5], + DeviceListPosition: positions[6], + NotificationDataPosition: positions[7], } return token, nil } @@ -383,6 +391,10 @@ type JoinResponse struct { AccountData struct { Events []gomatrixserverlib.ClientEvent `json:"events"` } `json:"account_data"` + UnreadNotifications struct { + HighlightCount int `json:"highlight_count"` + NotificationCount int `json:"notification_count"` + } `json:"unread_notifications"` } // NewJoinResponse creates an empty response with initialised arrays. @@ -462,3 +474,16 @@ type Peek struct { New bool Deleted bool } + +type ReadUpdate struct { + UserID string `json:"user_id"` + RoomID string `json:"room_id"` + Read StreamPosition `json:"read,omitempty"` + FullyRead StreamPosition `json:"fully_read,omitempty"` +} + +// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. +type StreamedEvent struct { + Event *gomatrixserverlib.HeaderedEvent `json:"event"` + StreamPosition StreamPosition `json:"stream_position"` +} diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index cda178b37..ff78bfb9d 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -9,10 +9,10 @@ import ( func TestSyncTokens(t *testing.T) { shouldPass := map[string]string{ - "s4_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0}.String(), - "s3_1_0_0_0_0_2": StreamingToken{3, 1, 0, 0, 0, 0, 2}.String(), - "s3_1_2_3_5_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0}.String(), - "t3_1": TopologyToken{3, 1}.String(), + "s4_0_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0}.String(), + "s3_1_0_0_0_0_2_0": StreamingToken{3, 1, 0, 0, 0, 0, 2, 0}.String(), + "s3_1_2_3_5_0_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0, 0}.String(), + "t3_1": TopologyToken{3, 1}.String(), } for a, b := range shouldPass { diff --git a/sytest-blacklist b/sytest-blacklist index e8617dcdf..becc500ec 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -24,9 +24,15 @@ Local device key changes get to remote servers with correct prev_id # Flakey Local device key changes appear in /keys/changes -/context/ with lazy_load_members filter works # we don't support groups Remove group category Remove group role +# Flakey +AS-ghosted users can use rooms themselves + +# Flakey, need additional investigation +Messages that notify from another user increment notification_count +Messages that highlight from another user increment unread highlight count +Notifications can be viewed with GET /notifications \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 12522cfb3..63d779bf1 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -78,7 +78,7 @@ Room creation reports m.room.member to myself Outbound federation rejects send_join responses with no m.room.create event Outbound federation rejects m.room.create events with an unknown room version Invited user can see room metadata -# Blacklisted because these tests call /r0/events which we don't implement +# Blacklisted because these tests call /v3/events which we don't implement # New room members see their own join event # Existing members see new members' join events setting 'm.room.power_levels' respects room powerlevel @@ -257,7 +257,7 @@ Guest non-joined users cannot send messages to guest_access rooms if not joined Real non-joined users cannot room initalSync for non-world_readable rooms Push rules come down in an initial /sync Regular users can add and delete aliases in the default room configuration -GET /r0/capabilities is not public +GET /v3/capabilities is not public GET /joined_rooms lists newly-created room /joined_rooms returns only joined rooms Message history can be paginated over federation @@ -339,17 +339,17 @@ Existing members see new members' join events Inbound federation can receive events Inbound federation can receive redacted events Can logout current device -Can send a message directly to a device using PUT /sendToDevice -Can recv a device message using /sync -Can recv device messages until they are acknowledged -Device messages with the same txn_id are deduplicated -Device messages wake up /sync -Can recv device messages over federation -Device messages over federation wake up /sync -Can send messages with a wildcard device id -Can send messages with a wildcard device id to two devices -Wildcard device messages wake up /sync -Wildcard device messages over federation wake up /sync +Can send a message directly to a device using PUT /sendToDevice +Can recv a device message using /sync +Can recv device messages until they are acknowledged +Device messages with the same txn_id are deduplicated +Device messages wake up /sync +Can recv device messages over federation +Device messages over federation wake up /sync +Can send messages with a wildcard device id +Can send messages with a wildcard device id to two devices +Wildcard device messages wake up /sync +Wildcard device messages over federation wake up /sync Can send a to-device message to two users which both receive it using /sync User can create and send/receive messages in a room with version 6 local user can join room with version 6 @@ -366,8 +366,8 @@ Outbound federation will ignore a missing event with bad JSON for room version 6 Server correctly handles transactions that break edu limits Server rejects invalid JSON in a version 6 room Can download without a file name over federation -POST /media/r0/upload can create an upload -GET /media/r0/download can fetch the value again +POST /media/v3/upload can create an upload +GET /media/v3/download can fetch the value again Remote users can join room by alias Alias creators can delete alias with no ops Alias creators can delete canonical alias with no ops @@ -477,7 +477,7 @@ Federation key API can act as a notary server via a GET request Inbound /make_join rejects attempts to join rooms where all users have left Inbound federation rejects invites which include invalid JSON for room version 6 Inbound federation rejects invite rejections which include invalid JSON for room version 6 -GET /capabilities is present and well formed for registered user +GET /capabilities is present and well formed for registered user m.room.history_visibility == "joined" allows/forbids appropriately for Guest users m.room.history_visibility == "joined" allows/forbids appropriately for Real users POST rejects invalid utf-8 in JSON @@ -588,6 +588,58 @@ User can invite remote user to room with version 9 Remote user can backfill in a room with version 9 Can reject invites over federation for rooms with version 9 Can receive redactions from regular users over federation in room version 9 +Pushers created with a different access token are deleted on password change +Pushers created with a the same access token are not deleted on password change +Can fetch a user's pushers +Can add global push rule for room +Can add global push rule for sender +Can add global push rule for content +Can add global push rule for override +Can add global push rule for underride +Can add global push rule for content +New rules appear before old rules by default +Can add global push rule before an existing rule +Can add global push rule after an existing rule +Can delete a push rule +Can disable a push rule +Adding the same push rule twice is idempotent +Can change the actions of default rules +Can change the actions of a user specified rule +Adding a push rule wakes up an incremental /sync +Disabling a push rule wakes up an incremental /sync +Enabling a push rule wakes up an incremental /sync +Setting actions for a push rule wakes up an incremental /sync +Can enable/disable default rules +Trying to add push rule with missing template fails with 400 +Trying to add push rule with missing rule_id fails with 400 +Trying to add push rule with empty rule_id fails with 400 +Trying to add push rule with invalid template fails with 400 +Trying to add push rule with rule_id with slashes fails with 400 +Trying to add push rule with override rule without conditions fails with 400 +Trying to add push rule with underride rule without conditions fails with 400 +Trying to add push rule with condition without kind fails with 400 +Trying to add push rule with content rule without pattern fails with 400 +Trying to add push rule with no actions fails with 400 +Trying to add push rule with invalid action fails with 400 +Trying to add push rule with invalid attr fails with 400 +Trying to add push rule with invalid value for enabled fails with 400 +Trying to get push rules with no trailing slash fails with 400 +Trying to get push rules with scope without trailing slash fails with 400 +Trying to get push rules with template without tailing slash fails with 400 +Trying to get push rules with unknown scope fails with 400 +Trying to get push rules with unknown template fails with 400 +Trying to get push rules with unknown attribute fails with 400 +Getting push rules doesn't corrupt the cache SYN-390 +Test that a message is pushed +Invites are pushed +Rooms with names are correctly named in pushes +Rooms with canonical alias are correctly named in pushed +Rooms with many users are correctly pushed +Don't get pushed for rooms you've muted +Rejected events are not pushed +Test that rejected pushers are removed. +Trying to add push rule with no scope fails with 400 +Trying to add push rule with invalid scope fails with 400 Forward extremities remain so even after the next events are populated as outliers If a device list update goes missing, the server resyncs on the next one uploading self-signing key notifies over federation @@ -597,6 +649,7 @@ Device list doesn't change if remote server is down /context/ on non world readable room does not work /context/ returns correct number of events /context/ with lazy_load_members filter works +GET /rooms/:room_id/messages lazy loads members correctly Can query remote device keys using POST after notification Device deletion propagates over federation Get left notifs in sync and /keys/changes when other user leaves @@ -604,4 +657,6 @@ Remote banned user is kicked and may not rejoin until unbanned registration remembers parameters registration accepts non-ascii passwords registration with inhibit_login inhibits login - +The operation must be consistent through an interactive authentication session +Multiple calls to /sync should not cause 500 errors +/context/ with lazy_load_members filter works diff --git a/userapi/api/api.go b/userapi/api/api.go index 6d8927043..244d13bb1 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/pushrules" ) // UserInternalAPI is the internal API for information about users and devices. @@ -28,6 +29,7 @@ type UserInternalAPI interface { LoginTokenInternalAPI InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error + PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error @@ -37,6 +39,10 @@ type UserInternalAPI interface { PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error + PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error + PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error + PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error + QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error @@ -45,6 +51,9 @@ type UserInternalAPI interface { QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error + QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error + QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error + QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error @@ -456,3 +465,77 @@ const ( // AccountTypeAppService indicates this is an appservice account AccountTypeAppService AccountType = 4 ) + +type QueryPushersRequest struct { + Localpart string +} + +type QueryPushersResponse struct { + Pushers []Pusher `json:"pushers"` +} + +type PerformPusherSetRequest struct { + Pusher // Anonymous field because that's how clientapi unmarshals it. + Localpart string + Append bool `json:"append"` +} + +type PerformPusherDeletionRequest struct { + Localpart string + SessionID int64 +} + +// Pusher represents a push notification subscriber +type Pusher struct { + SessionID int64 `json:"session_id,omitempty"` + PushKey string `json:"pushkey"` + PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"` + Kind PusherKind `json:"kind"` + AppID string `json:"app_id"` + AppDisplayName string `json:"app_display_name"` + DeviceDisplayName string `json:"device_display_name"` + ProfileTag string `json:"profile_tag"` + Language string `json:"lang"` + Data map[string]interface{} `json:"data"` +} + +type PusherKind string + +const ( + EmailKind PusherKind = "email" + HTTPKind PusherKind = "http" +) + +type PerformPushRulesPutRequest struct { + UserID string `json:"user_id"` + RuleSets *pushrules.AccountRuleSets `json:"rule_sets"` +} + +type QueryPushRulesRequest struct { + UserID string `json:"user_id"` +} + +type QueryPushRulesResponse struct { + RuleSets *pushrules.AccountRuleSets `json:"rule_sets"` +} + +type QueryNotificationsRequest struct { + Localpart string `json:"localpart"` // Required. + From string `json:"from,omitempty"` + Limit int `json:"limit,omitempty"` + Only string `json:"only,omitempty"` +} + +type QueryNotificationsResponse struct { + NextToken string `json:"next_token"` + Notifications []*Notification `json:"notifications"` // Required. +} + +type Notification struct { + Actions []*pushrules.Action `json:"actions"` // Required. + Event gomatrixserverlib.ClientEvent `json:"event"` // Required. + ProfileTag string `json:"profile_tag"` // Required by Sytest, but actually optional. + Read bool `json:"read"` // Required. + RoomID string `json:"room_id"` // Required. + TS gomatrixserverlib.Timestamp `json:"ts"` // Required. +} diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index 9d142d34b..b8bde4340 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -79,6 +79,21 @@ func (t *UserInternalAPITrace) PerformKeyBackup(ctx context.Context, req *Perfor util.GetLogger(ctx).Infof("PerformKeyBackup req=%+v res=%+v", js(req), js(res)) return err } +func (t *UserInternalAPITrace) PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error { + err := t.Impl.PerformPusherSet(ctx, req, res) + util.GetLogger(ctx).Infof("PerformPusherSet req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error { + err := t.Impl.PerformPusherDeletion(ctx, req, res) + util.GetLogger(ctx).Infof("PerformPusherDeletion req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error { + err := t.Impl.PerformPushRulesPut(ctx, req, res) + util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res)) + return err +} func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) { t.Impl.QueryKeyBackup(ctx, req, res) util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res)) @@ -118,6 +133,21 @@ func (t *UserInternalAPITrace) QueryOpenIDToken(ctx context.Context, req *QueryO util.GetLogger(ctx).Infof("QueryOpenIDToken req=%+v res=%+v", js(req), js(res)) return err } +func (t *UserInternalAPITrace) QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error { + err := t.Impl.QueryPushers(ctx, req, res) + util.GetLogger(ctx).Infof("QueryPushers req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error { + err := t.Impl.QueryPushRules(ctx, req, res) + util.GetLogger(ctx).Infof("QueryPushRules req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error { + err := t.Impl.QueryNotifications(ctx, req, res) + util.GetLogger(ctx).Infof("QueryNotifications req=%+v res=%+v", js(req), js(res)) + return err +} func (t *UserInternalAPITrace) QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error { err := t.Impl.QueryPolicyVersion(ctx, req, res) diff --git a/userapi/consumers/syncapi_readupdate.go b/userapi/consumers/syncapi_readupdate.go new file mode 100644 index 000000000..2e58020b4 --- /dev/null +++ b/userapi/consumers/syncapi_readupdate.go @@ -0,0 +1,136 @@ +package consumers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/types" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/producers" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/util" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +type OutputReadUpdateConsumer struct { + ctx context.Context + cfg *config.UserAPI + jetstream nats.JetStreamContext + durable string + db storage.Database + pgClient pushgateway.Client + ServerName gomatrixserverlib.ServerName + topic string + userAPI uapi.UserInternalAPI + syncProducer *producers.SyncAPI +} + +func NewOutputReadUpdateConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + store storage.Database, + pgClient pushgateway.Client, + userAPI uapi.UserInternalAPI, + syncProducer *producers.SyncAPI, +) *OutputReadUpdateConsumer { + return &OutputReadUpdateConsumer{ + ctx: process.Context(), + cfg: cfg, + jetstream: js, + db: store, + ServerName: cfg.Matrix.ServerName, + durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"), + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReadUpdate), + pgClient: pgClient, + userAPI: userAPI, + syncProducer: syncProducer, + } +} + +func (s *OutputReadUpdateConsumer) Start() error { + if err := jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { + return err + } + return nil +} + +func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var read types.ReadUpdate + if err := json.Unmarshal(msg.Data, &read); err != nil { + log.WithError(err).Error("userapi clientapi consumer: message parse failure") + return true + } + if read.FullyRead == 0 && read.Read == 0 { + return true + } + + userID := string(msg.Header.Get(jetstream.UserID)) + roomID := string(msg.Header.Get(jetstream.RoomID)) + + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + log.WithError(err).Error("userapi clientapi consumer: SplitID failure") + return true + } + if domain != s.ServerName { + log.Error("userapi clientapi consumer: not a local user") + return true + } + + log := log.WithFields(log.Fields{ + "room_id": roomID, + "user_id": userID, + }) + log.Tracef("Received read update from sync API: %#v", read) + + if read.Read > 0 { + updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true) + if err != nil { + log.WithError(err).Error("userapi EDU consumer") + return false + } + + if updated { + if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil { + log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed") + return false + } + if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { + log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed") + return false + } + } + } + + if read.FullyRead > 0 { + deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead)) + if err != nil { + log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed") + return false + } + + if deleted { + if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { + log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed") + return false + } + + if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil { + log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed") + return false + } + } + } + + return true +} diff --git a/userapi/consumers/syncapi_streamevent.go b/userapi/consumers/syncapi_streamevent.go new file mode 100644 index 000000000..110813274 --- /dev/null +++ b/userapi/consumers/syncapi_streamevent.go @@ -0,0 +1,588 @@ +package consumers + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/internal/pushrules" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/producers" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/util" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +type OutputStreamEventConsumer struct { + ctx context.Context + cfg *config.UserAPI + userAPI api.UserInternalAPI + rsAPI rsapi.RoomserverInternalAPI + jetstream nats.JetStreamContext + durable string + db storage.Database + topic string + pgClient pushgateway.Client + syncProducer *producers.SyncAPI +} + +func NewOutputStreamEventConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + store storage.Database, + pgClient pushgateway.Client, + userAPI api.UserInternalAPI, + rsAPI rsapi.RoomserverInternalAPI, + syncProducer *producers.SyncAPI, +) *OutputStreamEventConsumer { + return &OutputStreamEventConsumer{ + ctx: process.Context(), + cfg: cfg, + jetstream: js, + db: store, + durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"), + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputStreamEvent), + pgClient: pgClient, + userAPI: userAPI, + rsAPI: rsAPI, + syncProducer: syncProducer, + } +} + +func (s *OutputStreamEventConsumer) Start() error { + if err := jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { + return err + } + return nil +} + +func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var output types.StreamedEvent + output.Event = &gomatrixserverlib.HeaderedEvent{} + if err := json.Unmarshal(msg.Data, &output); err != nil { + log.WithError(err).Errorf("userapi consumer: message parse failure") + return true + } + if output.Event.Event == nil { + log.Errorf("userapi consumer: expected event") + return true + } + + log.WithFields(log.Fields{ + "event_id": output.Event.EventID(), + "event_type": output.Event.Type(), + "stream_pos": output.StreamPosition, + }).Tracef("Received message from sync API: %#v", output) + + if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil { + log.WithFields(log.Fields{ + "event_id": output.Event.EventID(), + }).WithError(err).Errorf("userapi consumer: process room event failure") + } + + return true +} + +func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error { + members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) + if err != nil { + return fmt.Errorf("s.localRoomMembers: %w", err) + } + + if event.Type() == gomatrixserverlib.MRoomMember { + cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll) + var member *localMembership + member, err = newLocalMembership(&cevent) + if err != nil { + return fmt.Errorf("newLocalMembership: %w", err) + } + if member.Membership == gomatrixserverlib.Invite && member.Domain == s.cfg.Matrix.ServerName { + // localRoomMembers only adds joined members. An invite + // should also be pushed to the target user. + members = append(members, member) + } + } + + // TODO: run in parallel with localRoomMembers. + roomName, err := s.roomName(ctx, event) + if err != nil { + return fmt.Errorf("s.roomName: %w", err) + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "num_members": len(members), + "room_size": roomSize, + }).Tracef("Notifying members") + + // Notification.UserIsTarget is a per-member field, so we + // cannot group all users in a single request. + // + // TODO: does it have to be set? It's not required, and + // removing it means we can send all notifications to + // e.g. Element's Push gateway in one go. + for _, mem := range members { + if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil { + log.WithFields(log.Fields{ + "localpart": mem.Localpart, + }).WithError(err).Debugf("Unable to push to local user") + continue + } + } + + return nil +} + +type localMembership struct { + gomatrixserverlib.MemberContent + UserID string + Localpart string + Domain gomatrixserverlib.ServerName +} + +func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership, error) { + if event.StateKey == nil { + return nil, fmt.Errorf("missing state_key") + } + + var member localMembership + if err := json.Unmarshal(event.Content, &member.MemberContent); err != nil { + return nil, err + } + + localpart, domain, err := gomatrixserverlib.SplitID('@', *event.StateKey) + if err != nil { + return nil, err + } + + member.UserID = *event.StateKey + member.Localpart = localpart + member.Domain = domain + return &member, nil +} + +// localRoomMembers fetches the current local members of a room, and +// the total number of members. +func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { + req := &rsapi.QueryMembershipsForRoomRequest{ + RoomID: roomID, + JoinedOnly: true, + } + var res rsapi.QueryMembershipsForRoomResponse + + // XXX: This could potentially race if the state for the event is not known yet + // e.g. the event came over federation but we do not have the full state persisted. + if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil { + return nil, 0, err + } + + var members []*localMembership + var ntotal int + for _, event := range res.JoinEvents { + member, err := newLocalMembership(&event) + if err != nil { + log.WithError(err).Errorf("Parsing MemberContent") + continue + } + if member.Membership != gomatrixserverlib.Join { + continue + } + if member.Domain != s.cfg.Matrix.ServerName { + continue + } + + ntotal++ + members = append(members, member) + } + + return members, ntotal, nil +} + +// roomName returns the name in the event (if type==m.room.name), or +// looks it up in roomserver. If there is no name, +// m.room.canonical_alias is consulted. Returns an empty string if the +// room has no name. +func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { + if event.Type() == gomatrixserverlib.MRoomName { + name, err := unmarshalRoomName(event) + if err != nil { + return "", err + } + + if name != "" { + return name, nil + } + } + + req := &rsapi.QueryCurrentStateRequest{ + RoomID: event.RoomID(), + StateTuples: []gomatrixserverlib.StateKeyTuple{roomNameTuple, canonicalAliasTuple}, + } + var res rsapi.QueryCurrentStateResponse + + if err := s.rsAPI.QueryCurrentState(ctx, req, &res); err != nil { + return "", nil + } + + if eventS := res.StateEvents[roomNameTuple]; eventS != nil { + return unmarshalRoomName(eventS) + } + + if event.Type() == gomatrixserverlib.MRoomCanonicalAlias { + alias, err := unmarshalCanonicalAlias(event) + if err != nil { + return "", err + } + + if alias != "" { + return alias, nil + } + } + + if event = res.StateEvents[canonicalAliasTuple]; event != nil { + return unmarshalCanonicalAlias(event) + } + + return "", nil +} + +var ( + canonicalAliasTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias} + roomNameTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomName} +) + +func unmarshalRoomName(event *gomatrixserverlib.HeaderedEvent) (string, error) { + var nc eventutil.NameContent + if err := json.Unmarshal(event.Content(), &nc); err != nil { + return "", fmt.Errorf("unmarshaling NameContent: %w", err) + } + + return nc.Name, nil +} + +func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, error) { + var cac eventutil.CanonicalAliasContent + if err := json.Unmarshal(event.Content(), &cac); err != nil { + return "", fmt.Errorf("unmarshaling CanonicalAliasContent: %w", err) + } + + return cac.Alias, nil +} + +// notifyLocal finds the right push actions for a local user, given an event. +func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error { + actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) + if err != nil { + return err + } + a, tweaks, err := pushrules.ActionsToTweaks(actions) + if err != nil { + return err + } + // TODO: support coalescing. + if a != pushrules.NotifyAction && a != pushrules.CoalesceAction { + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "localpart": mem.Localpart, + }).Tracef("Push rule evaluation rejected the event") + return nil + } + + devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks) + if err != nil { + return err + } + + n := &api.Notification{ + Actions: actions, + // UNSPEC: the spec doesn't say this is a ClientEvent, but the + // fields seem to match. room_id should be missing, which + // matches the behaviour of FormatSync. + Event: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatSync), + // TODO: this is per-device, but it's not part of the primary + // key. So inserting one notification per profile tag doesn't + // make sense. What is this supposed to be? Sytests require it + // to "work", but they only use a single device. + ProfileTag: profileTag, + RoomID: event.RoomID(), + TS: gomatrixserverlib.AsTimestamp(time.Now()), + } + if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil { + return err + } + + if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil { + return err + } + + // We do this after InsertNotification. Thus, this should always return >=1. + userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "localpart": mem.Localpart, + "num_urls": len(devicesByURLAndFormat), + "num_unread": userNumUnreadNotifs, + }).Tracef("Notifying single member") + + // Push gateways are out of our control, and we cannot risk + // looking up the server on a misbehaving push gateway. Each user + // receives a goroutine now that all internal API calls have been + // made. + // + // TODO: think about bounding this to one per user, and what + // ordering guarantees we must provide. + go func() { + // This background processing cannot be tied to a request. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + var rejected []*pushgateway.Device + for url, fmts := range devicesByURLAndFormat { + for format, devices := range fmts { + // TODO: support "email". + if !strings.HasPrefix(url, "http") { + continue + } + + // UNSPEC: the specification suggests there can be + // more than one device per request. There is at least + // one Sytest that expects one HTTP request per + // device, rather than per URL. For now, we must + // notify each one separately. + for _, dev := range devices { + rej, err := s.notifyHTTP(ctx, event, url, format, []*pushgateway.Device{dev}, mem.Localpart, roomName, int(userNumUnreadNotifs)) + if err != nil { + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "localpart": mem.Localpart, + }).WithError(err).Errorf("Unable to notify HTTP pusher") + continue + } + rejected = append(rejected, rej...) + } + } + } + + if len(rejected) > 0 { + s.deleteRejectedPushers(ctx, rejected, mem.Localpart) + } + }() + + return nil +} + +// evaluatePushRules fetches and evaluates the push rules of a local +// user. Returns actions (including dont_notify). +func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { + if event.Sender() == mem.UserID { + // SPEC: Homeservers MUST NOT notify the Push Gateway for + // events that the user has sent themselves. + return nil, nil + } + + var res api.QueryPushRulesResponse + if err := s.userAPI.QueryPushRules(ctx, &api.QueryPushRulesRequest{UserID: mem.UserID}, &res); err != nil { + return nil, err + } + + ec := &ruleSetEvalContext{ + ctx: ctx, + rsAPI: s.rsAPI, + mem: mem, + roomID: event.RoomID(), + roomSize: roomSize, + } + eval := pushrules.NewRuleSetEvaluator(ec, &res.RuleSets.Global) + rule, err := eval.MatchEvent(event.Event) + if err != nil { + return nil, err + } + if rule == nil { + // SPEC: If no rules match an event, the homeserver MUST NOT + // notify the Push Gateway for that event. + return nil, err + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "localpart": mem.Localpart, + "rule_id": rule.RuleID, + }).Tracef("Matched a push rule") + + return rule.Actions, nil +} + +type ruleSetEvalContext struct { + ctx context.Context + rsAPI rsapi.RoomserverInternalAPI + mem *localMembership + roomID string + roomSize int +} + +func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.DisplayName } + +func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil } + +func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) { + req := &rsapi.QueryLatestEventsAndStateRequest{ + RoomID: rse.roomID, + StateToFetch: []gomatrixserverlib.StateKeyTuple{ + {EventType: gomatrixserverlib.MRoomPowerLevels}, + }, + } + var res rsapi.QueryLatestEventsAndStateResponse + if err := rse.rsAPI.QueryLatestEventsAndState(rse.ctx, req, &res); err != nil { + return false, err + } + for _, ev := range res.StateEvents { + if ev.Type() != gomatrixserverlib.MRoomPowerLevels { + continue + } + + plc, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev.Event) + if err != nil { + return false, err + } + return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil + } + return true, nil +} + +// localPushDevices pushes to the configured devices of a local +// user. The map keys are [url][format]. +func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { + pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) + if err != nil { + return nil, "", err + } + + var profileTag string + devicesByURL := make(map[string]map[string][]*pushgateway.Device, len(pusherDevices)) + for _, pusherDevice := range pusherDevices { + if profileTag == "" { + profileTag = pusherDevice.Pusher.ProfileTag + } + + url := pusherDevice.URL + if devicesByURL[url] == nil { + devicesByURL[url] = make(map[string][]*pushgateway.Device, 2) + } + devicesByURL[url][pusherDevice.Format] = append(devicesByURL[url][pusherDevice.Format], &pusherDevice.Device) + } + + return devicesByURL, profileTag, nil +} + +// notifyHTTP performs a notificatation to a Push Gateway. +func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { + logger := log.WithFields(log.Fields{ + "event_id": event.EventID(), + "url": url, + "localpart": localpart, + "num_devices": len(devices), + }) + + var req pushgateway.NotifyRequest + switch format { + case "event_id_only": + req = pushgateway.NotifyRequest{ + Notification: pushgateway.Notification{ + Counts: &pushgateway.Counts{}, + Devices: devices, + EventID: event.EventID(), + RoomID: event.RoomID(), + }, + } + + default: + req = pushgateway.NotifyRequest{ + Notification: pushgateway.Notification{ + Content: event.Content(), + Counts: &pushgateway.Counts{ + Unread: userNumUnreadNotifs, + }, + Devices: devices, + EventID: event.EventID(), + ID: event.EventID(), + RoomID: event.RoomID(), + RoomName: roomName, + Sender: event.Sender(), + Type: event.Type(), + }, + } + if mem, err := event.Membership(); err == nil { + req.Notification.Membership = mem + } + if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) { + req.Notification.UserIsTarget = true + } + } + + logger.Debugf("Notifying push gateway %s", url) + var res pushgateway.NotifyResponse + if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil { + logger.WithError(err).Errorf("Failed to notify push gateway %s", url) + return nil, err + } + logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result") + + if len(res.Rejected) == 0 { + return nil, nil + } + + devMap := make(map[string]*pushgateway.Device, len(devices)) + for _, d := range devices { + devMap[d.PushKey] = d + } + rejected := make([]*pushgateway.Device, 0, len(res.Rejected)) + for _, pushKey := range res.Rejected { + d := devMap[pushKey] + if d != nil { + rejected = append(rejected, d) + } + } + + return rejected, nil +} + +// deleteRejectedPushers deletes the pushers associated with the given devices. +func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id0": devices[0].AppID, + "num_devices": len(devices), + }).Warnf("Deleting pushers rejected by the HTTP push gateway") + + for _, d := range devices { + if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil { + log.WithFields(log.Fields{ + "localpart": localpart, + }).WithError(err).Errorf("Unable to delete rejected pusher") + } + } +} diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 8ef6a0900..dafbf2180 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -20,6 +20,8 @@ import ( "encoding/json" "errors" "fmt" + "strconv" + "time" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -27,16 +29,22 @@ import ( "github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/sqlutil" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) type UserInternalAPI struct { - DB storage.Database - ServerName gomatrixserverlib.ServerName + DB storage.Database + SyncProducer *producers.SyncAPI + + DisableTLSValidation bool + ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService KeyAPI keyapi.KeyInternalAPI @@ -596,6 +604,165 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB res.Keys = result } +func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { + if req.Limit == 0 || req.Limit > 1000 { + req.Limit = 1000 + } + + var fromID int64 + var err error + if req.From != "" { + fromID, err = strconv.ParseInt(req.From, 10, 64) + if err != nil { + return fmt.Errorf("QueryNotifications: parsing 'from': %w", err) + } + } + var filter tables.NotificationFilter = tables.AllNotifications + if req.Only == "highlight" { + filter = tables.HighlightNotifications + } + notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter) + if err != nil { + return err + } + if notifs == nil { + // This ensures empty is JSON-encoded as [] instead of null. + notifs = []*api.Notification{} + } + res.Notifications = notifs + if lastID >= 0 { + res.NextToken = strconv.FormatInt(lastID+1, 10) + } + return nil +} + +func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.PerformPusherSetRequest, res *struct{}) error { + util.GetLogger(ctx).WithFields(logrus.Fields{ + "localpart": req.Localpart, + "pushkey": req.Pusher.PushKey, + "display_name": req.Pusher.AppDisplayName, + }).Info("PerformPusherCreation") + if !req.Append { + err := a.DB.RemovePushers(ctx, req.Pusher.AppID, req.Pusher.PushKey) + if err != nil { + return err + } + } + if req.Pusher.Kind == "" { + return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart) + } + if req.Pusher.PushKeyTS == 0 { + req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now()) + } + return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart) +} + +func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { + pushers, err := a.DB.GetPushers(ctx, req.Localpart) + if err != nil { + return err + } + for i := range pushers { + logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID) + if pushers[i].SessionID != req.SessionID { + err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart) + if err != nil { + return err + } + } + } + return nil +} + +func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { + var err error + res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart) + return err +} + +func (a *UserInternalAPI) PerformPushRulesPut( + ctx context.Context, + req *api.PerformPushRulesPutRequest, + _ *struct{}, +) error { + bs, err := json.Marshal(&req.RuleSets) + if err != nil { + return err + } + userReq := api.InputAccountDataRequest{ + UserID: req.UserID, + DataType: pushRulesAccountDataType, + AccountData: json.RawMessage(bs), + } + var userRes api.InputAccountDataResponse // empty + if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil { + return err + } + + if err := a.SyncProducer.SendAccountData(req.UserID, "" /* roomID */, pushRulesAccountDataType); err != nil { + util.GetLogger(ctx).WithError(err).Errorf("syncProducer.SendData failed") + } + + return nil +} + +func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { + userReq := api.QueryAccountDataRequest{ + UserID: req.UserID, + DataType: pushRulesAccountDataType, + } + var userRes api.QueryAccountDataResponse + if err := a.QueryAccountData(ctx, &userReq, &userRes); err != nil { + return err + } + bs, ok := userRes.GlobalAccountData[pushRulesAccountDataType] + if ok { + // Legacy Dendrite users will have completely empty push rules, so we should + // detect that situation and set some defaults. + var rules struct { + G struct { + Content []json.RawMessage `json:"content"` + Override []json.RawMessage `json:"override"` + Room []json.RawMessage `json:"room"` + Sender []json.RawMessage `json:"sender"` + Underride []json.RawMessage `json:"underride"` + } `json:"global"` + } + if err := json.Unmarshal([]byte(bs), &rules); err == nil { + count := len(rules.G.Content) + len(rules.G.Override) + + len(rules.G.Room) + len(rules.G.Sender) + len(rules.G.Underride) + ok = count > 0 + } + } + if !ok { + // If we didn't find any default push rules then we should just generate some + // fresh ones. + localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return fmt.Errorf("failed to split user ID %q for push rules", req.UserID) + } + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, a.ServerName) + prbs, err := json.Marshal(pushRuleSets) + if err != nil { + return fmt.Errorf("failed to marshal default push rules: %w", err) + } + if err := a.DB.SaveAccountData(ctx, localpart, "", pushRulesAccountDataType, json.RawMessage(prbs)); err != nil { + return fmt.Errorf("failed to save default push rules: %w", err) + } + res.RuleSets = pushRuleSets + return nil + } + var data pushrules.AccountRuleSets + if err := json.Unmarshal([]byte(bs), &data); err != nil { + util.GetLogger(ctx).WithError(err).Error("json.Unmarshal of push rules failed") + return err + } + res.RuleSets = &data + return nil +} + +const pushRulesAccountDataType = "m.push_rules" + func (a *UserInternalAPI) QueryPolicyVersion( ctx context.Context, req *api.QueryPolicyVersionRequest, diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index b5176aa3e..59be975cd 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -37,16 +37,22 @@ const ( PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" PerformKeyBackupPath = "/userapi/performKeyBackup" + PerformPusherSetPath = "/pushserver/performPusherSet" + PerformPusherDeletionPath = "/pushserver/performPusherDeletion" + PerformPushRulesPutPath = "/pushserver/performPushRulesPut" PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion" - QueryKeyBackupPath = "/userapi/queryKeyBackup" - QueryProfilePath = "/userapi/queryProfile" - QueryAccessTokenPath = "/userapi/queryAccessToken" - QueryDevicesPath = "/userapi/queryDevices" - QueryAccountDataPath = "/userapi/queryAccountData" - QueryDeviceInfosPath = "/userapi/queryDeviceInfos" - QuerySearchProfilesPath = "/userapi/querySearchProfiles" - QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" + QueryKeyBackupPath = "/userapi/queryKeyBackup" + QueryProfilePath = "/userapi/queryProfile" + QueryAccessTokenPath = "/userapi/queryAccessToken" + QueryDevicesPath = "/userapi/queryDevices" + QueryAccountDataPath = "/userapi/queryAccountData" + QueryDeviceInfosPath = "/userapi/queryDeviceInfos" + QuerySearchProfilesPath = "/userapi/querySearchProfiles" + QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" + QueryPushersPath = "/pushserver/queryPushers" + QueryPushRulesPath = "/pushserver/queryPushRules" + QueryNotificationsPath = "/pushserver/queryNotifications" QueryPolicyVersionPath = "/userapi/queryPolicyVersion" QueryOutdatedPolicyUsersPath = "/userapi/queryOutdatedPolicy" ) @@ -253,6 +259,61 @@ func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.Query } } +func (h *httpUserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryNotifications") + defer span.Finish() + + return httputil.PostJSON(ctx, span, h.httpClient, h.apiURL+QueryNotificationsPath, req, res) +} + +func (h *httpUserInternalAPI) PerformPusherSet( + ctx context.Context, + request *api.PerformPusherSetRequest, + response *struct{}, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherSet") + defer span.Finish() + + apiURL := h.apiURL + PerformPusherSetPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherDeletion") + defer span.Finish() + + apiURL := h.apiURL + PerformPusherDeletionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushers") + defer span.Finish() + + apiURL := h.apiURL + QueryPushersPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) PerformPushRulesPut( + ctx context.Context, + request *api.PerformPushRulesPutRequest, + response *struct{}, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPushRulesPut") + defer span.Finish() + + apiURL := h.apiURL + PerformPushRulesPutPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushRules") + defer span.Finish() + + apiURL := h.apiURL + QueryPushRulesPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + func (h *httpUserInternalAPI) QueryPolicyVersion(ctx context.Context, req *api.QueryPolicyVersionRequest, res *api.QueryPolicyVersionResponse) error { span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPolicyVersion") defer span.Finish() diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 8dc0dee55..66e2c6fd0 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -265,6 +265,88 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryNotificationsPath, + httputil.MakeInternalAPI("queryNotifications", func(req *http.Request) util.JSONResponse { + var request api.QueryNotificationsRequest + var response api.QueryNotificationsResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryNotifications(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(PerformPusherSetPath, + httputil.MakeInternalAPI("performPusherSet", func(req *http.Request) util.JSONResponse { + request := api.PerformPusherSetRequest{} + response := struct{}{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPusherSet(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformPusherDeletionPath, + httputil.MakeInternalAPI("performPusherDeletion", func(req *http.Request) util.JSONResponse { + request := api.PerformPusherDeletionRequest{} + response := struct{}{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPusherDeletion(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(QueryPushersPath, + httputil.MakeInternalAPI("queryPushers", func(req *http.Request) util.JSONResponse { + request := api.QueryPushersRequest{} + response := api.QueryPushersResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryPushers(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(PerformPushRulesPutPath, + httputil.MakeInternalAPI("performPushRulesPut", func(req *http.Request) util.JSONResponse { + request := api.PerformPushRulesPutRequest{} + response := struct{}{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPushRulesPut(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(QueryPushRulesPath, + httputil.MakeInternalAPI("queryPushRules", func(req *http.Request) util.JSONResponse { + request := api.QueryPushRulesRequest{} + response := api.QueryPushRulesResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryPushRules(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(QueryPolicyVersionPath, httputil.MakeInternalAPI("queryPolicyVersion", func(req *http.Request) util.JSONResponse { request := api.QueryPolicyVersionRequest{} diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go new file mode 100644 index 000000000..4a206f333 --- /dev/null +++ b/userapi/producers/syncapi.go @@ -0,0 +1,104 @@ +package producers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +type JetStreamPublisher interface { + PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) +} + +// SyncAPI produces messages for the Sync API server to consume. +type SyncAPI struct { + db storage.Database + producer JetStreamPublisher + clientDataTopic string + notificationDataTopic string +} + +func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { + return &SyncAPI{ + db: db, + producer: js, + clientDataTopic: clientDataTopic, + notificationDataTopic: notificationDataTopic, + } +} + +// SendAccountData sends account data to the Sync API server. +func (p *SyncAPI) SendAccountData(userID string, roomID string, dataType string) error { + m := &nats.Msg{ + Subject: p.clientDataTopic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, userID) + + var err error + m.Data, err = json.Marshal(eventutil.AccountData{ + RoomID: roomID, + Type: dataType, + }) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": roomID, + "data_type": dataType, + }).Tracef("Producing to topic '%s'", p.clientDataTopic) + + _, err = p.producer.PublishMsg(m) + return err +} + +// GetAndSendNotificationData reads the database and sends data about unread +// notifications to the Sync API server. +func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error { + localpart, _, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + + ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID) + if err != nil { + return err + } + + return p.sendNotificationData(userID, &eventutil.NotificationData{ + RoomID: roomID, + UnreadHighlightCount: int(nhighlight), + UnreadNotificationCount: int(ntotal), + }) +} + +// sendNotificationData sends data about unread notifications to the Sync API server. +func (p *SyncAPI) sendNotificationData(userID string, data *eventutil.NotificationData) error { + m := &nats.Msg{ + Subject: p.notificationDataTopic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, userID) + + var err error + m.Data, err = json.Marshal(data) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": data.RoomID, + }).Tracef("Producing to topic '%s'", p.clientDataTopic) + + _, err = p.producer.PublishMsg(m) + return err +} diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 39164fed0..aca669820 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) type Database interface { @@ -92,6 +93,19 @@ type Database interface { // GetLoginTokenDataByToken returns the data associated with the given token. // May return sql.ErrNoRows. GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) + + InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) + SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) + GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) + GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) + GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) + DeleteOldNotifications(ctx context.Context) error + + UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error + GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) + RemovePusher(ctx context.Context, appid, pushkey, localpart string) error + RemovePushers(ctx context.Context, appid, pushkey string) error } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go new file mode 100644 index 000000000..a27c1125e --- /dev/null +++ b/userapi/storage/postgres/notifications_table.go @@ -0,0 +1,235 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +type notificationsStatements struct { + insertStmt *sql.Stmt + deleteUpToStmt *sql.Stmt + updateReadStmt *sql.Stmt + selectStmt *sql.Stmt + selectCountStmt *sql.Stmt + selectRoomCountsStmt *sql.Stmt + cleanNotificationsStmt *sql.Stmt +} + +const notificationSchema = ` +CREATE TABLE IF NOT EXISTS userapi_notifications ( + id BIGSERIAL PRIMARY KEY, + localpart TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + stream_pos BIGINT NOT NULL, + ts_ms BIGINT NOT NULL, + highlight BOOLEAN NOT NULL, + notification_json TEXT NOT NULL, + read BOOLEAN NOT NULL DEFAULT FALSE +); + +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +` + +const insertNotificationSQL = "" + + "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + +const deleteNotificationsUpToSQL = "" + + "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + +const updateNotificationReadSQL = "" + + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + +const selectNotificationSQL = "" + + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $4" + +const selectNotificationCountSQL = "" + + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + + "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read" + +const selectRoomNotificationCountsSQL = "" + + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + + "WHERE localpart = $1 AND room_id = $2 AND NOT read" + +const cleanNotificationsSQL = "" + + "DELETE FROM userapi_notifications WHERE" + + " (highlight = FALSE AND ts_ms < $1) OR (highlight = TRUE AND ts_ms < $2)" + +func NewPostgresNotificationTable(db *sql.DB) (tables.NotificationTable, error) { + s := ¬ificationsStatements{} + _, err := db.Exec(notificationSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertStmt, insertNotificationSQL}, + {&s.deleteUpToStmt, deleteNotificationsUpToSQL}, + {&s.updateReadStmt, updateNotificationReadSQL}, + {&s.selectStmt, selectNotificationSQL}, + {&s.selectCountStmt, selectNotificationCountSQL}, + {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL}, + {&s.cleanNotificationsStmt, cleanNotificationsSQL}, + }.Prepare(db) +} + +func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error { + _, err := sqlutil.TxStmt(txn, s.cleanNotificationsStmt).ExecContext( + ctx, + time.Now().AddDate(0, 0, -1).UnixNano()/int64(time.Millisecond), // keep non-highlights for a day + time.Now().AddDate(0, -1, 0).UnixNano()/int64(time.Millisecond), // keep highlights for a month + ) + return err +} + +// Insert inserts a notification into the database. +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { + roomID, tsMS := n.RoomID, n.TS + nn := *n + // Clears out fields that have their own columns to (1) shrink the + // data and (2) avoid difficult-to-debug inconsistency bugs. + nn.RoomID = "" + nn.TS, nn.Read = 0, false + bs, err := json.Marshal(nn) + if err != nil { + return err + } + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + return err +} + +// DeleteUpTo deletes all previous notifications, up to and including the event. +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows) + return nrows > 0, nil +} + +// UpdateRead updates the "read" value for an event. +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows) + return nrows > 0, nil +} + +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) + + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + var maxID int64 = -1 + var notifs []*api.Notification + for rows.Next() { + var id int64 + var roomID string + var ts gomatrixserverlib.Timestamp + var read bool + var jsonStr string + err = rows.Scan( + &id, + &roomID, + &ts, + &read, + &jsonStr) + if err != nil { + return nil, 0, err + } + + var n api.Notification + err := json.Unmarshal([]byte(jsonStr), &n) + if err != nil { + return nil, 0, err + } + n.RoomID = roomID + n.TS = ts + n.Read = read + notifs = append(notifs, &n) + + if maxID < id { + maxID = id + } + } + return notifs, maxID, rows.Err() +} + +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) + + if err != nil { + return 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var count int64 + if err := rows.Scan(&count); err != nil { + return 0, err + } + + return count, nil + } + return 0, rows.Err() +} + +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { + rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) + + if err != nil { + return 0, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var total, highlight int64 + if err := rows.Scan(&total, &highlight); err != nil { + return 0, 0, err + } + + return total, highlight, nil + } + return 0, 0, rows.Err() +} diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go new file mode 100644 index 000000000..670dc916f --- /dev/null +++ b/userapi/storage/postgres/pusher_table.go @@ -0,0 +1,157 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers +const pushersSchema = ` +CREATE TABLE IF NOT EXISTS userapi_pushers ( + id BIGSERIAL PRIMARY KEY, + -- The Matrix user ID localpart for this pusher + localpart TEXT NOT NULL, + session_id BIGINT DEFAULT NULL, + profile_tag TEXT, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + pushkey_ts_ms BIGINT NOT NULL DEFAULT 0, + lang TEXT NOT NULL, + data TEXT NOT NULL +); + +-- For faster deleting by app_id, pushkey pair. +CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); + +-- For faster retrieving by localpart. +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); + +-- Pushkey must be unique for a given user and app. +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +` + +const insertPusherSQL = "" + + "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + + "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + +const selectPushersSQL = "" + + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + +const deletePusherSQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + +const deletePushersByAppIdAndPushKeySQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" + +func NewPostgresPusherTable(db *sql.DB) (tables.PusherTable, error) { + s := &pushersStatements{} + _, err := db.Exec(pushersSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertPusherStmt, insertPusherSQL}, + {&s.selectPushersStmt, selectPushersSQL}, + {&s.deletePusherStmt, deletePusherSQL}, + {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL}, + }.Prepare(db) +} + +type pushersStatements struct { + insertPusherStmt *sql.Stmt + selectPushersStmt *sql.Stmt + deletePusherStmt *sql.Stmt + deletePushersByAppIdAndPushKeyStmt *sql.Stmt +} + +// insertPusher creates a new pusher. +// Returns an error if the user already has a pusher with the given pusher pushkey. +// Returns nil error success. +func (s *pushersStatements) InsertPusher( + ctx context.Context, txn *sql.Tx, session_id int64, + pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, +) error { + _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) + logrus.Debugf("Created pusher %d", session_id) + return err +} + +func (s *pushersStatements) SelectPushers( + ctx context.Context, txn *sql.Tx, localpart string, +) ([]api.Pusher, error) { + pushers := []api.Pusher{} + rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart) + + if err != nil { + return pushers, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed") + + for rows.Next() { + var pusher api.Pusher + var data []byte + err = rows.Scan( + &pusher.SessionID, + &pusher.PushKey, + &pusher.PushKeyTS, + &pusher.Kind, + &pusher.AppID, + &pusher.AppDisplayName, + &pusher.DeviceDisplayName, + &pusher.ProfileTag, + &pusher.Language, + &data) + if err != nil { + return pushers, err + } + err := json.Unmarshal(data, &pusher.Data) + if err != nil { + return pushers, err + } + pushers = append(pushers, pusher) + } + + logrus.Debugf("Database returned %d pushers", len(pushers)) + return pushers, rows.Err() +} + +// deletePusher removes a single pusher by pushkey and user localpart. +func (s *pushersStatements) DeletePusher( + ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, +) error { + _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) + return err +} + +func (s *pushersStatements) DeletePushers( + ctx context.Context, txn *sql.Tx, appid, pushkey string, +) error { + _, err := sqlutil.TxStmt(txn, s.deletePushersByAppIdAndPushKeyStmt).ExecContext(ctx, appid, pushkey) + return err +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 28b43a01d..4d85190de 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -86,6 +86,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err) } + pusherTable, err := NewPostgresPusherTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresPusherTable: %w", err) + } + notificationsTable, err := NewPostgresNotificationTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err) + } return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, @@ -96,6 +104,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver OpenIDTokens: openIDTable, Profiles: profilesTable, ThreePIDs: threePIDTable, + Pushers: pusherTable, + Notifications: notificationsTable, ServerName: serverName, DB: db, Writer: sqlutil.NewDummyWriter(), diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 085a4e207..262531dcb 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -29,6 +29,7 @@ import ( "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" @@ -47,6 +48,8 @@ type Database struct { KeyBackupVersions tables.KeyBackupVersionTable Devices tables.DevicesTable LoginTokens tables.LoginTokenTable + Notifications tables.NotificationTable + Pushers tables.PusherTable LoginTokenLifetime time.Duration ServerName gomatrixserverlib.ServerName BcryptCost int @@ -160,15 +163,12 @@ func (d *Database) createAccount( if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { return nil, err } - if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`)); err != nil { + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) + prbs, err := json.Marshal(pushRuleSets) + if err != nil { + return nil, err + } + if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil { return nil, err } return account, nil @@ -671,6 +671,101 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) ( return d.LoginTokens.SelectLoginToken(ctx, token) } +func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) + }) +} + +func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) + return err + }) + return +} + +func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) + return err + }) + return +} + +func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter) +} + +func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) { + return d.Notifications.SelectCount(ctx, nil, localpart, filter) +} + +func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) { + return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID) +} + +func (d *Database) DeleteOldNotifications(ctx context.Context) error { + return d.Notifications.Clean(ctx, nil) +} + +func (d *Database) UpsertPusher( + ctx context.Context, p api.Pusher, localpart string, +) error { + data, err := json.Marshal(p.Data) + if err != nil { + return err + } + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Pushers.InsertPusher( + ctx, txn, + p.SessionID, + p.PushKey, + p.PushKeyTS, + p.Kind, + p.AppID, + p.AppDisplayName, + p.DeviceDisplayName, + p.ProfileTag, + p.Language, + string(data), + localpart) + }) +} + +// GetPushers returns the pushers matching the given localpart. +func (d *Database) GetPushers( + ctx context.Context, localpart string, +) ([]api.Pusher, error) { + return d.Pushers.SelectPushers(ctx, nil, localpart) +} + +// RemovePusher deletes one pusher +// Invoked when `append` is true and `kind` is null in +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set +func (d *Database) RemovePusher( + ctx context.Context, appid, pushkey, localpart string, +) error { + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart) + if err == sql.ErrNoRows { + return nil + } + return err + }) +} + +// RemovePushers deletes all pushers that match given App Id and Push Key pair. +// Invoked when `append` parameter is false in +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set +func (d *Database) RemovePushers( + ctx context.Context, appid, pushkey string, +) error { + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.Pushers.DeletePushers(ctx, txn, appid, pushkey) + }) +} + // GetPrivacyPolicy returns the accepted privacy policy version, if any. func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go new file mode 100644 index 000000000..df8260251 --- /dev/null +++ b/userapi/storage/sqlite3/notifications_table.go @@ -0,0 +1,235 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +type notificationsStatements struct { + insertStmt *sql.Stmt + deleteUpToStmt *sql.Stmt + updateReadStmt *sql.Stmt + selectStmt *sql.Stmt + selectCountStmt *sql.Stmt + selectRoomCountsStmt *sql.Stmt + cleanNotificationsStmt *sql.Stmt +} + +const notificationSchema = ` +CREATE TABLE IF NOT EXISTS userapi_notifications ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + localpart TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + stream_pos BIGINT NOT NULL, + ts_ms BIGINT NOT NULL, + highlight BOOLEAN NOT NULL, + notification_json TEXT NOT NULL, + read BOOLEAN NOT NULL DEFAULT FALSE +); + +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +` + +const insertNotificationSQL = "" + + "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + +const deleteNotificationsUpToSQL = "" + + "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + +const updateNotificationReadSQL = "" + + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + +const selectNotificationSQL = "" + + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $4" + +const selectNotificationCountSQL = "" + + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + + "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read" + +const selectRoomNotificationCountsSQL = "" + + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + + "WHERE localpart = $1 AND room_id = $2 AND NOT read" + +const cleanNotificationsSQL = "" + + "DELETE FROM userapi_notifications WHERE" + + " (highlight = FALSE AND ts_ms < $1) OR (highlight = TRUE AND ts_ms < $2)" + +func NewSQLiteNotificationTable(db *sql.DB) (tables.NotificationTable, error) { + s := ¬ificationsStatements{} + _, err := db.Exec(notificationSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertStmt, insertNotificationSQL}, + {&s.deleteUpToStmt, deleteNotificationsUpToSQL}, + {&s.updateReadStmt, updateNotificationReadSQL}, + {&s.selectStmt, selectNotificationSQL}, + {&s.selectCountStmt, selectNotificationCountSQL}, + {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL}, + {&s.cleanNotificationsStmt, cleanNotificationsSQL}, + }.Prepare(db) +} + +func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error { + _, err := sqlutil.TxStmt(txn, s.cleanNotificationsStmt).ExecContext( + ctx, + time.Now().AddDate(0, 0, -1).UnixNano()/int64(time.Millisecond), // keep non-highlights for a day + time.Now().AddDate(0, -1, 0).UnixNano()/int64(time.Millisecond), // keep highlights for a month + ) + return err +} + +// Insert inserts a notification into the database. +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { + roomID, tsMS := n.RoomID, n.TS + nn := *n + // Clears out fields that have their own columns to (1) shrink the + // data and (2) avoid difficult-to-debug inconsistency bugs. + nn.RoomID = "" + nn.TS, nn.Read = 0, false + bs, err := json.Marshal(nn) + if err != nil { + return err + } + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + return err +} + +// DeleteUpTo deletes all previous notifications, up to and including the event. +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows) + return nrows > 0, nil +} + +// UpdateRead updates the "read" value for an event. +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows) + return nrows > 0, nil +} + +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) + + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + var maxID int64 = -1 + var notifs []*api.Notification + for rows.Next() { + var id int64 + var roomID string + var ts gomatrixserverlib.Timestamp + var read bool + var jsonStr string + err = rows.Scan( + &id, + &roomID, + &ts, + &read, + &jsonStr) + if err != nil { + return nil, 0, err + } + + var n api.Notification + err := json.Unmarshal([]byte(jsonStr), &n) + if err != nil { + return nil, 0, err + } + n.RoomID = roomID + n.TS = ts + n.Read = read + notifs = append(notifs, &n) + + if maxID < id { + maxID = id + } + } + return notifs, maxID, rows.Err() +} + +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) + + if err != nil { + return 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var count int64 + if err := rows.Scan(&count); err != nil { + return 0, err + } + + return count, nil + } + return 0, rows.Err() +} + +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { + rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) + + if err != nil { + return 0, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var total, highlight int64 + if err := rows.Scan(&total, &highlight); err != nil { + return 0, 0, err + } + + return total, highlight, nil + } + return 0, 0, rows.Err() +} diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go new file mode 100644 index 000000000..e718792e1 --- /dev/null +++ b/userapi/storage/sqlite3/pusher_table.go @@ -0,0 +1,157 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers +const pushersSchema = ` +CREATE TABLE IF NOT EXISTS userapi_pushers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The Matrix user ID localpart for this pusher + localpart TEXT NOT NULL, + session_id BIGINT DEFAULT NULL, + profile_tag TEXT, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + pushkey_ts_ms BIGINT NOT NULL DEFAULT 0, + lang TEXT NOT NULL, + data TEXT NOT NULL +); + +-- For faster deleting by app_id, pushkey pair. +CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); + +-- For faster retrieving by localpart. +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); + +-- Pushkey must be unique for a given user and app. +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +` + +const insertPusherSQL = "" + + "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + + "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + +const selectPushersSQL = "" + + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + +const deletePusherSQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + +const deletePushersByAppIdAndPushKeySQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" + +func NewSQLitePusherTable(db *sql.DB) (tables.PusherTable, error) { + s := &pushersStatements{} + _, err := db.Exec(pushersSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertPusherStmt, insertPusherSQL}, + {&s.selectPushersStmt, selectPushersSQL}, + {&s.deletePusherStmt, deletePusherSQL}, + {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL}, + }.Prepare(db) +} + +type pushersStatements struct { + insertPusherStmt *sql.Stmt + selectPushersStmt *sql.Stmt + deletePusherStmt *sql.Stmt + deletePushersByAppIdAndPushKeyStmt *sql.Stmt +} + +// insertPusher creates a new pusher. +// Returns an error if the user already has a pusher with the given pusher pushkey. +// Returns nil error success. +func (s *pushersStatements) InsertPusher( + ctx context.Context, txn *sql.Tx, session_id int64, + pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, +) error { + _, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) + logrus.Debugf("Created pusher %d", session_id) + return err +} + +func (s *pushersStatements) SelectPushers( + ctx context.Context, txn *sql.Tx, localpart string, +) ([]api.Pusher, error) { + pushers := []api.Pusher{} + rows, err := s.selectPushersStmt.QueryContext(ctx, localpart) + + if err != nil { + return pushers, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed") + + for rows.Next() { + var pusher api.Pusher + var data []byte + err = rows.Scan( + &pusher.SessionID, + &pusher.PushKey, + &pusher.PushKeyTS, + &pusher.Kind, + &pusher.AppID, + &pusher.AppDisplayName, + &pusher.DeviceDisplayName, + &pusher.ProfileTag, + &pusher.Language, + &data) + if err != nil { + return pushers, err + } + err := json.Unmarshal(data, &pusher.Data) + if err != nil { + return pushers, err + } + pushers = append(pushers, pusher) + } + + logrus.Debugf("Database returned %d pushers", len(pushers)) + return pushers, rows.Err() +} + +// deletePusher removes a single pusher by pushkey and user localpart. +func (s *pushersStatements) DeletePusher( + ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, +) error { + _, err := s.deletePusherStmt.ExecContext(ctx, appid, pushkey, localpart) + return err +} + +func (s *pushersStatements) DeletePushers( + ctx context.Context, txn *sql.Tx, appid, pushkey string, +) error { + _, err := s.deletePushersByAppIdAndPushKeyStmt.ExecContext(ctx, appid, pushkey) + return err +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 44a79c94e..5c52f6ac6 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -87,6 +87,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err) } + pusherTable, err := NewSQLitePusherTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresPusherTable: %w", err) + } + notificationsTable, err := NewSQLiteNotificationTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err) + } return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, @@ -97,6 +105,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver OpenIDTokens: openIDTable, Profiles: profilesTable, ThreePIDs: threePIDTable, + Pushers: pusherTable, + Notifications: notificationsTable, ServerName: serverName, DB: db, Writer: sqlutil.NewExclusiveWriter(), diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index a126086b1..ef067ed07 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" ) type AccountDataTable interface { @@ -97,3 +98,43 @@ type ThreePIDTable interface { InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) } + +type PusherTable interface { + InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error + SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error) + DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error + DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error +} + +type NotificationTable interface { + Clean(ctx context.Context, txn *sql.Tx) error + Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error + DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) + UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) + Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) + SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) + SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) +} + +type NotificationFilter uint32 + +const ( + // HighlightNotifications returns notifications that had a + // "highlight" tweak assigned to them from evaluating push rules. + HighlightNotifications NotificationFilter = 1 << iota + + // NonHighlightNotifications returns notifications that don't + // match HighlightNotifications. + NonHighlightNotifications + + // NoNotifications is a filter to exclude all types of + // notifications. It's useful as a zero value, but isn't likely to + // be used in a call to Notifications.Select*. + NoNotifications NotificationFilter = 0 + + // AllNotifications is a filter to include all types of + // notifications in Notifications.Select*. Note that PostgreSQL + // balks if this doesn't fit in INTEGER, even though we use + // uint32. + AllNotifications NotificationFilter = (1 << 31) - 1 +) diff --git a/userapi/userapi.go b/userapi/userapi.go index 4a5793abb..251a4edad 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -18,11 +18,17 @@ import ( "time" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/pushgateway" keyapi "github.com/matrix-org/dendrite/keyserver/api" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/consumers" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" + "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" "github.com/sirupsen/logrus" ) @@ -36,26 +42,54 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, + base *base.BaseDendrite, db storage.Database, cfg *config.UserAPI, + appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, + rsAPI rsapi.RoomserverInternalAPI, pgClient pushgateway.Client, ) api.UserInternalAPI { - db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime) - if err != nil { - logrus.WithError(err).Panicf("failed to connect to device db") + js := jetstream.Prepare(&cfg.Matrix.JetStream) + + syncProducer := producers.NewSyncAPI( + db, js, + // TODO: user API should handle syncs for account data. Right now, + // it's handled by clientapi, and hence uses its topic. When user + // API handles it for all account data, we can remove it from + // here. + cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData), + cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData), + ) + + userAPI := &internal.UserInternalAPI{ + DB: db, + SyncProducer: syncProducer, + ServerName: cfg.Matrix.ServerName, + AppServices: appServices, + KeyAPI: keyAPI, + DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, } - return newInternalAPI(db, cfg, appServices, keyAPI) -} - -func newInternalAPI( - db storage.Database, - cfg *config.UserAPI, - appServices []config.ApplicationService, - keyAPI keyapi.KeyInternalAPI, -) api.UserInternalAPI { - return &internal.UserInternalAPI{ - DB: db, - ServerName: cfg.Matrix.ServerName, - AppServices: appServices, - KeyAPI: keyAPI, + readConsumer := consumers.NewOutputReadUpdateConsumer( + base.ProcessContext, cfg, js, db, pgClient, userAPI, syncProducer, + ) + if err := readConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start user API read update consumer") } + + eventConsumer := consumers.NewOutputStreamEventConsumer( + base.ProcessContext, cfg, js, db, pgClient, userAPI, rsAPI, syncProducer, + ) + if err := eventConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start user API streamed event consumer") + } + + var cleanOldNotifs func() + cleanOldNotifs = func() { + logrus.Infof("Cleaning old notifications") + if err := db.DeleteOldNotifications(base.Context()); err != nil { + logrus.WithError(err).Error("Failed to clean old notifications") + } + time.AfterFunc(time.Hour, cleanOldNotifs) + } + time.AfterFunc(time.Minute, cleanOldNotifs) + + return userAPI } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index fac6aefbe..1f73a6176 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -30,6 +30,7 @@ import ( "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage" ) @@ -62,7 +63,10 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s }, } - return newInternalAPI(accountDB, cfg, nil, nil), accountDB + return &internal.UserInternalAPI{ + DB: accountDB, + ServerName: cfg.Matrix.ServerName, + }, accountDB } func TestQueryProfile(t *testing.T) { diff --git a/userapi/util/devices.go b/userapi/util/devices.go new file mode 100644 index 000000000..cbf3bd28f --- /dev/null +++ b/userapi/util/devices.go @@ -0,0 +1,100 @@ +package util + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" + log "github.com/sirupsen/logrus" +) + +type PusherDevice struct { + Device pushgateway.Device + Pusher *api.Pusher + URL string + Format string +} + +// GetPushDevices pushes to the configured devices of a local user. +func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { + pushers, err := db.GetPushers(ctx, localpart) + if err != nil { + return nil, err + } + + devices := make([]*PusherDevice, 0, len(pushers)) + for _, pusher := range pushers { + var url, format string + data := pusher.Data + switch pusher.Kind { + case api.EmailKind: + url = "mailto:" + + case api.HTTPKind: + // TODO: The spec says only event_id_only is supported, + // but Sytests assume "" means "full notification". + fmtIface := pusher.Data["format"] + var ok bool + format, ok = fmtIface.(string) + if ok && format != "event_id_only" { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id": pusher.AppID, + }).Errorf("Only data.format event_id_only or empty is supported") + continue + } + + urlIface := pusher.Data["url"] + url, ok = urlIface.(string) + if !ok { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id": pusher.AppID, + }).Errorf("No data.url configured for HTTP Pusher") + continue + } + data = mapWithout(data, "url") + + default: + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id": pusher.AppID, + "kind": pusher.Kind, + }).Errorf("Unhandled pusher kind") + continue + } + + devices = append(devices, &PusherDevice{ + Device: pushgateway.Device{ + AppID: pusher.AppID, + Data: data, + PushKey: pusher.PushKey, + PushKeyTS: pusher.PushKeyTS, + Tweaks: tweaks, + }, + Pusher: &pusher, + URL: url, + Format: format, + }) + } + + return devices, nil +} + +// mapWithout returns a shallow copy of the map, without the given +// key. Returns nil if the resulting map is empty. +func mapWithout(m map[string]interface{}, key string) map[string]interface{} { + ret := make(map[string]interface{}, len(m)) + for k, v := range m { + // The specification says we do not send "url". + if k == key { + continue + } + ret[k] = v + } + if len(ret) == 0 { + return nil + } + return ret +} diff --git a/userapi/util/notify.go b/userapi/util/notify.go new file mode 100644 index 000000000..ff206bd3c --- /dev/null +++ b/userapi/util/notify.go @@ -0,0 +1,76 @@ +package util + +import ( + "context" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/storage/tables" + log "github.com/sirupsen/logrus" +) + +// NotifyUserCountsAsync sends notifications to a local user's +// notification destinations. Database lookups run synchronously, but +// a single goroutine is started when talking to the Push +// gateways. There is no way to know when the background goroutine has +// finished. +func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error { + pusherDevices, err := GetPushDevices(ctx, localpart, nil, db) + if err != nil { + return err + } + + if len(pusherDevices) == 0 { + return nil + } + + userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id0": pusherDevices[0].Device.AppID, + "pushkey": pusherDevices[0].Device.PushKey, + }).Tracef("Notifying HTTP push gateway about notification counts") + + // TODO: think about bounding this to one per user, and what + // ordering guarantees we must provide. + go func() { + // This background processing cannot be tied to a request. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // TODO: we could batch all devices with the same URL, but + // Sytest requires consumers/roomserver.go to do it + // one-by-one, so we do the same here. + for _, pusherDevice := range pusherDevices { + // TODO: support "email". + if !strings.HasPrefix(pusherDevice.URL, "http") { + continue + } + + req := pushgateway.NotifyRequest{ + Notification: pushgateway.Notification{ + Counts: &pushgateway.Counts{ + Unread: int(userNumUnreadNotifs), + }, + Devices: []*pushgateway.Device{&pusherDevice.Device}, + }, + } + if err := pgClient.Notify(ctx, pusherDevice.URL, &req, &pushgateway.NotifyResponse{}); err != nil { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id0": pusherDevice.Device.AppID, + "pushkey": pusherDevice.Device.PushKey, + }).WithError(err).Error("HTTP push gateway request failed") + return + } + } + }() + + return nil +}