diff --git a/.gitignore b/.gitignore index 092f4501c..2a8c2cf55 100644 --- a/.gitignore +++ b/.gitignore @@ -54,7 +54,7 @@ dendrite.yaml *.db # Log files -*.log* +*.log* # Generated code cmd/dendrite-demo-yggdrasil/embed/fs*.go @@ -62,5 +62,7 @@ cmd/dendrite-demo-yggdrasil/embed/fs*.go # Test dependencies test/wasm/node_modules -media_store/ +# Ignore complement folder when running locally +complement/ +media_store/ diff --git a/build/docker/config/dendrite.yaml b/build/docker/config/dendrite.yaml index 6d5ebc9fd..ebae50132 100644 --- a/build/docker/config/dendrite.yaml +++ b/build/docker/config/dendrite.yaml @@ -318,6 +318,17 @@ user_api: max_idle_conns: 2 conn_max_lifetime: -1 +# Configuration for the Push Server API. +push_server: + internal_api: + listen: http://localhost:7782 + connect: http://localhost:7782 + database: + connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_pushserver?sslmode=disable + max_open_conns: 10 + max_idle_conns: 2 + conn_max_lifetime: -1 + # Configuration for Opentracing. # See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on # how this works and how to set it up. diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index aa8cc6e6e..5ab90adaf 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -312,7 +312,7 @@ func (m *DendriteMonolith) Start() { ) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) - m.userAPI = userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) + m.userAPI = userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(m.userAPI) eduInputAPI := eduserver.NewInternalAPI( diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 8b9c88f2a..3329485aa 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -116,7 +116,7 @@ func (m *DendriteMonolith) Start() { ) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) + userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) eduInputAPI := eduserver.NewInternalAPI( diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index a65f3b70d..918476674 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -59,6 +59,7 @@ func AddPublicRoutes( routing.Setup( router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI, accountsDB, userAPI, federation, - syncProducer, transactionsCache, fsAPI, keyAPI, extRoomsProvider, mscCfg, + syncProducer, transactionsCache, fsAPI, keyAPI, + extRoomsProvider, mscCfg, ) } diff --git a/clientapi/producers/syncapi.go b/clientapi/producers/syncapi.go index 9b1d6b1a2..9ab90391d 100644 --- a/clientapi/producers/syncapi.go +++ b/clientapi/producers/syncapi.go @@ -30,7 +30,7 @@ type SyncAPIProducer struct { } // SendData sends account data to the sync API server -func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string) error { +func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string, readMarker *eventutil.ReadMarkerJSON) error { m := &nats.Msg{ Subject: p.Topic, Header: nats.Header{}, @@ -38,8 +38,9 @@ func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string m.Header.Set(jetstream.UserID, userID) data := eventutil.AccountData{ - RoomID: roomID, - Type: dataType, + RoomID: roomID, + Type: dataType, + ReadMarker: readMarker, } var err error m.Data, err = json.Marshal(data) diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 03025f1da..d8e982690 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal/eventutil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" @@ -127,7 +128,7 @@ func SaveAccountData( } // TODO: user API should do this since it's account data - if err := syncProducer.SendData(userID, roomID, dataType); err != nil { + if err := syncProducer.SendData(userID, roomID, dataType, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") return jsonerror.InternalServerError() } @@ -138,11 +139,6 @@ func SaveAccountData( } } -type readMarkerJSON struct { - FullyRead string `json:"m.fully_read"` - Read string `json:"m.read"` -} - type fullyReadEvent struct { EventID string `json:"event_id"` } @@ -159,7 +155,7 @@ func SaveReadMarker( return *resErr } - var r readMarkerJSON + var r eventutil.ReadMarkerJSON resErr = httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { return *resErr @@ -189,7 +185,7 @@ func SaveReadMarker( return util.ErrorResponse(err) } - if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read"); err != nil { + if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read", &r); err != nil { util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/notification.go b/clientapi/routing/notification.go new file mode 100644 index 000000000..ee715d323 --- /dev/null +++ b/clientapi/routing/notification.go @@ -0,0 +1,63 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + "strconv" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// GetNotifications handles /_matrix/client/r0/notifications +func GetNotifications( + req *http.Request, device *userapi.Device, + userAPI userapi.UserInternalAPI, +) util.JSONResponse { + var limit int64 + if limitStr := req.URL.Query().Get("limit"); limitStr != "" { + var err error + limit, err = strconv.ParseInt(limitStr, 10, 64) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("ParseInt(limit) failed") + return jsonerror.InternalServerError() + } + } + + var queryRes userapi.QueryNotificationsResponse + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") + return jsonerror.InternalServerError() + } + err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{ + Localpart: localpart, + From: req.URL.Query().Get("from"), + Limit: int(limit), + Only: req.URL.Query().Get("only"), + }, &queryRes) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed") + return jsonerror.InternalServerError() + } + util.GetLogger(req.Context()).WithField("from", req.URL.Query().Get("from")).WithField("limit", limit).WithField("only", req.URL.Query().Get("only")).WithField("next", queryRes.NextToken).Infof("QueryNotifications: len %d", len(queryRes.Notifications)) + return util.JSONResponse{ + Code: http.StatusOK, + JSON: queryRes, + } +} diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index acac60fa5..c63412d08 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -12,6 +12,7 @@ import ( userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) type newPasswordRequest struct { @@ -37,6 +38,11 @@ func Password( var r newPasswordRequest r.LogoutDevices = true + logrus.WithFields(logrus.Fields{ + "sessionId": device.SessionID, + "userId": device.UserID, + }).Debug("Changing password") + // Unmarshal the request. resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { @@ -116,6 +122,15 @@ func Password( util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") return jsonerror.InternalServerError() } + + pushersReq := &api.PerformPusherDeletionRequest{ + Localpart: localpart, + SessionID: device.SessionID, + } + if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed") + return jsonerror.InternalServerError() + } } // Return a success code. diff --git a/clientapi/routing/pusher.go b/clientapi/routing/pusher.go new file mode 100644 index 000000000..9d6bef8bd --- /dev/null +++ b/clientapi/routing/pusher.go @@ -0,0 +1,114 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + "net/url" + + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// GetPushers handles /_matrix/client/r0/pushers +func GetPushers( + req *http.Request, device *userapi.Device, + userAPI userapi.UserInternalAPI, +) util.JSONResponse { + var queryRes userapi.QueryPushersResponse + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") + return jsonerror.InternalServerError() + } + err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{ + Localpart: localpart, + }, &queryRes) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed") + return jsonerror.InternalServerError() + } + for i := range queryRes.Pushers { + queryRes.Pushers[i].SessionID = 0 + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: queryRes, + } +} + +// SetPusher handles /_matrix/client/r0/pushers/set +// This endpoint allows the creation, modification and deletion of pushers for this user ID. +// The behaviour of this endpoint varies depending on the values in the JSON body. +func SetPusher( + req *http.Request, device *userapi.Device, + userAPI userapi.UserInternalAPI, +) util.JSONResponse { + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") + return jsonerror.InternalServerError() + } + body := userapi.PerformPusherSetRequest{} + if resErr := httputil.UnmarshalJSONRequest(req, &body); resErr != nil { + return *resErr + } + if len(body.AppID) > 64 { + return invalidParam("length of app_id must be no more than 64 characters") + } + if len(body.PushKey) > 512 { + return invalidParam("length of pushkey must be no more than 512 bytes") + } + uInt := body.Data["url"] + if uInt != nil { + u, ok := uInt.(string) + if !ok { + return invalidParam("url must be string") + } + if u != "" { + var pushUrl *url.URL + pushUrl, err = url.Parse(u) + if err != nil { + return invalidParam("malformed url passed") + } + if pushUrl.Scheme != "https" { + return invalidParam("only https scheme is allowed") + } + } + + } + body.Localpart = localpart + body.SessionID = device.SessionID + err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{}) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformPusherSet failed") + return jsonerror.InternalServerError() + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +func invalidParam(msg string) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidParam(msg), + } +} diff --git a/clientapi/routing/pushrules.go b/clientapi/routing/pushrules.go new file mode 100644 index 000000000..81a33b25a --- /dev/null +++ b/clientapi/routing/pushrules.go @@ -0,0 +1,386 @@ +package routing + +import ( + "context" + "encoding/json" + "io" + "net/http" + "reflect" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/pushrules" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +func errorResponse(ctx context.Context, err error, msg string, args ...interface{}) util.JSONResponse { + if eerr, ok := err.(*jsonerror.MatrixError); ok { + var status int + switch eerr.ErrCode { + case "M_INVALID_ARGUMENT_VALUE": + status = http.StatusBadRequest + case "M_NOT_FOUND": + status = http.StatusNotFound + default: + status = http.StatusInternalServerError + } + return util.MatrixErrorResponse(status, eerr.ErrCode, eerr.Err) + } + util.GetLogger(ctx).WithError(err).Errorf(msg, args...) + return jsonerror.InternalServerError() +} + +func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRulesJSON failed") + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: ruleSets, + } +} + +func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRulesJSON failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: ruleSet, + } +} + +func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: *rulesPtr, + } +} + +func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + i := pushRuleIndexByID(*rulesPtr, ruleID) + if i < 0 { + return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: (*rulesPtr)[i], + } +} + +func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + var newRule pushrules.Rule + if err := json.NewDecoder(body).Decode(&newRule); err != nil { + return errorResponse(ctx, err, "JSON Decode failed") + } + newRule.RuleID = ruleID + + errs := pushrules.ValidateRule(pushrules.Kind(kind), &newRule) + if len(errs) > 0 { + return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs) + } + + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + i := pushRuleIndexByID(*rulesPtr, ruleID) + if i >= 0 && afterRuleID == "" && beforeRuleID == "" { + // Modify rule at the same index. + + // TODO: The spec does not say what to do in this case, but + // this feels reasonable. + *((*rulesPtr)[i]) = newRule + util.GetLogger(ctx).Infof("Modified existing push rule at %d", i) + } else { + if i >= 0 { + // Delete old rule. + *rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...) + util.GetLogger(ctx).Infof("Deleted old push rule at %d", i) + } else { + // SPEC: When creating push rules, they MUST be enabled by default. + // + // TODO: it's unclear if we must reject disabled rules, or force + // the value to true. Sytests fail if we don't force it. + newRule.Enabled = true + } + + // Add new rule. + i, err := findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID) + if err != nil { + return errorResponse(ctx, err, "findPushRuleInsertionIndex failed") + } + + *rulesPtr = append((*rulesPtr)[:i], append([]*pushrules.Rule{&newRule}, (*rulesPtr)[i:]...)...) + util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i) + } + + if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { + return errorResponse(ctx, err, "putPushRules failed") + } + + return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} +} + +func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + i := pushRuleIndexByID(*rulesPtr, ruleID) + if i < 0 { + return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + } + + *rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...) + + if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { + return errorResponse(ctx, err, "putPushRules failed") + } + + return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} +} + +func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + attrGet, err := pushRuleAttrGetter(attr) + if err != nil { + return errorResponse(ctx, err, "pushRuleAttrGetter failed") + } + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + i := pushRuleIndexByID(*rulesPtr, ruleID) + if i < 0 { + return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: map[string]interface{}{ + attr: attrGet((*rulesPtr)[i]), + }, + } +} + +func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + var newPartialRule pushrules.Rule + if err := json.NewDecoder(body).Decode(&newPartialRule); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(err.Error()), + } + } + if newPartialRule.Actions == nil { + // This ensures json.Marshal encodes the empty list as [] rather than null. + newPartialRule.Actions = []*pushrules.Action{} + } + + attrGet, err := pushRuleAttrGetter(attr) + if err != nil { + return errorResponse(ctx, err, "pushRuleAttrGetter failed") + } + attrSet, err := pushRuleAttrSetter(attr) + if err != nil { + return errorResponse(ctx, err, "pushRuleAttrSetter failed") + } + + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + i := pushRuleIndexByID(*rulesPtr, ruleID) + if i < 0 { + return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + } + + if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) { + attrSet((*rulesPtr)[i], &newPartialRule) + + if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { + return errorResponse(ctx, err, "putPushRules failed") + } + } + + return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} +} + +func queryPushRules(ctx context.Context, userID string, userAPI userapi.UserInternalAPI) (*pushrules.AccountRuleSets, error) { + var res userapi.QueryPushRulesResponse + if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed") + return nil, err + } + return res.RuleSets, nil +} + +func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.UserInternalAPI) error { + req := userapi.PerformPushRulesPutRequest{ + UserID: userID, + RuleSets: ruleSets, + } + var res struct{} + if err := userAPI.PerformPushRulesPut(ctx, &req, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("userAPI.PerformPushRulesPut failed") + return err + } + return nil +} + +func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet { + switch scope { + case pushrules.GlobalScope: + return &ruleSets.Global + default: + return nil + } +} + +func pushRuleSetKindPointer(ruleSet *pushrules.RuleSet, kind pushrules.Kind) *[]*pushrules.Rule { + switch kind { + case pushrules.OverrideKind: + return &ruleSet.Override + case pushrules.ContentKind: + return &ruleSet.Content + case pushrules.RoomKind: + return &ruleSet.Room + case pushrules.SenderKind: + return &ruleSet.Sender + case pushrules.UnderrideKind: + return &ruleSet.Underride + default: + return nil + } +} + +func pushRuleIndexByID(rules []*pushrules.Rule, id string) int { + for i, rule := range rules { + if rule.RuleID == id { + return i + } + } + return -1 +} + +func pushRuleAttrGetter(attr string) (func(*pushrules.Rule) interface{}, error) { + switch attr { + case "actions": + return func(rule *pushrules.Rule) interface{} { return rule.Actions }, nil + case "enabled": + return func(rule *pushrules.Rule) interface{} { return rule.Enabled }, nil + default: + return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute") + } +} + +func pushRuleAttrSetter(attr string) (func(dest, src *pushrules.Rule), error) { + switch attr { + case "actions": + return func(dest, src *pushrules.Rule) { dest.Actions = src.Actions }, nil + case "enabled": + return func(dest, src *pushrules.Rule) { dest.Enabled = src.Enabled }, nil + default: + return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute") + } +} + +func findPushRuleInsertionIndex(rules []*pushrules.Rule, afterID, beforeID string) (int, error) { + var i int + + if afterID != "" { + for ; i < len(rules); i++ { + if rules[i].RuleID == afterID { + break + } + } + if i == len(rules) { + return 0, jsonerror.NotFound("after: rule ID not found") + } + if rules[i].Default { + return 0, jsonerror.NotFound("after: rule ID must not be a default rule") + } + // We stopped on the "after" match to differentiate + // not-found from is-last-entry. Now we move to the earliest + // insertion point. + i++ + } + + if beforeID != "" { + for ; i < len(rules); i++ { + if rules[i].RuleID == beforeID { + break + } + } + if i == len(rules) { + return 0, jsonerror.NotFound("before: rule ID not found") + } + if rules[i].Default { + return 0, jsonerror.NotFound("before: rule ID must not be a default rule") + } + } + + // UNSPEC: The spec does not say what to do if no after/before is + // given. Sytest fails if it doesn't go first. + return i, nil +} diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index c683cc949..83294b180 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -98,7 +98,7 @@ func PutTag( return jsonerror.InternalServerError() } - if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + if err = syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil { logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") } @@ -151,7 +151,7 @@ func DeleteTag( } // TODO: user API should do this since it's account data - if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + if err := syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil { logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index d75f58b81..d22fbd809 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -16,7 +16,6 @@ package routing import ( "context" - "encoding/json" "net/http" "strings" @@ -561,25 +560,142 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - v3mux.Handle("/pushrules/", - httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse { - // TODO: Implement push rules API - res := json.RawMessage(`{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`) + // Push rules + + v3mux.Handle("/pushrules", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ - Code: http.StatusOK, - JSON: &res, + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("missing trailing slash"), } }), ).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/pushrules/", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return GetAllPushRules(req.Context(), device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("scope, kind and rule ID must be specified"), + } + }), + ).Methods(http.MethodPut) + + v3mux.Handle("/pushrules/{scope}/", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetPushRulesByScope(req.Context(), vars["scope"], device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("missing trailing slash after scope"), + } + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope:[^/]+/?}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("kind and rule ID must be specified"), + } + }), + ).Methods(http.MethodPut) + + v3mux.Handle("/pushrules/{scope}/{kind}/", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetPushRulesByKind(req.Context(), vars["scope"], vars["kind"], device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope}/{kind}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("missing trailing slash after kind"), + } + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope}/{kind:[^/]+/?}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("rule ID must be specified"), + } + }), + ).Methods(http.MethodPut) + + v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.Limit(req); r != nil { + return *r + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + query := req.URL.Query() + return PutPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], query.Get("after"), query.Get("before"), req.Body, device, userAPI) + }), + ).Methods(http.MethodPut) + + v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return DeletePushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI) + }), + ).Methods(http.MethodDelete) + + v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return PutPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], req.Body, device, userAPI) + }), + ).Methods(http.MethodPut) + // Element user settings v3mux.Handle("/profile/{userID}", @@ -885,6 +1001,27 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/notifications", + httputil.MakeAuthAPI("get_notifications", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return GetNotifications(req, device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushers", + httputil.MakeAuthAPI("get_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return GetPushers(req, device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushers/set", + httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.Limit(req); r != nil { + return *r + } + return SetPusher(req, device, userAPI) + }), + ).Methods(http.MethodPost, http.MethodOptions) + // Stub implementations for sytest v3mux.Handle("/events", httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 78536901c..8ce641914 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -144,12 +144,14 @@ func main() { accountDB := base.Base.CreateAccountsDB() federation := createFederationClient(base) keyAPI := keyserver.NewInternalAPI(&base.Base, &base.Base.Cfg.KeyServer, federation) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) - keyAPI.SetUserAPI(userAPI) rsAPI := roomserver.NewInternalAPI( &base.Base, ) + + userAPI := userapi.NewInternalAPI(&base.Base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.Base.PushGatewayHTTPClient()) + keyAPI.SetUserAPI(userAPI) + eduInputAPI := eduserver.NewInternalAPI( &base.Base, cache.New(), userAPI, ) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 5810a7f18..45f186985 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -187,7 +187,7 @@ func main() { ) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) + userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) eduInputAPI := eduserver.NewInternalAPI( diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index d16f0e9e5..b7e30ba2e 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -111,14 +111,15 @@ func main() { keyRing := serverKeyAPI.KeyRing() keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) - keyAPI.SetUserAPI(userAPI) rsComponent := roomserver.NewInternalAPI( base, ) rsAPI := rsComponent + userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) + keyAPI.SetUserAPI(userAPI) + eduInputAPI := eduserver.NewInternalAPI( base, cache.New(), userAPI, ) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index bb2685208..3b952504b 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -106,7 +106,8 @@ func main() { keyAPI = base.KeyServerHTTPClient() } - userImpl := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) + pgClient := base.PushGatewayHTTPClient() + userImpl := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient) userAPI := userImpl if base.UseHTTPAPIs { userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) diff --git a/cmd/dendrite-polylith-multi/personalities/userapi.go b/cmd/dendrite-polylith-multi/personalities/userapi.go index f147cda14..f1fa379c7 100644 --- a/cmd/dendrite-polylith-multi/personalities/userapi.go +++ b/cmd/dendrite-polylith-multi/personalities/userapi.go @@ -23,7 +23,11 @@ import ( func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { accountDB := base.CreateAccountsDB() - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient()) + userAPI := userapi.NewInternalAPI( + base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, + base.KeyServerHTTPClient(), base.RoomserverHTTPClient(), + base.PushGatewayHTTPClient(), + ) userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) diff --git a/cmd/dendritejs-pinecone/main.go b/cmd/dendritejs-pinecone/main.go index 664f644f3..407081f59 100644 --- a/cmd/dendritejs-pinecone/main.go +++ b/cmd/dendritejs-pinecone/main.go @@ -184,13 +184,15 @@ func startup() { accountDB := base.CreateAccountsDB() federation := conn.CreateFederationClient(base, pSessions) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) - keyAPI.SetUserAPI(userAPI) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() rsAPI := roomserver.NewInternalAPI(base) + + userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) + keyAPI.SetUserAPI(userAPI) + eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI) asQuery := appservice.NewInternalAPI( base, userAPI, rsAPI, diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index 0ea41b4c4..37cbb12dd 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -212,6 +212,8 @@ func main() { rsAPI.SetFederationAPI(fedSenderAPI, keyRing) p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation) + psAPI := pushserver.NewInternalAPI(base) + monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, @@ -225,6 +227,7 @@ func main() { RoomserverAPI: rsAPI, UserAPI: userAPI, KeyAPI: keyAPI, + PushserverAPI: psAPI, //ServerKeyAPI: serverKeyAPI, ExtPublicRoomsProvider: p2pPublicRoomProvider, } diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 533b5c952..0236851c4 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -6,7 +6,7 @@ # # At a minimum, to get started, you will need to update the settings in the # "global" section for your deployment, and you will need to check that the -# database "connection_string" line in each component section is correct. +# database "connection_string" line in each component section is correct. # # Each component with a "database" section can accept the following formats # for "connection_string": @@ -21,13 +21,13 @@ # small number of users and likely will perform worse still with a higher volume # of users. # -# The "max_open_conns" and "max_idle_conns" settings configure the maximum +# The "max_open_conns" and "max_idle_conns" settings configure the maximum # number of open/idle database connections. The value 0 will use the database # engine default, and a negative value will use unlimited connections. The # "conn_max_lifetime" option controls the maximum length of time a database # connection can be idle in seconds - a negative value is unlimited. -# The version of the configuration file. +# The version of the configuration file. version: 2 # Global Matrix configuration. This configuration applies to all components. @@ -61,8 +61,8 @@ global: # Lists of domains that the server will trust as identity servers to verify third # party identifiers such as phone numbers and email addresses. trusted_third_party_id_servers: - - matrix.org - - vector.im + - matrix.org + - vector.im # Disables federation. Dendrite will not be able to make any outbound HTTP requests # to other servers and the federation API will not be exposed. @@ -87,14 +87,14 @@ global: # in monolith mode. It is required to specify the address of at least one # NATS Server node if running in polylith mode. addresses: - # - localhost:4222 + # - localhost:4222 # Keep all NATS streams in memory, rather than persisting it to the storage # path below. This option is present primarily for integration testing and # should not be used on a real world Dendrite deployment. in_memory: false - # Persistent directory to store JetStream streams in. This directory + # Persistent directory to store JetStream streams in. This directory # should be preserved across Dendrite restarts. storage_path: ./ @@ -126,7 +126,7 @@ global: # Configuration for the Appservice API. app_service_api: internal_api: - listen: http://localhost:7777 # Only used in polylith deployments + listen: http://localhost:7777 # Only used in polylith deployments connect: http://localhost:7777 # Only used in polylith deployments database: connection_string: file:appservice.db @@ -145,7 +145,7 @@ app_service_api: # Configuration for the Client API. client_api: internal_api: - listen: http://localhost:7771 # Only used in polylith deployments + listen: http://localhost:7771 # Only used in polylith deployments connect: http://localhost:7771 # Only used in polylith deployments external_api: listen: http://[::]:8071 @@ -165,13 +165,13 @@ client_api: # Whether to require reCAPTCHA for registration. enable_registration_captcha: false - # Settings for ReCAPTCHA. + # Settings for ReCAPTCHA. recaptcha_public_key: "" recaptcha_private_key: "" recaptcha_bypass_secret: "" recaptcha_siteverify_api: "" - # TURN server information that this homeserver should send to clients. + # TURN server information that this homeserver should send to clients. turn: turn_user_lifetime: "" turn_uris: [] @@ -180,7 +180,7 @@ client_api: turn_password: "" # Settings for rate-limited endpoints. Rate limiting will kick in after the - # threshold number of "slots" have been taken by requests from a specific + # threshold number of "slots" have been taken by requests from a specific # host. Each "slot" will be released after the cooloff time in milliseconds. rate_limiting: enabled: true @@ -190,13 +190,13 @@ client_api: # Configuration for the EDU server. edu_server: internal_api: - listen: http://localhost:7778 # Only used in polylith deployments + listen: http://localhost:7778 # Only used in polylith deployments connect: http://localhost:7778 # Only used in polylith deployments # Configuration for the Federation API. federation_api: internal_api: - listen: http://localhost:7772 # Only used in polylith deployments + listen: http://localhost:7772 # Only used in polylith deployments connect: http://localhost:7772 # Only used in polylith deployments external_api: listen: http://[::]:8072 @@ -224,12 +224,12 @@ federation_api: # be required to satisfy key requests for servers that are no longer online when # joining some rooms. key_perspectives: - - server_name: matrix.org - keys: - - key_id: ed25519:auto - public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw - - key_id: ed25519:a_RXGa - public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ + - server_name: matrix.org + keys: + - key_id: ed25519:auto + public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw + - key_id: ed25519:a_RXGa + public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ # This option will control whether Dendrite will prefer to look up keys directly # or whether it should try perspective servers first, using direct fetches as a @@ -239,7 +239,7 @@ federation_api: # Configuration for the Key Server (for end-to-end encryption). key_server: internal_api: - listen: http://localhost:7779 # Only used in polylith deployments + listen: http://localhost:7779 # Only used in polylith deployments connect: http://localhost:7779 # Only used in polylith deployments database: connection_string: file:keyserver.db @@ -250,7 +250,7 @@ key_server: # Configuration for the Media API. media_api: internal_api: - listen: http://localhost:7774 # Only used in polylith deployments + listen: http://localhost:7774 # Only used in polylith deployments connect: http://localhost:7774 # Only used in polylith deployments external_api: listen: http://[::]:8074 @@ -276,15 +276,15 @@ media_api: # A list of thumbnail sizes to be generated for media content. thumbnail_sizes: - - width: 32 - height: 32 - method: crop - - width: 96 - height: 96 - method: crop - - width: 640 - height: 480 - method: scale + - width: 32 + height: 32 + method: crop + - width: 96 + height: 96 + method: crop + - width: 640 + height: 480 + method: scale # Configuration for experimental MSC's mscs: @@ -302,7 +302,7 @@ mscs: # Configuration for the Room Server. room_server: internal_api: - listen: http://localhost:7770 # Only used in polylith deployments + listen: http://localhost:7770 # Only used in polylith deployments connect: http://localhost:7770 # Only used in polylith deployments database: connection_string: file:roomserver.db @@ -313,7 +313,7 @@ room_server: # Configuration for the Sync API. sync_api: internal_api: - listen: http://localhost:7773 # Only used in polylith deployments + listen: http://localhost:7773 # Only used in polylith deployments connect: http://localhost:7773 # Only used in polylith deployments external_api: listen: http://[::]:8073 @@ -338,16 +338,16 @@ user_api: # This value can be low if performing tests or on embedded Dendrite instances (e.g WASM builds) # bcrypt_cost: 10 internal_api: - listen: http://localhost:7781 # Only used in polylith deployments + listen: http://localhost:7781 # Only used in polylith deployments connect: http://localhost:7781 # Only used in polylith deployments account_database: connection_string: file:userapi_accounts.db max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 - # The length of time that a token issued for a relying party from + # The length of time that a token issued for a relying party from # /_matrix/client/r0/user/{userId}/openid/request_token endpoint - # is considered to be valid in milliseconds. + # is considered to be valid in milliseconds. # The default lifetime is 3600000ms (60 minutes). # openid_token_lifetime_ms: 3600000 @@ -369,10 +369,10 @@ tracing: # Logging configuration logging: -- type: std - level: info -- type: file - # The logging level, must be one of debug, info, warn, error, fatal, panic. - level: info - params: - path: ./logs + - type: std + level: info + - type: file + # The logging level, must be one of debug, info, warn, error, fatal, panic. + level: info + params: + path: ./logs diff --git a/go.mod b/go.mod index dbcae5d54..525950daa 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/frankban/quicktest v1.14.0 // indirect github.com/getsentry/sentry-go v0.12.0 github.com/gologme/log v1.3.0 + github.com/google/go-cmp v0.5.6 github.com/google/uuid v1.2.0 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.4.2 diff --git a/internal/eventutil/types.go b/internal/eventutil/types.go index 6d119ce6d..17861d6c5 100644 --- a/internal/eventutil/types.go +++ b/internal/eventutil/types.go @@ -26,8 +26,30 @@ var ErrProfileNoExists = errors.New("no known profile for given user ID") // AccountData represents account data sent from the client API server to the // sync API server type AccountData struct { + RoomID string `json:"room_id"` + Type string `json:"type"` + ReadMarker *ReadMarkerJSON `json:"read_marker,omitempty"` // optional +} + +type ReadMarkerJSON struct { + FullyRead string `json:"m.fully_read"` + Read string `json:"m.read"` +} + +// NotificationData contains statistics about notifications, sent from +// the Push Server to the Sync API server. +type NotificationData struct { + // RoomID identifies the scope of the statistics, together with + // MXID (which is encoded in the Kafka key). RoomID string `json:"room_id"` - Type string `json:"type"` + + // HighlightCount is the number of unread notifications with the + // highlight tweak. + UnreadHighlightCount int `json:"unread_highlight_count"` + + // UnreadNotificationCount is the total number of unread + // notifications. + UnreadNotificationCount int `json:"unread_notification_count"` } // ProfileResponse is a struct containing all known user profile data diff --git a/internal/pushgateway/client.go b/internal/pushgateway/client.go new file mode 100644 index 000000000..49907cee8 --- /dev/null +++ b/internal/pushgateway/client.go @@ -0,0 +1,66 @@ +package pushgateway + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/opentracing/opentracing-go" +) + +type httpClient struct { + hc *http.Client +} + +// NewHTTPClient creates a new Push Gateway client. +func NewHTTPClient(disableTLSValidation bool) Client { + hc := &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + DisableKeepAlives: true, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: disableTLSValidation, + }, + }, + } + return &httpClient{hc: hc} +} + +func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "Notify") + defer span.Finish() + + body, err := json.Marshal(req) + if err != nil { + return err + } + hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return err + } + hreq.Header.Set("Content-Type", "application/json") + + hresp, err := h.hc.Do(hreq) + if err != nil { + return err + } + + //nolint:errcheck + defer hresp.Body.Close() + + if hresp.StatusCode == http.StatusOK { + return json.NewDecoder(hresp.Body).Decode(resp) + } + + var errorBody struct { + Message string `json:"message"` + } + if err := json.NewDecoder(hresp.Body).Decode(&errorBody); err == nil { + return fmt.Errorf("push gateway: %d from %s: %s", hresp.StatusCode, url, errorBody.Message) + } + return fmt.Errorf("push gateway: %d from %s", hresp.StatusCode, url) +} diff --git a/internal/pushgateway/pushgateway.go b/internal/pushgateway/pushgateway.go new file mode 100644 index 000000000..88c326eb2 --- /dev/null +++ b/internal/pushgateway/pushgateway.go @@ -0,0 +1,62 @@ +package pushgateway + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/gomatrixserverlib" +) + +// A Client is how interactions with a Push Gateway is done. +type Client interface { + // Notify sends a notification to the gateway at the given URL. + Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error +} + +type NotifyRequest struct { + Notification Notification `json:"notification"` // Required +} + +type NotifyResponse struct { + // Rejected is the list of device push keys that were rejected + // during the push. The caller should remove the push keys so they + // are not used again. + Rejected []string `json:"rejected"` // Required +} + +type Notification struct { + Content json.RawMessage `json:"content,omitempty"` + Counts *Counts `json:"counts,omitempty"` + Devices []*Device `json:"devices"` // Required + EventID string `json:"event_id,omitempty"` + ID string `json:"id,omitempty"` // Deprecated name for EventID. + Membership string `json:"membership,omitempty"` // UNSPEC: required for Sytest. + Prio Prio `json:"prio,omitempty"` + RoomAlias string `json:"room_alias,omitempty"` + RoomID string `json:"room_id,omitempty"` + RoomName string `json:"room_name,omitempty"` + Sender string `json:"sender,omitempty"` + SenderDisplayName string `json:"sender_display_name,omitempty"` + Type string `json:"type,omitempty"` + UserIsTarget bool `json:"user_is_target,omitempty"` +} + +type Counts struct { + MissedCalls int `json:"missed_calls,omitempty"` + Unread int `json:"unread"` // TODO: UNSPEC: the spec says zero must be omitted, but Sytest 61push/01message-pushed.pl requires it. +} + +type Device struct { + AppID string `json:"app_id"` // Required + Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys. + PushKey string `json:"pushkey"` // Required + PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"` + Tweaks map[string]interface{} `json:"tweaks,omitempty"` +} + +type Prio string + +const ( + HighPrio Prio = "high" + LowPrio Prio = "low" +) diff --git a/internal/pushrules/action.go b/internal/pushrules/action.go new file mode 100644 index 000000000..c7b8cec83 --- /dev/null +++ b/internal/pushrules/action.go @@ -0,0 +1,102 @@ +package pushrules + +import ( + "bytes" + "encoding/json" + "fmt" +) + +// An Action is (part of) an outcome of a rule. There are +// (unofficially) terminal actions, and modifier actions. +type Action struct { + // Kind is the type of action. Has custom encoding in JSON. + Kind ActionKind `json:"-"` + + // Tweak is the property to tweak. Has custom encoding in JSON. + Tweak TweakKey `json:"-"` + + // Value is some value interpreted according to Kind and Tweak. + Value interface{} `json:"value"` +} + +func (a *Action) MarshalJSON() ([]byte, error) { + if a.Tweak == UnknownTweak && a.Value == nil { + return json.Marshal(a.Kind) + } + + if a.Kind != SetTweakAction { + return nil, fmt.Errorf("only set_tweak actions may have a value, but got kind %q", a.Kind) + } + + m := map[string]interface{}{ + string(a.Kind): a.Tweak, + } + if a.Value != nil { + m["value"] = a.Value + } + + return json.Marshal(m) +} + +func (a *Action) UnmarshalJSON(bs []byte) error { + if bytes.HasPrefix(bs, []byte("\"")) { + return json.Unmarshal(bs, &a.Kind) + } + + var raw struct { + SetTweak TweakKey `json:"set_tweak"` + Value interface{} `json:"value"` + } + if err := json.Unmarshal(bs, &raw); err != nil { + return err + } + if raw.SetTweak == UnknownTweak { + return fmt.Errorf("got unknown action JSON: %s", string(bs)) + } + a.Kind = SetTweakAction + a.Tweak = raw.SetTweak + if raw.Value != nil { + a.Value = raw.Value + } + + return nil +} + +// ActionKind is the primary discriminator for actions. +type ActionKind string + +const ( + UnknownAction ActionKind = "" + + // NotifyAction indicates the clients should show a notification. + NotifyAction ActionKind = "notify" + + // DontNotifyAction indicates the clients should not show a notification. + DontNotifyAction ActionKind = "dont_notify" + + // CoalesceAction tells the clients to show a notification, and + // tells both servers and clients that multiple events can be + // coalesced into a single notification. The behaviour is + // implementation-specific. + CoalesceAction ActionKind = "coalesce" + + // SetTweakAction uses the Tweak and Value fields to add a + // tweak. Multiple SetTweakAction can be provided in a rule, + // combined with NotifyAction or CoalesceAction. + SetTweakAction ActionKind = "set_tweak" +) + +// A TweakKey describes a property to be modified/tweaked for events +// that match the rule. +type TweakKey string + +const ( + UnknownTweak TweakKey = "" + + // SoundTweak describes which sound to play. Using "default" means + // "enable sound". + SoundTweak TweakKey = "sound" + + // HighlightTweak asks the clients to highlight the conversation. + HighlightTweak TweakKey = "highlight" +) diff --git a/internal/pushrules/action_test.go b/internal/pushrules/action_test.go new file mode 100644 index 000000000..72db9c998 --- /dev/null +++ b/internal/pushrules/action_test.go @@ -0,0 +1,39 @@ +package pushrules + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestActionJSON(t *testing.T) { + tsts := []struct { + Want Action + }{ + {Action{Kind: NotifyAction}}, + {Action{Kind: DontNotifyAction}}, + {Action{Kind: CoalesceAction}}, + {Action{Kind: SetTweakAction}}, + + {Action{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, + {Action{Kind: SetTweakAction, Tweak: HighlightTweak}}, + {Action{Kind: SetTweakAction, Tweak: HighlightTweak, Value: "false"}}, + } + for _, tst := range tsts { + t.Run(fmt.Sprintf("%+v", tst.Want), func(t *testing.T) { + bs, err := json.Marshal(&tst.Want) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + var got Action + if err := json.Unmarshal(bs, &got); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if diff := cmp.Diff(tst.Want, got); diff != "" { + t.Errorf("+got -want:\n%s", diff) + } + }) + } +} diff --git a/internal/pushrules/condition.go b/internal/pushrules/condition.go new file mode 100644 index 000000000..2d9773c0f --- /dev/null +++ b/internal/pushrules/condition.go @@ -0,0 +1,49 @@ +package pushrules + +// A Condition dictates extra conditions for a matching rules. See +// ConditionKind. +type Condition struct { + // Kind is the primary discriminator for the condition + // type. Required. + Kind ConditionKind `json:"kind"` + + // Key indicates the dot-separated path of Event fields to + // match. Required for EventMatchCondition and + // SenderNotificationPermissionCondition. + Key string `json:"key,omitempty"` + + // Pattern indicates the value pattern that must match. Required + // for EventMatchCondition. + Pattern string `json:"pattern,omitempty"` + + // Is indicates the condition that must be fulfilled. Required for + // RoomMemberCountCondition. + Is string `json:"is,omitempty"` +} + +// ConditionKind represents a kind of condition. +// +// SPEC: Unrecognised conditions MUST NOT match any events, +// effectively making the push rule disabled. +type ConditionKind string + +const ( + UnknownCondition ConditionKind = "" + + // EventMatchCondition indicates the condition looks for a key + // path and matches a pattern. How paths that don't reference a + // simple value match against rules is implementation-specific. + EventMatchCondition ConditionKind = "event_match" + + // ContainsDisplayNameCondition indicates the current user's + // display name must be found in the content body. + ContainsDisplayNameCondition ConditionKind = "contains_display_name" + + // RoomMemberCountCondition matches a simple arithmetic comparison + // against the total number of members in a room. + RoomMemberCountCondition ConditionKind = "room_member_count" + + // SenderNotificationPermissionCondition compares power level for + // the sender in the event's room. + SenderNotificationPermissionCondition ConditionKind = "sender_notification_permission" +) diff --git a/internal/pushrules/default.go b/internal/pushrules/default.go new file mode 100644 index 000000000..996985514 --- /dev/null +++ b/internal/pushrules/default.go @@ -0,0 +1,23 @@ +package pushrules + +import ( + "github.com/matrix-org/gomatrixserverlib" +) + +// DefaultAccountRuleSets is the complete set of default push rules +// for an account. +func DefaultAccountRuleSets(localpart string, serverName gomatrixserverlib.ServerName) *AccountRuleSets { + return &AccountRuleSets{ + Global: *DefaultGlobalRuleSet(localpart, serverName), + } +} + +// DefaultGlobalRuleSet returns the default ruleset for a given (fully +// qualified) MXID. +func DefaultGlobalRuleSet(localpart string, serverName gomatrixserverlib.ServerName) *RuleSet { + return &RuleSet{ + Override: defaultOverrideRules("@" + localpart + ":" + string(serverName)), + Content: defaultContentRules(localpart), + Underride: defaultUnderrideRules, + } +} diff --git a/internal/pushrules/default_content.go b/internal/pushrules/default_content.go new file mode 100644 index 000000000..158afd18b --- /dev/null +++ b/internal/pushrules/default_content.go @@ -0,0 +1,33 @@ +package pushrules + +func defaultContentRules(localpart string) []*Rule { + return []*Rule{ + mRuleContainsUserNameDefinition(localpart), + } +} + +const ( + MRuleContainsUserName = ".m.rule.contains_user_name" +) + +func mRuleContainsUserNameDefinition(localpart string) *Rule { + return &Rule{ + RuleID: MRuleContainsUserName, + Default: true, + Enabled: true, + Pattern: localpart, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "default", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: true, + }, + }, + } +} diff --git a/internal/pushrules/default_override.go b/internal/pushrules/default_override.go new file mode 100644 index 000000000..6f66fd66a --- /dev/null +++ b/internal/pushrules/default_override.go @@ -0,0 +1,165 @@ +package pushrules + +func defaultOverrideRules(userID string) []*Rule { + return []*Rule{ + &mRuleMasterDefinition, + &mRuleSuppressNoticesDefinition, + mRuleInviteForMeDefinition(userID), + &mRuleMemberEventDefinition, + &mRuleContainsDisplayNameDefinition, + &mRuleTombstoneDefinition, + &mRuleRoomNotifDefinition, + } +} + +const ( + MRuleMaster = ".m.rule.master" + MRuleSuppressNotices = ".m.rule.suppress_notices" + MRuleInviteForMe = ".m.rule.invite_for_me" + MRuleMemberEvent = ".m.rule.member_event" + MRuleContainsDisplayName = ".m.rule.contains_display_name" + MRuleTombstone = ".m.rule.tombstone" + MRuleRoomNotif = ".m.rule.roomnotif" +) + +var ( + mRuleMasterDefinition = Rule{ + RuleID: MRuleMaster, + Default: true, + Enabled: false, + Conditions: []*Condition{}, + Actions: []*Action{{Kind: DontNotifyAction}}, + } + mRuleSuppressNoticesDefinition = Rule{ + RuleID: MRuleSuppressNotices, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "content.msgtype", + Pattern: "m.notice", + }, + }, + Actions: []*Action{{Kind: DontNotifyAction}}, + } + mRuleMemberEventDefinition = Rule{ + RuleID: MRuleMemberEvent, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.member", + }, + }, + Actions: []*Action{{Kind: DontNotifyAction}}, + } + mRuleContainsDisplayNameDefinition = Rule{ + RuleID: MRuleContainsDisplayName, + Default: true, + Enabled: true, + Conditions: []*Condition{{Kind: ContainsDisplayNameCondition}}, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "default", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: true, + }, + }, + } + mRuleTombstoneDefinition = Rule{ + RuleID: MRuleTombstone, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.tombstone", + }, + { + Kind: EventMatchCondition, + Key: "state_key", + Pattern: "", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleRoomNotifDefinition = Rule{ + RuleID: MRuleRoomNotif, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "content.body", + Pattern: "@room", + }, + { + Kind: SenderNotificationPermissionCondition, + Key: "room", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } +) + +func mRuleInviteForMeDefinition(userID string) *Rule { + return &Rule{ + RuleID: MRuleInviteForMe, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.member", + }, + { + Kind: EventMatchCondition, + Key: "content.membership", + Pattern: "invite", + }, + { + Kind: EventMatchCondition, + Key: "state_key", + Pattern: userID, + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "default", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } +} diff --git a/internal/pushrules/default_underride.go b/internal/pushrules/default_underride.go new file mode 100644 index 000000000..de72bd526 --- /dev/null +++ b/internal/pushrules/default_underride.go @@ -0,0 +1,119 @@ +package pushrules + +const ( + MRuleCall = ".m.rule.call" + MRuleEncryptedRoomOneToOne = ".m.rule.encrypted_room_one_to_one" + MRuleRoomOneToOne = ".m.rule.room_one_to_one" + MRuleMessage = ".m.rule.message" + MRuleEncrypted = ".m.rule.encrypted" +) + +var defaultUnderrideRules = []*Rule{ + &mRuleCallDefinition, + &mRuleEncryptedRoomOneToOneDefinition, + &mRuleRoomOneToOneDefinition, + &mRuleMessageDefinition, + &mRuleEncryptedDefinition, +} + +var ( + mRuleCallDefinition = Rule{ + RuleID: MRuleCall, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.call.invite", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "ring", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleEncryptedRoomOneToOneDefinition = Rule{ + RuleID: MRuleEncryptedRoomOneToOne, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: RoomMemberCountCondition, + Is: "2", + }, + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.encrypted", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleRoomOneToOneDefinition = Rule{ + RuleID: MRuleRoomOneToOne, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: RoomMemberCountCondition, + Is: "2", + }, + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.message", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleMessageDefinition = Rule{ + RuleID: MRuleMessage, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.message", + }, + }, + Actions: []*Action{{Kind: NotifyAction}}, + } + mRuleEncryptedDefinition = Rule{ + RuleID: MRuleEncrypted, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.encrypted", + }, + }, + Actions: []*Action{{Kind: NotifyAction}}, + } +) diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go new file mode 100644 index 000000000..df22cb042 --- /dev/null +++ b/internal/pushrules/evaluate.go @@ -0,0 +1,165 @@ +package pushrules + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/matrix-org/gomatrixserverlib" +) + +// A RuleSetEvaluator encapsulates context to evaluate an event +// against a rule set. +type RuleSetEvaluator struct { + ec EvaluationContext + ruleSet []kindAndRules +} + +// An EvaluationContext gives a RuleSetEvaluator access to the +// environment, for rules that require that. +type EvaluationContext interface { + // UserDisplayName returns the current user's display name. + UserDisplayName() string + + // RoomMemberCount returns the number of members in the room of + // the current event. + RoomMemberCount() (int, error) + + // HasPowerLevel returns whether the user has at least the given + // power in the room of the current event. + HasPowerLevel(userID, levelKey string) (bool, error) +} + +// A kindAndRules is just here to simplify iteration of the (ordered) +// kinds of rules. +type kindAndRules struct { + Kind Kind + Rules []*Rule +} + +// NewRuleSetEvaluator creates a new evaluator for the given rule set. +func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluator { + return &RuleSetEvaluator{ + ec: ec, + ruleSet: []kindAndRules{ + {OverrideKind, ruleSet.Override}, + {ContentKind, ruleSet.Content}, + {RoomKind, ruleSet.Room}, + {SenderKind, ruleSet.Sender}, + {UnderrideKind, ruleSet.Underride}, + }, + } +} + +// MatchEvent returns the first matching rule. Returns nil if there +// was no match rule. +func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule, error) { + // TODO: server-default rules have lower priority than user rules, + // but they are stored together with the user rules. It's a bit + // unclear what the specification (11.14.1.4 Predefined rules) + // means the ordering should be. + // + // The most reasonable interpretation is that default overrides + // still have lower priority than user content rules, so we + // iterate twice. + for _, rsat := range rse.ruleSet { + for _, defRules := range []bool{false, true} { + for _, rule := range rsat.Rules { + if rule.Default != defRules { + continue + } + ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec) + if err != nil { + return nil, err + } + if ok { + return rule, nil + } + } + } + } + + // No matching rule. + return nil, nil +} + +func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { + if !rule.Enabled { + return false, nil + } + + switch kind { + case OverrideKind, UnderrideKind: + for _, cond := range rule.Conditions { + ok, err := conditionMatches(cond, event, ec) + if err != nil { + return false, err + } + if !ok { + return false, nil + } + } + return true, nil + + case ContentKind: + // TODO: "These configure behaviour for (unencrypted) messages + // that match certain patterns." - Does that mean "content.body"? + return patternMatches("content.body", rule.Pattern, event) + + case RoomKind: + return rule.RuleID == event.RoomID(), nil + + case SenderKind: + return rule.RuleID == event.Sender(), nil + + default: + return false, nil + } +} + +func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { + switch cond.Kind { + case EventMatchCondition: + return patternMatches(cond.Key, cond.Pattern, event) + + case ContainsDisplayNameCondition: + return patternMatches("content.body", ec.UserDisplayName(), event) + + case RoomMemberCountCondition: + cmp, err := parseRoomMemberCountCondition(cond.Is) + if err != nil { + return false, fmt.Errorf("parsing room_member_count condition: %w", err) + } + n, err := ec.RoomMemberCount() + if err != nil { + return false, fmt.Errorf("RoomMemberCount failed: %w", err) + } + return cmp(n), nil + + case SenderNotificationPermissionCondition: + return ec.HasPowerLevel(event.Sender(), cond.Key) + + default: + return false, nil + } +} + +func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) { + re, err := globToRegexp(pattern) + if err != nil { + return false, err + } + + var eventMap map[string]interface{} + if err = json.Unmarshal(event.JSON(), &eventMap); err != nil { + return false, fmt.Errorf("parsing event: %w", err) + } + v, err := lookupMapPath(strings.Split(key, "."), eventMap) + if err != nil { + // An unknown path is a benign error that shouldn't stop rule + // processing. It's just a non-match. + return false, nil + } + + return re.MatchString(fmt.Sprint(v)), nil +} diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go new file mode 100644 index 000000000..50e703365 --- /dev/null +++ b/internal/pushrules/evaluate_test.go @@ -0,0 +1,189 @@ +package pushrules + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/matrix-org/gomatrixserverlib" +) + +func TestRuleSetEvaluatorMatchEvent(t *testing.T) { + ev := mustEventFromJSON(t, `{}`) + defaultEnabled := &Rule{ + RuleID: ".default.enabled", + Default: true, + Enabled: true, + } + userEnabled := &Rule{ + RuleID: ".user.enabled", + Default: false, + Enabled: true, + } + userEnabled2 := &Rule{ + RuleID: ".user.enabled.2", + Default: false, + Enabled: true, + } + tsts := []struct { + Name string + RuleSet RuleSet + Want *Rule + }{ + {"empty", RuleSet{}, nil}, + {"defaultCanWin", RuleSet{Override: []*Rule{defaultEnabled}}, defaultEnabled}, + {"userWins", RuleSet{Override: []*Rule{defaultEnabled, userEnabled}}, userEnabled}, + {"defaultOverrideWins", RuleSet{Override: []*Rule{defaultEnabled}, Underride: []*Rule{userEnabled}}, defaultEnabled}, + {"overrideContent", RuleSet{Override: []*Rule{userEnabled}, Content: []*Rule{userEnabled2}}, userEnabled}, + {"overrideRoom", RuleSet{Override: []*Rule{userEnabled}, Room: []*Rule{userEnabled2}}, userEnabled}, + {"overrideSender", RuleSet{Override: []*Rule{userEnabled}, Sender: []*Rule{userEnabled2}}, userEnabled}, + {"overrideUnderride", RuleSet{Override: []*Rule{userEnabled}, Underride: []*Rule{userEnabled2}}, userEnabled}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + rse := NewRuleSetEvaluator(nil, &tst.RuleSet) + got, err := rse.MatchEvent(ev) + if err != nil { + t.Fatalf("MatchEvent failed: %v", err) + } + if diff := cmp.Diff(tst.Want, got); diff != "" { + t.Errorf("MatchEvent rule: +got -want:\n%s", diff) + } + }) + } +} + +func TestRuleMatches(t *testing.T) { + emptyRule := Rule{Enabled: true} + tsts := []struct { + Name string + Kind Kind + Rule Rule + EventJSON string + Want bool + }{ + {"emptyOverride", OverrideKind, emptyRule, `{}`, true}, + {"emptyContent", ContentKind, emptyRule, `{}`, false}, + {"emptyRoom", RoomKind, emptyRule, `{}`, true}, + {"emptySender", SenderKind, emptyRule, `{}`, true}, + {"emptyUnderride", UnderrideKind, emptyRule, `{}`, true}, + + {"disabled", OverrideKind, Rule{}, `{}`, false}, + + {"overrideConditionMatch", OverrideKind, Rule{Enabled: true}, `{}`, true}, + {"overrideConditionNoMatch", OverrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, + + {"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true}, + {"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, + + {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true}, + {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false}, + + {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true}, + {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false}, + + {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true}, + {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil) + if err != nil { + t.Fatalf("ruleMatches failed: %v", err) + } + if got != tst.Want { + t.Errorf("ruleMatches: got %v, want %v", got, tst.Want) + } + }) + } +} + +func TestConditionMatches(t *testing.T) { + tsts := []struct { + Name string + Cond Condition + EventJSON string + Want bool + }{ + {"empty", Condition{}, `{}`, false}, + {"empty", Condition{Kind: "unknownstring"}, `{}`, false}, + + {"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, true}, + + {"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false}, + {"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true}, + + {"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false}, + {"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true}, + {"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false}, + {"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true}, + {"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false}, + {"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true}, + {"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false}, + {"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true}, + {"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false}, + {"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true}, + + {"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true}, + {"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{}) + if err != nil { + t.Fatalf("conditionMatches failed: %v", err) + } + if got != tst.Want { + t.Errorf("conditionMatches: got %v, want %v", got, tst.Want) + } + }) + } +} + +type fakeEvaluationContext struct{} + +func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" } +func (fakeEvaluationContext) RoomMemberCount() (int, error) { return 2, nil } +func (fakeEvaluationContext) HasPowerLevel(userID, levelKey string) (bool, error) { + return userID == "@poweruser:example.com" && levelKey == "powerlevel", nil +} + +func TestPatternMatches(t *testing.T) { + tsts := []struct { + Name string + Key string + Pattern string + EventJSON string + Want bool + }{ + {"empty", "", "", `{}`, false}, + + // Note that an empty pattern contains no wildcard characters, + // which implicitly means "*". + {"patternEmpty", "content", "", `{"content":{}}`, true}, + + {"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true}, + {"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true}, + {"singlePattern", "content.creator", "acr?ator", `{"content":{"creator":"acreator"}}`, true}, + {"multiPattern", "content.creator", "a*ea*r", `{"content":{"creator":"acreator"}}`, true}, + {"patternNoSubstring", "content.creator", "r*t", `{"content":{"creator":"acreator"}}`, false}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got, err := patternMatches(tst.Key, tst.Pattern, mustEventFromJSON(t, tst.EventJSON)) + if err != nil { + t.Fatalf("patternMatches failed: %v", err) + } + if got != tst.Want { + t.Errorf("patternMatches: got %v, want %v", got, tst.Want) + } + }) + } +} + +func mustEventFromJSON(t *testing.T, json string) *gomatrixserverlib.Event { + ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(json), false, gomatrixserverlib.RoomVersionV7) + if err != nil { + t.Fatal(err) + } + return ev +} diff --git a/internal/pushrules/pushrules.go b/internal/pushrules/pushrules.go new file mode 100644 index 000000000..bbed1f95f --- /dev/null +++ b/internal/pushrules/pushrules.go @@ -0,0 +1,71 @@ +package pushrules + +// An AccountRuleSets carries the rule sets associated with an +// account. +type AccountRuleSets struct { + Global RuleSet `json:"global"` // Required +} + +// A RuleSet contains all the various push rules for an +// account. Listed in decreasing order of priority. +type RuleSet struct { + Override []*Rule `json:"override,omitempty"` + Content []*Rule `json:"content,omitempty"` + Room []*Rule `json:"room,omitempty"` + Sender []*Rule `json:"sender,omitempty"` + Underride []*Rule `json:"underride,omitempty"` +} + +// A Rule contains matchers, conditions and final actions. While +// evaluating, at most one rule is considered matching. +// +// Kind and scope are part of the push rules request/responses, but +// not of the core data model. +type Rule struct { + // RuleID is either a free identifier, or the sender's MXID for + // SenderKind. Required. + RuleID string `json:"rule_id"` + + // Default indicates whether this is a server-defined default, or + // a user-provided rule. Required. + // + // The server-default rules have the lowest priority. + Default bool `json:"default"` + + // Enabled allows the user to disable rules while keeping them + // around. Required. + Enabled bool `json:"enabled"` + + // Actions describe the desired outcome, should the rule + // match. Required. + Actions []*Action `json:"actions"` + + // Conditions provide the rule's conditions for OverrideKind and + // UnderrideKind. Not allowed for other kinds. + Conditions []*Condition `json:"conditions"` + + // Pattern is the body pattern to match for ContentKind. Required + // for that kind. The interpretation is the same as that of + // Condition.Pattern. + Pattern string `json:"pattern"` +} + +// Scope only has one valid value. See also AccountRuleSets. +type Scope string + +const ( + UnknownScope Scope = "" + GlobalScope Scope = "global" +) + +// Kind is the type of push rule. See also RuleSet. +type Kind string + +const ( + UnknownKind Kind = "" + OverrideKind Kind = "override" + ContentKind Kind = "content" + RoomKind Kind = "room" + SenderKind Kind = "sender" + UnderrideKind Kind = "underride" +) diff --git a/internal/pushrules/util.go b/internal/pushrules/util.go new file mode 100644 index 000000000..027d35ef6 --- /dev/null +++ b/internal/pushrules/util.go @@ -0,0 +1,125 @@ +package pushrules + +import ( + "fmt" + "regexp" + "strconv" + "strings" +) + +// ActionsToTweaks converts a list of actions into a primary action +// kind and a tweaks map. Returns a nil map if it would have been +// empty. +func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) { + var kind ActionKind + tweaks := map[string]interface{}{} + + for _, a := range as { + if a.Kind == SetTweakAction { + tweaks[string(a.Tweak)] = a.Value + continue + } + if kind != UnknownAction { + return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind) + } + kind = a.Kind + } + + if len(tweaks) == 0 { + tweaks = nil + } + + return kind, tweaks, nil +} + +// BoolTweakOr returns the named tweak as a boolean, and returns `def` +// on failure. +func BoolTweakOr(tweaks map[string]interface{}, key TweakKey, def bool) bool { + v, ok := tweaks[string(key)] + if !ok { + return def + } + b, ok := v.(bool) + if !ok { + return def + } + return b +} + +// globToRegexp converts a Matrix glob-style pattern to a Regular expression. +func globToRegexp(pattern string) (*regexp.Regexp, error) { + // TODO: It's unclear which glob characters are supported. The only + // place this is discussed is for the unrelated "m.policy.rule.*" + // events. Assuming, the same: /[*?]/ + if !strings.ContainsAny(pattern, "*?") { + pattern = "*" + pattern + "*" + } + + // The defined syntax doesn't allow escaping the glob wildcard + // characters, which makes this a straight-forward + // replace-after-quote. + pattern = globNonMetaRegexp.ReplaceAllStringFunc(pattern, regexp.QuoteMeta) + pattern = strings.Replace(pattern, "*", ".*", -1) + pattern = strings.Replace(pattern, "?", ".", -1) + return regexp.Compile("^(" + pattern + ")$") +} + +// globNonMetaRegexp are the characters that are not considered glob +// meta-characters (i.e. may need escaping). +var globNonMetaRegexp = regexp.MustCompile("[^*?]+") + +// lookupMapPath traverses a hierarchical map structure, like the one +// produced by json.Unmarshal, to return the leaf value. Traversing +// arrays/slices is not supported, only objects/maps. +func lookupMapPath(path []string, m map[string]interface{}) (interface{}, error) { + if len(path) == 0 { + return nil, fmt.Errorf("empty path") + } + + var v interface{} = m + for i, key := range path { + m, ok := v.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("expected an object for path %q, but got %T", strings.Join(path[:i+1], "."), v) + } + + v, ok = m[key] + if !ok { + return nil, fmt.Errorf("path not found: %s", strings.Join(path[:i+1], ".")) + } + } + + return v, nil +} + +// parseRoomMemberCountCondition parses a string like "2", "==2", "<2" +// into a function that checks if the argument to it fulfils the +// condition. +func parseRoomMemberCountCondition(s string) (func(int) bool, error) { + var b int + var cmp = func(a int) bool { return a == b } + switch { + case strings.HasPrefix(s, "<="): + cmp = func(a int) bool { return a <= b } + s = s[2:] + case strings.HasPrefix(s, ">="): + cmp = func(a int) bool { return a >= b } + s = s[2:] + case strings.HasPrefix(s, "<"): + cmp = func(a int) bool { return a < b } + s = s[1:] + case strings.HasPrefix(s, ">"): + cmp = func(a int) bool { return a > b } + s = s[1:] + case strings.HasPrefix(s, "=="): + // Same cmp as the default. + s = s[2:] + } + + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return nil, err + } + b = int(v) + return cmp, nil +} diff --git a/internal/pushrules/util_test.go b/internal/pushrules/util_test.go new file mode 100644 index 000000000..a951c55a2 --- /dev/null +++ b/internal/pushrules/util_test.go @@ -0,0 +1,169 @@ +package pushrules + +import ( + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestActionsToTweaks(t *testing.T) { + tsts := []struct { + Name string + Input []*Action + WantKind ActionKind + WantTweaks map[string]interface{} + }{ + {"empty", nil, UnknownAction, nil}, + {"zero", []*Action{{}}, UnknownAction, nil}, + {"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil}, + {"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}}, + {"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}}, + { + "all", + []*Action{ + {Kind: CoalesceAction}, + {Kind: SetTweakAction, Tweak: HighlightTweak}, + {Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}, + }, + CoalesceAction, + map[string]interface{}{"highlight": nil, "sound": "default"}, + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + gotKind, gotTweaks, err := ActionsToTweaks(tst.Input) + if err != nil { + t.Fatalf("ActionsToTweaks failed: %v", err) + } + if gotKind != tst.WantKind { + t.Errorf("kind: got %v, want %v", gotKind, tst.WantKind) + } + if diff := cmp.Diff(tst.WantTweaks, gotTweaks); diff != "" { + t.Errorf("tweaks: +got -want:\n%s", diff) + } + }) + } +} + +func TestBoolTweakOr(t *testing.T) { + tsts := []struct { + Name string + Input map[string]interface{} + Def bool + Want bool + }{ + {"nil", nil, false, false}, + {"nilValue", map[string]interface{}{"highlight": nil}, true, true}, + {"false", map[string]interface{}{"highlight": false}, true, false}, + {"true", map[string]interface{}{"highlight": true}, false, true}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got := BoolTweakOr(tst.Input, HighlightTweak, tst.Def) + if got != tst.Want { + t.Errorf("BoolTweakOr: got %v, want %v", got, tst.Want) + } + }) + } +} + +func TestGlobToRegexp(t *testing.T) { + tsts := []struct { + Input string + Want string + }{ + {"", "^(.*.*)$"}, + {"a", "^(.*a.*)$"}, + {"a.b", "^(.*a\\.b.*)$"}, + {"a?b", "^(a.b)$"}, + {"a*b*", "^(a.*b.*)$"}, + {"a*b?", "^(a.*b.)$"}, + } + for _, tst := range tsts { + t.Run(tst.Want, func(t *testing.T) { + got, err := globToRegexp(tst.Input) + if err != nil { + t.Fatalf("globToRegexp failed: %v", err) + } + if got.String() != tst.Want { + t.Errorf("got %v, want %v", got.String(), tst.Want) + } + }) + } +} + +func TestLookupMapPath(t *testing.T) { + tsts := []struct { + Path []string + Root map[string]interface{} + Want interface{} + }{ + {[]string{"a"}, map[string]interface{}{"a": "b"}, "b"}, + {[]string{"a"}, map[string]interface{}{"a": 42}, 42}, + {[]string{"a", "b"}, map[string]interface{}{"a": map[string]interface{}{"b": "c"}}, "c"}, + } + for _, tst := range tsts { + t.Run(fmt.Sprint(tst.Path, "/", tst.Want), func(t *testing.T) { + got, err := lookupMapPath(tst.Path, tst.Root) + if err != nil { + t.Fatalf("lookupMapPath failed: %v", err) + } + if diff := cmp.Diff(tst.Want, got); diff != "" { + t.Errorf("+got -want:\n%s", diff) + } + }) + } +} + +func TestLookupMapPathInvalid(t *testing.T) { + tsts := []struct { + Path []string + Root map[string]interface{} + }{ + {nil, nil}, + {[]string{"a"}, nil}, + {[]string{"a", "b"}, map[string]interface{}{"a": "c"}}, + } + for _, tst := range tsts { + t.Run(fmt.Sprint(tst.Path), func(t *testing.T) { + got, err := lookupMapPath(tst.Path, tst.Root) + if err == nil { + t.Fatalf("lookupMapPath succeeded with %#v, but want failure", got) + } + }) + } +} + +func TestParseRoomMemberCountCondition(t *testing.T) { + tsts := []struct { + Input string + WantTrue []int + WantFalse []int + }{ + {"1", []int{1}, []int{0, 2}}, + {"==1", []int{1}, []int{0, 2}}, + {"<1", []int{0}, []int{1, 2}}, + {"<=1", []int{0, 1}, []int{2}}, + {">1", []int{2}, []int{0, 1}}, + {">=42", []int{42, 43}, []int{41}}, + } + for _, tst := range tsts { + t.Run(tst.Input, func(t *testing.T) { + got, err := parseRoomMemberCountCondition(tst.Input) + if err != nil { + t.Fatalf("parseRoomMemberCountCondition failed: %v", err) + } + for _, v := range tst.WantTrue { + if !got(v) { + t.Errorf("parseRoomMemberCountCondition(%q)(%d): got false, want true", tst.Input, v) + } + } + for _, v := range tst.WantFalse { + if got(v) { + t.Errorf("parseRoomMemberCountCondition(%q)(%d): got true, want false", tst.Input, v) + } + } + }) + } +} diff --git a/internal/pushrules/validate.go b/internal/pushrules/validate.go new file mode 100644 index 000000000..5d260f0b9 --- /dev/null +++ b/internal/pushrules/validate.go @@ -0,0 +1,85 @@ +package pushrules + +import ( + "fmt" + "regexp" +) + +// ValidateRule checks the rule for errors. These follow from Sytests +// and the specification. +func ValidateRule(kind Kind, rule *Rule) []error { + var errs []error + + if !validRuleIDRE.MatchString(rule.RuleID) { + errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID)) + } + + if len(rule.Actions) == 0 { + errs = append(errs, fmt.Errorf("missing actions")) + } + for _, action := range rule.Actions { + errs = append(errs, validateAction(action)...) + } + + for _, cond := range rule.Conditions { + errs = append(errs, validateCondition(cond)...) + } + + switch kind { + case OverrideKind, UnderrideKind: + // The empty list is allowed, but for JSON-encoding reasons, + // it must not be nil. + if rule.Conditions == nil { + errs = append(errs, fmt.Errorf("missing rule conditions")) + } + + case ContentKind: + if rule.Pattern == "" { + errs = append(errs, fmt.Errorf("missing content rule pattern")) + } + + case RoomKind, SenderKind: + // Do nothing. + + default: + errs = append(errs, fmt.Errorf("invalid rule kind: %s", kind)) + } + + return errs +} + +// validRuleIDRE is a regexp for valid IDs. +// +// TODO: the specification doesn't seem to say what the rule ID syntax +// is. A Sytest fails if it contains a backslash. +var validRuleIDRE = regexp.MustCompile(`^([^\\]+)$`) + +// validateAction returns issues with an Action. +func validateAction(action *Action) []error { + var errs []error + + switch action.Kind { + case NotifyAction, DontNotifyAction, CoalesceAction, SetTweakAction: + // Do nothing. + + default: + errs = append(errs, fmt.Errorf("invalid rule action kind: %s", action.Kind)) + } + + return errs +} + +// validateCondition returns issues with a Condition. +func validateCondition(cond *Condition) []error { + var errs []error + + switch cond.Kind { + case EventMatchCondition, ContainsDisplayNameCondition, RoomMemberCountCondition, SenderNotificationPermissionCondition: + // Do nothing. + + default: + errs = append(errs, fmt.Errorf("invalid rule condition kind: %s", cond.Kind)) + } + + return errs +} diff --git a/internal/pushrules/validate_test.go b/internal/pushrules/validate_test.go new file mode 100644 index 000000000..b276eb551 --- /dev/null +++ b/internal/pushrules/validate_test.go @@ -0,0 +1,163 @@ +package pushrules + +import ( + "strings" + "testing" +) + +func TestValidateRuleNegatives(t *testing.T) { + tsts := []struct { + Name string + Kind Kind + Rule Rule + WantErrString string + }{ + {"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"}, + {"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"}, + {"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"}, + {"noActions", OverrideKind, Rule{}, "missing actions"}, + {"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"}, + {"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"}, + {"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"}, + {"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"}, + {"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := ValidateRule(tst.Kind, &tst.Rule) + var foundErr error + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantErrString) { + foundErr = err + } + } + if foundErr == nil { + t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString) + } + }) + } +} + +func TestValidateRulePositives(t *testing.T) { + tsts := []struct { + Name string + Kind Kind + Rule Rule + WantNoErrString string + }{ + {"invalidKind", OverrideKind, Rule{}, "invalid rule kind"}, + {"invalidActionNoActions", OverrideKind, Rule{}, "invalid rule action kind"}, + {"invalidConditionNoConditions", OverrideKind, Rule{}, "invalid rule condition kind"}, + {"contentNoCondition", ContentKind, Rule{}, "missing rule conditions"}, + {"roomNoCondition", RoomKind, Rule{}, "missing rule conditions"}, + {"senderNoCondition", SenderKind, Rule{}, "missing rule conditions"}, + {"overrideNoPattern", OverrideKind, Rule{}, "missing content rule pattern"}, + {"overrideEmptyConditions", OverrideKind, Rule{Conditions: []*Condition{}}, "missing rule conditions"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := ValidateRule(tst.Kind, &tst.Rule) + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantNoErrString) { + t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString) + } + } + }) + } +} + +func TestValidateActionNegatives(t *testing.T) { + tsts := []struct { + Name string + Action Action + WantErrString string + }{ + {"emptyKind", Action{}, "invalid rule action kind"}, + {"invalidKind", Action{Kind: ActionKind("something else")}, "invalid rule action kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateAction(&tst.Action) + var foundErr error + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantErrString) { + foundErr = err + } + } + if foundErr == nil { + t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString) + } + }) + } +} + +func TestValidateActionPositives(t *testing.T) { + tsts := []struct { + Name string + Action Action + WantNoErrString string + }{ + {"invalidKind", Action{Kind: NotifyAction}, "invalid rule action kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateAction(&tst.Action) + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantNoErrString) { + t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString) + } + } + }) + } +} + +func TestValidateConditionNegatives(t *testing.T) { + tsts := []struct { + Name string + Condition Condition + WantErrString string + }{ + {"emptyKind", Condition{}, "invalid rule condition kind"}, + {"invalidKind", Condition{Kind: ConditionKind("something else")}, "invalid rule condition kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateCondition(&tst.Condition) + var foundErr error + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantErrString) { + foundErr = err + } + } + if foundErr == nil { + t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString) + } + }) + } +} + +func TestValidateConditionPositives(t *testing.T) { + tsts := []struct { + Name string + Condition Condition + WantNoErrString string + }{ + {"invalidKind", Condition{Kind: EventMatchCondition}, "invalid rule condition kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateCondition(&tst.Condition) + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantNoErrString) { + t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString) + } + } + }) + } +} diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 8d0d2dfa5..19483b268 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -163,6 +163,7 @@ type StatementList []struct { func (s StatementList) Prepare(db *sql.DB) (err error) { for _, statement := range s { if *statement.Statement, err = db.Prepare(statement.SQL); err != nil { + err = fmt.Errorf("Error %q while preparing statement: %s", err, statement.SQL) return } } diff --git a/q.sqlite b/q.sqlite new file mode 100644 index 000000000..b7d6268e2 Binary files /dev/null and b/q.sqlite differ diff --git a/run-sytest.sh b/run-sytest.sh new file mode 100755 index 000000000..47635fd12 --- /dev/null +++ b/run-sytest.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# +# Runs SyTest either from Docker Hub, or from ../sytest. If it's run +# locally, the Docker image is rebuilt first. +# +# Logs are stored in ../sytestout/logs. + +set -e +set -o pipefail + +main() { + local tag=buster + local base_image=debian:$tag + local runargs=() + + cd "$(dirname "$0")" + + if [ -d ../sytest ]; then + local tmpdir + tmpdir="$(mktemp -d --tmpdir run-systest.XXXXXXXXXX)" + trap "rm -r '$tmpdir'" EXIT + + if [ -z "$DISABLE_BUILDING_SYTEST" ]; then + echo "Re-building ../sytest Docker images..." + + local status + ( + cd ../sytest + + docker build -f docker/base.Dockerfile --build-arg BASE_IMAGE="$base_image" --tag matrixdotorg/sytest:"$tag" . + docker build -f docker/dendrite.Dockerfile --build-arg SYTEST_IMAGE_TAG="$tag" --tag matrixdotorg/sytest-dendrite:latest . + ) &>"$tmpdir/buildlog" || status=$? + if (( status != 0 )); then + # Docker is very verbose, and we don't really care about + # building SyTest. So we accumulate and only output on + # failure. + cat "$tmpdir/buildlog" >&2 + return $status + fi + fi + + runargs+=( -v "$PWD/../sytest:/sytest:ro" ) + fi + if [ -n "$SYTEST_POSTGRES" ]; then + runargs+=( -e POSTGRES=1 ) + fi + + local sytestout=$PWD/../sytestout + mkdir -p "$sytestout"/{logs,cache/go-build,cache/go-pkg} + docker run \ + --rm \ + --name "sytest-dendrite-${LOGNAME}" \ + -e LOGS_USER=$(id -u) \ + -e LOGS_GROUP=$(id -g) \ + -v "$PWD:/src/:ro" \ + -v "$sytestout/logs:/logs/" \ + -v "$sytestout/cache/go-build:/root/.cache/go-build" \ + -v "$sytestout/cache/go-pkg:/gopath/pkg" \ + "${runargs[@]}" \ + matrixdotorg/sytest-dendrite:latest "$@" +} + +main "$@" diff --git a/setup/base/base.go b/setup/base/base.go index e39977541..ef3b2be29 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -30,6 +30,7 @@ import ( sentryhttp "github.com/getsentry/sentry-go/http" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/gomatrixserverlib" "github.com/prometheus/client_golang/prometheus/promhttp" "go.uber.org/atomic" @@ -271,6 +272,11 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI { return f } +// PushGatewayHTTPClient returns a new client for interacting with (external) Push Gateways. +func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client { + return pushgateway.NewHTTPClient(b.Cfg.UserAPI.PushGatewayDisableTLSValidation) +} + // CreateAccountsDB creates a new instance of the accounts database. Should only // be called once per component. func (b *BaseDendrite) CreateAccountsDB() userdb.Database { diff --git a/setup/config/config_test.go b/setup/config/config_test.go index 8f7611f0a..6467b7c82 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -205,6 +205,11 @@ user_api: max_open_conns: 100 max_idle_conns: 2 conn_max_lifetime: -1 + pusher_database: + connection_string: file:pushserver.db + max_open_conns: 100 + max_idle_conns: 2 + conn_max_lifetime: -1 tracing: enabled: false jaeger: diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go index 1cb5eba18..570dc6030 100644 --- a/setup/config/config_userapi.go +++ b/setup/config/config_userapi.go @@ -13,6 +13,9 @@ type UserAPI struct { // The length of time an OpenID token is condidered valid in milliseconds OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"` + // Disable TLS validation on HTTPS calls to push gatways. NOT RECOMMENDED! + PushGatewayDisableTLSValidation bool `yaml:"push_gateway_disable_tls_validation"` + // The Account database stores the login details and account information // for local users. It is accessed by the UserAPI. AccountDatabase DatabaseOptions `yaml:"account_database"` diff --git a/setup/jetstream/streams.go b/setup/jetstream/streams.go index 5810a2a91..3f07488f9 100644 --- a/setup/jetstream/streams.go +++ b/setup/jetstream/streams.go @@ -18,7 +18,10 @@ var ( OutputKeyChangeEvent = "OutputKeyChangeEvent" OutputTypingEvent = "OutputTypingEvent" OutputClientData = "OutputClientData" + OutputNotificationData = "OutputNotificationData" OutputReceiptEvent = "OutputReceiptEvent" + OutputStreamEvent = "OutputStreamEvent" + OutputReadUpdate = "OutputReadUpdate" ) var streams = []*nats.StreamConfig{ @@ -58,4 +61,19 @@ var streams = []*nats.StreamConfig{ Retention: nats.InterestPolicy, Storage: nats.FileStorage, }, + { + Name: OutputNotificationData, + Retention: nats.InterestPolicy, + Storage: nats.FileStorage, + }, + { + Name: OutputStreamEvent, + Retention: nats.InterestPolicy, + Storage: nats.FileStorage, + }, + { + Name: OutputReadUpdate, + Retention: nats.InterestPolicy, + Storage: nats.FileStorage, + }, } diff --git a/setup/monolith.go b/setup/monolith.go index 61125e4a9..7dbd2eeaa 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -60,8 +60,8 @@ func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ss csMux, synapseMux, &m.Config.ClientAPI, m.AccountDB, m.FedClient, m.RoomserverAPI, m.EDUInternalAPI, m.AppserviceAPI, transactions.New(), - m.FederationAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider, - &m.Config.MSCs, + m.FederationAPI, m.UserAPI, m.KeyAPI, + m.ExtPublicRoomsProvider, &m.Config.MSCs, ) federationapi.AddPublicRoutes( ssMux, keyMux, wkMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient, diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index c3650085f..f01afce6d 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -17,6 +17,7 @@ package consumers import ( "context" "encoding/json" + "fmt" "github.com/getsentry/sentry-go" "github.com/matrix-org/dendrite/internal/eventutil" @@ -24,21 +25,26 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" ) // OutputClientDataConsumer consumes events that originated in the client API server. type OutputClientDataConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - stream types.StreamProvider - notifier *notifier.Notifier + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + stream types.StreamProvider + notifier *notifier.Notifier + serverName gomatrixserverlib.ServerName + producer *producers.UserAPIReadProducer } // NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers. @@ -49,15 +55,18 @@ func NewOutputClientDataConsumer( store storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, + producer *producers.UserAPIReadProducer, ) *OutputClientDataConsumer { return &OutputClientDataConsumer{ - ctx: process.Context(), - jetstream: js, - topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData), - durable: cfg.Matrix.JetStream.Durable("SyncAPIClientAPIConsumer"), - db: store, - notifier: notifier, - stream: stream, + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData), + durable: cfg.Matrix.JetStream.Durable("SyncAPIClientAPIConsumer"), + db: store, + notifier: notifier, + stream: stream, + serverName: cfg.Matrix.ServerName, + producer: producer, } } @@ -100,8 +109,48 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) }).Panicf("could not save account data") } + if err = s.sendReadUpdate(ctx, userID, output); err != nil { + log.WithError(err).WithFields(logrus.Fields{ + "user_id": userID, + "room_id": output.RoomID, + }).Errorf("Failed to generate read update") + sentry.CaptureException(err) + return false + } + s.stream.Advance(streamPos) s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos}) return true } + +func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error { + if output.Type != "m.fully_read" || output.ReadMarker == nil { + return nil + } + _, serverName, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + } + if serverName != s.serverName { + return nil + } + var readPos types.StreamPosition + var fullyReadPos types.StreamPosition + if output.ReadMarker.Read != "" { + if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil { + return fmt.Errorf("s.db.PositionInTopology (Read): %w", err) + } + } + if output.ReadMarker.FullyRead != "" { + if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil { + return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err) + } + } + if readPos > 0 || fullyReadPos > 0 { + if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil { + return fmt.Errorf("s.producer.SendReadUpdate: %w", err) + } + } + return nil +} diff --git a/syncapi/consumers/eduserver_receipts.go b/syncapi/consumers/eduserver_receipts.go index 392840ece..881583449 100644 --- a/syncapi/consumers/eduserver_receipts.go +++ b/syncapi/consumers/eduserver_receipts.go @@ -17,6 +17,7 @@ package consumers import ( "context" "encoding/json" + "fmt" "github.com/getsentry/sentry-go" "github.com/matrix-org/dendrite/eduserver/api" @@ -24,21 +25,26 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" ) // OutputReceiptEventConsumer consumes events that originated in the EDU server. type OutputReceiptEventConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - stream types.StreamProvider - notifier *notifier.Notifier + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + stream types.StreamProvider + notifier *notifier.Notifier + serverName gomatrixserverlib.ServerName + producer *producers.UserAPIReadProducer } // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. @@ -50,15 +56,18 @@ func NewOutputReceiptEventConsumer( store storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, + producer *producers.UserAPIReadProducer, ) *OutputReceiptEventConsumer { return &OutputReceiptEventConsumer{ - ctx: process.Context(), - jetstream: js, - topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent), - durable: cfg.Matrix.JetStream.Durable("SyncAPIEDUServerReceiptConsumer"), - db: store, - notifier: notifier, - stream: stream, + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent), + durable: cfg.Matrix.JetStream.Durable("SyncAPIEDUServerReceiptConsumer"), + db: store, + notifier: notifier, + stream: stream, + serverName: cfg.Matrix.ServerName, + producer: producer, } } @@ -92,8 +101,42 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Ms return true } + if err = s.sendReadUpdate(ctx, output); err != nil { + log.WithError(err).WithFields(logrus.Fields{ + "user_id": output.UserID, + "room_id": output.RoomID, + }).Errorf("Failed to generate read update") + sentry.CaptureException(err) + return false + } + s.stream.Advance(streamPos) s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) return true } + +func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output api.OutputReceiptEvent) error { + if output.Type != "m.read" { + return nil + } + _, serverName, err := gomatrixserverlib.SplitID('@', output.UserID) + if err != nil { + return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + } + if serverName != s.serverName { + return nil + } + var readPos types.StreamPosition + if output.EventID != "" { + if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil { + return fmt.Errorf("s.db.PositionInTopology (Read): %w", err) + } + } + if readPos > 0 { + if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil { + return fmt.Errorf("s.producer.SendReadUpdate: %w", err) + } + } + return nil +} diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 15485bb35..159657f9f 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -45,6 +46,7 @@ type OutputRoomEventConsumer struct { pduStream types.StreamProvider inviteStream types.StreamProvider notifier *notifier.Notifier + producer *producers.UserAPIStreamEventProducer } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -57,6 +59,7 @@ func NewOutputRoomEventConsumer( pduStream types.StreamProvider, inviteStream types.StreamProvider, rsAPI api.RoomserverInternalAPI, + producer *producers.UserAPIStreamEventProducer, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ ctx: process.Context(), @@ -69,6 +72,7 @@ func NewOutputRoomEventConsumer( pduStream: pduStream, inviteStream: inviteStream, rsAPI: rsAPI, + producer: producer, } } @@ -194,6 +198,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( return nil } + if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil { + log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID()) + sentry.CaptureException(err) + return err + } + if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) sentry.CaptureException(err) diff --git a/syncapi/consumers/userapi.go b/syncapi/consumers/userapi.go new file mode 100644 index 000000000..a3b2dd53d --- /dev/null +++ b/syncapi/consumers/userapi.go @@ -0,0 +1,110 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + "encoding/json" + + "github.com/getsentry/sentry-go" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +// OutputNotificationDataConsumer consumes events that originated in +// the Push server. +type OutputNotificationDataConsumer struct { + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + notifier *notifier.Notifier + stream types.StreamProvider +} + +// NewOutputNotificationDataConsumer creates a new consumer. Call +// Start() to begin consuming. +func NewOutputNotificationDataConsumer( + process *process.ProcessContext, + cfg *config.SyncAPI, + js nats.JetStreamContext, + store storage.Database, + notifier *notifier.Notifier, + stream types.StreamProvider, +) *OutputNotificationDataConsumer { + s := &OutputNotificationDataConsumer{ + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Durable("SyncAPINotificationDataConsumer"), + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData), + db: store, + notifier: notifier, + stream: stream, + } + return s +} + +// Start starts consumption. +func (s *OutputNotificationDataConsumer) Start() error { + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) +} + +// onMessage is called when the Sync server receives a new event from +// the push server. It is not safe for this function to be called from +// multiple goroutines, or else the sync stream position may race and +// be incorrectly calculated. +func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + userID := string(msg.Header.Get(jetstream.UserID)) + + // Parse out the event JSON + var data eventutil.NotificationData + if err := json.Unmarshal(msg.Data, &data); err != nil { + sentry.CaptureException(err) + log.WithField("user_id", userID).WithError(err).Error("user API consumer: message parse failure") + return true + } + + streamPos, err := s.db.UpsertRoomUnreadNotificationCounts(ctx, userID, data.RoomID, data.UnreadNotificationCount, data.UnreadHighlightCount) + if err != nil { + sentry.CaptureException(err) + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": data.RoomID, + }).WithError(err).Error("Could not save notification counts") + return false + } + + s.stream.Advance(streamPos) + s.notifier.OnNewNotificationData(userID, types.StreamingToken{NotificationDataPosition: streamPos}) + + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": data.RoomID, + "streamPos": streamPos, + }).Trace("Received notification data from user API") + + return true +} diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index d853cc0e4..6a641e6f8 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -217,6 +217,17 @@ func (n *Notifier) OnNewInvite( n.wakeupUsers([]string{wakeUserID}, nil, n.currPos) } +func (n *Notifier) OnNewNotificationData( + userID string, + posUpdate types.StreamingToken, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + n.currPos.ApplyUpdates(posUpdate) + n.wakeupUsers([]string{userID}, nil, n.currPos) +} + // GetListener returns a UserStreamListener that can be used to wait for // updates for a user. Must be closed. // notify for anything before sincePos diff --git a/syncapi/notifier/notifier_test.go b/syncapi/notifier/notifier_test.go index c6d3df7ee..60403d5d5 100644 --- a/syncapi/notifier/notifier_test.go +++ b/syncapi/notifier/notifier_test.go @@ -219,7 +219,7 @@ func TestEDUWakeup(t *testing.T) { go func() { pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter)) if err != nil { - t.Errorf("TestNewInviteEventForUser error: %w", err) + t.Errorf("TestNewInviteEventForUser error: %v", err) } mustEqualPositions(t, pos, syncPositionNewEDU) wg.Done() diff --git a/syncapi/producers/userapi_readupdate.go b/syncapi/producers/userapi_readupdate.go new file mode 100644 index 000000000..d56cab776 --- /dev/null +++ b/syncapi/producers/userapi_readupdate.go @@ -0,0 +1,62 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package producers + +import ( + "encoding/json" + + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +// UserAPIProducer produces events for the user API server to consume +type UserAPIReadProducer struct { + Topic string + JetStream nats.JetStreamContext +} + +// SendData sends account data to the user API server +func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error { + m := &nats.Msg{ + Subject: p.Topic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, userID) + m.Header.Set(jetstream.RoomID, roomID) + + data := types.ReadUpdate{ + UserID: userID, + RoomID: roomID, + Read: readPos, + FullyRead: fullyReadPos, + } + var err error + m.Data, err = json.Marshal(data) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": roomID, + "read_pos": readPos, + "fully_read_pos": fullyReadPos, + }).Tracef("Producing to topic '%s'", p.Topic) + + _, err = p.JetStream.PublishMsg(m) + return err +} diff --git a/syncapi/producers/userapi_streamevent.go b/syncapi/producers/userapi_streamevent.go new file mode 100644 index 000000000..2bbd19c0b --- /dev/null +++ b/syncapi/producers/userapi_streamevent.go @@ -0,0 +1,60 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package producers + +import ( + "encoding/json" + + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +// UserAPIProducer produces events for the user API server to consume +type UserAPIStreamEventProducer struct { + Topic string + JetStream nats.JetStreamContext +} + +// SendData sends account data to the user API server +func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error { + m := &nats.Msg{ + Subject: p.Topic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.RoomID, roomID) + + data := types.StreamedEvent{ + Event: event, + StreamPosition: pos, + } + var err error + m.Data, err = json.Marshal(data) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "room_id": roomID, + "event_id": event.EventID(), + "event_type": event.Type(), + "stream_pos": pos, + }).Tracef("Producing to topic '%s'", p.Topic) + + _, err = p.JetStream.PublishMsg(m) + return err +} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 126bc8658..e44766338 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -18,6 +18,7 @@ import ( "context" eduAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" @@ -31,6 +32,7 @@ type Database interface { MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) + MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) @@ -138,6 +140,12 @@ type Database interface { // GetRoomReceipts gets all receipts for a given roomID GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) + // UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key. + UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) + + // GetUserUnreadNotificationCounts returns statistics per room a user is interested in. + GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) + SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) diff --git a/syncapi/storage/postgres/notification_data_table.go b/syncapi/storage/postgres/notification_data_table.go new file mode 100644 index 000000000..f3fc4451f --- /dev/null +++ b/syncapi/storage/postgres/notification_data_table.go @@ -0,0 +1,108 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, error) { + _, err := db.Exec(notificationDataSchema) + if err != nil { + return nil, err + } + r := ¬ificationDataStatements{} + return r, sqlutil.StatementList{ + {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, + {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, + {&r.selectMaxID, selectMaxNotificationIDSQL}, + }.Prepare(db) +} + +type notificationDataStatements struct { + upsertRoomUnreadCounts *sql.Stmt + selectUserUnreadCounts *sql.Stmt + selectMaxID *sql.Stmt +} + +const notificationDataSchema = ` +CREATE TABLE IF NOT EXISTS syncapi_notification_data ( + id BIGSERIAL PRIMARY KEY, + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notification_count BIGINT NOT NULL DEFAULT 0, + highlight_count BIGINT NOT NULL DEFAULT 0, + CONSTRAINT syncapi_notification_data_unique UNIQUE (user_id, room_id) +);` + +const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_data + (user_id, room_id, notification_count, highlight_count) + VALUES ($1, $2, $3, $4) + ON CONFLICT (user_id, room_id) + DO UPDATE SET notification_count = $3, highlight_count = $4 + RETURNING id` + +const selectUserUnreadNotificationCountsSQL = `SELECT + id, room_id, notification_count, highlight_count + FROM syncapi_notification_data + WHERE + user_id = $1 AND + id BETWEEN $2 + 1 AND $3` + +const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` + +func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { + err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos) + return +} + +func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { + rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") + + roomCounts := map[string]*eventutil.NotificationData{} + for rows.Next() { + var id types.StreamPosition + var roomID string + var notificationCount, highlightCount int + + if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil { + return nil, err + } + + roomCounts[roomID] = &eventutil.NotificationData{ + RoomID: roomID, + UnreadNotificationCount: notificationCount, + UnreadHighlightCount: highlightCount, + } + } + return roomCounts, rows.Err() +} + +func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) { + var id int64 + err := r.selectMaxID.QueryRowContext(ctx).Scan(&id) + return id, err +} diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 6f4e7749d..60fe5b54d 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -90,6 +90,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e if err != nil { return nil, err } + notificationData, err := NewPostgresNotificationDataTable(d.db) + if err != nil { + return nil, err + } m := sqlutil.NewMigrations() deltas.LoadFixSequences(m) deltas.LoadRemoveSendToDeviceSentColumn(m) @@ -110,6 +114,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e SendToDevice: sendToDevice, Receipts: receipts, Memberships: memberships, + NotificationData: notificationData, } return &d, nil } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 819851b33..87d7c6df7 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -48,6 +48,7 @@ type Database struct { Filter tables.Filter Receipts tables.Receipts Memberships tables.Memberships + NotificationData tables.NotificationData } func (d *Database) readOnlySnapshot(ctx context.Context) (*sql.Tx, error) { @@ -102,6 +103,14 @@ func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.S return types.StreamPosition(id), nil } +func (d *Database) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) { + id, err := d.NotificationData.SelectMaxID(ctx) + if err != nil { + return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err) + } + return types.StreamPosition(id), nil +} + func (d *Database) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { return d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilterPart, excludeEventIDs) } @@ -956,6 +965,18 @@ func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, stream return receipts, err } +func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { + err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, userID, roomID, notificationCount, highlightCount) + return err + }) + return +} + +func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) { + return d.NotificationData.SelectUserUnreadCounts(ctx, userID, from, to) +} + func (s *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { return s.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) } diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go new file mode 100644 index 000000000..4b3f074db --- /dev/null +++ b/syncapi/storage/sqlite3/notification_data_table.go @@ -0,0 +1,108 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +func NewSqliteNotificationDataTable(db *sql.DB) (tables.NotificationData, error) { + _, err := db.Exec(notificationDataSchema) + if err != nil { + return nil, err + } + r := ¬ificationDataStatements{} + return r, sqlutil.StatementList{ + {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, + {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, + {&r.selectMaxID, selectMaxNotificationIDSQL}, + }.Prepare(db) +} + +type notificationDataStatements struct { + upsertRoomUnreadCounts *sql.Stmt + selectUserUnreadCounts *sql.Stmt + selectMaxID *sql.Stmt +} + +const notificationDataSchema = ` +CREATE TABLE IF NOT EXISTS syncapi_notification_data ( + id INTEGER PRIMARY KEY, + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notification_count BIGINT NOT NULL DEFAULT 0, + highlight_count BIGINT NOT NULL DEFAULT 0, + CONSTRAINT syncapi_notifications_unique UNIQUE (user_id, room_id) +);` + +const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_data + (user_id, room_id, notification_count, highlight_count) + VALUES ($1, $2, $3, $4) + ON CONFLICT (user_id, room_id) + DO UPDATE SET notification_count = $3, highlight_count = $4 + RETURNING id` + +const selectUserUnreadNotificationCountsSQL = `SELECT + id, room_id, notification_count, highlight_count + FROM syncapi_notification_data + WHERE + user_id = $1 AND + id BETWEEN $2 + 1 AND $3` + +const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` + +func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { + err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos) + return +} + +func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { + rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") + + roomCounts := map[string]*eventutil.NotificationData{} + for rows.Next() { + var id types.StreamPosition + var roomID string + var notificationCount, highlightCount int + + if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil { + return nil, err + } + + roomCounts[roomID] = &eventutil.NotificationData{ + RoomID: roomID, + UnreadNotificationCount: notificationCount, + UnreadHighlightCount: highlightCount, + } + } + return roomCounts, rows.Err() +} + +func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) { + var id int64 + err := r.selectMaxID.QueryRowContext(ctx).Scan(&id) + return id, err +} diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 581ee6928..1b256f91a 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -62,16 +62,19 @@ const selectEventsSQL = "" + const selectRecentEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const selectRecentEventsForSyncSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const selectEarlyEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const selectMaxEventIDSQL = "" + @@ -85,6 +88,7 @@ const selectStateInRangeSQL = "" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2)" + " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const deleteEventsForRoomSQL = "" + @@ -95,10 +99,12 @@ const selectContextEventSQL = "" + const selectContextBeforeEventSQL = "" + "SELECT headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id < $2" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const selectContextAfterEventSQL = "" + "SELECT id, headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters type outputRoomEventsStatements struct { diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 706d43f81..f5ae9fdd7 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -100,6 +100,10 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er if err != nil { return err } + notificationData, err := NewSqliteNotificationDataTable(d.db) + if err != nil { + return err + } m := sqlutil.NewMigrations() deltas.LoadFixSequences(m) deltas.LoadRemoveSendToDeviceSentColumn(m) @@ -120,6 +124,7 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er SendToDevice: sendToDevice, Receipts: receipts, Memberships: memberships, + NotificationData: notificationData, } return nil } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 1d807ee6b..1ebb42651 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -19,6 +19,7 @@ import ( "database/sql" eduAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -171,3 +172,9 @@ type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error) } + +type NotificationData interface { + UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) + SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) + SelectMaxID(ctx context.Context) (int64, error) +} diff --git a/syncapi/streams/stream_notificationdata.go b/syncapi/streams/stream_notificationdata.go new file mode 100644 index 000000000..8ba9e07ca --- /dev/null +++ b/syncapi/streams/stream_notificationdata.go @@ -0,0 +1,55 @@ +package streams + +import ( + "context" + + "github.com/matrix-org/dendrite/syncapi/types" +) + +type NotificationDataStreamProvider struct { + StreamProvider +} + +func (p *NotificationDataStreamProvider) Setup() { + p.StreamProvider.Setup() + + id, err := p.DB.MaxStreamPositionForNotificationData(context.Background()) + if err != nil { + panic(err) + } + p.latest = id +} + +func (p *NotificationDataStreamProvider) CompleteSync( + ctx context.Context, + req *types.SyncRequest, +) types.StreamPosition { + return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) +} + +func (p *NotificationDataStreamProvider) IncrementalSync( + ctx context.Context, + req *types.SyncRequest, + from, to types.StreamPosition, +) types.StreamPosition { + // We want counts for all possible rooms, so always start from zero. + countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to) + if err != nil { + req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed") + return from + } + + // We're merely decorating existing rooms. Note that the Join map + // values are not pointers. + for roomID, jr := range req.Response.Rooms.Join { + counts := countsByRoom[roomID] + if counts == nil { + continue + } + + jr.UnreadNotifications.HighlightCount = counts.UnreadHighlightCount + jr.UnreadNotifications.NotificationCount = counts.UnreadNotificationCount + req.Response.Rooms.Join[roomID] = jr + } + return to +} diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index c71095af6..17951acb4 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -12,13 +12,14 @@ import ( ) type Streams struct { - PDUStreamProvider types.StreamProvider - TypingStreamProvider types.StreamProvider - ReceiptStreamProvider types.StreamProvider - InviteStreamProvider types.StreamProvider - SendToDeviceStreamProvider types.StreamProvider - AccountDataStreamProvider types.StreamProvider - DeviceListStreamProvider types.StreamProvider + PDUStreamProvider types.StreamProvider + TypingStreamProvider types.StreamProvider + ReceiptStreamProvider types.StreamProvider + InviteStreamProvider types.StreamProvider + SendToDeviceStreamProvider types.StreamProvider + AccountDataStreamProvider types.StreamProvider + DeviceListStreamProvider types.StreamProvider + NotificationDataStreamProvider types.StreamProvider } func NewSyncStreamProviders( @@ -47,6 +48,9 @@ func NewSyncStreamProviders( StreamProvider: StreamProvider{DB: d}, userAPI: userAPI, }, + NotificationDataStreamProvider: &NotificationDataStreamProvider{ + StreamProvider: StreamProvider{DB: d}, + }, DeviceListStreamProvider: &DeviceListStreamProvider{ StreamProvider: StreamProvider{DB: d}, rsAPI: rsAPI, @@ -60,6 +64,7 @@ func NewSyncStreamProviders( streams.InviteStreamProvider.Setup() streams.SendToDeviceStreamProvider.Setup() streams.AccountDataStreamProvider.Setup() + streams.NotificationDataStreamProvider.Setup() streams.DeviceListStreamProvider.Setup() return streams @@ -67,12 +72,13 @@ func NewSyncStreamProviders( func (s *Streams) Latest(ctx context.Context) types.StreamingToken { return types.StreamingToken{ - PDUPosition: s.PDUStreamProvider.LatestPosition(ctx), - TypingPosition: s.TypingStreamProvider.LatestPosition(ctx), - ReceiptPosition: s.ReceiptStreamProvider.LatestPosition(ctx), - InvitePosition: s.InviteStreamProvider.LatestPosition(ctx), - SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx), - AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx), - DeviceListPosition: s.DeviceListStreamProvider.LatestPosition(ctx), + PDUPosition: s.PDUStreamProvider.LatestPosition(ctx), + TypingPosition: s.TypingStreamProvider.LatestPosition(ctx), + ReceiptPosition: s.ReceiptStreamProvider.LatestPosition(ctx), + InvitePosition: s.InviteStreamProvider.LatestPosition(ctx), + SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx), + AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx), + NotificationDataPosition: s.NotificationDataStreamProvider.LatestPosition(ctx), + DeviceListPosition: s.DeviceListStreamProvider.LatestPosition(ctx), } } diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index ca35951a0..2c9920d18 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -189,7 +189,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. currentPos.ApplyUpdates(userStreamListener.GetSyncPosition()) } } else { - syncReq.Log.Debugln("Responding to sync immediately") + syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately") } if syncReq.Since.IsEmpty() { @@ -213,6 +213,9 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync( syncReq.Context, syncReq, ), + NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync( + syncReq.Context, syncReq, + ), DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync( syncReq.Context, syncReq, ), @@ -244,6 +247,10 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. syncReq.Context, syncReq, syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition, ), + NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync( + syncReq.Context, syncReq, + syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition, + ), DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync( syncReq.Context, syncReq, syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition, diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 72462459c..cb9890ff7 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/consumers" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/routing" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" @@ -64,6 +65,18 @@ func AddPublicRoutes( requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier) + userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{ + JetStream: js, + Topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputStreamEvent), + } + + userAPIReadUpdateProducer := &producers.UserAPIReadProducer{ + JetStream: js, + Topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReadUpdate), + } + + _ = userAPIReadUpdateProducer + keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( process, cfg, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), js, keyAPI, rsAPI, syncDB, notifier, @@ -75,7 +88,7 @@ func AddPublicRoutes( roomConsumer := consumers.NewOutputRoomEventConsumer( process, cfg, js, syncDB, notifier, streams.PDUStreamProvider, - streams.InviteStreamProvider, rsAPI, + streams.InviteStreamProvider, rsAPI, userAPIStreamEventProducer, ) if err = roomConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start room server consumer") @@ -83,11 +96,19 @@ func AddPublicRoutes( clientConsumer := consumers.NewOutputClientDataConsumer( process, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider, + userAPIReadUpdateProducer, ) if err = clientConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start client data consumer") } + notificationConsumer := consumers.NewOutputNotificationDataConsumer( + process, cfg, js, syncDB, notifier, streams.NotificationDataStreamProvider, + ) + if err = notificationConsumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start notification data consumer") + } + typingConsumer := consumers.NewOutputTypingEventConsumer( process, cfg, js, syncDB, eduCache, notifier, streams.TypingStreamProvider, ) @@ -104,6 +125,7 @@ func AddPublicRoutes( receiptConsumer := consumers.NewOutputReceiptEventConsumer( process, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider, + userAPIReadUpdateProducer, ) if err = receiptConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start receipts consumer") diff --git a/syncapi/types/types.go b/syncapi/types/types.go index c2e8ed01c..4150e6c98 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -95,13 +95,14 @@ const ( ) type StreamingToken struct { - PDUPosition StreamPosition - TypingPosition StreamPosition - ReceiptPosition StreamPosition - SendToDevicePosition StreamPosition - InvitePosition StreamPosition - AccountDataPosition StreamPosition - DeviceListPosition StreamPosition + PDUPosition StreamPosition + TypingPosition StreamPosition + ReceiptPosition StreamPosition + SendToDevicePosition StreamPosition + InvitePosition StreamPosition + AccountDataPosition StreamPosition + DeviceListPosition StreamPosition + NotificationDataPosition StreamPosition } // This will be used as a fallback by json.Marshal. @@ -117,10 +118,11 @@ func (s *StreamingToken) UnmarshalText(text []byte) (err error) { func (t StreamingToken) String() string { posStr := fmt.Sprintf( - "s%d_%d_%d_%d_%d_%d_%d", + "s%d_%d_%d_%d_%d_%d_%d_%d", t.PDUPosition, t.TypingPosition, t.ReceiptPosition, t.SendToDevicePosition, - t.InvitePosition, t.AccountDataPosition, t.DeviceListPosition, + t.InvitePosition, t.AccountDataPosition, + t.DeviceListPosition, t.NotificationDataPosition, ) return posStr } @@ -142,12 +144,14 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool { return true case t.DeviceListPosition > other.DeviceListPosition: return true + case t.NotificationDataPosition > other.NotificationDataPosition: + return true } return false } func (t *StreamingToken) IsEmpty() bool { - return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition == 0 + return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition+t.NotificationDataPosition == 0 } // WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken. @@ -185,6 +189,9 @@ func (t *StreamingToken) ApplyUpdates(other StreamingToken) { if other.DeviceListPosition > t.DeviceListPosition { t.DeviceListPosition = other.DeviceListPosition } + if other.NotificationDataPosition > t.NotificationDataPosition { + t.NotificationDataPosition = other.NotificationDataPosition + } } type TopologyToken struct { @@ -277,7 +284,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { // s478_0_0_0_0_13.dl-0-2 but we have now removed partitioned stream positions tok = strings.Split(tok, ".")[0] parts := strings.Split(tok[1:], "_") - var positions [7]StreamPosition + var positions [8]StreamPosition for i, p := range parts { if i >= len(positions) { break @@ -291,13 +298,14 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { positions[i] = StreamPosition(pos) } token = StreamingToken{ - PDUPosition: positions[0], - TypingPosition: positions[1], - ReceiptPosition: positions[2], - SendToDevicePosition: positions[3], - InvitePosition: positions[4], - AccountDataPosition: positions[5], - DeviceListPosition: positions[6], + PDUPosition: positions[0], + TypingPosition: positions[1], + ReceiptPosition: positions[2], + SendToDevicePosition: positions[3], + InvitePosition: positions[4], + AccountDataPosition: positions[5], + DeviceListPosition: positions[6], + NotificationDataPosition: positions[7], } return token, nil } @@ -383,6 +391,10 @@ type JoinResponse struct { AccountData struct { Events []gomatrixserverlib.ClientEvent `json:"events"` } `json:"account_data"` + UnreadNotifications struct { + HighlightCount int `json:"highlight_count"` + NotificationCount int `json:"notification_count"` + } `json:"unread_notifications"` } // NewJoinResponse creates an empty response with initialised arrays. @@ -462,3 +474,16 @@ type Peek struct { New bool Deleted bool } + +type ReadUpdate struct { + UserID string `json:"user_id"` + RoomID string `json:"room_id"` + Read StreamPosition `json:"read,omitempty"` + FullyRead StreamPosition `json:"fully_read,omitempty"` +} + +// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. +type StreamedEvent struct { + Event *gomatrixserverlib.HeaderedEvent `json:"event"` + StreamPosition StreamPosition `json:"stream_position"` +} diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index cda178b37..ff78bfb9d 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -9,10 +9,10 @@ import ( func TestSyncTokens(t *testing.T) { shouldPass := map[string]string{ - "s4_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0}.String(), - "s3_1_0_0_0_0_2": StreamingToken{3, 1, 0, 0, 0, 0, 2}.String(), - "s3_1_2_3_5_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0}.String(), - "t3_1": TopologyToken{3, 1}.String(), + "s4_0_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0}.String(), + "s3_1_0_0_0_0_2_0": StreamingToken{3, 1, 0, 0, 0, 0, 2, 0}.String(), + "s3_1_2_3_5_0_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0, 0}.String(), + "t3_1": TopologyToken{3, 1}.String(), } for a, b := range shouldPass { diff --git a/sytest-blacklist b/sytest-blacklist index e8617dcdf..7f518b21a 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -30,3 +30,9 @@ Local device key changes appear in /keys/changes Remove group category Remove group role +# Flakey +AS-ghosted users can use rooms themselves + +# Flakey, need additional investigation +Messages that notify from another user increment notification_count +Messages that highlight from another user increment unread highlight count diff --git a/sytest-whitelist b/sytest-whitelist index 3e38176f4..602f86465 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -339,17 +339,17 @@ Existing members see new members' join events Inbound federation can receive events Inbound federation can receive redacted events Can logout current device -Can send a message directly to a device using PUT /sendToDevice -Can recv a device message using /sync -Can recv device messages until they are acknowledged -Device messages with the same txn_id are deduplicated -Device messages wake up /sync -Can recv device messages over federation -Device messages over federation wake up /sync -Can send messages with a wildcard device id -Can send messages with a wildcard device id to two devices -Wildcard device messages wake up /sync -Wildcard device messages over federation wake up /sync +Can send a message directly to a device using PUT /sendToDevice +Can recv a device message using /sync +Can recv device messages until they are acknowledged +Device messages with the same txn_id are deduplicated +Device messages wake up /sync +Can recv device messages over federation +Device messages over federation wake up /sync +Can send messages with a wildcard device id +Can send messages with a wildcard device id to two devices +Wildcard device messages wake up /sync +Wildcard device messages over federation wake up /sync Can send a to-device message to two users which both receive it using /sync User can create and send/receive messages in a room with version 6 local user can join room with version 6 @@ -477,7 +477,7 @@ Federation key API can act as a notary server via a GET request Inbound /make_join rejects attempts to join rooms where all users have left Inbound federation rejects invites which include invalid JSON for room version 6 Inbound federation rejects invite rejections which include invalid JSON for room version 6 -GET /capabilities is present and well formed for registered user +GET /capabilities is present and well formed for registered user m.room.history_visibility == "joined" allows/forbids appropriately for Guest users m.room.history_visibility == "joined" allows/forbids appropriately for Real users POST rejects invalid utf-8 in JSON @@ -588,6 +588,59 @@ User can invite remote user to room with version 9 Remote user can backfill in a room with version 9 Can reject invites over federation for rooms with version 9 Can receive redactions from regular users over federation in room version 9 +Pushers created with a different access token are deleted on password change +Pushers created with a the same access token are not deleted on password change +Can fetch a user's pushers +Can add global push rule for room +Can add global push rule for sender +Can add global push rule for content +Can add global push rule for override +Can add global push rule for underride +Can add global push rule for content +New rules appear before old rules by default +Can add global push rule before an existing rule +Can add global push rule after an existing rule +Can delete a push rule +Can disable a push rule +Adding the same push rule twice is idempotent +Can change the actions of default rules +Can change the actions of a user specified rule +Adding a push rule wakes up an incremental /sync +Disabling a push rule wakes up an incremental /sync +Enabling a push rule wakes up an incremental /sync +Setting actions for a push rule wakes up an incremental /sync +Can enable/disable default rules +Trying to add push rule with missing template fails with 400 +Trying to add push rule with missing rule_id fails with 400 +Trying to add push rule with empty rule_id fails with 400 +Trying to add push rule with invalid template fails with 400 +Trying to add push rule with rule_id with slashes fails with 400 +Trying to add push rule with override rule without conditions fails with 400 +Trying to add push rule with underride rule without conditions fails with 400 +Trying to add push rule with condition without kind fails with 400 +Trying to add push rule with content rule without pattern fails with 400 +Trying to add push rule with no actions fails with 400 +Trying to add push rule with invalid action fails with 400 +Trying to add push rule with invalid attr fails with 400 +Trying to add push rule with invalid value for enabled fails with 400 +Trying to get push rules with no trailing slash fails with 400 +Trying to get push rules with scope without trailing slash fails with 400 +Trying to get push rules with template without tailing slash fails with 400 +Trying to get push rules with unknown scope fails with 400 +Trying to get push rules with unknown template fails with 400 +Trying to get push rules with unknown attribute fails with 400 +Getting push rules doesn't corrupt the cache SYN-390 +Test that a message is pushed +Invites are pushed +Rooms with names are correctly named in pushes +Rooms with canonical alias are correctly named in pushed +Rooms with many users are correctly pushed +Don't get pushed for rooms you've muted +Rejected events are not pushed +Test that rejected pushers are removed. +Notifications can be viewed with GET /notifications +Trying to add push rule with no scope fails with 400 +Trying to add push rule with invalid scope fails with 400 Forward extremities remain so even after the next events are populated as outliers If a device list update goes missing, the server resyncs on the next one uploading self-signing key notifies over federation @@ -607,4 +660,4 @@ registration accepts non-ascii passwords registration with inhibit_login inhibits login The operation must be consistent through an interactive authentication session Multiple calls to /sync should not cause 500 errors - +/context/ with lazy_load_members filter works diff --git a/userapi/api/api.go b/userapi/api/api.go index 2be662e55..e9cdbe01c 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/pushrules" ) // UserInternalAPI is the internal API for information about users and devices. @@ -28,6 +29,7 @@ type UserInternalAPI interface { LoginTokenInternalAPI InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error + PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error @@ -37,6 +39,10 @@ type UserInternalAPI interface { PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error + PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error + PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error + PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error + QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error @@ -45,6 +51,9 @@ type UserInternalAPI interface { QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error + QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error + QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error + QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error } type PerformKeyBackupRequest struct { @@ -424,3 +433,77 @@ const ( // AccountTypeAppService indicates this is an appservice account AccountTypeAppService AccountType = 4 ) + +type QueryPushersRequest struct { + Localpart string +} + +type QueryPushersResponse struct { + Pushers []Pusher `json:"pushers"` +} + +type PerformPusherSetRequest struct { + Pusher // Anonymous field because that's how clientapi unmarshals it. + Localpart string + Append bool `json:"append"` +} + +type PerformPusherDeletionRequest struct { + Localpart string + SessionID int64 +} + +// Pusher represents a push notification subscriber +type Pusher struct { + SessionID int64 `json:"session_id,omitempty"` + PushKey string `json:"pushkey"` + PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"` + Kind PusherKind `json:"kind"` + AppID string `json:"app_id"` + AppDisplayName string `json:"app_display_name"` + DeviceDisplayName string `json:"device_display_name"` + ProfileTag string `json:"profile_tag"` + Language string `json:"lang"` + Data map[string]interface{} `json:"data"` +} + +type PusherKind string + +const ( + EmailKind PusherKind = "email" + HTTPKind PusherKind = "http" +) + +type PerformPushRulesPutRequest struct { + UserID string `json:"user_id"` + RuleSets *pushrules.AccountRuleSets `json:"rule_sets"` +} + +type QueryPushRulesRequest struct { + UserID string `json:"user_id"` +} + +type QueryPushRulesResponse struct { + RuleSets *pushrules.AccountRuleSets `json:"rule_sets"` +} + +type QueryNotificationsRequest struct { + Localpart string `json:"localpart"` // Required. + From string `json:"from,omitempty"` + Limit int `json:"limit,omitempty"` + Only string `json:"only,omitempty"` +} + +type QueryNotificationsResponse struct { + NextToken string `json:"next_token"` + Notifications []*Notification `json:"notifications"` // Required. +} + +type Notification struct { + Actions []*pushrules.Action `json:"actions"` // Required. + Event gomatrixserverlib.ClientEvent `json:"event"` // Required. + ProfileTag string `json:"profile_tag"` // Required by Sytest, but actually optional. + Read bool `json:"read"` // Required. + RoomID string `json:"room_id"` // Required. + TS gomatrixserverlib.Timestamp `json:"ts"` // Required. +} diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index aa069f40b..9334f4455 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -79,6 +79,21 @@ func (t *UserInternalAPITrace) PerformKeyBackup(ctx context.Context, req *Perfor util.GetLogger(ctx).Infof("PerformKeyBackup req=%+v res=%+v", js(req), js(res)) return err } +func (t *UserInternalAPITrace) PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error { + err := t.Impl.PerformPusherSet(ctx, req, res) + util.GetLogger(ctx).Infof("PerformPusherSet req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error { + err := t.Impl.PerformPusherDeletion(ctx, req, res) + util.GetLogger(ctx).Infof("PerformPusherDeletion req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error { + err := t.Impl.PerformPushRulesPut(ctx, req, res) + util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res)) + return err +} func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) { t.Impl.QueryKeyBackup(ctx, req, res) util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res)) @@ -118,6 +133,21 @@ func (t *UserInternalAPITrace) QueryOpenIDToken(ctx context.Context, req *QueryO util.GetLogger(ctx).Infof("QueryOpenIDToken req=%+v res=%+v", js(req), js(res)) return err } +func (t *UserInternalAPITrace) QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error { + err := t.Impl.QueryPushers(ctx, req, res) + util.GetLogger(ctx).Infof("QueryPushers req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error { + err := t.Impl.QueryPushRules(ctx, req, res) + util.GetLogger(ctx).Infof("QueryPushRules req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error { + err := t.Impl.QueryNotifications(ctx, req, res) + util.GetLogger(ctx).Infof("QueryNotifications req=%+v res=%+v", js(req), js(res)) + return err +} func js(thing interface{}) string { b, err := json.Marshal(thing) diff --git a/userapi/consumers/syncapi_readupdate.go b/userapi/consumers/syncapi_readupdate.go new file mode 100644 index 000000000..2e58020b4 --- /dev/null +++ b/userapi/consumers/syncapi_readupdate.go @@ -0,0 +1,136 @@ +package consumers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/types" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/producers" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/util" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +type OutputReadUpdateConsumer struct { + ctx context.Context + cfg *config.UserAPI + jetstream nats.JetStreamContext + durable string + db storage.Database + pgClient pushgateway.Client + ServerName gomatrixserverlib.ServerName + topic string + userAPI uapi.UserInternalAPI + syncProducer *producers.SyncAPI +} + +func NewOutputReadUpdateConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + store storage.Database, + pgClient pushgateway.Client, + userAPI uapi.UserInternalAPI, + syncProducer *producers.SyncAPI, +) *OutputReadUpdateConsumer { + return &OutputReadUpdateConsumer{ + ctx: process.Context(), + cfg: cfg, + jetstream: js, + db: store, + ServerName: cfg.Matrix.ServerName, + durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"), + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReadUpdate), + pgClient: pgClient, + userAPI: userAPI, + syncProducer: syncProducer, + } +} + +func (s *OutputReadUpdateConsumer) Start() error { + if err := jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { + return err + } + return nil +} + +func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var read types.ReadUpdate + if err := json.Unmarshal(msg.Data, &read); err != nil { + log.WithError(err).Error("userapi clientapi consumer: message parse failure") + return true + } + if read.FullyRead == 0 && read.Read == 0 { + return true + } + + userID := string(msg.Header.Get(jetstream.UserID)) + roomID := string(msg.Header.Get(jetstream.RoomID)) + + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + log.WithError(err).Error("userapi clientapi consumer: SplitID failure") + return true + } + if domain != s.ServerName { + log.Error("userapi clientapi consumer: not a local user") + return true + } + + log := log.WithFields(log.Fields{ + "room_id": roomID, + "user_id": userID, + }) + log.Tracef("Received read update from sync API: %#v", read) + + if read.Read > 0 { + updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true) + if err != nil { + log.WithError(err).Error("userapi EDU consumer") + return false + } + + if updated { + if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil { + log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed") + return false + } + if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { + log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed") + return false + } + } + } + + if read.FullyRead > 0 { + deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead)) + if err != nil { + log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed") + return false + } + + if deleted { + if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { + log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed") + return false + } + + if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil { + log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed") + return false + } + } + } + + return true +} diff --git a/userapi/consumers/syncapi_streamevent.go b/userapi/consumers/syncapi_streamevent.go new file mode 100644 index 000000000..110813274 --- /dev/null +++ b/userapi/consumers/syncapi_streamevent.go @@ -0,0 +1,588 @@ +package consumers + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/internal/pushrules" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/producers" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/util" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +type OutputStreamEventConsumer struct { + ctx context.Context + cfg *config.UserAPI + userAPI api.UserInternalAPI + rsAPI rsapi.RoomserverInternalAPI + jetstream nats.JetStreamContext + durable string + db storage.Database + topic string + pgClient pushgateway.Client + syncProducer *producers.SyncAPI +} + +func NewOutputStreamEventConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + store storage.Database, + pgClient pushgateway.Client, + userAPI api.UserInternalAPI, + rsAPI rsapi.RoomserverInternalAPI, + syncProducer *producers.SyncAPI, +) *OutputStreamEventConsumer { + return &OutputStreamEventConsumer{ + ctx: process.Context(), + cfg: cfg, + jetstream: js, + db: store, + durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"), + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputStreamEvent), + pgClient: pgClient, + userAPI: userAPI, + rsAPI: rsAPI, + syncProducer: syncProducer, + } +} + +func (s *OutputStreamEventConsumer) Start() error { + if err := jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { + return err + } + return nil +} + +func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var output types.StreamedEvent + output.Event = &gomatrixserverlib.HeaderedEvent{} + if err := json.Unmarshal(msg.Data, &output); err != nil { + log.WithError(err).Errorf("userapi consumer: message parse failure") + return true + } + if output.Event.Event == nil { + log.Errorf("userapi consumer: expected event") + return true + } + + log.WithFields(log.Fields{ + "event_id": output.Event.EventID(), + "event_type": output.Event.Type(), + "stream_pos": output.StreamPosition, + }).Tracef("Received message from sync API: %#v", output) + + if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil { + log.WithFields(log.Fields{ + "event_id": output.Event.EventID(), + }).WithError(err).Errorf("userapi consumer: process room event failure") + } + + return true +} + +func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error { + members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) + if err != nil { + return fmt.Errorf("s.localRoomMembers: %w", err) + } + + if event.Type() == gomatrixserverlib.MRoomMember { + cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll) + var member *localMembership + member, err = newLocalMembership(&cevent) + if err != nil { + return fmt.Errorf("newLocalMembership: %w", err) + } + if member.Membership == gomatrixserverlib.Invite && member.Domain == s.cfg.Matrix.ServerName { + // localRoomMembers only adds joined members. An invite + // should also be pushed to the target user. + members = append(members, member) + } + } + + // TODO: run in parallel with localRoomMembers. + roomName, err := s.roomName(ctx, event) + if err != nil { + return fmt.Errorf("s.roomName: %w", err) + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "num_members": len(members), + "room_size": roomSize, + }).Tracef("Notifying members") + + // Notification.UserIsTarget is a per-member field, so we + // cannot group all users in a single request. + // + // TODO: does it have to be set? It's not required, and + // removing it means we can send all notifications to + // e.g. Element's Push gateway in one go. + for _, mem := range members { + if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil { + log.WithFields(log.Fields{ + "localpart": mem.Localpart, + }).WithError(err).Debugf("Unable to push to local user") + continue + } + } + + return nil +} + +type localMembership struct { + gomatrixserverlib.MemberContent + UserID string + Localpart string + Domain gomatrixserverlib.ServerName +} + +func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership, error) { + if event.StateKey == nil { + return nil, fmt.Errorf("missing state_key") + } + + var member localMembership + if err := json.Unmarshal(event.Content, &member.MemberContent); err != nil { + return nil, err + } + + localpart, domain, err := gomatrixserverlib.SplitID('@', *event.StateKey) + if err != nil { + return nil, err + } + + member.UserID = *event.StateKey + member.Localpart = localpart + member.Domain = domain + return &member, nil +} + +// localRoomMembers fetches the current local members of a room, and +// the total number of members. +func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { + req := &rsapi.QueryMembershipsForRoomRequest{ + RoomID: roomID, + JoinedOnly: true, + } + var res rsapi.QueryMembershipsForRoomResponse + + // XXX: This could potentially race if the state for the event is not known yet + // e.g. the event came over federation but we do not have the full state persisted. + if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil { + return nil, 0, err + } + + var members []*localMembership + var ntotal int + for _, event := range res.JoinEvents { + member, err := newLocalMembership(&event) + if err != nil { + log.WithError(err).Errorf("Parsing MemberContent") + continue + } + if member.Membership != gomatrixserverlib.Join { + continue + } + if member.Domain != s.cfg.Matrix.ServerName { + continue + } + + ntotal++ + members = append(members, member) + } + + return members, ntotal, nil +} + +// roomName returns the name in the event (if type==m.room.name), or +// looks it up in roomserver. If there is no name, +// m.room.canonical_alias is consulted. Returns an empty string if the +// room has no name. +func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { + if event.Type() == gomatrixserverlib.MRoomName { + name, err := unmarshalRoomName(event) + if err != nil { + return "", err + } + + if name != "" { + return name, nil + } + } + + req := &rsapi.QueryCurrentStateRequest{ + RoomID: event.RoomID(), + StateTuples: []gomatrixserverlib.StateKeyTuple{roomNameTuple, canonicalAliasTuple}, + } + var res rsapi.QueryCurrentStateResponse + + if err := s.rsAPI.QueryCurrentState(ctx, req, &res); err != nil { + return "", nil + } + + if eventS := res.StateEvents[roomNameTuple]; eventS != nil { + return unmarshalRoomName(eventS) + } + + if event.Type() == gomatrixserverlib.MRoomCanonicalAlias { + alias, err := unmarshalCanonicalAlias(event) + if err != nil { + return "", err + } + + if alias != "" { + return alias, nil + } + } + + if event = res.StateEvents[canonicalAliasTuple]; event != nil { + return unmarshalCanonicalAlias(event) + } + + return "", nil +} + +var ( + canonicalAliasTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias} + roomNameTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomName} +) + +func unmarshalRoomName(event *gomatrixserverlib.HeaderedEvent) (string, error) { + var nc eventutil.NameContent + if err := json.Unmarshal(event.Content(), &nc); err != nil { + return "", fmt.Errorf("unmarshaling NameContent: %w", err) + } + + return nc.Name, nil +} + +func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, error) { + var cac eventutil.CanonicalAliasContent + if err := json.Unmarshal(event.Content(), &cac); err != nil { + return "", fmt.Errorf("unmarshaling CanonicalAliasContent: %w", err) + } + + return cac.Alias, nil +} + +// notifyLocal finds the right push actions for a local user, given an event. +func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error { + actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) + if err != nil { + return err + } + a, tweaks, err := pushrules.ActionsToTweaks(actions) + if err != nil { + return err + } + // TODO: support coalescing. + if a != pushrules.NotifyAction && a != pushrules.CoalesceAction { + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "localpart": mem.Localpart, + }).Tracef("Push rule evaluation rejected the event") + return nil + } + + devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks) + if err != nil { + return err + } + + n := &api.Notification{ + Actions: actions, + // UNSPEC: the spec doesn't say this is a ClientEvent, but the + // fields seem to match. room_id should be missing, which + // matches the behaviour of FormatSync. + Event: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatSync), + // TODO: this is per-device, but it's not part of the primary + // key. So inserting one notification per profile tag doesn't + // make sense. What is this supposed to be? Sytests require it + // to "work", but they only use a single device. + ProfileTag: profileTag, + RoomID: event.RoomID(), + TS: gomatrixserverlib.AsTimestamp(time.Now()), + } + if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil { + return err + } + + if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil { + return err + } + + // We do this after InsertNotification. Thus, this should always return >=1. + userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "localpart": mem.Localpart, + "num_urls": len(devicesByURLAndFormat), + "num_unread": userNumUnreadNotifs, + }).Tracef("Notifying single member") + + // Push gateways are out of our control, and we cannot risk + // looking up the server on a misbehaving push gateway. Each user + // receives a goroutine now that all internal API calls have been + // made. + // + // TODO: think about bounding this to one per user, and what + // ordering guarantees we must provide. + go func() { + // This background processing cannot be tied to a request. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + var rejected []*pushgateway.Device + for url, fmts := range devicesByURLAndFormat { + for format, devices := range fmts { + // TODO: support "email". + if !strings.HasPrefix(url, "http") { + continue + } + + // UNSPEC: the specification suggests there can be + // more than one device per request. There is at least + // one Sytest that expects one HTTP request per + // device, rather than per URL. For now, we must + // notify each one separately. + for _, dev := range devices { + rej, err := s.notifyHTTP(ctx, event, url, format, []*pushgateway.Device{dev}, mem.Localpart, roomName, int(userNumUnreadNotifs)) + if err != nil { + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "localpart": mem.Localpart, + }).WithError(err).Errorf("Unable to notify HTTP pusher") + continue + } + rejected = append(rejected, rej...) + } + } + } + + if len(rejected) > 0 { + s.deleteRejectedPushers(ctx, rejected, mem.Localpart) + } + }() + + return nil +} + +// evaluatePushRules fetches and evaluates the push rules of a local +// user. Returns actions (including dont_notify). +func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { + if event.Sender() == mem.UserID { + // SPEC: Homeservers MUST NOT notify the Push Gateway for + // events that the user has sent themselves. + return nil, nil + } + + var res api.QueryPushRulesResponse + if err := s.userAPI.QueryPushRules(ctx, &api.QueryPushRulesRequest{UserID: mem.UserID}, &res); err != nil { + return nil, err + } + + ec := &ruleSetEvalContext{ + ctx: ctx, + rsAPI: s.rsAPI, + mem: mem, + roomID: event.RoomID(), + roomSize: roomSize, + } + eval := pushrules.NewRuleSetEvaluator(ec, &res.RuleSets.Global) + rule, err := eval.MatchEvent(event.Event) + if err != nil { + return nil, err + } + if rule == nil { + // SPEC: If no rules match an event, the homeserver MUST NOT + // notify the Push Gateway for that event. + return nil, err + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "localpart": mem.Localpart, + "rule_id": rule.RuleID, + }).Tracef("Matched a push rule") + + return rule.Actions, nil +} + +type ruleSetEvalContext struct { + ctx context.Context + rsAPI rsapi.RoomserverInternalAPI + mem *localMembership + roomID string + roomSize int +} + +func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.DisplayName } + +func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil } + +func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) { + req := &rsapi.QueryLatestEventsAndStateRequest{ + RoomID: rse.roomID, + StateToFetch: []gomatrixserverlib.StateKeyTuple{ + {EventType: gomatrixserverlib.MRoomPowerLevels}, + }, + } + var res rsapi.QueryLatestEventsAndStateResponse + if err := rse.rsAPI.QueryLatestEventsAndState(rse.ctx, req, &res); err != nil { + return false, err + } + for _, ev := range res.StateEvents { + if ev.Type() != gomatrixserverlib.MRoomPowerLevels { + continue + } + + plc, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev.Event) + if err != nil { + return false, err + } + return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil + } + return true, nil +} + +// localPushDevices pushes to the configured devices of a local +// user. The map keys are [url][format]. +func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { + pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) + if err != nil { + return nil, "", err + } + + var profileTag string + devicesByURL := make(map[string]map[string][]*pushgateway.Device, len(pusherDevices)) + for _, pusherDevice := range pusherDevices { + if profileTag == "" { + profileTag = pusherDevice.Pusher.ProfileTag + } + + url := pusherDevice.URL + if devicesByURL[url] == nil { + devicesByURL[url] = make(map[string][]*pushgateway.Device, 2) + } + devicesByURL[url][pusherDevice.Format] = append(devicesByURL[url][pusherDevice.Format], &pusherDevice.Device) + } + + return devicesByURL, profileTag, nil +} + +// notifyHTTP performs a notificatation to a Push Gateway. +func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { + logger := log.WithFields(log.Fields{ + "event_id": event.EventID(), + "url": url, + "localpart": localpart, + "num_devices": len(devices), + }) + + var req pushgateway.NotifyRequest + switch format { + case "event_id_only": + req = pushgateway.NotifyRequest{ + Notification: pushgateway.Notification{ + Counts: &pushgateway.Counts{}, + Devices: devices, + EventID: event.EventID(), + RoomID: event.RoomID(), + }, + } + + default: + req = pushgateway.NotifyRequest{ + Notification: pushgateway.Notification{ + Content: event.Content(), + Counts: &pushgateway.Counts{ + Unread: userNumUnreadNotifs, + }, + Devices: devices, + EventID: event.EventID(), + ID: event.EventID(), + RoomID: event.RoomID(), + RoomName: roomName, + Sender: event.Sender(), + Type: event.Type(), + }, + } + if mem, err := event.Membership(); err == nil { + req.Notification.Membership = mem + } + if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) { + req.Notification.UserIsTarget = true + } + } + + logger.Debugf("Notifying push gateway %s", url) + var res pushgateway.NotifyResponse + if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil { + logger.WithError(err).Errorf("Failed to notify push gateway %s", url) + return nil, err + } + logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result") + + if len(res.Rejected) == 0 { + return nil, nil + } + + devMap := make(map[string]*pushgateway.Device, len(devices)) + for _, d := range devices { + devMap[d.PushKey] = d + } + rejected := make([]*pushgateway.Device, 0, len(res.Rejected)) + for _, pushKey := range res.Rejected { + d := devMap[pushKey] + if d != nil { + rejected = append(rejected, d) + } + } + + return rejected, nil +} + +// deleteRejectedPushers deletes the pushers associated with the given devices. +func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id0": devices[0].AppID, + "num_devices": len(devices), + }).Warnf("Deleting pushers rejected by the HTTP push gateway") + + for _, d := range devices { + if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil { + log.WithFields(log.Fields{ + "localpart": localpart, + }).WithError(err).Errorf("Unable to delete rejected pusher") + } + } +} diff --git a/userapi/internal/api.go b/userapi/internal/api.go index f54cc6137..7a42fc605 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -20,6 +20,8 @@ import ( "encoding/json" "errors" "fmt" + "strconv" + "time" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -27,16 +29,22 @@ import ( "github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/sqlutil" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) type UserInternalAPI struct { - DB storage.Database - ServerName gomatrixserverlib.ServerName + DB storage.Database + SyncProducer *producers.SyncAPI + + DisableTLSValidation bool + ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService KeyAPI keyapi.KeyInternalAPI @@ -595,3 +603,162 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB } res.Keys = result } + +func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { + if req.Limit == 0 || req.Limit > 1000 { + req.Limit = 1000 + } + + var fromID int64 + var err error + if req.From != "" { + fromID, err = strconv.ParseInt(req.From, 10, 64) + if err != nil { + return fmt.Errorf("QueryNotifications: parsing 'from': %w", err) + } + } + var filter tables.NotificationFilter = tables.AllNotifications + if req.Only == "highlight" { + filter = tables.HighlightNotifications + } + notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter) + if err != nil { + return err + } + if notifs == nil { + // This ensures empty is JSON-encoded as [] instead of null. + notifs = []*api.Notification{} + } + res.Notifications = notifs + if lastID >= 0 { + res.NextToken = strconv.FormatInt(lastID+1, 10) + } + return nil +} + +func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.PerformPusherSetRequest, res *struct{}) error { + util.GetLogger(ctx).WithFields(logrus.Fields{ + "localpart": req.Localpart, + "pushkey": req.Pusher.PushKey, + "display_name": req.Pusher.AppDisplayName, + }).Info("PerformPusherCreation") + if !req.Append { + err := a.DB.RemovePushers(ctx, req.Pusher.AppID, req.Pusher.PushKey) + if err != nil { + return err + } + } + if req.Pusher.Kind == "" { + return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart) + } + if req.Pusher.PushKeyTS == 0 { + req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now()) + } + return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart) +} + +func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { + pushers, err := a.DB.GetPushers(ctx, req.Localpart) + if err != nil { + return err + } + for i := range pushers { + logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID) + if pushers[i].SessionID != req.SessionID { + err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart) + if err != nil { + return err + } + } + } + return nil +} + +func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { + var err error + res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart) + return err +} + +func (a *UserInternalAPI) PerformPushRulesPut( + ctx context.Context, + req *api.PerformPushRulesPutRequest, + _ *struct{}, +) error { + bs, err := json.Marshal(&req.RuleSets) + if err != nil { + return err + } + userReq := api.InputAccountDataRequest{ + UserID: req.UserID, + DataType: pushRulesAccountDataType, + AccountData: json.RawMessage(bs), + } + var userRes api.InputAccountDataResponse // empty + if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil { + return err + } + + if err := a.SyncProducer.SendAccountData(req.UserID, "" /* roomID */, pushRulesAccountDataType); err != nil { + util.GetLogger(ctx).WithError(err).Errorf("syncProducer.SendData failed") + } + + return nil +} + +func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { + userReq := api.QueryAccountDataRequest{ + UserID: req.UserID, + DataType: pushRulesAccountDataType, + } + var userRes api.QueryAccountDataResponse + if err := a.QueryAccountData(ctx, &userReq, &userRes); err != nil { + return err + } + bs, ok := userRes.GlobalAccountData[pushRulesAccountDataType] + if ok { + // Legacy Dendrite users will have completely empty push rules, so we should + // detect that situation and set some defaults. + var rules struct { + G struct { + Content []json.RawMessage `json:"content"` + Override []json.RawMessage `json:"override"` + Room []json.RawMessage `json:"room"` + Sender []json.RawMessage `json:"sender"` + Underride []json.RawMessage `json:"underride"` + } `json:"global"` + } + if err := json.Unmarshal([]byte(bs), &rules); err == nil { + count := len(rules.G.Content) + len(rules.G.Override) + + len(rules.G.Room) + len(rules.G.Sender) + len(rules.G.Underride) + ok = count > 0 + } + } + if !ok { + // If we didn't find any default push rules then we should just generate some + // fresh ones. + localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return fmt.Errorf("failed to split user ID %q for push rules", req.UserID) + } + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, a.ServerName) + prbs, err := json.Marshal(pushRuleSets) + if err != nil { + return fmt.Errorf("failed to marshal default push rules: %w", err) + } + if err := a.DB.SaveAccountData(ctx, localpart, "", pushRulesAccountDataType, json.RawMessage(prbs)); err != nil { + return fmt.Errorf("failed to save default push rules: %w", err) + } + res.RuleSets = pushRuleSets + return nil + } + var data pushrules.AccountRuleSets + if err := json.Unmarshal([]byte(bs), &data); err != nil { + util.GetLogger(ctx).WithError(err).Error("json.Unmarshal of push rules failed") + return err + } + res.RuleSets = &data + return nil +} + +const pushRulesAccountDataType = "m.push_rules" diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 1599d4639..8ec649ad0 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -37,6 +37,9 @@ const ( PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" PerformKeyBackupPath = "/userapi/performKeyBackup" + PerformPusherSetPath = "/pushserver/performPusherSet" + PerformPusherDeletionPath = "/pushserver/performPusherDeletion" + PerformPushRulesPutPath = "/pushserver/performPushRulesPut" QueryKeyBackupPath = "/userapi/queryKeyBackup" QueryProfilePath = "/userapi/queryProfile" @@ -46,6 +49,9 @@ const ( QueryDeviceInfosPath = "/userapi/queryDeviceInfos" QuerySearchProfilesPath = "/userapi/querySearchProfiles" QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" + QueryPushersPath = "/pushserver/queryPushers" + QueryPushRulesPath = "/pushserver/queryPushRules" + QueryNotificationsPath = "/pushserver/queryNotifications" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -249,3 +255,58 @@ func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.Query res.Error = err.Error() } } + +func (h *httpUserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryNotifications") + defer span.Finish() + + return httputil.PostJSON(ctx, span, h.httpClient, h.apiURL+QueryNotificationsPath, req, res) +} + +func (h *httpUserInternalAPI) PerformPusherSet( + ctx context.Context, + request *api.PerformPusherSetRequest, + response *struct{}, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherSet") + defer span.Finish() + + apiURL := h.apiURL + PerformPusherSetPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherDeletion") + defer span.Finish() + + apiURL := h.apiURL + PerformPusherDeletionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushers") + defer span.Finish() + + apiURL := h.apiURL + QueryPushersPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) PerformPushRulesPut( + ctx context.Context, + request *api.PerformPushRulesPutRequest, + response *struct{}, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPushRulesPut") + defer span.Finish() + + apiURL := h.apiURL + PerformPushRulesPutPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushRules") + defer span.Finish() + + apiURL := h.apiURL + QueryPushRulesPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index d00ee042c..526f99575 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -265,4 +265,86 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryNotificationsPath, + httputil.MakeInternalAPI("queryNotifications", func(req *http.Request) util.JSONResponse { + var request api.QueryNotificationsRequest + var response api.QueryNotificationsResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryNotifications(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(PerformPusherSetPath, + httputil.MakeInternalAPI("performPusherSet", func(req *http.Request) util.JSONResponse { + request := api.PerformPusherSetRequest{} + response := struct{}{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPusherSet(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformPusherDeletionPath, + httputil.MakeInternalAPI("performPusherDeletion", func(req *http.Request) util.JSONResponse { + request := api.PerformPusherDeletionRequest{} + response := struct{}{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPusherDeletion(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(QueryPushersPath, + httputil.MakeInternalAPI("queryPushers", func(req *http.Request) util.JSONResponse { + request := api.QueryPushersRequest{} + response := api.QueryPushersResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryPushers(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(PerformPushRulesPutPath, + httputil.MakeInternalAPI("performPushRulesPut", func(req *http.Request) util.JSONResponse { + request := api.PerformPushRulesPutRequest{} + response := struct{}{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPushRulesPut(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(QueryPushRulesPath, + httputil.MakeInternalAPI("queryPushRules", func(req *http.Request) util.JSONResponse { + request := api.QueryPushRulesRequest{} + response := api.QueryPushRulesResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryPushRules(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go new file mode 100644 index 000000000..4a206f333 --- /dev/null +++ b/userapi/producers/syncapi.go @@ -0,0 +1,104 @@ +package producers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +type JetStreamPublisher interface { + PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) +} + +// SyncAPI produces messages for the Sync API server to consume. +type SyncAPI struct { + db storage.Database + producer JetStreamPublisher + clientDataTopic string + notificationDataTopic string +} + +func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { + return &SyncAPI{ + db: db, + producer: js, + clientDataTopic: clientDataTopic, + notificationDataTopic: notificationDataTopic, + } +} + +// SendAccountData sends account data to the Sync API server. +func (p *SyncAPI) SendAccountData(userID string, roomID string, dataType string) error { + m := &nats.Msg{ + Subject: p.clientDataTopic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, userID) + + var err error + m.Data, err = json.Marshal(eventutil.AccountData{ + RoomID: roomID, + Type: dataType, + }) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": roomID, + "data_type": dataType, + }).Tracef("Producing to topic '%s'", p.clientDataTopic) + + _, err = p.producer.PublishMsg(m) + return err +} + +// GetAndSendNotificationData reads the database and sends data about unread +// notifications to the Sync API server. +func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error { + localpart, _, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + + ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID) + if err != nil { + return err + } + + return p.sendNotificationData(userID, &eventutil.NotificationData{ + RoomID: roomID, + UnreadHighlightCount: int(nhighlight), + UnreadNotificationCount: int(ntotal), + }) +} + +// sendNotificationData sends data about unread notifications to the Sync API server. +func (p *SyncAPI) sendNotificationData(userID string, data *eventutil.NotificationData) error { + m := &nats.Msg{ + Subject: p.notificationDataTopic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, userID) + + var err error + m.Data, err = json.Marshal(data) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": data.RoomID, + }).Tracef("Producing to topic '%s'", p.clientDataTopic) + + _, err = p.producer.PublishMsg(m) + return err +} diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index a131dac47..6d22fea9d 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) type Database interface { @@ -89,6 +90,18 @@ type Database interface { // GetLoginTokenDataByToken returns the data associated with the given token. // May return sql.ErrNoRows. GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) + + InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) + SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) + GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) + GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) + GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) + + UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error + GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) + RemovePusher(ctx context.Context, appid, pushkey, localpart string) error + RemovePushers(ctx context.Context, appid, pushkey string) error } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go new file mode 100644 index 000000000..7bcc0f9cd --- /dev/null +++ b/userapi/storage/postgres/notifications_table.go @@ -0,0 +1,219 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +type notificationsStatements struct { + insertStmt *sql.Stmt + deleteUpToStmt *sql.Stmt + updateReadStmt *sql.Stmt + selectStmt *sql.Stmt + selectCountStmt *sql.Stmt + selectRoomCountsStmt *sql.Stmt +} + +const notificationSchema = ` +CREATE TABLE IF NOT EXISTS userapi_notifications ( + id BIGSERIAL PRIMARY KEY, + localpart TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + stream_pos BIGINT NOT NULL, + ts_ms BIGINT NOT NULL, + highlight BOOLEAN NOT NULL, + notification_json TEXT NOT NULL, + read BOOLEAN NOT NULL DEFAULT FALSE +); + +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +` + +const insertNotificationSQL = "" + + "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + +const deleteNotificationsUpToSQL = "" + + "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + +const updateNotificationReadSQL = "" + + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + +const selectNotificationSQL = "" + + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $4" + +const selectNotificationCountSQL = "" + + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + + "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read" + +const selectRoomNotificationCountsSQL = "" + + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + + "WHERE localpart = $1 AND room_id = $2 AND NOT read" + +func NewPostgresNotificationTable(db *sql.DB) (tables.NotificationTable, error) { + s := ¬ificationsStatements{} + _, err := db.Exec(notificationSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertStmt, insertNotificationSQL}, + {&s.deleteUpToStmt, deleteNotificationsUpToSQL}, + {&s.updateReadStmt, updateNotificationReadSQL}, + {&s.selectStmt, selectNotificationSQL}, + {&s.selectCountStmt, selectNotificationCountSQL}, + {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL}, + }.Prepare(db) +} + +// Insert inserts a notification into the database. +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { + roomID, tsMS := n.RoomID, n.TS + nn := *n + // Clears out fields that have their own columns to (1) shrink the + // data and (2) avoid difficult-to-debug inconsistency bugs. + nn.RoomID = "" + nn.TS, nn.Read = 0, false + bs, err := json.Marshal(nn) + if err != nil { + return err + } + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + return err +} + +// DeleteUpTo deletes all previous notifications, up to and including the event. +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows) + return nrows > 0, nil +} + +// UpdateRead updates the "read" value for an event. +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows) + return nrows > 0, nil +} + +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) + + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + var maxID int64 = -1 + var notifs []*api.Notification + for rows.Next() { + var id int64 + var roomID string + var ts gomatrixserverlib.Timestamp + var read bool + var jsonStr string + err = rows.Scan( + &id, + &roomID, + &ts, + &read, + &jsonStr) + if err != nil { + return nil, 0, err + } + + var n api.Notification + err := json.Unmarshal([]byte(jsonStr), &n) + if err != nil { + return nil, 0, err + } + n.RoomID = roomID + n.TS = ts + n.Read = read + notifs = append(notifs, &n) + + if maxID < id { + maxID = id + } + } + return notifs, maxID, rows.Err() +} + +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) + + if err != nil { + return 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var count int64 + if err := rows.Scan(&count); err != nil { + return 0, err + } + + return count, nil + } + return 0, rows.Err() +} + +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { + rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) + + if err != nil { + return 0, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var total, highlight int64 + if err := rows.Scan(&total, &highlight); err != nil { + return 0, 0, err + } + + return total, highlight, nil + } + return 0, 0, rows.Err() +} diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go new file mode 100644 index 000000000..670dc916f --- /dev/null +++ b/userapi/storage/postgres/pusher_table.go @@ -0,0 +1,157 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers +const pushersSchema = ` +CREATE TABLE IF NOT EXISTS userapi_pushers ( + id BIGSERIAL PRIMARY KEY, + -- The Matrix user ID localpart for this pusher + localpart TEXT NOT NULL, + session_id BIGINT DEFAULT NULL, + profile_tag TEXT, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + pushkey_ts_ms BIGINT NOT NULL DEFAULT 0, + lang TEXT NOT NULL, + data TEXT NOT NULL +); + +-- For faster deleting by app_id, pushkey pair. +CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); + +-- For faster retrieving by localpart. +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); + +-- Pushkey must be unique for a given user and app. +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +` + +const insertPusherSQL = "" + + "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + + "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + +const selectPushersSQL = "" + + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + +const deletePusherSQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + +const deletePushersByAppIdAndPushKeySQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" + +func NewPostgresPusherTable(db *sql.DB) (tables.PusherTable, error) { + s := &pushersStatements{} + _, err := db.Exec(pushersSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertPusherStmt, insertPusherSQL}, + {&s.selectPushersStmt, selectPushersSQL}, + {&s.deletePusherStmt, deletePusherSQL}, + {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL}, + }.Prepare(db) +} + +type pushersStatements struct { + insertPusherStmt *sql.Stmt + selectPushersStmt *sql.Stmt + deletePusherStmt *sql.Stmt + deletePushersByAppIdAndPushKeyStmt *sql.Stmt +} + +// insertPusher creates a new pusher. +// Returns an error if the user already has a pusher with the given pusher pushkey. +// Returns nil error success. +func (s *pushersStatements) InsertPusher( + ctx context.Context, txn *sql.Tx, session_id int64, + pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, +) error { + _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) + logrus.Debugf("Created pusher %d", session_id) + return err +} + +func (s *pushersStatements) SelectPushers( + ctx context.Context, txn *sql.Tx, localpart string, +) ([]api.Pusher, error) { + pushers := []api.Pusher{} + rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart) + + if err != nil { + return pushers, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed") + + for rows.Next() { + var pusher api.Pusher + var data []byte + err = rows.Scan( + &pusher.SessionID, + &pusher.PushKey, + &pusher.PushKeyTS, + &pusher.Kind, + &pusher.AppID, + &pusher.AppDisplayName, + &pusher.DeviceDisplayName, + &pusher.ProfileTag, + &pusher.Language, + &data) + if err != nil { + return pushers, err + } + err := json.Unmarshal(data, &pusher.Data) + if err != nil { + return pushers, err + } + pushers = append(pushers, pusher) + } + + logrus.Debugf("Database returned %d pushers", len(pushers)) + return pushers, rows.Err() +} + +// deletePusher removes a single pusher by pushkey and user localpart. +func (s *pushersStatements) DeletePusher( + ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, +) error { + _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) + return err +} + +func (s *pushersStatements) DeletePushers( + ctx context.Context, txn *sql.Tx, appid, pushkey string, +) error { + _, err := sqlutil.TxStmt(txn, s.deletePushersByAppIdAndPushKeyStmt).ExecContext(ctx, appid, pushkey) + return err +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index ac5c59b81..c74a999f4 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -85,6 +85,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err) } + pusherTable, err := NewPostgresPusherTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresPusherTable: %w", err) + } + notificationsTable, err := NewPostgresNotificationTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err) + } return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, @@ -95,6 +103,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver OpenIDTokens: openIDTable, Profiles: profilesTable, ThreePIDs: threePIDTable, + Pushers: pusherTable, + Notifications: notificationsTable, ServerName: serverName, DB: db, Writer: sqlutil.NewDummyWriter(), diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 5f1f95005..a58974b41 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -29,6 +29,7 @@ import ( "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" @@ -47,6 +48,8 @@ type Database struct { KeyBackupVersions tables.KeyBackupVersionTable Devices tables.DevicesTable LoginTokens tables.LoginTokenTable + Notifications tables.NotificationTable + Pushers tables.PusherTable LoginTokenLifetime time.Duration ServerName gomatrixserverlib.ServerName BcryptCost int @@ -160,15 +163,12 @@ func (d *Database) createAccount( if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { return nil, err } - if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`)); err != nil { + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) + prbs, err := json.Marshal(pushRuleSets) + if err != nil { + return nil, err + } + if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil { return nil, err } return account, nil @@ -670,3 +670,94 @@ func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { return d.LoginTokens.SelectLoginToken(ctx, token) } + +func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) + }) +} + +func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) + return err + }) + return +} + +func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) + return err + }) + return +} + +func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter) +} + +func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) { + return d.Notifications.SelectCount(ctx, nil, localpart, filter) +} + +func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) { + return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID) +} + +func (d *Database) UpsertPusher( + ctx context.Context, p api.Pusher, localpart string, +) error { + data, err := json.Marshal(p.Data) + if err != nil { + return err + } + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Pushers.InsertPusher( + ctx, txn, + p.SessionID, + p.PushKey, + p.PushKeyTS, + p.Kind, + p.AppID, + p.AppDisplayName, + p.DeviceDisplayName, + p.ProfileTag, + p.Language, + string(data), + localpart) + }) +} + +// GetPushers returns the pushers matching the given localpart. +func (d *Database) GetPushers( + ctx context.Context, localpart string, +) ([]api.Pusher, error) { + return d.Pushers.SelectPushers(ctx, nil, localpart) +} + +// RemovePusher deletes one pusher +// Invoked when `append` is true and `kind` is null in +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set +func (d *Database) RemovePusher( + ctx context.Context, appid, pushkey, localpart string, +) error { + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart) + if err == sql.ErrNoRows { + return nil + } + return err + }) +} + +// RemovePushers deletes all pushers that match given App Id and Push Key pair. +// Invoked when `append` parameter is false in +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set +func (d *Database) RemovePushers( + ctx context.Context, appid, pushkey string, +) error { + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.Pushers.DeletePushers(ctx, txn, appid, pushkey) + }) +} diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go new file mode 100644 index 000000000..fcfb1aadc --- /dev/null +++ b/userapi/storage/sqlite3/notifications_table.go @@ -0,0 +1,219 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +type notificationsStatements struct { + insertStmt *sql.Stmt + deleteUpToStmt *sql.Stmt + updateReadStmt *sql.Stmt + selectStmt *sql.Stmt + selectCountStmt *sql.Stmt + selectRoomCountsStmt *sql.Stmt +} + +const notificationSchema = ` +CREATE TABLE IF NOT EXISTS userapi_notifications ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + localpart TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + stream_pos BIGINT NOT NULL, + ts_ms BIGINT NOT NULL, + highlight BOOLEAN NOT NULL, + notification_json TEXT NOT NULL, + read BOOLEAN NOT NULL DEFAULT FALSE +); + +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +` + +const insertNotificationSQL = "" + + "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + +const deleteNotificationsUpToSQL = "" + + "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + +const updateNotificationReadSQL = "" + + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + +const selectNotificationSQL = "" + + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $4" + +const selectNotificationCountSQL = "" + + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + + "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read" + +const selectRoomNotificationCountsSQL = "" + + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + + "WHERE localpart = $1 AND room_id = $2 AND NOT read" + +func NewSQLiteNotificationTable(db *sql.DB) (tables.NotificationTable, error) { + s := ¬ificationsStatements{} + _, err := db.Exec(notificationSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertStmt, insertNotificationSQL}, + {&s.deleteUpToStmt, deleteNotificationsUpToSQL}, + {&s.updateReadStmt, updateNotificationReadSQL}, + {&s.selectStmt, selectNotificationSQL}, + {&s.selectCountStmt, selectNotificationCountSQL}, + {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL}, + }.Prepare(db) +} + +// Insert inserts a notification into the database. +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { + roomID, tsMS := n.RoomID, n.TS + nn := *n + // Clears out fields that have their own columns to (1) shrink the + // data and (2) avoid difficult-to-debug inconsistency bugs. + nn.RoomID = "" + nn.TS, nn.Read = 0, false + bs, err := json.Marshal(nn) + if err != nil { + return err + } + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + return err +} + +// DeleteUpTo deletes all previous notifications, up to and including the event. +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows) + return nrows > 0, nil +} + +// UpdateRead updates the "read" value for an event. +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows) + return nrows > 0, nil +} + +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) + + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + var maxID int64 = -1 + var notifs []*api.Notification + for rows.Next() { + var id int64 + var roomID string + var ts gomatrixserverlib.Timestamp + var read bool + var jsonStr string + err = rows.Scan( + &id, + &roomID, + &ts, + &read, + &jsonStr) + if err != nil { + return nil, 0, err + } + + var n api.Notification + err := json.Unmarshal([]byte(jsonStr), &n) + if err != nil { + return nil, 0, err + } + n.RoomID = roomID + n.TS = ts + n.Read = read + notifs = append(notifs, &n) + + if maxID < id { + maxID = id + } + } + return notifs, maxID, rows.Err() +} + +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) + + if err != nil { + return 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var count int64 + if err := rows.Scan(&count); err != nil { + return 0, err + } + + return count, nil + } + return 0, rows.Err() +} + +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { + rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) + + if err != nil { + return 0, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var total, highlight int64 + if err := rows.Scan(&total, &highlight); err != nil { + return 0, 0, err + } + + return total, highlight, nil + } + return 0, 0, rows.Err() +} diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go new file mode 100644 index 000000000..e718792e1 --- /dev/null +++ b/userapi/storage/sqlite3/pusher_table.go @@ -0,0 +1,157 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers +const pushersSchema = ` +CREATE TABLE IF NOT EXISTS userapi_pushers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The Matrix user ID localpart for this pusher + localpart TEXT NOT NULL, + session_id BIGINT DEFAULT NULL, + profile_tag TEXT, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + pushkey_ts_ms BIGINT NOT NULL DEFAULT 0, + lang TEXT NOT NULL, + data TEXT NOT NULL +); + +-- For faster deleting by app_id, pushkey pair. +CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); + +-- For faster retrieving by localpart. +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); + +-- Pushkey must be unique for a given user and app. +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +` + +const insertPusherSQL = "" + + "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + + "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + +const selectPushersSQL = "" + + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + +const deletePusherSQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + +const deletePushersByAppIdAndPushKeySQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" + +func NewSQLitePusherTable(db *sql.DB) (tables.PusherTable, error) { + s := &pushersStatements{} + _, err := db.Exec(pushersSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertPusherStmt, insertPusherSQL}, + {&s.selectPushersStmt, selectPushersSQL}, + {&s.deletePusherStmt, deletePusherSQL}, + {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL}, + }.Prepare(db) +} + +type pushersStatements struct { + insertPusherStmt *sql.Stmt + selectPushersStmt *sql.Stmt + deletePusherStmt *sql.Stmt + deletePushersByAppIdAndPushKeyStmt *sql.Stmt +} + +// insertPusher creates a new pusher. +// Returns an error if the user already has a pusher with the given pusher pushkey. +// Returns nil error success. +func (s *pushersStatements) InsertPusher( + ctx context.Context, txn *sql.Tx, session_id int64, + pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, +) error { + _, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) + logrus.Debugf("Created pusher %d", session_id) + return err +} + +func (s *pushersStatements) SelectPushers( + ctx context.Context, txn *sql.Tx, localpart string, +) ([]api.Pusher, error) { + pushers := []api.Pusher{} + rows, err := s.selectPushersStmt.QueryContext(ctx, localpart) + + if err != nil { + return pushers, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed") + + for rows.Next() { + var pusher api.Pusher + var data []byte + err = rows.Scan( + &pusher.SessionID, + &pusher.PushKey, + &pusher.PushKeyTS, + &pusher.Kind, + &pusher.AppID, + &pusher.AppDisplayName, + &pusher.DeviceDisplayName, + &pusher.ProfileTag, + &pusher.Language, + &data) + if err != nil { + return pushers, err + } + err := json.Unmarshal(data, &pusher.Data) + if err != nil { + return pushers, err + } + pushers = append(pushers, pusher) + } + + logrus.Debugf("Database returned %d pushers", len(pushers)) + return pushers, rows.Err() +} + +// deletePusher removes a single pusher by pushkey and user localpart. +func (s *pushersStatements) DeletePusher( + ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, +) error { + _, err := s.deletePusherStmt.ExecContext(ctx, appid, pushkey, localpart) + return err +} + +func (s *pushersStatements) DeletePushers( + ctx context.Context, txn *sql.Tx, appid, pushkey string, +) error { + _, err := s.deletePushersByAppIdAndPushKeyStmt.ExecContext(ctx, appid, pushkey) + return err +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 98c244977..b5bb96c42 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -86,6 +86,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err) } + pusherTable, err := NewSQLitePusherTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresPusherTable: %w", err) + } + notificationsTable, err := NewSQLiteNotificationTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err) + } return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, @@ -96,6 +104,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver OpenIDTokens: openIDTable, Profiles: profilesTable, ThreePIDs: threePIDTable, + Pushers: pusherTable, + Notifications: notificationsTable, ServerName: serverName, DB: db, Writer: sqlutil.NewExclusiveWriter(), diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 12939ced5..815e51193 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" ) type AccountDataTable interface { @@ -93,3 +94,42 @@ type ThreePIDTable interface { InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) } + +type PusherTable interface { + InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error + SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error) + DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error + DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error +} + +type NotificationTable interface { + Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error + DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) + UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) + Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) + SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) + SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) +} + +type NotificationFilter uint32 + +const ( + // HighlightNotifications returns notifications that had a + // "highlight" tweak assigned to them from evaluating push rules. + HighlightNotifications NotificationFilter = 1 << iota + + // NonHighlightNotifications returns notifications that don't + // match HighlightNotifications. + NonHighlightNotifications + + // NoNotifications is a filter to exclude all types of + // notifications. It's useful as a zero value, but isn't likely to + // be used in a call to Notifications.Select*. + NoNotifications NotificationFilter = 0 + + // AllNotifications is a filter to include all types of + // notifications in Notifications.Select*. Note that PostgreSQL + // balks if this doesn't fit in INTEGER, even though we use + // uint32. + AllNotifications NotificationFilter = (1 << 31) - 1 +) diff --git a/userapi/userapi.go b/userapi/userapi.go index 4a5793abb..2382e9512 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -18,11 +18,17 @@ import ( "time" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/pushgateway" keyapi "github.com/matrix-org/dendrite/keyserver/api" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/consumers" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" + "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" "github.com/sirupsen/logrus" ) @@ -36,26 +42,49 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, + base *base.BaseDendrite, db storage.Database, cfg *config.UserAPI, + appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, + rsAPI rsapi.RoomserverInternalAPI, pgClient pushgateway.Client, ) api.UserInternalAPI { db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime) if err != nil { logrus.WithError(err).Panicf("failed to connect to device db") } - return newInternalAPI(db, cfg, appServices, keyAPI) -} + js := jetstream.Prepare(&cfg.Matrix.JetStream) -func newInternalAPI( - db storage.Database, - cfg *config.UserAPI, - appServices []config.ApplicationService, - keyAPI keyapi.KeyInternalAPI, -) api.UserInternalAPI { - return &internal.UserInternalAPI{ - DB: db, - ServerName: cfg.Matrix.ServerName, - AppServices: appServices, - KeyAPI: keyAPI, + syncProducer := producers.NewSyncAPI( + db, js, + // TODO: user API should handle syncs for account data. Right now, + // it's handled by clientapi, and hence uses its topic. When user + // API handles it for all account data, we can remove it from + // here. + cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData), + cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData), + ) + + userAPI := &internal.UserInternalAPI{ + DB: db, + SyncProducer: syncProducer, + ServerName: cfg.Matrix.ServerName, + AppServices: appServices, + KeyAPI: keyAPI, + DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, } + + readConsumer := consumers.NewOutputReadUpdateConsumer( + base.ProcessContext, cfg, js, db, pgClient, userAPI, syncProducer, + ) + if err := readConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start user API read update consumer") + } + + eventConsumer := consumers.NewOutputStreamEventConsumer( + base.ProcessContext, cfg, js, db, pgClient, userAPI, rsAPI, syncProducer, + ) + if err := eventConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start user API streamed event consumer") + } + + return userAPI } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 4214c07f7..25319c4bf 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -30,6 +30,7 @@ import ( "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage" ) @@ -62,7 +63,10 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s }, } - return newInternalAPI(accountDB, cfg, nil, nil), accountDB + return &internal.UserInternalAPI{ + DB: accountDB, + ServerName: cfg.Matrix.ServerName, + }, accountDB } func TestQueryProfile(t *testing.T) { diff --git a/userapi/util/devices.go b/userapi/util/devices.go new file mode 100644 index 000000000..cbf3bd28f --- /dev/null +++ b/userapi/util/devices.go @@ -0,0 +1,100 @@ +package util + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" + log "github.com/sirupsen/logrus" +) + +type PusherDevice struct { + Device pushgateway.Device + Pusher *api.Pusher + URL string + Format string +} + +// GetPushDevices pushes to the configured devices of a local user. +func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { + pushers, err := db.GetPushers(ctx, localpart) + if err != nil { + return nil, err + } + + devices := make([]*PusherDevice, 0, len(pushers)) + for _, pusher := range pushers { + var url, format string + data := pusher.Data + switch pusher.Kind { + case api.EmailKind: + url = "mailto:" + + case api.HTTPKind: + // TODO: The spec says only event_id_only is supported, + // but Sytests assume "" means "full notification". + fmtIface := pusher.Data["format"] + var ok bool + format, ok = fmtIface.(string) + if ok && format != "event_id_only" { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id": pusher.AppID, + }).Errorf("Only data.format event_id_only or empty is supported") + continue + } + + urlIface := pusher.Data["url"] + url, ok = urlIface.(string) + if !ok { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id": pusher.AppID, + }).Errorf("No data.url configured for HTTP Pusher") + continue + } + data = mapWithout(data, "url") + + default: + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id": pusher.AppID, + "kind": pusher.Kind, + }).Errorf("Unhandled pusher kind") + continue + } + + devices = append(devices, &PusherDevice{ + Device: pushgateway.Device{ + AppID: pusher.AppID, + Data: data, + PushKey: pusher.PushKey, + PushKeyTS: pusher.PushKeyTS, + Tweaks: tweaks, + }, + Pusher: &pusher, + URL: url, + Format: format, + }) + } + + return devices, nil +} + +// mapWithout returns a shallow copy of the map, without the given +// key. Returns nil if the resulting map is empty. +func mapWithout(m map[string]interface{}, key string) map[string]interface{} { + ret := make(map[string]interface{}, len(m)) + for k, v := range m { + // The specification says we do not send "url". + if k == key { + continue + } + ret[k] = v + } + if len(ret) == 0 { + return nil + } + return ret +} diff --git a/userapi/util/notify.go b/userapi/util/notify.go new file mode 100644 index 000000000..ff206bd3c --- /dev/null +++ b/userapi/util/notify.go @@ -0,0 +1,76 @@ +package util + +import ( + "context" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/storage/tables" + log "github.com/sirupsen/logrus" +) + +// NotifyUserCountsAsync sends notifications to a local user's +// notification destinations. Database lookups run synchronously, but +// a single goroutine is started when talking to the Push +// gateways. There is no way to know when the background goroutine has +// finished. +func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error { + pusherDevices, err := GetPushDevices(ctx, localpart, nil, db) + if err != nil { + return err + } + + if len(pusherDevices) == 0 { + return nil + } + + userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id0": pusherDevices[0].Device.AppID, + "pushkey": pusherDevices[0].Device.PushKey, + }).Tracef("Notifying HTTP push gateway about notification counts") + + // TODO: think about bounding this to one per user, and what + // ordering guarantees we must provide. + go func() { + // This background processing cannot be tied to a request. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // TODO: we could batch all devices with the same URL, but + // Sytest requires consumers/roomserver.go to do it + // one-by-one, so we do the same here. + for _, pusherDevice := range pusherDevices { + // TODO: support "email". + if !strings.HasPrefix(pusherDevice.URL, "http") { + continue + } + + req := pushgateway.NotifyRequest{ + Notification: pushgateway.Notification{ + Counts: &pushgateway.Counts{ + Unread: int(userNumUnreadNotifs), + }, + Devices: []*pushgateway.Device{&pusherDevice.Device}, + }, + } + if err := pgClient.Notify(ctx, pusherDevice.URL, &req, &pushgateway.NotifyResponse{}); err != nil { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id0": pusherDevice.Device.AppID, + "pushkey": pusherDevice.Device.PushKey, + }).WithError(err).Error("HTTP push gateway request failed") + return + } + } + }() + + return nil +}