mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-01 11:13:12 -06:00
Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/phonehomestats
This commit is contained in:
commit
c1989c024c
6
.gitignore
vendored
6
.gitignore
vendored
|
|
@ -54,7 +54,7 @@ dendrite.yaml
|
|||
*.db
|
||||
|
||||
# Log files
|
||||
*.log*
|
||||
*.log*
|
||||
|
||||
# Generated code
|
||||
cmd/dendrite-demo-yggdrasil/embed/fs*.go
|
||||
|
|
@ -62,5 +62,7 @@ cmd/dendrite-demo-yggdrasil/embed/fs*.go
|
|||
# Test dependencies
|
||||
test/wasm/node_modules
|
||||
|
||||
media_store/
|
||||
# Ignore complement folder when running locally
|
||||
complement/
|
||||
|
||||
media_store/
|
||||
|
|
|
|||
|
|
@ -318,6 +318,17 @@ user_api:
|
|||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
||||
# Configuration for the Push Server API.
|
||||
push_server:
|
||||
internal_api:
|
||||
listen: http://localhost:7782
|
||||
connect: http://localhost:7782
|
||||
database:
|
||||
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_pushserver?sslmode=disable
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
||||
# Configuration for Opentracing.
|
||||
# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on
|
||||
# how this works and how to set it up.
|
||||
|
|
|
|||
|
|
@ -312,7 +312,7 @@ func (m *DendriteMonolith) Start() {
|
|||
)
|
||||
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
||||
m.userAPI = userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI)
|
||||
m.userAPI = userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(m.userAPI)
|
||||
|
||||
eduInputAPI := eduserver.NewInternalAPI(
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ func (m *DendriteMonolith) Start() {
|
|||
)
|
||||
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI)
|
||||
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
eduInputAPI := eduserver.NewInternalAPI(
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ func AddPublicRoutes(
|
|||
routing.Setup(
|
||||
router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI,
|
||||
accountsDB, userAPI, federation,
|
||||
syncProducer, transactionsCache, fsAPI, keyAPI, extRoomsProvider, mscCfg,
|
||||
syncProducer, transactionsCache, fsAPI, keyAPI,
|
||||
extRoomsProvider, mscCfg,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ type SyncAPIProducer struct {
|
|||
}
|
||||
|
||||
// SendData sends account data to the sync API server
|
||||
func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string) error {
|
||||
func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string, readMarker *eventutil.ReadMarkerJSON) error {
|
||||
m := &nats.Msg{
|
||||
Subject: p.Topic,
|
||||
Header: nats.Header{},
|
||||
|
|
@ -38,8 +38,9 @@ func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string
|
|||
m.Header.Set(jetstream.UserID, userID)
|
||||
|
||||
data := eventutil.AccountData{
|
||||
RoomID: roomID,
|
||||
Type: dataType,
|
||||
RoomID: roomID,
|
||||
Type: dataType,
|
||||
ReadMarker: readMarker,
|
||||
}
|
||||
var err error
|
||||
m.Data, err = json.Marshal(data)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/clientapi/producers"
|
||||
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
|
||||
|
|
@ -127,7 +128,7 @@ func SaveAccountData(
|
|||
}
|
||||
|
||||
// TODO: user API should do this since it's account data
|
||||
if err := syncProducer.SendData(userID, roomID, dataType); err != nil {
|
||||
if err := syncProducer.SendData(userID, roomID, dataType, nil); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
|
@ -138,11 +139,6 @@ func SaveAccountData(
|
|||
}
|
||||
}
|
||||
|
||||
type readMarkerJSON struct {
|
||||
FullyRead string `json:"m.fully_read"`
|
||||
Read string `json:"m.read"`
|
||||
}
|
||||
|
||||
type fullyReadEvent struct {
|
||||
EventID string `json:"event_id"`
|
||||
}
|
||||
|
|
@ -159,7 +155,7 @@ func SaveReadMarker(
|
|||
return *resErr
|
||||
}
|
||||
|
||||
var r readMarkerJSON
|
||||
var r eventutil.ReadMarkerJSON
|
||||
resErr = httputil.UnmarshalJSONRequest(req, &r)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
|
|
@ -189,7 +185,7 @@ func SaveReadMarker(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
|
||||
if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read"); err != nil {
|
||||
if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read", &r); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
|
|
|||
63
clientapi/routing/notification.go
Normal file
63
clientapi/routing/notification.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// GetNotifications handles /_matrix/client/r0/notifications
|
||||
func GetNotifications(
|
||||
req *http.Request, device *userapi.Device,
|
||||
userAPI userapi.UserInternalAPI,
|
||||
) util.JSONResponse {
|
||||
var limit int64
|
||||
if limitStr := req.URL.Query().Get("limit"); limitStr != "" {
|
||||
var err error
|
||||
limit, err = strconv.ParseInt(limitStr, 10, 64)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("ParseInt(limit) failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
}
|
||||
|
||||
var queryRes userapi.QueryNotificationsResponse
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
|
||||
Localpart: localpart,
|
||||
From: req.URL.Query().Get("from"),
|
||||
Limit: int(limit),
|
||||
Only: req.URL.Query().Get("only"),
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
util.GetLogger(req.Context()).WithField("from", req.URL.Query().Get("from")).WithField("limit", limit).WithField("only", req.URL.Query().Get("only")).WithField("next", queryRes.NextToken).Infof("QueryNotifications: len %d", len(queryRes.Notifications))
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: queryRes,
|
||||
}
|
||||
}
|
||||
|
|
@ -12,6 +12,7 @@ import (
|
|||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type newPasswordRequest struct {
|
||||
|
|
@ -37,6 +38,11 @@ func Password(
|
|||
var r newPasswordRequest
|
||||
r.LogoutDevices = true
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"sessionId": device.SessionID,
|
||||
"userId": device.UserID,
|
||||
}).Debug("Changing password")
|
||||
|
||||
// Unmarshal the request.
|
||||
resErr := httputil.UnmarshalJSONRequest(req, &r)
|
||||
if resErr != nil {
|
||||
|
|
@ -116,6 +122,15 @@ func Password(
|
|||
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
pushersReq := &api.PerformPusherDeletionRequest{
|
||||
Localpart: localpart,
|
||||
SessionID: device.SessionID,
|
||||
}
|
||||
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
}
|
||||
|
||||
// Return a success code.
|
||||
|
|
|
|||
114
clientapi/routing/pusher.go
Normal file
114
clientapi/routing/pusher.go
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// GetPushers handles /_matrix/client/r0/pushers
|
||||
func GetPushers(
|
||||
req *http.Request, device *userapi.Device,
|
||||
userAPI userapi.UserInternalAPI,
|
||||
) util.JSONResponse {
|
||||
var queryRes userapi.QueryPushersResponse
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
|
||||
Localpart: localpart,
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
for i := range queryRes.Pushers {
|
||||
queryRes.Pushers[i].SessionID = 0
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: queryRes,
|
||||
}
|
||||
}
|
||||
|
||||
// SetPusher handles /_matrix/client/r0/pushers/set
|
||||
// This endpoint allows the creation, modification and deletion of pushers for this user ID.
|
||||
// The behaviour of this endpoint varies depending on the values in the JSON body.
|
||||
func SetPusher(
|
||||
req *http.Request, device *userapi.Device,
|
||||
userAPI userapi.UserInternalAPI,
|
||||
) util.JSONResponse {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
body := userapi.PerformPusherSetRequest{}
|
||||
if resErr := httputil.UnmarshalJSONRequest(req, &body); resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
if len(body.AppID) > 64 {
|
||||
return invalidParam("length of app_id must be no more than 64 characters")
|
||||
}
|
||||
if len(body.PushKey) > 512 {
|
||||
return invalidParam("length of pushkey must be no more than 512 bytes")
|
||||
}
|
||||
uInt := body.Data["url"]
|
||||
if uInt != nil {
|
||||
u, ok := uInt.(string)
|
||||
if !ok {
|
||||
return invalidParam("url must be string")
|
||||
}
|
||||
if u != "" {
|
||||
var pushUrl *url.URL
|
||||
pushUrl, err = url.Parse(u)
|
||||
if err != nil {
|
||||
return invalidParam("malformed url passed")
|
||||
}
|
||||
if pushUrl.Scheme != "https" {
|
||||
return invalidParam("only https scheme is allowed")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
body.Localpart = localpart
|
||||
body.SessionID = device.SessionID
|
||||
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherSet failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: struct{}{},
|
||||
}
|
||||
}
|
||||
|
||||
func invalidParam(msg string) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidParam(msg),
|
||||
}
|
||||
}
|
||||
386
clientapi/routing/pushrules.go
Normal file
386
clientapi/routing/pushrules.go
Normal file
|
|
@ -0,0 +1,386 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
func errorResponse(ctx context.Context, err error, msg string, args ...interface{}) util.JSONResponse {
|
||||
if eerr, ok := err.(*jsonerror.MatrixError); ok {
|
||||
var status int
|
||||
switch eerr.ErrCode {
|
||||
case "M_INVALID_ARGUMENT_VALUE":
|
||||
status = http.StatusBadRequest
|
||||
case "M_NOT_FOUND":
|
||||
status = http.StatusNotFound
|
||||
default:
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
return util.MatrixErrorResponse(status, eerr.ErrCode, eerr.Err)
|
||||
}
|
||||
util.GetLogger(ctx).WithError(err).Errorf(msg, args...)
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRulesJSON failed")
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: ruleSets,
|
||||
}
|
||||
}
|
||||
|
||||
func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRulesJSON failed")
|
||||
}
|
||||
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||
if ruleSet == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: ruleSet,
|
||||
}
|
||||
}
|
||||
|
||||
func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRules failed")
|
||||
}
|
||||
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||
if ruleSet == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||
}
|
||||
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||
if rulesPtr == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: *rulesPtr,
|
||||
}
|
||||
}
|
||||
|
||||
func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRules failed")
|
||||
}
|
||||
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||
if ruleSet == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||
}
|
||||
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||
if rulesPtr == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||
}
|
||||
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||
if i < 0 {
|
||||
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: (*rulesPtr)[i],
|
||||
}
|
||||
}
|
||||
|
||||
func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
var newRule pushrules.Rule
|
||||
if err := json.NewDecoder(body).Decode(&newRule); err != nil {
|
||||
return errorResponse(ctx, err, "JSON Decode failed")
|
||||
}
|
||||
newRule.RuleID = ruleID
|
||||
|
||||
errs := pushrules.ValidateRule(pushrules.Kind(kind), &newRule)
|
||||
if len(errs) > 0 {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs)
|
||||
}
|
||||
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRules failed")
|
||||
}
|
||||
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||
if ruleSet == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||
}
|
||||
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||
if rulesPtr == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||
}
|
||||
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||
if i >= 0 && afterRuleID == "" && beforeRuleID == "" {
|
||||
// Modify rule at the same index.
|
||||
|
||||
// TODO: The spec does not say what to do in this case, but
|
||||
// this feels reasonable.
|
||||
*((*rulesPtr)[i]) = newRule
|
||||
util.GetLogger(ctx).Infof("Modified existing push rule at %d", i)
|
||||
} else {
|
||||
if i >= 0 {
|
||||
// Delete old rule.
|
||||
*rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
|
||||
util.GetLogger(ctx).Infof("Deleted old push rule at %d", i)
|
||||
} else {
|
||||
// SPEC: When creating push rules, they MUST be enabled by default.
|
||||
//
|
||||
// TODO: it's unclear if we must reject disabled rules, or force
|
||||
// the value to true. Sytests fail if we don't force it.
|
||||
newRule.Enabled = true
|
||||
}
|
||||
|
||||
// Add new rule.
|
||||
i, err := findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "findPushRuleInsertionIndex failed")
|
||||
}
|
||||
|
||||
*rulesPtr = append((*rulesPtr)[:i], append([]*pushrules.Rule{&newRule}, (*rulesPtr)[i:]...)...)
|
||||
util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i)
|
||||
}
|
||||
|
||||
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
|
||||
return errorResponse(ctx, err, "putPushRules failed")
|
||||
}
|
||||
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
|
||||
}
|
||||
|
||||
func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRules failed")
|
||||
}
|
||||
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||
if ruleSet == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||
}
|
||||
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||
if rulesPtr == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||
}
|
||||
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||
if i < 0 {
|
||||
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
|
||||
}
|
||||
|
||||
*rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
|
||||
|
||||
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
|
||||
return errorResponse(ctx, err, "putPushRules failed")
|
||||
}
|
||||
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
|
||||
}
|
||||
|
||||
func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
attrGet, err := pushRuleAttrGetter(attr)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "pushRuleAttrGetter failed")
|
||||
}
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRules failed")
|
||||
}
|
||||
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||
if ruleSet == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||
}
|
||||
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||
if rulesPtr == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||
}
|
||||
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||
if i < 0 {
|
||||
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: map[string]interface{}{
|
||||
attr: attrGet((*rulesPtr)[i]),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
var newPartialRule pushrules.Rule
|
||||
if err := json.NewDecoder(body).Decode(&newPartialRule); err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON(err.Error()),
|
||||
}
|
||||
}
|
||||
if newPartialRule.Actions == nil {
|
||||
// This ensures json.Marshal encodes the empty list as [] rather than null.
|
||||
newPartialRule.Actions = []*pushrules.Action{}
|
||||
}
|
||||
|
||||
attrGet, err := pushRuleAttrGetter(attr)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "pushRuleAttrGetter failed")
|
||||
}
|
||||
attrSet, err := pushRuleAttrSetter(attr)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "pushRuleAttrSetter failed")
|
||||
}
|
||||
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRules failed")
|
||||
}
|
||||
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||
if ruleSet == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||
}
|
||||
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||
if rulesPtr == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||
}
|
||||
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||
if i < 0 {
|
||||
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) {
|
||||
attrSet((*rulesPtr)[i], &newPartialRule)
|
||||
|
||||
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
|
||||
return errorResponse(ctx, err, "putPushRules failed")
|
||||
}
|
||||
}
|
||||
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
|
||||
}
|
||||
|
||||
func queryPushRules(ctx context.Context, userID string, userAPI userapi.UserInternalAPI) (*pushrules.AccountRuleSets, error) {
|
||||
var res userapi.QueryPushRulesResponse
|
||||
if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed")
|
||||
return nil, err
|
||||
}
|
||||
return res.RuleSets, nil
|
||||
}
|
||||
|
||||
func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.UserInternalAPI) error {
|
||||
req := userapi.PerformPushRulesPutRequest{
|
||||
UserID: userID,
|
||||
RuleSets: ruleSets,
|
||||
}
|
||||
var res struct{}
|
||||
if err := userAPI.PerformPushRulesPut(ctx, &req, &res); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("userAPI.PerformPushRulesPut failed")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet {
|
||||
switch scope {
|
||||
case pushrules.GlobalScope:
|
||||
return &ruleSets.Global
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func pushRuleSetKindPointer(ruleSet *pushrules.RuleSet, kind pushrules.Kind) *[]*pushrules.Rule {
|
||||
switch kind {
|
||||
case pushrules.OverrideKind:
|
||||
return &ruleSet.Override
|
||||
case pushrules.ContentKind:
|
||||
return &ruleSet.Content
|
||||
case pushrules.RoomKind:
|
||||
return &ruleSet.Room
|
||||
case pushrules.SenderKind:
|
||||
return &ruleSet.Sender
|
||||
case pushrules.UnderrideKind:
|
||||
return &ruleSet.Underride
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func pushRuleIndexByID(rules []*pushrules.Rule, id string) int {
|
||||
for i, rule := range rules {
|
||||
if rule.RuleID == id {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func pushRuleAttrGetter(attr string) (func(*pushrules.Rule) interface{}, error) {
|
||||
switch attr {
|
||||
case "actions":
|
||||
return func(rule *pushrules.Rule) interface{} { return rule.Actions }, nil
|
||||
case "enabled":
|
||||
return func(rule *pushrules.Rule) interface{} { return rule.Enabled }, nil
|
||||
default:
|
||||
return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute")
|
||||
}
|
||||
}
|
||||
|
||||
func pushRuleAttrSetter(attr string) (func(dest, src *pushrules.Rule), error) {
|
||||
switch attr {
|
||||
case "actions":
|
||||
return func(dest, src *pushrules.Rule) { dest.Actions = src.Actions }, nil
|
||||
case "enabled":
|
||||
return func(dest, src *pushrules.Rule) { dest.Enabled = src.Enabled }, nil
|
||||
default:
|
||||
return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute")
|
||||
}
|
||||
}
|
||||
|
||||
func findPushRuleInsertionIndex(rules []*pushrules.Rule, afterID, beforeID string) (int, error) {
|
||||
var i int
|
||||
|
||||
if afterID != "" {
|
||||
for ; i < len(rules); i++ {
|
||||
if rules[i].RuleID == afterID {
|
||||
break
|
||||
}
|
||||
}
|
||||
if i == len(rules) {
|
||||
return 0, jsonerror.NotFound("after: rule ID not found")
|
||||
}
|
||||
if rules[i].Default {
|
||||
return 0, jsonerror.NotFound("after: rule ID must not be a default rule")
|
||||
}
|
||||
// We stopped on the "after" match to differentiate
|
||||
// not-found from is-last-entry. Now we move to the earliest
|
||||
// insertion point.
|
||||
i++
|
||||
}
|
||||
|
||||
if beforeID != "" {
|
||||
for ; i < len(rules); i++ {
|
||||
if rules[i].RuleID == beforeID {
|
||||
break
|
||||
}
|
||||
}
|
||||
if i == len(rules) {
|
||||
return 0, jsonerror.NotFound("before: rule ID not found")
|
||||
}
|
||||
if rules[i].Default {
|
||||
return 0, jsonerror.NotFound("before: rule ID must not be a default rule")
|
||||
}
|
||||
}
|
||||
|
||||
// UNSPEC: The spec does not say what to do if no after/before is
|
||||
// given. Sytest fails if it doesn't go first.
|
||||
return i, nil
|
||||
}
|
||||
|
|
@ -214,19 +214,19 @@ func TestSessionCleanUp(t *testing.T) {
|
|||
s := newSessionsDict()
|
||||
|
||||
t.Run("session is cleaned up after a while", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// t.Parallel()
|
||||
dummySession := "helloWorld"
|
||||
// manually added, as s.addParams() would start the timer with the default timeout
|
||||
s.params[dummySession] = registerRequest{Username: "Testing"}
|
||||
s.startTimer(time.Millisecond, dummySession)
|
||||
time.Sleep(time.Millisecond * 2)
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
if data, ok := s.getParams(dummySession); ok {
|
||||
t.Errorf("expected session to be deleted: %+v", data)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("session is deleted, once the registration completed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// t.Parallel()
|
||||
dummySession := "helloWorld2"
|
||||
s.startTimer(time.Minute, dummySession)
|
||||
s.deleteSession(dummySession)
|
||||
|
|
@ -236,7 +236,7 @@ func TestSessionCleanUp(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("session timer is restarted after second call", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// t.Parallel()
|
||||
dummySession := "helloWorld3"
|
||||
// the following will start a timer with the default timeout of 5min
|
||||
s.addParams(dummySession, registerRequest{Username: "Testing"})
|
||||
|
|
@ -246,7 +246,7 @@ func TestSessionCleanUp(t *testing.T) {
|
|||
s.getCompletedStages(dummySession)
|
||||
// reset the timer with a lower timeout
|
||||
s.startTimer(time.Millisecond, dummySession)
|
||||
time.Sleep(time.Millisecond * 2)
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
if data, ok := s.getParams(dummySession); ok {
|
||||
t.Errorf("expected session to be deleted: %+v", data)
|
||||
}
|
||||
|
|
@ -260,4 +260,4 @@ func TestSessionCleanUp(t *testing.T) {
|
|||
t.Error("expected session to device to be delete")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ func PutTag(
|
|||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
|
||||
if err = syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil {
|
||||
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
|
||||
}
|
||||
|
||||
|
|
@ -151,7 +151,7 @@ func DeleteTag(
|
|||
}
|
||||
|
||||
// TODO: user API should do this since it's account data
|
||||
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
|
||||
if err := syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil {
|
||||
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ package routing
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
|
|
@ -561,25 +560,142 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/",
|
||||
httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse {
|
||||
// TODO: Implement push rules API
|
||||
res := json.RawMessage(`{
|
||||
"global": {
|
||||
"content": [],
|
||||
"override": [],
|
||||
"room": [],
|
||||
"sender": [],
|
||||
"underride": []
|
||||
}
|
||||
}`)
|
||||
// Push rules
|
||||
|
||||
v3mux.Handle("/pushrules",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: &res,
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("missing trailing slash"),
|
||||
}
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetAllPushRules(req.Context(), device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("scope, kind and rule ID must be specified"),
|
||||
}
|
||||
}),
|
||||
).Methods(http.MethodPut)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetPushRulesByScope(req.Context(), vars["scope"], device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("missing trailing slash after scope"),
|
||||
}
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope:[^/]+/?}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("kind and rule ID must be specified"),
|
||||
}
|
||||
}),
|
||||
).Methods(http.MethodPut)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetPushRulesByKind(req.Context(), vars["scope"], vars["kind"], device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("missing trailing slash after kind"),
|
||||
}
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind:[^/]+/?}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("rule ID must be specified"),
|
||||
}
|
||||
}),
|
||||
).Methods(http.MethodPut)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
}
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
query := req.URL.Query()
|
||||
return PutPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], query.Get("after"), query.Get("before"), req.Body, device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodPut)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return DeletePushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodDelete)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return PutPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], req.Body, device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodPut)
|
||||
|
||||
// Element user settings
|
||||
|
||||
v3mux.Handle("/profile/{userID}",
|
||||
|
|
@ -885,6 +1001,27 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/notifications",
|
||||
httputil.MakeAuthAPI("get_notifications", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetNotifications(req, device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushers",
|
||||
httputil.MakeAuthAPI("get_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetPushers(req, device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushers/set",
|
||||
httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
}
|
||||
return SetPusher(req, device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
// Stub implementations for sytest
|
||||
v3mux.Handle("/events",
|
||||
httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
|
||||
|
|
|
|||
|
|
@ -144,12 +144,14 @@ func main() {
|
|||
accountDB := base.Base.CreateAccountsDB()
|
||||
federation := createFederationClient(base)
|
||||
keyAPI := keyserver.NewInternalAPI(&base.Base, &base.Base.Cfg.KeyServer, federation)
|
||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
rsAPI := roomserver.NewInternalAPI(
|
||||
&base.Base,
|
||||
)
|
||||
|
||||
userAPI := userapi.NewInternalAPI(&base.Base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.Base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
eduInputAPI := eduserver.NewInternalAPI(
|
||||
&base.Base, cache.New(), userAPI,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -187,7 +187,7 @@ func main() {
|
|||
)
|
||||
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
|
||||
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
eduInputAPI := eduserver.NewInternalAPI(
|
||||
|
|
|
|||
|
|
@ -111,14 +111,15 @@ func main() {
|
|||
keyRing := serverKeyAPI.KeyRing()
|
||||
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
rsComponent := roomserver.NewInternalAPI(
|
||||
base,
|
||||
)
|
||||
rsAPI := rsComponent
|
||||
|
||||
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
eduInputAPI := eduserver.NewInternalAPI(
|
||||
base, cache.New(), userAPI,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -106,7 +106,8 @@ func main() {
|
|||
keyAPI = base.KeyServerHTTPClient()
|
||||
}
|
||||
|
||||
userImpl := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI)
|
||||
pgClient := base.PushGatewayHTTPClient()
|
||||
userImpl := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient)
|
||||
userAPI := userImpl
|
||||
if base.UseHTTPAPIs {
|
||||
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,11 @@ import (
|
|||
func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
|
||||
accountDB := base.CreateAccountsDB()
|
||||
|
||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient())
|
||||
userAPI := userapi.NewInternalAPI(
|
||||
base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices,
|
||||
base.KeyServerHTTPClient(), base.RoomserverHTTPClient(),
|
||||
base.PushGatewayHTTPClient(),
|
||||
)
|
||||
|
||||
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)
|
||||
|
||||
|
|
|
|||
|
|
@ -184,13 +184,15 @@ func startup() {
|
|||
accountDB := base.CreateAccountsDB()
|
||||
federation := conn.CreateFederationClient(base, pSessions)
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
serverKeyAPI := &signing.YggdrasilKeys{}
|
||||
keyRing := serverKeyAPI.KeyRing()
|
||||
|
||||
rsAPI := roomserver.NewInternalAPI(base)
|
||||
|
||||
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI)
|
||||
asQuery := appservice.NewInternalAPI(
|
||||
base, userAPI, rsAPI,
|
||||
|
|
|
|||
|
|
@ -212,6 +212,8 @@ func main() {
|
|||
rsAPI.SetFederationAPI(fedSenderAPI, keyRing)
|
||||
p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation)
|
||||
|
||||
psAPI := pushserver.NewInternalAPI(base)
|
||||
|
||||
monolith := setup.Monolith{
|
||||
Config: base.Cfg,
|
||||
AccountDB: accountDB,
|
||||
|
|
@ -225,6 +227,7 @@ func main() {
|
|||
RoomserverAPI: rsAPI,
|
||||
UserAPI: userAPI,
|
||||
KeyAPI: keyAPI,
|
||||
PushserverAPI: psAPI,
|
||||
//ServerKeyAPI: serverKeyAPI,
|
||||
ExtPublicRoomsProvider: p2pPublicRoomProvider,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
#
|
||||
# At a minimum, to get started, you will need to update the settings in the
|
||||
# "global" section for your deployment, and you will need to check that the
|
||||
# database "connection_string" line in each component section is correct.
|
||||
# database "connection_string" line in each component section is correct.
|
||||
#
|
||||
# Each component with a "database" section can accept the following formats
|
||||
# for "connection_string":
|
||||
|
|
@ -21,13 +21,13 @@
|
|||
# small number of users and likely will perform worse still with a higher volume
|
||||
# of users.
|
||||
#
|
||||
# The "max_open_conns" and "max_idle_conns" settings configure the maximum
|
||||
# The "max_open_conns" and "max_idle_conns" settings configure the maximum
|
||||
# number of open/idle database connections. The value 0 will use the database
|
||||
# engine default, and a negative value will use unlimited connections. The
|
||||
# "conn_max_lifetime" option controls the maximum length of time a database
|
||||
# connection can be idle in seconds - a negative value is unlimited.
|
||||
|
||||
# The version of the configuration file.
|
||||
# The version of the configuration file.
|
||||
version: 2
|
||||
|
||||
# Global Matrix configuration. This configuration applies to all components.
|
||||
|
|
@ -61,8 +61,8 @@ global:
|
|||
# Lists of domains that the server will trust as identity servers to verify third
|
||||
# party identifiers such as phone numbers and email addresses.
|
||||
trusted_third_party_id_servers:
|
||||
- matrix.org
|
||||
- vector.im
|
||||
- matrix.org
|
||||
- vector.im
|
||||
|
||||
# Disables federation. Dendrite will not be able to make any outbound HTTP requests
|
||||
# to other servers and the federation API will not be exposed.
|
||||
|
|
@ -95,14 +95,14 @@ global:
|
|||
# in monolith mode. It is required to specify the address of at least one
|
||||
# NATS Server node if running in polylith mode.
|
||||
addresses:
|
||||
# - localhost:4222
|
||||
# - localhost:4222
|
||||
|
||||
# Keep all NATS streams in memory, rather than persisting it to the storage
|
||||
# path below. This option is present primarily for integration testing and
|
||||
# should not be used on a real world Dendrite deployment.
|
||||
in_memory: false
|
||||
|
||||
# Persistent directory to store JetStream streams in. This directory
|
||||
# Persistent directory to store JetStream streams in. This directory
|
||||
# should be preserved across Dendrite restarts.
|
||||
storage_path: ./
|
||||
|
||||
|
|
@ -134,7 +134,7 @@ global:
|
|||
# Configuration for the Appservice API.
|
||||
app_service_api:
|
||||
internal_api:
|
||||
listen: http://localhost:7777 # Only used in polylith deployments
|
||||
listen: http://localhost:7777 # Only used in polylith deployments
|
||||
connect: http://localhost:7777 # Only used in polylith deployments
|
||||
database:
|
||||
connection_string: file:appservice.db
|
||||
|
|
@ -153,7 +153,7 @@ app_service_api:
|
|||
# Configuration for the Client API.
|
||||
client_api:
|
||||
internal_api:
|
||||
listen: http://localhost:7771 # Only used in polylith deployments
|
||||
listen: http://localhost:7771 # Only used in polylith deployments
|
||||
connect: http://localhost:7771 # Only used in polylith deployments
|
||||
external_api:
|
||||
listen: http://[::]:8071
|
||||
|
|
@ -173,13 +173,13 @@ client_api:
|
|||
# Whether to require reCAPTCHA for registration.
|
||||
enable_registration_captcha: false
|
||||
|
||||
# Settings for ReCAPTCHA.
|
||||
# Settings for ReCAPTCHA.
|
||||
recaptcha_public_key: ""
|
||||
recaptcha_private_key: ""
|
||||
recaptcha_bypass_secret: ""
|
||||
recaptcha_siteverify_api: ""
|
||||
|
||||
# TURN server information that this homeserver should send to clients.
|
||||
# TURN server information that this homeserver should send to clients.
|
||||
turn:
|
||||
turn_user_lifetime: ""
|
||||
turn_uris: []
|
||||
|
|
@ -188,7 +188,7 @@ client_api:
|
|||
turn_password: ""
|
||||
|
||||
# Settings for rate-limited endpoints. Rate limiting will kick in after the
|
||||
# threshold number of "slots" have been taken by requests from a specific
|
||||
# threshold number of "slots" have been taken by requests from a specific
|
||||
# host. Each "slot" will be released after the cooloff time in milliseconds.
|
||||
rate_limiting:
|
||||
enabled: true
|
||||
|
|
@ -198,13 +198,13 @@ client_api:
|
|||
# Configuration for the EDU server.
|
||||
edu_server:
|
||||
internal_api:
|
||||
listen: http://localhost:7778 # Only used in polylith deployments
|
||||
listen: http://localhost:7778 # Only used in polylith deployments
|
||||
connect: http://localhost:7778 # Only used in polylith deployments
|
||||
|
||||
# Configuration for the Federation API.
|
||||
federation_api:
|
||||
internal_api:
|
||||
listen: http://localhost:7772 # Only used in polylith deployments
|
||||
listen: http://localhost:7772 # Only used in polylith deployments
|
||||
connect: http://localhost:7772 # Only used in polylith deployments
|
||||
external_api:
|
||||
listen: http://[::]:8072
|
||||
|
|
@ -232,12 +232,12 @@ federation_api:
|
|||
# be required to satisfy key requests for servers that are no longer online when
|
||||
# joining some rooms.
|
||||
key_perspectives:
|
||||
- server_name: matrix.org
|
||||
keys:
|
||||
- key_id: ed25519:auto
|
||||
public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw
|
||||
- key_id: ed25519:a_RXGa
|
||||
public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ
|
||||
- server_name: matrix.org
|
||||
keys:
|
||||
- key_id: ed25519:auto
|
||||
public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw
|
||||
- key_id: ed25519:a_RXGa
|
||||
public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ
|
||||
|
||||
# This option will control whether Dendrite will prefer to look up keys directly
|
||||
# or whether it should try perspective servers first, using direct fetches as a
|
||||
|
|
@ -247,7 +247,7 @@ federation_api:
|
|||
# Configuration for the Key Server (for end-to-end encryption).
|
||||
key_server:
|
||||
internal_api:
|
||||
listen: http://localhost:7779 # Only used in polylith deployments
|
||||
listen: http://localhost:7779 # Only used in polylith deployments
|
||||
connect: http://localhost:7779 # Only used in polylith deployments
|
||||
database:
|
||||
connection_string: file:keyserver.db
|
||||
|
|
@ -258,7 +258,7 @@ key_server:
|
|||
# Configuration for the Media API.
|
||||
media_api:
|
||||
internal_api:
|
||||
listen: http://localhost:7774 # Only used in polylith deployments
|
||||
listen: http://localhost:7774 # Only used in polylith deployments
|
||||
connect: http://localhost:7774 # Only used in polylith deployments
|
||||
external_api:
|
||||
listen: http://[::]:8074
|
||||
|
|
@ -284,15 +284,15 @@ media_api:
|
|||
|
||||
# A list of thumbnail sizes to be generated for media content.
|
||||
thumbnail_sizes:
|
||||
- width: 32
|
||||
height: 32
|
||||
method: crop
|
||||
- width: 96
|
||||
height: 96
|
||||
method: crop
|
||||
- width: 640
|
||||
height: 480
|
||||
method: scale
|
||||
- width: 32
|
||||
height: 32
|
||||
method: crop
|
||||
- width: 96
|
||||
height: 96
|
||||
method: crop
|
||||
- width: 640
|
||||
height: 480
|
||||
method: scale
|
||||
|
||||
# Configuration for experimental MSC's
|
||||
mscs:
|
||||
|
|
@ -310,7 +310,7 @@ mscs:
|
|||
# Configuration for the Room Server.
|
||||
room_server:
|
||||
internal_api:
|
||||
listen: http://localhost:7770 # Only used in polylith deployments
|
||||
listen: http://localhost:7770 # Only used in polylith deployments
|
||||
connect: http://localhost:7770 # Only used in polylith deployments
|
||||
database:
|
||||
connection_string: file:roomserver.db
|
||||
|
|
@ -321,7 +321,7 @@ room_server:
|
|||
# Configuration for the Sync API.
|
||||
sync_api:
|
||||
internal_api:
|
||||
listen: http://localhost:7773 # Only used in polylith deployments
|
||||
listen: http://localhost:7773 # Only used in polylith deployments
|
||||
connect: http://localhost:7773 # Only used in polylith deployments
|
||||
external_api:
|
||||
listen: http://[::]:8073
|
||||
|
|
@ -346,16 +346,16 @@ user_api:
|
|||
# This value can be low if performing tests or on embedded Dendrite instances (e.g WASM builds)
|
||||
# bcrypt_cost: 10
|
||||
internal_api:
|
||||
listen: http://localhost:7781 # Only used in polylith deployments
|
||||
listen: http://localhost:7781 # Only used in polylith deployments
|
||||
connect: http://localhost:7781 # Only used in polylith deployments
|
||||
account_database:
|
||||
connection_string: file:userapi_accounts.db
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
# The length of time that a token issued for a relying party from
|
||||
# The length of time that a token issued for a relying party from
|
||||
# /_matrix/client/r0/user/{userId}/openid/request_token endpoint
|
||||
# is considered to be valid in milliseconds.
|
||||
# is considered to be valid in milliseconds.
|
||||
# The default lifetime is 3600000ms (60 minutes).
|
||||
# openid_token_lifetime_ms: 3600000
|
||||
|
||||
|
|
@ -377,10 +377,10 @@ tracing:
|
|||
|
||||
# Logging configuration
|
||||
logging:
|
||||
- type: std
|
||||
level: info
|
||||
- type: file
|
||||
# The logging level, must be one of debug, info, warn, error, fatal, panic.
|
||||
level: info
|
||||
params:
|
||||
path: ./logs
|
||||
- type: std
|
||||
level: info
|
||||
- type: file
|
||||
# The logging level, must be one of debug, info, warn, error, fatal, panic.
|
||||
level: info
|
||||
params:
|
||||
path: ./logs
|
||||
|
|
|
|||
1
go.mod
1
go.mod
|
|
@ -18,6 +18,7 @@ require (
|
|||
github.com/frankban/quicktest v1.14.0 // indirect
|
||||
github.com/getsentry/sentry-go v0.12.0
|
||||
github.com/gologme/log v1.3.0
|
||||
github.com/google/go-cmp v0.5.6
|
||||
github.com/google/uuid v1.2.0
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/gorilla/websocket v1.4.2
|
||||
|
|
|
|||
|
|
@ -26,8 +26,30 @@ var ErrProfileNoExists = errors.New("no known profile for given user ID")
|
|||
// AccountData represents account data sent from the client API server to the
|
||||
// sync API server
|
||||
type AccountData struct {
|
||||
RoomID string `json:"room_id"`
|
||||
Type string `json:"type"`
|
||||
ReadMarker *ReadMarkerJSON `json:"read_marker,omitempty"` // optional
|
||||
}
|
||||
|
||||
type ReadMarkerJSON struct {
|
||||
FullyRead string `json:"m.fully_read"`
|
||||
Read string `json:"m.read"`
|
||||
}
|
||||
|
||||
// NotificationData contains statistics about notifications, sent from
|
||||
// the Push Server to the Sync API server.
|
||||
type NotificationData struct {
|
||||
// RoomID identifies the scope of the statistics, together with
|
||||
// MXID (which is encoded in the Kafka key).
|
||||
RoomID string `json:"room_id"`
|
||||
Type string `json:"type"`
|
||||
|
||||
// HighlightCount is the number of unread notifications with the
|
||||
// highlight tweak.
|
||||
UnreadHighlightCount int `json:"unread_highlight_count"`
|
||||
|
||||
// UnreadNotificationCount is the total number of unread
|
||||
// notifications.
|
||||
UnreadNotificationCount int `json:"unread_notification_count"`
|
||||
}
|
||||
|
||||
// ProfileResponse is a struct containing all known user profile data
|
||||
|
|
|
|||
66
internal/pushgateway/client.go
Normal file
66
internal/pushgateway/client.go
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
package pushgateway
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
)
|
||||
|
||||
type httpClient struct {
|
||||
hc *http.Client
|
||||
}
|
||||
|
||||
// NewHTTPClient creates a new Push Gateway client.
|
||||
func NewHTTPClient(disableTLSValidation bool) Client {
|
||||
hc := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: disableTLSValidation,
|
||||
},
|
||||
},
|
||||
}
|
||||
return &httpClient{hc: hc}
|
||||
}
|
||||
|
||||
func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "Notify")
|
||||
defer span.Finish()
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hreq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
hresp, err := h.hc.Do(hreq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//nolint:errcheck
|
||||
defer hresp.Body.Close()
|
||||
|
||||
if hresp.StatusCode == http.StatusOK {
|
||||
return json.NewDecoder(hresp.Body).Decode(resp)
|
||||
}
|
||||
|
||||
var errorBody struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.NewDecoder(hresp.Body).Decode(&errorBody); err == nil {
|
||||
return fmt.Errorf("push gateway: %d from %s: %s", hresp.StatusCode, url, errorBody.Message)
|
||||
}
|
||||
return fmt.Errorf("push gateway: %d from %s", hresp.StatusCode, url)
|
||||
}
|
||||
62
internal/pushgateway/pushgateway.go
Normal file
62
internal/pushgateway/pushgateway.go
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
package pushgateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// A Client is how interactions with a Push Gateway is done.
|
||||
type Client interface {
|
||||
// Notify sends a notification to the gateway at the given URL.
|
||||
Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error
|
||||
}
|
||||
|
||||
type NotifyRequest struct {
|
||||
Notification Notification `json:"notification"` // Required
|
||||
}
|
||||
|
||||
type NotifyResponse struct {
|
||||
// Rejected is the list of device push keys that were rejected
|
||||
// during the push. The caller should remove the push keys so they
|
||||
// are not used again.
|
||||
Rejected []string `json:"rejected"` // Required
|
||||
}
|
||||
|
||||
type Notification struct {
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
Counts *Counts `json:"counts,omitempty"`
|
||||
Devices []*Device `json:"devices"` // Required
|
||||
EventID string `json:"event_id,omitempty"`
|
||||
ID string `json:"id,omitempty"` // Deprecated name for EventID.
|
||||
Membership string `json:"membership,omitempty"` // UNSPEC: required for Sytest.
|
||||
Prio Prio `json:"prio,omitempty"`
|
||||
RoomAlias string `json:"room_alias,omitempty"`
|
||||
RoomID string `json:"room_id,omitempty"`
|
||||
RoomName string `json:"room_name,omitempty"`
|
||||
Sender string `json:"sender,omitempty"`
|
||||
SenderDisplayName string `json:"sender_display_name,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
UserIsTarget bool `json:"user_is_target,omitempty"`
|
||||
}
|
||||
|
||||
type Counts struct {
|
||||
MissedCalls int `json:"missed_calls,omitempty"`
|
||||
Unread int `json:"unread"` // TODO: UNSPEC: the spec says zero must be omitted, but Sytest 61push/01message-pushed.pl requires it.
|
||||
}
|
||||
|
||||
type Device struct {
|
||||
AppID string `json:"app_id"` // Required
|
||||
Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys.
|
||||
PushKey string `json:"pushkey"` // Required
|
||||
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
|
||||
Tweaks map[string]interface{} `json:"tweaks,omitempty"`
|
||||
}
|
||||
|
||||
type Prio string
|
||||
|
||||
const (
|
||||
HighPrio Prio = "high"
|
||||
LowPrio Prio = "low"
|
||||
)
|
||||
102
internal/pushrules/action.go
Normal file
102
internal/pushrules/action.go
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// An Action is (part of) an outcome of a rule. There are
|
||||
// (unofficially) terminal actions, and modifier actions.
|
||||
type Action struct {
|
||||
// Kind is the type of action. Has custom encoding in JSON.
|
||||
Kind ActionKind `json:"-"`
|
||||
|
||||
// Tweak is the property to tweak. Has custom encoding in JSON.
|
||||
Tweak TweakKey `json:"-"`
|
||||
|
||||
// Value is some value interpreted according to Kind and Tweak.
|
||||
Value interface{} `json:"value"`
|
||||
}
|
||||
|
||||
func (a *Action) MarshalJSON() ([]byte, error) {
|
||||
if a.Tweak == UnknownTweak && a.Value == nil {
|
||||
return json.Marshal(a.Kind)
|
||||
}
|
||||
|
||||
if a.Kind != SetTweakAction {
|
||||
return nil, fmt.Errorf("only set_tweak actions may have a value, but got kind %q", a.Kind)
|
||||
}
|
||||
|
||||
m := map[string]interface{}{
|
||||
string(a.Kind): a.Tweak,
|
||||
}
|
||||
if a.Value != nil {
|
||||
m["value"] = a.Value
|
||||
}
|
||||
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
func (a *Action) UnmarshalJSON(bs []byte) error {
|
||||
if bytes.HasPrefix(bs, []byte("\"")) {
|
||||
return json.Unmarshal(bs, &a.Kind)
|
||||
}
|
||||
|
||||
var raw struct {
|
||||
SetTweak TweakKey `json:"set_tweak"`
|
||||
Value interface{} `json:"value"`
|
||||
}
|
||||
if err := json.Unmarshal(bs, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
if raw.SetTweak == UnknownTweak {
|
||||
return fmt.Errorf("got unknown action JSON: %s", string(bs))
|
||||
}
|
||||
a.Kind = SetTweakAction
|
||||
a.Tweak = raw.SetTweak
|
||||
if raw.Value != nil {
|
||||
a.Value = raw.Value
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ActionKind is the primary discriminator for actions.
|
||||
type ActionKind string
|
||||
|
||||
const (
|
||||
UnknownAction ActionKind = ""
|
||||
|
||||
// NotifyAction indicates the clients should show a notification.
|
||||
NotifyAction ActionKind = "notify"
|
||||
|
||||
// DontNotifyAction indicates the clients should not show a notification.
|
||||
DontNotifyAction ActionKind = "dont_notify"
|
||||
|
||||
// CoalesceAction tells the clients to show a notification, and
|
||||
// tells both servers and clients that multiple events can be
|
||||
// coalesced into a single notification. The behaviour is
|
||||
// implementation-specific.
|
||||
CoalesceAction ActionKind = "coalesce"
|
||||
|
||||
// SetTweakAction uses the Tweak and Value fields to add a
|
||||
// tweak. Multiple SetTweakAction can be provided in a rule,
|
||||
// combined with NotifyAction or CoalesceAction.
|
||||
SetTweakAction ActionKind = "set_tweak"
|
||||
)
|
||||
|
||||
// A TweakKey describes a property to be modified/tweaked for events
|
||||
// that match the rule.
|
||||
type TweakKey string
|
||||
|
||||
const (
|
||||
UnknownTweak TweakKey = ""
|
||||
|
||||
// SoundTweak describes which sound to play. Using "default" means
|
||||
// "enable sound".
|
||||
SoundTweak TweakKey = "sound"
|
||||
|
||||
// HighlightTweak asks the clients to highlight the conversation.
|
||||
HighlightTweak TweakKey = "highlight"
|
||||
)
|
||||
39
internal/pushrules/action_test.go
Normal file
39
internal/pushrules/action_test.go
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestActionJSON(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Want Action
|
||||
}{
|
||||
{Action{Kind: NotifyAction}},
|
||||
{Action{Kind: DontNotifyAction}},
|
||||
{Action{Kind: CoalesceAction}},
|
||||
{Action{Kind: SetTweakAction}},
|
||||
|
||||
{Action{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}},
|
||||
{Action{Kind: SetTweakAction, Tweak: HighlightTweak}},
|
||||
{Action{Kind: SetTweakAction, Tweak: HighlightTweak, Value: "false"}},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(fmt.Sprintf("%+v", tst.Want), func(t *testing.T) {
|
||||
bs, err := json.Marshal(&tst.Want)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
var got Action
|
||||
if err := json.Unmarshal(bs, &got); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tst.Want, got); diff != "" {
|
||||
t.Errorf("+got -want:\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
49
internal/pushrules/condition.go
Normal file
49
internal/pushrules/condition.go
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
package pushrules
|
||||
|
||||
// A Condition dictates extra conditions for a matching rules. See
|
||||
// ConditionKind.
|
||||
type Condition struct {
|
||||
// Kind is the primary discriminator for the condition
|
||||
// type. Required.
|
||||
Kind ConditionKind `json:"kind"`
|
||||
|
||||
// Key indicates the dot-separated path of Event fields to
|
||||
// match. Required for EventMatchCondition and
|
||||
// SenderNotificationPermissionCondition.
|
||||
Key string `json:"key,omitempty"`
|
||||
|
||||
// Pattern indicates the value pattern that must match. Required
|
||||
// for EventMatchCondition.
|
||||
Pattern string `json:"pattern,omitempty"`
|
||||
|
||||
// Is indicates the condition that must be fulfilled. Required for
|
||||
// RoomMemberCountCondition.
|
||||
Is string `json:"is,omitempty"`
|
||||
}
|
||||
|
||||
// ConditionKind represents a kind of condition.
|
||||
//
|
||||
// SPEC: Unrecognised conditions MUST NOT match any events,
|
||||
// effectively making the push rule disabled.
|
||||
type ConditionKind string
|
||||
|
||||
const (
|
||||
UnknownCondition ConditionKind = ""
|
||||
|
||||
// EventMatchCondition indicates the condition looks for a key
|
||||
// path and matches a pattern. How paths that don't reference a
|
||||
// simple value match against rules is implementation-specific.
|
||||
EventMatchCondition ConditionKind = "event_match"
|
||||
|
||||
// ContainsDisplayNameCondition indicates the current user's
|
||||
// display name must be found in the content body.
|
||||
ContainsDisplayNameCondition ConditionKind = "contains_display_name"
|
||||
|
||||
// RoomMemberCountCondition matches a simple arithmetic comparison
|
||||
// against the total number of members in a room.
|
||||
RoomMemberCountCondition ConditionKind = "room_member_count"
|
||||
|
||||
// SenderNotificationPermissionCondition compares power level for
|
||||
// the sender in the event's room.
|
||||
SenderNotificationPermissionCondition ConditionKind = "sender_notification_permission"
|
||||
)
|
||||
23
internal/pushrules/default.go
Normal file
23
internal/pushrules/default.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// DefaultAccountRuleSets is the complete set of default push rules
|
||||
// for an account.
|
||||
func DefaultAccountRuleSets(localpart string, serverName gomatrixserverlib.ServerName) *AccountRuleSets {
|
||||
return &AccountRuleSets{
|
||||
Global: *DefaultGlobalRuleSet(localpart, serverName),
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultGlobalRuleSet returns the default ruleset for a given (fully
|
||||
// qualified) MXID.
|
||||
func DefaultGlobalRuleSet(localpart string, serverName gomatrixserverlib.ServerName) *RuleSet {
|
||||
return &RuleSet{
|
||||
Override: defaultOverrideRules("@" + localpart + ":" + string(serverName)),
|
||||
Content: defaultContentRules(localpart),
|
||||
Underride: defaultUnderrideRules,
|
||||
}
|
||||
}
|
||||
33
internal/pushrules/default_content.go
Normal file
33
internal/pushrules/default_content.go
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
package pushrules
|
||||
|
||||
func defaultContentRules(localpart string) []*Rule {
|
||||
return []*Rule{
|
||||
mRuleContainsUserNameDefinition(localpart),
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
MRuleContainsUserName = ".m.rule.contains_user_name"
|
||||
)
|
||||
|
||||
func mRuleContainsUserNameDefinition(localpart string) *Rule {
|
||||
return &Rule{
|
||||
RuleID: MRuleContainsUserName,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Pattern: localpart,
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: SoundTweak,
|
||||
Value: "default",
|
||||
},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
165
internal/pushrules/default_override.go
Normal file
165
internal/pushrules/default_override.go
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
package pushrules
|
||||
|
||||
func defaultOverrideRules(userID string) []*Rule {
|
||||
return []*Rule{
|
||||
&mRuleMasterDefinition,
|
||||
&mRuleSuppressNoticesDefinition,
|
||||
mRuleInviteForMeDefinition(userID),
|
||||
&mRuleMemberEventDefinition,
|
||||
&mRuleContainsDisplayNameDefinition,
|
||||
&mRuleTombstoneDefinition,
|
||||
&mRuleRoomNotifDefinition,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
MRuleMaster = ".m.rule.master"
|
||||
MRuleSuppressNotices = ".m.rule.suppress_notices"
|
||||
MRuleInviteForMe = ".m.rule.invite_for_me"
|
||||
MRuleMemberEvent = ".m.rule.member_event"
|
||||
MRuleContainsDisplayName = ".m.rule.contains_display_name"
|
||||
MRuleTombstone = ".m.rule.tombstone"
|
||||
MRuleRoomNotif = ".m.rule.roomnotif"
|
||||
)
|
||||
|
||||
var (
|
||||
mRuleMasterDefinition = Rule{
|
||||
RuleID: MRuleMaster,
|
||||
Default: true,
|
||||
Enabled: false,
|
||||
Conditions: []*Condition{},
|
||||
Actions: []*Action{{Kind: DontNotifyAction}},
|
||||
}
|
||||
mRuleSuppressNoticesDefinition = Rule{
|
||||
RuleID: MRuleSuppressNotices,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "content.msgtype",
|
||||
Pattern: "m.notice",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{{Kind: DontNotifyAction}},
|
||||
}
|
||||
mRuleMemberEventDefinition = Rule{
|
||||
RuleID: MRuleMemberEvent,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.room.member",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{{Kind: DontNotifyAction}},
|
||||
}
|
||||
mRuleContainsDisplayNameDefinition = Rule{
|
||||
RuleID: MRuleContainsDisplayName,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{{Kind: ContainsDisplayNameCondition}},
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: SoundTweak,
|
||||
Value: "default",
|
||||
},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
mRuleTombstoneDefinition = Rule{
|
||||
RuleID: MRuleTombstone,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.room.tombstone",
|
||||
},
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "state_key",
|
||||
Pattern: "",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
mRuleRoomNotifDefinition = Rule{
|
||||
RuleID: MRuleRoomNotif,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "content.body",
|
||||
Pattern: "@room",
|
||||
},
|
||||
{
|
||||
Kind: SenderNotificationPermissionCondition,
|
||||
Key: "room",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func mRuleInviteForMeDefinition(userID string) *Rule {
|
||||
return &Rule{
|
||||
RuleID: MRuleInviteForMe,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.room.member",
|
||||
},
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "content.membership",
|
||||
Pattern: "invite",
|
||||
},
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "state_key",
|
||||
Pattern: userID,
|
||||
},
|
||||
},
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: SoundTweak,
|
||||
Value: "default",
|
||||
},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
119
internal/pushrules/default_underride.go
Normal file
119
internal/pushrules/default_underride.go
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
package pushrules
|
||||
|
||||
const (
|
||||
MRuleCall = ".m.rule.call"
|
||||
MRuleEncryptedRoomOneToOne = ".m.rule.encrypted_room_one_to_one"
|
||||
MRuleRoomOneToOne = ".m.rule.room_one_to_one"
|
||||
MRuleMessage = ".m.rule.message"
|
||||
MRuleEncrypted = ".m.rule.encrypted"
|
||||
)
|
||||
|
||||
var defaultUnderrideRules = []*Rule{
|
||||
&mRuleCallDefinition,
|
||||
&mRuleEncryptedRoomOneToOneDefinition,
|
||||
&mRuleRoomOneToOneDefinition,
|
||||
&mRuleMessageDefinition,
|
||||
&mRuleEncryptedDefinition,
|
||||
}
|
||||
|
||||
var (
|
||||
mRuleCallDefinition = Rule{
|
||||
RuleID: MRuleCall,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.call.invite",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: SoundTweak,
|
||||
Value: "ring",
|
||||
},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
mRuleEncryptedRoomOneToOneDefinition = Rule{
|
||||
RuleID: MRuleEncryptedRoomOneToOne,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: RoomMemberCountCondition,
|
||||
Is: "2",
|
||||
},
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.room.encrypted",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
mRuleRoomOneToOneDefinition = Rule{
|
||||
RuleID: MRuleRoomOneToOne,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: RoomMemberCountCondition,
|
||||
Is: "2",
|
||||
},
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.room.message",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
mRuleMessageDefinition = Rule{
|
||||
RuleID: MRuleMessage,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.room.message",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{{Kind: NotifyAction}},
|
||||
}
|
||||
mRuleEncryptedDefinition = Rule{
|
||||
RuleID: MRuleEncrypted,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.room.encrypted",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{{Kind: NotifyAction}},
|
||||
}
|
||||
)
|
||||
165
internal/pushrules/evaluate.go
Normal file
165
internal/pushrules/evaluate.go
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// A RuleSetEvaluator encapsulates context to evaluate an event
|
||||
// against a rule set.
|
||||
type RuleSetEvaluator struct {
|
||||
ec EvaluationContext
|
||||
ruleSet []kindAndRules
|
||||
}
|
||||
|
||||
// An EvaluationContext gives a RuleSetEvaluator access to the
|
||||
// environment, for rules that require that.
|
||||
type EvaluationContext interface {
|
||||
// UserDisplayName returns the current user's display name.
|
||||
UserDisplayName() string
|
||||
|
||||
// RoomMemberCount returns the number of members in the room of
|
||||
// the current event.
|
||||
RoomMemberCount() (int, error)
|
||||
|
||||
// HasPowerLevel returns whether the user has at least the given
|
||||
// power in the room of the current event.
|
||||
HasPowerLevel(userID, levelKey string) (bool, error)
|
||||
}
|
||||
|
||||
// A kindAndRules is just here to simplify iteration of the (ordered)
|
||||
// kinds of rules.
|
||||
type kindAndRules struct {
|
||||
Kind Kind
|
||||
Rules []*Rule
|
||||
}
|
||||
|
||||
// NewRuleSetEvaluator creates a new evaluator for the given rule set.
|
||||
func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluator {
|
||||
return &RuleSetEvaluator{
|
||||
ec: ec,
|
||||
ruleSet: []kindAndRules{
|
||||
{OverrideKind, ruleSet.Override},
|
||||
{ContentKind, ruleSet.Content},
|
||||
{RoomKind, ruleSet.Room},
|
||||
{SenderKind, ruleSet.Sender},
|
||||
{UnderrideKind, ruleSet.Underride},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// MatchEvent returns the first matching rule. Returns nil if there
|
||||
// was no match rule.
|
||||
func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule, error) {
|
||||
// TODO: server-default rules have lower priority than user rules,
|
||||
// but they are stored together with the user rules. It's a bit
|
||||
// unclear what the specification (11.14.1.4 Predefined rules)
|
||||
// means the ordering should be.
|
||||
//
|
||||
// The most reasonable interpretation is that default overrides
|
||||
// still have lower priority than user content rules, so we
|
||||
// iterate twice.
|
||||
for _, rsat := range rse.ruleSet {
|
||||
for _, defRules := range []bool{false, true} {
|
||||
for _, rule := range rsat.Rules {
|
||||
if rule.Default != defRules {
|
||||
continue
|
||||
}
|
||||
ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
return rule, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No matching rule.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
|
||||
if !rule.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case OverrideKind, UnderrideKind:
|
||||
for _, cond := range rule.Conditions {
|
||||
ok, err := conditionMatches(cond, event, ec)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
|
||||
case ContentKind:
|
||||
// TODO: "These configure behaviour for (unencrypted) messages
|
||||
// that match certain patterns." - Does that mean "content.body"?
|
||||
return patternMatches("content.body", rule.Pattern, event)
|
||||
|
||||
case RoomKind:
|
||||
return rule.RuleID == event.RoomID(), nil
|
||||
|
||||
case SenderKind:
|
||||
return rule.RuleID == event.Sender(), nil
|
||||
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
|
||||
switch cond.Kind {
|
||||
case EventMatchCondition:
|
||||
return patternMatches(cond.Key, cond.Pattern, event)
|
||||
|
||||
case ContainsDisplayNameCondition:
|
||||
return patternMatches("content.body", ec.UserDisplayName(), event)
|
||||
|
||||
case RoomMemberCountCondition:
|
||||
cmp, err := parseRoomMemberCountCondition(cond.Is)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing room_member_count condition: %w", err)
|
||||
}
|
||||
n, err := ec.RoomMemberCount()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("RoomMemberCount failed: %w", err)
|
||||
}
|
||||
return cmp(n), nil
|
||||
|
||||
case SenderNotificationPermissionCondition:
|
||||
return ec.HasPowerLevel(event.Sender(), cond.Key)
|
||||
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) {
|
||||
re, err := globToRegexp(pattern)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var eventMap map[string]interface{}
|
||||
if err = json.Unmarshal(event.JSON(), &eventMap); err != nil {
|
||||
return false, fmt.Errorf("parsing event: %w", err)
|
||||
}
|
||||
v, err := lookupMapPath(strings.Split(key, "."), eventMap)
|
||||
if err != nil {
|
||||
// An unknown path is a benign error that shouldn't stop rule
|
||||
// processing. It's just a non-match.
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return re.MatchString(fmt.Sprint(v)), nil
|
||||
}
|
||||
189
internal/pushrules/evaluate_test.go
Normal file
189
internal/pushrules/evaluate_test.go
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
|
||||
ev := mustEventFromJSON(t, `{}`)
|
||||
defaultEnabled := &Rule{
|
||||
RuleID: ".default.enabled",
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
}
|
||||
userEnabled := &Rule{
|
||||
RuleID: ".user.enabled",
|
||||
Default: false,
|
||||
Enabled: true,
|
||||
}
|
||||
userEnabled2 := &Rule{
|
||||
RuleID: ".user.enabled.2",
|
||||
Default: false,
|
||||
Enabled: true,
|
||||
}
|
||||
tsts := []struct {
|
||||
Name string
|
||||
RuleSet RuleSet
|
||||
Want *Rule
|
||||
}{
|
||||
{"empty", RuleSet{}, nil},
|
||||
{"defaultCanWin", RuleSet{Override: []*Rule{defaultEnabled}}, defaultEnabled},
|
||||
{"userWins", RuleSet{Override: []*Rule{defaultEnabled, userEnabled}}, userEnabled},
|
||||
{"defaultOverrideWins", RuleSet{Override: []*Rule{defaultEnabled}, Underride: []*Rule{userEnabled}}, defaultEnabled},
|
||||
{"overrideContent", RuleSet{Override: []*Rule{userEnabled}, Content: []*Rule{userEnabled2}}, userEnabled},
|
||||
{"overrideRoom", RuleSet{Override: []*Rule{userEnabled}, Room: []*Rule{userEnabled2}}, userEnabled},
|
||||
{"overrideSender", RuleSet{Override: []*Rule{userEnabled}, Sender: []*Rule{userEnabled2}}, userEnabled},
|
||||
{"overrideUnderride", RuleSet{Override: []*Rule{userEnabled}, Underride: []*Rule{userEnabled2}}, userEnabled},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
rse := NewRuleSetEvaluator(nil, &tst.RuleSet)
|
||||
got, err := rse.MatchEvent(ev)
|
||||
if err != nil {
|
||||
t.Fatalf("MatchEvent failed: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tst.Want, got); diff != "" {
|
||||
t.Errorf("MatchEvent rule: +got -want:\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleMatches(t *testing.T) {
|
||||
emptyRule := Rule{Enabled: true}
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Kind Kind
|
||||
Rule Rule
|
||||
EventJSON string
|
||||
Want bool
|
||||
}{
|
||||
{"emptyOverride", OverrideKind, emptyRule, `{}`, true},
|
||||
{"emptyContent", ContentKind, emptyRule, `{}`, false},
|
||||
{"emptyRoom", RoomKind, emptyRule, `{}`, true},
|
||||
{"emptySender", SenderKind, emptyRule, `{}`, true},
|
||||
{"emptyUnderride", UnderrideKind, emptyRule, `{}`, true},
|
||||
|
||||
{"disabled", OverrideKind, Rule{}, `{}`, false},
|
||||
|
||||
{"overrideConditionMatch", OverrideKind, Rule{Enabled: true}, `{}`, true},
|
||||
{"overrideConditionNoMatch", OverrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
|
||||
|
||||
{"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true},
|
||||
{"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
|
||||
|
||||
{"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true},
|
||||
{"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false},
|
||||
|
||||
{"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true},
|
||||
{"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false},
|
||||
|
||||
{"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true},
|
||||
{"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ruleMatches failed: %v", err)
|
||||
}
|
||||
if got != tst.Want {
|
||||
t.Errorf("ruleMatches: got %v, want %v", got, tst.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConditionMatches(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Cond Condition
|
||||
EventJSON string
|
||||
Want bool
|
||||
}{
|
||||
{"empty", Condition{}, `{}`, false},
|
||||
{"empty", Condition{Kind: "unknownstring"}, `{}`, false},
|
||||
|
||||
{"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, true},
|
||||
|
||||
{"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false},
|
||||
{"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true},
|
||||
|
||||
{"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false},
|
||||
{"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true},
|
||||
{"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false},
|
||||
{"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true},
|
||||
{"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false},
|
||||
{"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true},
|
||||
{"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false},
|
||||
{"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true},
|
||||
{"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false},
|
||||
{"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true},
|
||||
|
||||
{"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true},
|
||||
{"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{})
|
||||
if err != nil {
|
||||
t.Fatalf("conditionMatches failed: %v", err)
|
||||
}
|
||||
if got != tst.Want {
|
||||
t.Errorf("conditionMatches: got %v, want %v", got, tst.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type fakeEvaluationContext struct{}
|
||||
|
||||
func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" }
|
||||
func (fakeEvaluationContext) RoomMemberCount() (int, error) { return 2, nil }
|
||||
func (fakeEvaluationContext) HasPowerLevel(userID, levelKey string) (bool, error) {
|
||||
return userID == "@poweruser:example.com" && levelKey == "powerlevel", nil
|
||||
}
|
||||
|
||||
func TestPatternMatches(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Key string
|
||||
Pattern string
|
||||
EventJSON string
|
||||
Want bool
|
||||
}{
|
||||
{"empty", "", "", `{}`, false},
|
||||
|
||||
// Note that an empty pattern contains no wildcard characters,
|
||||
// which implicitly means "*".
|
||||
{"patternEmpty", "content", "", `{"content":{}}`, true},
|
||||
|
||||
{"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true},
|
||||
{"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true},
|
||||
{"singlePattern", "content.creator", "acr?ator", `{"content":{"creator":"acreator"}}`, true},
|
||||
{"multiPattern", "content.creator", "a*ea*r", `{"content":{"creator":"acreator"}}`, true},
|
||||
{"patternNoSubstring", "content.creator", "r*t", `{"content":{"creator":"acreator"}}`, false},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
got, err := patternMatches(tst.Key, tst.Pattern, mustEventFromJSON(t, tst.EventJSON))
|
||||
if err != nil {
|
||||
t.Fatalf("patternMatches failed: %v", err)
|
||||
}
|
||||
if got != tst.Want {
|
||||
t.Errorf("patternMatches: got %v, want %v", got, tst.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustEventFromJSON(t *testing.T, json string) *gomatrixserverlib.Event {
|
||||
ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(json), false, gomatrixserverlib.RoomVersionV7)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return ev
|
||||
}
|
||||
71
internal/pushrules/pushrules.go
Normal file
71
internal/pushrules/pushrules.go
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
package pushrules
|
||||
|
||||
// An AccountRuleSets carries the rule sets associated with an
|
||||
// account.
|
||||
type AccountRuleSets struct {
|
||||
Global RuleSet `json:"global"` // Required
|
||||
}
|
||||
|
||||
// A RuleSet contains all the various push rules for an
|
||||
// account. Listed in decreasing order of priority.
|
||||
type RuleSet struct {
|
||||
Override []*Rule `json:"override,omitempty"`
|
||||
Content []*Rule `json:"content,omitempty"`
|
||||
Room []*Rule `json:"room,omitempty"`
|
||||
Sender []*Rule `json:"sender,omitempty"`
|
||||
Underride []*Rule `json:"underride,omitempty"`
|
||||
}
|
||||
|
||||
// A Rule contains matchers, conditions and final actions. While
|
||||
// evaluating, at most one rule is considered matching.
|
||||
//
|
||||
// Kind and scope are part of the push rules request/responses, but
|
||||
// not of the core data model.
|
||||
type Rule struct {
|
||||
// RuleID is either a free identifier, or the sender's MXID for
|
||||
// SenderKind. Required.
|
||||
RuleID string `json:"rule_id"`
|
||||
|
||||
// Default indicates whether this is a server-defined default, or
|
||||
// a user-provided rule. Required.
|
||||
//
|
||||
// The server-default rules have the lowest priority.
|
||||
Default bool `json:"default"`
|
||||
|
||||
// Enabled allows the user to disable rules while keeping them
|
||||
// around. Required.
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// Actions describe the desired outcome, should the rule
|
||||
// match. Required.
|
||||
Actions []*Action `json:"actions"`
|
||||
|
||||
// Conditions provide the rule's conditions for OverrideKind and
|
||||
// UnderrideKind. Not allowed for other kinds.
|
||||
Conditions []*Condition `json:"conditions"`
|
||||
|
||||
// Pattern is the body pattern to match for ContentKind. Required
|
||||
// for that kind. The interpretation is the same as that of
|
||||
// Condition.Pattern.
|
||||
Pattern string `json:"pattern"`
|
||||
}
|
||||
|
||||
// Scope only has one valid value. See also AccountRuleSets.
|
||||
type Scope string
|
||||
|
||||
const (
|
||||
UnknownScope Scope = ""
|
||||
GlobalScope Scope = "global"
|
||||
)
|
||||
|
||||
// Kind is the type of push rule. See also RuleSet.
|
||||
type Kind string
|
||||
|
||||
const (
|
||||
UnknownKind Kind = ""
|
||||
OverrideKind Kind = "override"
|
||||
ContentKind Kind = "content"
|
||||
RoomKind Kind = "room"
|
||||
SenderKind Kind = "sender"
|
||||
UnderrideKind Kind = "underride"
|
||||
)
|
||||
125
internal/pushrules/util.go
Normal file
125
internal/pushrules/util.go
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ActionsToTweaks converts a list of actions into a primary action
|
||||
// kind and a tweaks map. Returns a nil map if it would have been
|
||||
// empty.
|
||||
func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) {
|
||||
var kind ActionKind
|
||||
tweaks := map[string]interface{}{}
|
||||
|
||||
for _, a := range as {
|
||||
if a.Kind == SetTweakAction {
|
||||
tweaks[string(a.Tweak)] = a.Value
|
||||
continue
|
||||
}
|
||||
if kind != UnknownAction {
|
||||
return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind)
|
||||
}
|
||||
kind = a.Kind
|
||||
}
|
||||
|
||||
if len(tweaks) == 0 {
|
||||
tweaks = nil
|
||||
}
|
||||
|
||||
return kind, tweaks, nil
|
||||
}
|
||||
|
||||
// BoolTweakOr returns the named tweak as a boolean, and returns `def`
|
||||
// on failure.
|
||||
func BoolTweakOr(tweaks map[string]interface{}, key TweakKey, def bool) bool {
|
||||
v, ok := tweaks[string(key)]
|
||||
if !ok {
|
||||
return def
|
||||
}
|
||||
b, ok := v.(bool)
|
||||
if !ok {
|
||||
return def
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// globToRegexp converts a Matrix glob-style pattern to a Regular expression.
|
||||
func globToRegexp(pattern string) (*regexp.Regexp, error) {
|
||||
// TODO: It's unclear which glob characters are supported. The only
|
||||
// place this is discussed is for the unrelated "m.policy.rule.*"
|
||||
// events. Assuming, the same: /[*?]/
|
||||
if !strings.ContainsAny(pattern, "*?") {
|
||||
pattern = "*" + pattern + "*"
|
||||
}
|
||||
|
||||
// The defined syntax doesn't allow escaping the glob wildcard
|
||||
// characters, which makes this a straight-forward
|
||||
// replace-after-quote.
|
||||
pattern = globNonMetaRegexp.ReplaceAllStringFunc(pattern, regexp.QuoteMeta)
|
||||
pattern = strings.Replace(pattern, "*", ".*", -1)
|
||||
pattern = strings.Replace(pattern, "?", ".", -1)
|
||||
return regexp.Compile("^(" + pattern + ")$")
|
||||
}
|
||||
|
||||
// globNonMetaRegexp are the characters that are not considered glob
|
||||
// meta-characters (i.e. may need escaping).
|
||||
var globNonMetaRegexp = regexp.MustCompile("[^*?]+")
|
||||
|
||||
// lookupMapPath traverses a hierarchical map structure, like the one
|
||||
// produced by json.Unmarshal, to return the leaf value. Traversing
|
||||
// arrays/slices is not supported, only objects/maps.
|
||||
func lookupMapPath(path []string, m map[string]interface{}) (interface{}, error) {
|
||||
if len(path) == 0 {
|
||||
return nil, fmt.Errorf("empty path")
|
||||
}
|
||||
|
||||
var v interface{} = m
|
||||
for i, key := range path {
|
||||
m, ok := v.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected an object for path %q, but got %T", strings.Join(path[:i+1], "."), v)
|
||||
}
|
||||
|
||||
v, ok = m[key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("path not found: %s", strings.Join(path[:i+1], "."))
|
||||
}
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// parseRoomMemberCountCondition parses a string like "2", "==2", "<2"
|
||||
// into a function that checks if the argument to it fulfils the
|
||||
// condition.
|
||||
func parseRoomMemberCountCondition(s string) (func(int) bool, error) {
|
||||
var b int
|
||||
var cmp = func(a int) bool { return a == b }
|
||||
switch {
|
||||
case strings.HasPrefix(s, "<="):
|
||||
cmp = func(a int) bool { return a <= b }
|
||||
s = s[2:]
|
||||
case strings.HasPrefix(s, ">="):
|
||||
cmp = func(a int) bool { return a >= b }
|
||||
s = s[2:]
|
||||
case strings.HasPrefix(s, "<"):
|
||||
cmp = func(a int) bool { return a < b }
|
||||
s = s[1:]
|
||||
case strings.HasPrefix(s, ">"):
|
||||
cmp = func(a int) bool { return a > b }
|
||||
s = s[1:]
|
||||
case strings.HasPrefix(s, "=="):
|
||||
// Same cmp as the default.
|
||||
s = s[2:]
|
||||
}
|
||||
|
||||
v, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b = int(v)
|
||||
return cmp, nil
|
||||
}
|
||||
169
internal/pushrules/util_test.go
Normal file
169
internal/pushrules/util_test.go
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestActionsToTweaks(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Input []*Action
|
||||
WantKind ActionKind
|
||||
WantTweaks map[string]interface{}
|
||||
}{
|
||||
{"empty", nil, UnknownAction, nil},
|
||||
{"zero", []*Action{{}}, UnknownAction, nil},
|
||||
{"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil},
|
||||
{"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}},
|
||||
{"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}},
|
||||
{
|
||||
"all",
|
||||
[]*Action{
|
||||
{Kind: CoalesceAction},
|
||||
{Kind: SetTweakAction, Tweak: HighlightTweak},
|
||||
{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"},
|
||||
},
|
||||
CoalesceAction,
|
||||
map[string]interface{}{"highlight": nil, "sound": "default"},
|
||||
},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
gotKind, gotTweaks, err := ActionsToTweaks(tst.Input)
|
||||
if err != nil {
|
||||
t.Fatalf("ActionsToTweaks failed: %v", err)
|
||||
}
|
||||
if gotKind != tst.WantKind {
|
||||
t.Errorf("kind: got %v, want %v", gotKind, tst.WantKind)
|
||||
}
|
||||
if diff := cmp.Diff(tst.WantTweaks, gotTweaks); diff != "" {
|
||||
t.Errorf("tweaks: +got -want:\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolTweakOr(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Input map[string]interface{}
|
||||
Def bool
|
||||
Want bool
|
||||
}{
|
||||
{"nil", nil, false, false},
|
||||
{"nilValue", map[string]interface{}{"highlight": nil}, true, true},
|
||||
{"false", map[string]interface{}{"highlight": false}, true, false},
|
||||
{"true", map[string]interface{}{"highlight": true}, false, true},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
got := BoolTweakOr(tst.Input, HighlightTweak, tst.Def)
|
||||
if got != tst.Want {
|
||||
t.Errorf("BoolTweakOr: got %v, want %v", got, tst.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobToRegexp(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Input string
|
||||
Want string
|
||||
}{
|
||||
{"", "^(.*.*)$"},
|
||||
{"a", "^(.*a.*)$"},
|
||||
{"a.b", "^(.*a\\.b.*)$"},
|
||||
{"a?b", "^(a.b)$"},
|
||||
{"a*b*", "^(a.*b.*)$"},
|
||||
{"a*b?", "^(a.*b.)$"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Want, func(t *testing.T) {
|
||||
got, err := globToRegexp(tst.Input)
|
||||
if err != nil {
|
||||
t.Fatalf("globToRegexp failed: %v", err)
|
||||
}
|
||||
if got.String() != tst.Want {
|
||||
t.Errorf("got %v, want %v", got.String(), tst.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupMapPath(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Path []string
|
||||
Root map[string]interface{}
|
||||
Want interface{}
|
||||
}{
|
||||
{[]string{"a"}, map[string]interface{}{"a": "b"}, "b"},
|
||||
{[]string{"a"}, map[string]interface{}{"a": 42}, 42},
|
||||
{[]string{"a", "b"}, map[string]interface{}{"a": map[string]interface{}{"b": "c"}}, "c"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(fmt.Sprint(tst.Path, "/", tst.Want), func(t *testing.T) {
|
||||
got, err := lookupMapPath(tst.Path, tst.Root)
|
||||
if err != nil {
|
||||
t.Fatalf("lookupMapPath failed: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tst.Want, got); diff != "" {
|
||||
t.Errorf("+got -want:\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupMapPathInvalid(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Path []string
|
||||
Root map[string]interface{}
|
||||
}{
|
||||
{nil, nil},
|
||||
{[]string{"a"}, nil},
|
||||
{[]string{"a", "b"}, map[string]interface{}{"a": "c"}},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(fmt.Sprint(tst.Path), func(t *testing.T) {
|
||||
got, err := lookupMapPath(tst.Path, tst.Root)
|
||||
if err == nil {
|
||||
t.Fatalf("lookupMapPath succeeded with %#v, but want failure", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRoomMemberCountCondition(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Input string
|
||||
WantTrue []int
|
||||
WantFalse []int
|
||||
}{
|
||||
{"1", []int{1}, []int{0, 2}},
|
||||
{"==1", []int{1}, []int{0, 2}},
|
||||
{"<1", []int{0}, []int{1, 2}},
|
||||
{"<=1", []int{0, 1}, []int{2}},
|
||||
{">1", []int{2}, []int{0, 1}},
|
||||
{">=42", []int{42, 43}, []int{41}},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Input, func(t *testing.T) {
|
||||
got, err := parseRoomMemberCountCondition(tst.Input)
|
||||
if err != nil {
|
||||
t.Fatalf("parseRoomMemberCountCondition failed: %v", err)
|
||||
}
|
||||
for _, v := range tst.WantTrue {
|
||||
if !got(v) {
|
||||
t.Errorf("parseRoomMemberCountCondition(%q)(%d): got false, want true", tst.Input, v)
|
||||
}
|
||||
}
|
||||
for _, v := range tst.WantFalse {
|
||||
if got(v) {
|
||||
t.Errorf("parseRoomMemberCountCondition(%q)(%d): got true, want false", tst.Input, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
85
internal/pushrules/validate.go
Normal file
85
internal/pushrules/validate.go
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// ValidateRule checks the rule for errors. These follow from Sytests
|
||||
// and the specification.
|
||||
func ValidateRule(kind Kind, rule *Rule) []error {
|
||||
var errs []error
|
||||
|
||||
if !validRuleIDRE.MatchString(rule.RuleID) {
|
||||
errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID))
|
||||
}
|
||||
|
||||
if len(rule.Actions) == 0 {
|
||||
errs = append(errs, fmt.Errorf("missing actions"))
|
||||
}
|
||||
for _, action := range rule.Actions {
|
||||
errs = append(errs, validateAction(action)...)
|
||||
}
|
||||
|
||||
for _, cond := range rule.Conditions {
|
||||
errs = append(errs, validateCondition(cond)...)
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case OverrideKind, UnderrideKind:
|
||||
// The empty list is allowed, but for JSON-encoding reasons,
|
||||
// it must not be nil.
|
||||
if rule.Conditions == nil {
|
||||
errs = append(errs, fmt.Errorf("missing rule conditions"))
|
||||
}
|
||||
|
||||
case ContentKind:
|
||||
if rule.Pattern == "" {
|
||||
errs = append(errs, fmt.Errorf("missing content rule pattern"))
|
||||
}
|
||||
|
||||
case RoomKind, SenderKind:
|
||||
// Do nothing.
|
||||
|
||||
default:
|
||||
errs = append(errs, fmt.Errorf("invalid rule kind: %s", kind))
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
// validRuleIDRE is a regexp for valid IDs.
|
||||
//
|
||||
// TODO: the specification doesn't seem to say what the rule ID syntax
|
||||
// is. A Sytest fails if it contains a backslash.
|
||||
var validRuleIDRE = regexp.MustCompile(`^([^\\]+)$`)
|
||||
|
||||
// validateAction returns issues with an Action.
|
||||
func validateAction(action *Action) []error {
|
||||
var errs []error
|
||||
|
||||
switch action.Kind {
|
||||
case NotifyAction, DontNotifyAction, CoalesceAction, SetTweakAction:
|
||||
// Do nothing.
|
||||
|
||||
default:
|
||||
errs = append(errs, fmt.Errorf("invalid rule action kind: %s", action.Kind))
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
// validateCondition returns issues with a Condition.
|
||||
func validateCondition(cond *Condition) []error {
|
||||
var errs []error
|
||||
|
||||
switch cond.Kind {
|
||||
case EventMatchCondition, ContainsDisplayNameCondition, RoomMemberCountCondition, SenderNotificationPermissionCondition:
|
||||
// Do nothing.
|
||||
|
||||
default:
|
||||
errs = append(errs, fmt.Errorf("invalid rule condition kind: %s", cond.Kind))
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
163
internal/pushrules/validate_test.go
Normal file
163
internal/pushrules/validate_test.go
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateRuleNegatives(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Kind Kind
|
||||
Rule Rule
|
||||
WantErrString string
|
||||
}{
|
||||
{"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"},
|
||||
{"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"},
|
||||
{"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"},
|
||||
{"noActions", OverrideKind, Rule{}, "missing actions"},
|
||||
{"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"},
|
||||
{"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"},
|
||||
{"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"},
|
||||
{"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"},
|
||||
{"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
errs := ValidateRule(tst.Kind, &tst.Rule)
|
||||
var foundErr error
|
||||
for _, err := range errs {
|
||||
t.Logf("Got error %#v", err)
|
||||
if strings.Contains(err.Error(), tst.WantErrString) {
|
||||
foundErr = err
|
||||
}
|
||||
}
|
||||
if foundErr == nil {
|
||||
t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRulePositives(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Kind Kind
|
||||
Rule Rule
|
||||
WantNoErrString string
|
||||
}{
|
||||
{"invalidKind", OverrideKind, Rule{}, "invalid rule kind"},
|
||||
{"invalidActionNoActions", OverrideKind, Rule{}, "invalid rule action kind"},
|
||||
{"invalidConditionNoConditions", OverrideKind, Rule{}, "invalid rule condition kind"},
|
||||
{"contentNoCondition", ContentKind, Rule{}, "missing rule conditions"},
|
||||
{"roomNoCondition", RoomKind, Rule{}, "missing rule conditions"},
|
||||
{"senderNoCondition", SenderKind, Rule{}, "missing rule conditions"},
|
||||
{"overrideNoPattern", OverrideKind, Rule{}, "missing content rule pattern"},
|
||||
{"overrideEmptyConditions", OverrideKind, Rule{Conditions: []*Condition{}}, "missing rule conditions"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
errs := ValidateRule(tst.Kind, &tst.Rule)
|
||||
for _, err := range errs {
|
||||
t.Logf("Got error %#v", err)
|
||||
if strings.Contains(err.Error(), tst.WantNoErrString) {
|
||||
t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateActionNegatives(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Action Action
|
||||
WantErrString string
|
||||
}{
|
||||
{"emptyKind", Action{}, "invalid rule action kind"},
|
||||
{"invalidKind", Action{Kind: ActionKind("something else")}, "invalid rule action kind"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
errs := validateAction(&tst.Action)
|
||||
var foundErr error
|
||||
for _, err := range errs {
|
||||
t.Logf("Got error %#v", err)
|
||||
if strings.Contains(err.Error(), tst.WantErrString) {
|
||||
foundErr = err
|
||||
}
|
||||
}
|
||||
if foundErr == nil {
|
||||
t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateActionPositives(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Action Action
|
||||
WantNoErrString string
|
||||
}{
|
||||
{"invalidKind", Action{Kind: NotifyAction}, "invalid rule action kind"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
errs := validateAction(&tst.Action)
|
||||
for _, err := range errs {
|
||||
t.Logf("Got error %#v", err)
|
||||
if strings.Contains(err.Error(), tst.WantNoErrString) {
|
||||
t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConditionNegatives(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Condition Condition
|
||||
WantErrString string
|
||||
}{
|
||||
{"emptyKind", Condition{}, "invalid rule condition kind"},
|
||||
{"invalidKind", Condition{Kind: ConditionKind("something else")}, "invalid rule condition kind"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
errs := validateCondition(&tst.Condition)
|
||||
var foundErr error
|
||||
for _, err := range errs {
|
||||
t.Logf("Got error %#v", err)
|
||||
if strings.Contains(err.Error(), tst.WantErrString) {
|
||||
foundErr = err
|
||||
}
|
||||
}
|
||||
if foundErr == nil {
|
||||
t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConditionPositives(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Condition Condition
|
||||
WantNoErrString string
|
||||
}{
|
||||
{"invalidKind", Condition{Kind: EventMatchCondition}, "invalid rule condition kind"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
errs := validateCondition(&tst.Condition)
|
||||
for _, err := range errs {
|
||||
t.Logf("Got error %#v", err)
|
||||
if strings.Contains(err.Error(), tst.WantNoErrString) {
|
||||
t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -163,6 +163,7 @@ type StatementList []struct {
|
|||
func (s StatementList) Prepare(db *sql.DB) (err error) {
|
||||
for _, statement := range s {
|
||||
if *statement.Statement, err = db.Prepare(statement.SQL); err != nil {
|
||||
err = fmt.Errorf("Error %q while preparing statement: %s", err, statement.SQL)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
|||
63
run-sytest.sh
Executable file
63
run-sytest.sh
Executable file
|
|
@ -0,0 +1,63 @@
|
|||
#!/bin/bash
|
||||
#
|
||||
# Runs SyTest either from Docker Hub, or from ../sytest. If it's run
|
||||
# locally, the Docker image is rebuilt first.
|
||||
#
|
||||
# Logs are stored in ../sytestout/logs.
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
main() {
|
||||
local tag=buster
|
||||
local base_image=debian:$tag
|
||||
local runargs=()
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
if [ -d ../sytest ]; then
|
||||
local tmpdir
|
||||
tmpdir="$(mktemp -d --tmpdir run-systest.XXXXXXXXXX)"
|
||||
trap "rm -r '$tmpdir'" EXIT
|
||||
|
||||
if [ -z "$DISABLE_BUILDING_SYTEST" ]; then
|
||||
echo "Re-building ../sytest Docker images..."
|
||||
|
||||
local status
|
||||
(
|
||||
cd ../sytest
|
||||
|
||||
docker build -f docker/base.Dockerfile --build-arg BASE_IMAGE="$base_image" --tag matrixdotorg/sytest:"$tag" .
|
||||
docker build -f docker/dendrite.Dockerfile --build-arg SYTEST_IMAGE_TAG="$tag" --tag matrixdotorg/sytest-dendrite:latest .
|
||||
) &>"$tmpdir/buildlog" || status=$?
|
||||
if (( status != 0 )); then
|
||||
# Docker is very verbose, and we don't really care about
|
||||
# building SyTest. So we accumulate and only output on
|
||||
# failure.
|
||||
cat "$tmpdir/buildlog" >&2
|
||||
return $status
|
||||
fi
|
||||
fi
|
||||
|
||||
runargs+=( -v "$PWD/../sytest:/sytest:ro" )
|
||||
fi
|
||||
if [ -n "$SYTEST_POSTGRES" ]; then
|
||||
runargs+=( -e POSTGRES=1 )
|
||||
fi
|
||||
|
||||
local sytestout=$PWD/../sytestout
|
||||
mkdir -p "$sytestout"/{logs,cache/go-build,cache/go-pkg}
|
||||
docker run \
|
||||
--rm \
|
||||
--name "sytest-dendrite-${LOGNAME}" \
|
||||
-e LOGS_USER=$(id -u) \
|
||||
-e LOGS_GROUP=$(id -g) \
|
||||
-v "$PWD:/src/:ro" \
|
||||
-v "$sytestout/logs:/logs/" \
|
||||
-v "$sytestout/cache/go-build:/root/.cache/go-build" \
|
||||
-v "$sytestout/cache/go-pkg:/gopath/pkg" \
|
||||
"${runargs[@]}" \
|
||||
matrixdotorg/sytest-dendrite:latest "$@"
|
||||
}
|
||||
|
||||
main "$@"
|
||||
|
|
@ -30,6 +30,7 @@ import (
|
|||
sentryhttp "github.com/getsentry/sentry-go/http"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.uber.org/atomic"
|
||||
|
|
@ -271,6 +272,11 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
|
|||
return f
|
||||
}
|
||||
|
||||
// PushGatewayHTTPClient returns a new client for interacting with (external) Push Gateways.
|
||||
func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client {
|
||||
return pushgateway.NewHTTPClient(b.Cfg.UserAPI.PushGatewayDisableTLSValidation)
|
||||
}
|
||||
|
||||
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
||||
// be called once per component.
|
||||
func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
|
||||
|
|
|
|||
|
|
@ -205,6 +205,11 @@ user_api:
|
|||
max_open_conns: 100
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
pusher_database:
|
||||
connection_string: file:pushserver.db
|
||||
max_open_conns: 100
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
tracing:
|
||||
enabled: false
|
||||
jaeger:
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@ type UserAPI struct {
|
|||
// The length of time an OpenID token is condidered valid in milliseconds
|
||||
OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"`
|
||||
|
||||
// Disable TLS validation on HTTPS calls to push gatways. NOT RECOMMENDED!
|
||||
PushGatewayDisableTLSValidation bool `yaml:"push_gateway_disable_tls_validation"`
|
||||
|
||||
// The Account database stores the login details and account information
|
||||
// for local users. It is accessed by the UserAPI.
|
||||
AccountDatabase DatabaseOptions `yaml:"account_database"`
|
||||
|
|
|
|||
|
|
@ -18,7 +18,10 @@ var (
|
|||
OutputKeyChangeEvent = "OutputKeyChangeEvent"
|
||||
OutputTypingEvent = "OutputTypingEvent"
|
||||
OutputClientData = "OutputClientData"
|
||||
OutputNotificationData = "OutputNotificationData"
|
||||
OutputReceiptEvent = "OutputReceiptEvent"
|
||||
OutputStreamEvent = "OutputStreamEvent"
|
||||
OutputReadUpdate = "OutputReadUpdate"
|
||||
)
|
||||
|
||||
var streams = []*nats.StreamConfig{
|
||||
|
|
@ -58,4 +61,19 @@ var streams = []*nats.StreamConfig{
|
|||
Retention: nats.InterestPolicy,
|
||||
Storage: nats.FileStorage,
|
||||
},
|
||||
{
|
||||
Name: OutputNotificationData,
|
||||
Retention: nats.InterestPolicy,
|
||||
Storage: nats.FileStorage,
|
||||
},
|
||||
{
|
||||
Name: OutputStreamEvent,
|
||||
Retention: nats.InterestPolicy,
|
||||
Storage: nats.FileStorage,
|
||||
},
|
||||
{
|
||||
Name: OutputReadUpdate,
|
||||
Retention: nats.InterestPolicy,
|
||||
Storage: nats.FileStorage,
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -60,8 +60,8 @@ func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ss
|
|||
csMux, synapseMux, &m.Config.ClientAPI, m.AccountDB,
|
||||
m.FedClient, m.RoomserverAPI,
|
||||
m.EDUInternalAPI, m.AppserviceAPI, transactions.New(),
|
||||
m.FederationAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider,
|
||||
&m.Config.MSCs,
|
||||
m.FederationAPI, m.UserAPI, m.KeyAPI,
|
||||
m.ExtPublicRoomsProvider, &m.Config.MSCs,
|
||||
)
|
||||
federationapi.AddPublicRoutes(
|
||||
ssMux, keyMux, wkMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient,
|
||||
|
|
|
|||
|
|
@ -16,7 +16,9 @@ package consumers
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
|
|
@ -24,21 +26,26 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||
"github.com/matrix-org/dendrite/syncapi/producers"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// OutputClientDataConsumer consumes events that originated in the client API server.
|
||||
type OutputClientDataConsumer struct {
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable string
|
||||
topic string
|
||||
db storage.Database
|
||||
stream types.StreamProvider
|
||||
notifier *notifier.Notifier
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable string
|
||||
topic string
|
||||
db storage.Database
|
||||
stream types.StreamProvider
|
||||
notifier *notifier.Notifier
|
||||
serverName gomatrixserverlib.ServerName
|
||||
producer *producers.UserAPIReadProducer
|
||||
}
|
||||
|
||||
// NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers.
|
||||
|
|
@ -49,15 +56,18 @@ func NewOutputClientDataConsumer(
|
|||
store storage.Database,
|
||||
notifier *notifier.Notifier,
|
||||
stream types.StreamProvider,
|
||||
producer *producers.UserAPIReadProducer,
|
||||
) *OutputClientDataConsumer {
|
||||
return &OutputClientDataConsumer{
|
||||
ctx: process.Context(),
|
||||
jetstream: js,
|
||||
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData),
|
||||
durable: cfg.Matrix.JetStream.Durable("SyncAPIClientAPIConsumer"),
|
||||
db: store,
|
||||
notifier: notifier,
|
||||
stream: stream,
|
||||
ctx: process.Context(),
|
||||
jetstream: js,
|
||||
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData),
|
||||
durable: cfg.Matrix.JetStream.Durable("SyncAPIClientAPIConsumer"),
|
||||
db: store,
|
||||
notifier: notifier,
|
||||
stream: stream,
|
||||
serverName: cfg.Matrix.ServerName,
|
||||
producer: producer,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -100,8 +110,48 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg)
|
|||
}).Panicf("could not save account data")
|
||||
}
|
||||
|
||||
if err = s.sendReadUpdate(ctx, userID, output); err != nil {
|
||||
log.WithError(err).WithFields(logrus.Fields{
|
||||
"user_id": userID,
|
||||
"room_id": output.RoomID,
|
||||
}).Errorf("Failed to generate read update")
|
||||
sentry.CaptureException(err)
|
||||
return false
|
||||
}
|
||||
|
||||
s.stream.Advance(streamPos)
|
||||
s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos})
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error {
|
||||
if output.Type != "m.fully_read" || output.ReadMarker == nil {
|
||||
return nil
|
||||
}
|
||||
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||
}
|
||||
if serverName != s.serverName {
|
||||
return nil
|
||||
}
|
||||
var readPos types.StreamPosition
|
||||
var fullyReadPos types.StreamPosition
|
||||
if output.ReadMarker.Read != "" {
|
||||
if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil && err != sql.ErrNoRows {
|
||||
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
|
||||
}
|
||||
}
|
||||
if output.ReadMarker.FullyRead != "" {
|
||||
if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil && err != sql.ErrNoRows {
|
||||
return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err)
|
||||
}
|
||||
}
|
||||
if readPos > 0 || fullyReadPos > 0 {
|
||||
if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil {
|
||||
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,9 @@ package consumers
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/matrix-org/dendrite/eduserver/api"
|
||||
|
|
@ -24,21 +26,26 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||
"github.com/matrix-org/dendrite/syncapi/producers"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// OutputReceiptEventConsumer consumes events that originated in the EDU server.
|
||||
type OutputReceiptEventConsumer struct {
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable string
|
||||
topic string
|
||||
db storage.Database
|
||||
stream types.StreamProvider
|
||||
notifier *notifier.Notifier
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable string
|
||||
topic string
|
||||
db storage.Database
|
||||
stream types.StreamProvider
|
||||
notifier *notifier.Notifier
|
||||
serverName gomatrixserverlib.ServerName
|
||||
producer *producers.UserAPIReadProducer
|
||||
}
|
||||
|
||||
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
|
||||
|
|
@ -50,15 +57,18 @@ func NewOutputReceiptEventConsumer(
|
|||
store storage.Database,
|
||||
notifier *notifier.Notifier,
|
||||
stream types.StreamProvider,
|
||||
producer *producers.UserAPIReadProducer,
|
||||
) *OutputReceiptEventConsumer {
|
||||
return &OutputReceiptEventConsumer{
|
||||
ctx: process.Context(),
|
||||
jetstream: js,
|
||||
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent),
|
||||
durable: cfg.Matrix.JetStream.Durable("SyncAPIEDUServerReceiptConsumer"),
|
||||
db: store,
|
||||
notifier: notifier,
|
||||
stream: stream,
|
||||
ctx: process.Context(),
|
||||
jetstream: js,
|
||||
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent),
|
||||
durable: cfg.Matrix.JetStream.Durable("SyncAPIEDUServerReceiptConsumer"),
|
||||
db: store,
|
||||
notifier: notifier,
|
||||
stream: stream,
|
||||
serverName: cfg.Matrix.ServerName,
|
||||
producer: producer,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -92,8 +102,42 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Ms
|
|||
return true
|
||||
}
|
||||
|
||||
if err = s.sendReadUpdate(ctx, output); err != nil {
|
||||
log.WithError(err).WithFields(logrus.Fields{
|
||||
"user_id": output.UserID,
|
||||
"room_id": output.RoomID,
|
||||
}).Errorf("Failed to generate read update")
|
||||
sentry.CaptureException(err)
|
||||
return false
|
||||
}
|
||||
|
||||
s.stream.Advance(streamPos)
|
||||
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output api.OutputReceiptEvent) error {
|
||||
if output.Type != "m.read" {
|
||||
return nil
|
||||
}
|
||||
_, serverName, err := gomatrixserverlib.SplitID('@', output.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||
}
|
||||
if serverName != s.serverName {
|
||||
return nil
|
||||
}
|
||||
var readPos types.StreamPosition
|
||||
if output.EventID != "" {
|
||||
if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil && err != sql.ErrNoRows {
|
||||
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
|
||||
}
|
||||
}
|
||||
if readPos > 0 {
|
||||
if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil {
|
||||
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||
"github.com/matrix-org/dendrite/syncapi/producers"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -45,6 +46,7 @@ type OutputRoomEventConsumer struct {
|
|||
pduStream types.StreamProvider
|
||||
inviteStream types.StreamProvider
|
||||
notifier *notifier.Notifier
|
||||
producer *producers.UserAPIStreamEventProducer
|
||||
}
|
||||
|
||||
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
|
||||
|
|
@ -57,6 +59,7 @@ func NewOutputRoomEventConsumer(
|
|||
pduStream types.StreamProvider,
|
||||
inviteStream types.StreamProvider,
|
||||
rsAPI api.RoomserverInternalAPI,
|
||||
producer *producers.UserAPIStreamEventProducer,
|
||||
) *OutputRoomEventConsumer {
|
||||
return &OutputRoomEventConsumer{
|
||||
ctx: process.Context(),
|
||||
|
|
@ -69,6 +72,7 @@ func NewOutputRoomEventConsumer(
|
|||
pduStream: pduStream,
|
||||
inviteStream: inviteStream,
|
||||
rsAPI: rsAPI,
|
||||
producer: producer,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -194,6 +198,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
|
|||
return nil
|
||||
}
|
||||
|
||||
if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil {
|
||||
log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID())
|
||||
sentry.CaptureException(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
|
||||
log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
|
||||
sentry.CaptureException(err)
|
||||
|
|
|
|||
110
syncapi/consumers/userapi.go
Normal file
110
syncapi/consumers/userapi.go
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package consumers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/nats-io/nats.go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// OutputNotificationDataConsumer consumes events that originated in
|
||||
// the Push server.
|
||||
type OutputNotificationDataConsumer struct {
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable string
|
||||
topic string
|
||||
db storage.Database
|
||||
notifier *notifier.Notifier
|
||||
stream types.StreamProvider
|
||||
}
|
||||
|
||||
// NewOutputNotificationDataConsumer creates a new consumer. Call
|
||||
// Start() to begin consuming.
|
||||
func NewOutputNotificationDataConsumer(
|
||||
process *process.ProcessContext,
|
||||
cfg *config.SyncAPI,
|
||||
js nats.JetStreamContext,
|
||||
store storage.Database,
|
||||
notifier *notifier.Notifier,
|
||||
stream types.StreamProvider,
|
||||
) *OutputNotificationDataConsumer {
|
||||
s := &OutputNotificationDataConsumer{
|
||||
ctx: process.Context(),
|
||||
jetstream: js,
|
||||
durable: cfg.Matrix.JetStream.Durable("SyncAPINotificationDataConsumer"),
|
||||
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData),
|
||||
db: store,
|
||||
notifier: notifier,
|
||||
stream: stream,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Start starts consumption.
|
||||
func (s *OutputNotificationDataConsumer) Start() error {
|
||||
return jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
}
|
||||
|
||||
// onMessage is called when the Sync server receives a new event from
|
||||
// the push server. It is not safe for this function to be called from
|
||||
// multiple goroutines, or else the sync stream position may race and
|
||||
// be incorrectly calculated.
|
||||
func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
userID := string(msg.Header.Get(jetstream.UserID))
|
||||
|
||||
// Parse out the event JSON
|
||||
var data eventutil.NotificationData
|
||||
if err := json.Unmarshal(msg.Data, &data); err != nil {
|
||||
sentry.CaptureException(err)
|
||||
log.WithField("user_id", userID).WithError(err).Error("user API consumer: message parse failure")
|
||||
return true
|
||||
}
|
||||
|
||||
streamPos, err := s.db.UpsertRoomUnreadNotificationCounts(ctx, userID, data.RoomID, data.UnreadNotificationCount, data.UnreadHighlightCount)
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
log.WithFields(log.Fields{
|
||||
"user_id": userID,
|
||||
"room_id": data.RoomID,
|
||||
}).WithError(err).Error("Could not save notification counts")
|
||||
return false
|
||||
}
|
||||
|
||||
s.stream.Advance(streamPos)
|
||||
s.notifier.OnNewNotificationData(userID, types.StreamingToken{NotificationDataPosition: streamPos})
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"user_id": userID,
|
||||
"room_id": data.RoomID,
|
||||
"streamPos": streamPos,
|
||||
}).Trace("Received notification data from user API")
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
@ -217,6 +217,17 @@ func (n *Notifier) OnNewInvite(
|
|||
n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewNotificationData(
|
||||
userID string,
|
||||
posUpdate types.StreamingToken,
|
||||
) {
|
||||
n.streamLock.Lock()
|
||||
defer n.streamLock.Unlock()
|
||||
|
||||
n.currPos.ApplyUpdates(posUpdate)
|
||||
n.wakeupUsers([]string{userID}, nil, n.currPos)
|
||||
}
|
||||
|
||||
// GetListener returns a UserStreamListener that can be used to wait for
|
||||
// updates for a user. Must be closed.
|
||||
// notify for anything before sincePos
|
||||
|
|
|
|||
|
|
@ -219,7 +219,7 @@ func TestEDUWakeup(t *testing.T) {
|
|||
go func() {
|
||||
pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter))
|
||||
if err != nil {
|
||||
t.Errorf("TestNewInviteEventForUser error: %w", err)
|
||||
t.Errorf("TestNewInviteEventForUser error: %v", err)
|
||||
}
|
||||
mustEqualPositions(t, pos, syncPositionNewEDU)
|
||||
wg.Done()
|
||||
|
|
|
|||
62
syncapi/producers/userapi_readupdate.go
Normal file
62
syncapi/producers/userapi_readupdate.go
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package producers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/nats-io/nats.go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// UserAPIProducer produces events for the user API server to consume
|
||||
type UserAPIReadProducer struct {
|
||||
Topic string
|
||||
JetStream nats.JetStreamContext
|
||||
}
|
||||
|
||||
// SendData sends account data to the user API server
|
||||
func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error {
|
||||
m := &nats.Msg{
|
||||
Subject: p.Topic,
|
||||
Header: nats.Header{},
|
||||
}
|
||||
m.Header.Set(jetstream.UserID, userID)
|
||||
m.Header.Set(jetstream.RoomID, roomID)
|
||||
|
||||
data := types.ReadUpdate{
|
||||
UserID: userID,
|
||||
RoomID: roomID,
|
||||
Read: readPos,
|
||||
FullyRead: fullyReadPos,
|
||||
}
|
||||
var err error
|
||||
m.Data, err = json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"user_id": userID,
|
||||
"room_id": roomID,
|
||||
"read_pos": readPos,
|
||||
"fully_read_pos": fullyReadPos,
|
||||
}).Tracef("Producing to topic '%s'", p.Topic)
|
||||
|
||||
_, err = p.JetStream.PublishMsg(m)
|
||||
return err
|
||||
}
|
||||
60
syncapi/producers/userapi_streamevent.go
Normal file
60
syncapi/producers/userapi_streamevent.go
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package producers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// UserAPIProducer produces events for the user API server to consume
|
||||
type UserAPIStreamEventProducer struct {
|
||||
Topic string
|
||||
JetStream nats.JetStreamContext
|
||||
}
|
||||
|
||||
// SendData sends account data to the user API server
|
||||
func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error {
|
||||
m := &nats.Msg{
|
||||
Subject: p.Topic,
|
||||
Header: nats.Header{},
|
||||
}
|
||||
m.Header.Set(jetstream.RoomID, roomID)
|
||||
|
||||
data := types.StreamedEvent{
|
||||
Event: event,
|
||||
StreamPosition: pos,
|
||||
}
|
||||
var err error
|
||||
m.Data, err = json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"room_id": roomID,
|
||||
"event_id": event.EventID(),
|
||||
"event_type": event.Type(),
|
||||
"stream_pos": pos,
|
||||
}).Tracef("Producing to topic '%s'", p.Topic)
|
||||
|
||||
_, err = p.JetStream.PublishMsg(m)
|
||||
return err
|
||||
}
|
||||
|
|
@ -18,6 +18,7 @@ import (
|
|||
"context"
|
||||
|
||||
eduAPI "github.com/matrix-org/dendrite/eduserver/api"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
|
|
@ -31,6 +32,7 @@ type Database interface {
|
|||
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
|
||||
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
|
||||
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
|
||||
MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error)
|
||||
|
||||
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
|
||||
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
|
||||
|
|
@ -138,6 +140,12 @@ type Database interface {
|
|||
// GetRoomReceipts gets all receipts for a given roomID
|
||||
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error)
|
||||
|
||||
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
|
||||
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
|
||||
|
||||
// GetUserUnreadNotificationCounts returns statistics per room a user is interested in.
|
||||
GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error)
|
||||
|
||||
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
|
||||
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
|
||||
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
|
||||
|
|
|
|||
108
syncapi/storage/postgres/notification_data_table.go
Normal file
108
syncapi/storage/postgres/notification_data_table.go
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
)
|
||||
|
||||
func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, error) {
|
||||
_, err := db.Exec(notificationDataSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := ¬ificationDataStatements{}
|
||||
return r, sqlutil.StatementList{
|
||||
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
|
||||
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
|
||||
{&r.selectMaxID, selectMaxNotificationIDSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
type notificationDataStatements struct {
|
||||
upsertRoomUnreadCounts *sql.Stmt
|
||||
selectUserUnreadCounts *sql.Stmt
|
||||
selectMaxID *sql.Stmt
|
||||
}
|
||||
|
||||
const notificationDataSchema = `
|
||||
CREATE TABLE IF NOT EXISTS syncapi_notification_data (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
notification_count BIGINT NOT NULL DEFAULT 0,
|
||||
highlight_count BIGINT NOT NULL DEFAULT 0,
|
||||
CONSTRAINT syncapi_notification_data_unique UNIQUE (user_id, room_id)
|
||||
);`
|
||||
|
||||
const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_data
|
||||
(user_id, room_id, notification_count, highlight_count)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (user_id, room_id)
|
||||
DO UPDATE SET notification_count = $3, highlight_count = $4
|
||||
RETURNING id`
|
||||
|
||||
const selectUserUnreadNotificationCountsSQL = `SELECT
|
||||
id, room_id, notification_count, highlight_count
|
||||
FROM syncapi_notification_data
|
||||
WHERE
|
||||
user_id = $1 AND
|
||||
id BETWEEN $2 + 1 AND $3`
|
||||
|
||||
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
|
||||
|
||||
func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
|
||||
err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
|
||||
rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed")
|
||||
|
||||
roomCounts := map[string]*eventutil.NotificationData{}
|
||||
for rows.Next() {
|
||||
var id types.StreamPosition
|
||||
var roomID string
|
||||
var notificationCount, highlightCount int
|
||||
|
||||
if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roomCounts[roomID] = &eventutil.NotificationData{
|
||||
RoomID: roomID,
|
||||
UnreadNotificationCount: notificationCount,
|
||||
UnreadHighlightCount: highlightCount,
|
||||
}
|
||||
}
|
||||
return roomCounts, rows.Err()
|
||||
}
|
||||
|
||||
func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) {
|
||||
var id int64
|
||||
err := r.selectMaxID.QueryRowContext(ctx).Scan(&id)
|
||||
return id, err
|
||||
}
|
||||
|
|
@ -24,7 +24,6 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/shared"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// SyncServerDatasource represents a sync server datasource which manages
|
||||
|
|
@ -34,12 +33,11 @@ type SyncServerDatasource struct {
|
|||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
sqlutil.PartitionOffsetStatements
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
// NewDatabase creates a new sync server database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*SyncServerDatasource, error) {
|
||||
d := SyncServerDatasource{serverName: serverName}
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) {
|
||||
d := SyncServerDatasource{}
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
return nil, err
|
||||
|
|
@ -92,7 +90,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
notificationData, err := NewPostgresNotificationDataTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := sqlutil.NewMigrations()
|
||||
deltas.LoadFixSequences(m)
|
||||
deltas.LoadRemoveSendToDeviceSentColumn(m)
|
||||
|
|
@ -113,6 +114,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
|||
SendToDevice: sendToDevice,
|
||||
Receipts: receipts,
|
||||
Memberships: memberships,
|
||||
NotificationData: notificationData,
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ type Database struct {
|
|||
Filter tables.Filter
|
||||
Receipts tables.Receipts
|
||||
Memberships tables.Memberships
|
||||
NotificationData tables.NotificationData
|
||||
}
|
||||
|
||||
func (d *Database) readOnlySnapshot(ctx context.Context) (*sql.Tx, error) {
|
||||
|
|
@ -102,6 +103,14 @@ func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.S
|
|||
return types.StreamPosition(id), nil
|
||||
}
|
||||
|
||||
func (d *Database) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) {
|
||||
id, err := d.NotificationData.SelectMaxID(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err)
|
||||
}
|
||||
return types.StreamPosition(id), nil
|
||||
}
|
||||
|
||||
func (d *Database) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
return d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilterPart, excludeEventIDs)
|
||||
}
|
||||
|
|
@ -956,6 +965,18 @@ func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, stream
|
|||
return receipts, err
|
||||
}
|
||||
|
||||
func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
|
||||
err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||
pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, userID, roomID, notificationCount, highlightCount)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
|
||||
return d.NotificationData.SelectUserUnreadCounts(ctx, userID, from, to)
|
||||
}
|
||||
|
||||
func (s *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
|
||||
return s.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID)
|
||||
}
|
||||
|
|
|
|||
108
syncapi/storage/sqlite3/notification_data_table.go
Normal file
108
syncapi/storage/sqlite3/notification_data_table.go
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
)
|
||||
|
||||
func NewSqliteNotificationDataTable(db *sql.DB) (tables.NotificationData, error) {
|
||||
_, err := db.Exec(notificationDataSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := ¬ificationDataStatements{}
|
||||
return r, sqlutil.StatementList{
|
||||
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
|
||||
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
|
||||
{&r.selectMaxID, selectMaxNotificationIDSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
type notificationDataStatements struct {
|
||||
upsertRoomUnreadCounts *sql.Stmt
|
||||
selectUserUnreadCounts *sql.Stmt
|
||||
selectMaxID *sql.Stmt
|
||||
}
|
||||
|
||||
const notificationDataSchema = `
|
||||
CREATE TABLE IF NOT EXISTS syncapi_notification_data (
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
notification_count BIGINT NOT NULL DEFAULT 0,
|
||||
highlight_count BIGINT NOT NULL DEFAULT 0,
|
||||
CONSTRAINT syncapi_notifications_unique UNIQUE (user_id, room_id)
|
||||
);`
|
||||
|
||||
const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_data
|
||||
(user_id, room_id, notification_count, highlight_count)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (user_id, room_id)
|
||||
DO UPDATE SET notification_count = $3, highlight_count = $4
|
||||
RETURNING id`
|
||||
|
||||
const selectUserUnreadNotificationCountsSQL = `SELECT
|
||||
id, room_id, notification_count, highlight_count
|
||||
FROM syncapi_notification_data
|
||||
WHERE
|
||||
user_id = $1 AND
|
||||
id BETWEEN $2 + 1 AND $3`
|
||||
|
||||
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
|
||||
|
||||
func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
|
||||
err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
|
||||
rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed")
|
||||
|
||||
roomCounts := map[string]*eventutil.NotificationData{}
|
||||
for rows.Next() {
|
||||
var id types.StreamPosition
|
||||
var roomID string
|
||||
var notificationCount, highlightCount int
|
||||
|
||||
if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roomCounts[roomID] = &eventutil.NotificationData{
|
||||
RoomID: roomID,
|
||||
UnreadNotificationCount: notificationCount,
|
||||
UnreadHighlightCount: highlightCount,
|
||||
}
|
||||
}
|
||||
return roomCounts, rows.Err()
|
||||
}
|
||||
|
||||
func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) {
|
||||
var id int64
|
||||
err := r.selectMaxID.QueryRowContext(ctx).Scan(&id)
|
||||
return id, err
|
||||
}
|
||||
|
|
@ -62,16 +62,19 @@ const selectEventsSQL = "" +
|
|||
const selectRecentEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||
" WHERE room_id = $1 AND id > $2 AND id <= $3"
|
||||
|
||||
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
|
||||
|
||||
const selectRecentEventsForSyncSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE"
|
||||
|
||||
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
|
||||
|
||||
const selectEarlyEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||
" WHERE room_id = $1 AND id > $2 AND id <= $3"
|
||||
|
||||
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
|
||||
|
||||
const selectMaxEventIDSQL = "" +
|
||||
|
|
@ -85,6 +88,7 @@ const selectStateInRangeSQL = "" +
|
|||
" FROM syncapi_output_room_events" +
|
||||
" WHERE (id > $1 AND id <= $2)" +
|
||||
" AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))"
|
||||
|
||||
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
|
||||
|
||||
const deleteEventsForRoomSQL = "" +
|
||||
|
|
@ -95,10 +99,12 @@ const selectContextEventSQL = "" +
|
|||
|
||||
const selectContextBeforeEventSQL = "" +
|
||||
"SELECT headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id < $2"
|
||||
|
||||
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
|
||||
|
||||
const selectContextAfterEventSQL = "" +
|
||||
"SELECT id, headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2"
|
||||
|
||||
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
|
||||
|
||||
type outputRoomEventsStatements struct {
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// SyncServerDatasource represents a sync server datasource which manages
|
||||
|
|
@ -32,14 +31,13 @@ type SyncServerDatasource struct {
|
|||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
sqlutil.PartitionOffsetStatements
|
||||
streamID streamIDStatements
|
||||
serverName gomatrixserverlib.ServerName
|
||||
streamID streamIDStatements
|
||||
}
|
||||
|
||||
// NewDatabase creates a new sync server database
|
||||
// nolint: gocyclo
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*SyncServerDatasource, error) {
|
||||
d := SyncServerDatasource{serverName: serverName}
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) {
|
||||
var d SyncServerDatasource
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
return nil, err
|
||||
|
|
@ -102,8 +100,10 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
notificationData, err := NewSqliteNotificationDataTable(d.db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m := sqlutil.NewMigrations()
|
||||
deltas.LoadFixSequences(m)
|
||||
deltas.LoadRemoveSendToDeviceSentColumn(m)
|
||||
|
|
@ -124,6 +124,7 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
|
|||
SendToDevice: sendToDevice,
|
||||
Receipts: receipts,
|
||||
Memberships: memberships,
|
||||
NotificationData: notificationData,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,9 +30,9 @@ import (
|
|||
func NewSyncServerDatasource(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(dbProperties, serverName)
|
||||
return sqlite3.NewDatabase(dbProperties)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return postgres.NewDatabase(dbProperties, serverName)
|
||||
return postgres.NewDatabase(dbProperties)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ import (
|
|||
func NewSyncServerDatasource(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(dbProperties, serverName)
|
||||
return sqlite3.NewDatabase(dbProperties)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return nil, fmt.Errorf("can't use Postgres implementation")
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import (
|
|||
"database/sql"
|
||||
|
||||
eduAPI "github.com/matrix-org/dendrite/eduserver/api"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -171,3 +172,9 @@ type Memberships interface {
|
|||
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
|
||||
SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error)
|
||||
}
|
||||
|
||||
type NotificationData interface {
|
||||
UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
|
||||
SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error)
|
||||
SelectMaxID(ctx context.Context) (int64, error)
|
||||
}
|
||||
|
|
|
|||
55
syncapi/streams/stream_notificationdata.go
Normal file
55
syncapi/streams/stream_notificationdata.go
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
package streams
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
)
|
||||
|
||||
type NotificationDataStreamProvider struct {
|
||||
StreamProvider
|
||||
}
|
||||
|
||||
func (p *NotificationDataStreamProvider) Setup() {
|
||||
p.StreamProvider.Setup()
|
||||
|
||||
id, err := p.DB.MaxStreamPositionForNotificationData(context.Background())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.latest = id
|
||||
}
|
||||
|
||||
func (p *NotificationDataStreamProvider) CompleteSync(
|
||||
ctx context.Context,
|
||||
req *types.SyncRequest,
|
||||
) types.StreamPosition {
|
||||
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
|
||||
}
|
||||
|
||||
func (p *NotificationDataStreamProvider) IncrementalSync(
|
||||
ctx context.Context,
|
||||
req *types.SyncRequest,
|
||||
from, to types.StreamPosition,
|
||||
) types.StreamPosition {
|
||||
// We want counts for all possible rooms, so always start from zero.
|
||||
countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to)
|
||||
if err != nil {
|
||||
req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed")
|
||||
return from
|
||||
}
|
||||
|
||||
// We're merely decorating existing rooms. Note that the Join map
|
||||
// values are not pointers.
|
||||
for roomID, jr := range req.Response.Rooms.Join {
|
||||
counts := countsByRoom[roomID]
|
||||
if counts == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
jr.UnreadNotifications.HighlightCount = counts.UnreadHighlightCount
|
||||
jr.UnreadNotifications.NotificationCount = counts.UnreadNotificationCount
|
||||
req.Response.Rooms.Join[roomID] = jr
|
||||
}
|
||||
return to
|
||||
}
|
||||
|
|
@ -12,13 +12,14 @@ import (
|
|||
)
|
||||
|
||||
type Streams struct {
|
||||
PDUStreamProvider types.StreamProvider
|
||||
TypingStreamProvider types.StreamProvider
|
||||
ReceiptStreamProvider types.StreamProvider
|
||||
InviteStreamProvider types.StreamProvider
|
||||
SendToDeviceStreamProvider types.StreamProvider
|
||||
AccountDataStreamProvider types.StreamProvider
|
||||
DeviceListStreamProvider types.StreamProvider
|
||||
PDUStreamProvider types.StreamProvider
|
||||
TypingStreamProvider types.StreamProvider
|
||||
ReceiptStreamProvider types.StreamProvider
|
||||
InviteStreamProvider types.StreamProvider
|
||||
SendToDeviceStreamProvider types.StreamProvider
|
||||
AccountDataStreamProvider types.StreamProvider
|
||||
DeviceListStreamProvider types.StreamProvider
|
||||
NotificationDataStreamProvider types.StreamProvider
|
||||
}
|
||||
|
||||
func NewSyncStreamProviders(
|
||||
|
|
@ -47,6 +48,9 @@ func NewSyncStreamProviders(
|
|||
StreamProvider: StreamProvider{DB: d},
|
||||
userAPI: userAPI,
|
||||
},
|
||||
NotificationDataStreamProvider: &NotificationDataStreamProvider{
|
||||
StreamProvider: StreamProvider{DB: d},
|
||||
},
|
||||
DeviceListStreamProvider: &DeviceListStreamProvider{
|
||||
StreamProvider: StreamProvider{DB: d},
|
||||
rsAPI: rsAPI,
|
||||
|
|
@ -60,6 +64,7 @@ func NewSyncStreamProviders(
|
|||
streams.InviteStreamProvider.Setup()
|
||||
streams.SendToDeviceStreamProvider.Setup()
|
||||
streams.AccountDataStreamProvider.Setup()
|
||||
streams.NotificationDataStreamProvider.Setup()
|
||||
streams.DeviceListStreamProvider.Setup()
|
||||
|
||||
return streams
|
||||
|
|
@ -67,12 +72,13 @@ func NewSyncStreamProviders(
|
|||
|
||||
func (s *Streams) Latest(ctx context.Context) types.StreamingToken {
|
||||
return types.StreamingToken{
|
||||
PDUPosition: s.PDUStreamProvider.LatestPosition(ctx),
|
||||
TypingPosition: s.TypingStreamProvider.LatestPosition(ctx),
|
||||
ReceiptPosition: s.ReceiptStreamProvider.LatestPosition(ctx),
|
||||
InvitePosition: s.InviteStreamProvider.LatestPosition(ctx),
|
||||
SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx),
|
||||
AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx),
|
||||
DeviceListPosition: s.DeviceListStreamProvider.LatestPosition(ctx),
|
||||
PDUPosition: s.PDUStreamProvider.LatestPosition(ctx),
|
||||
TypingPosition: s.TypingStreamProvider.LatestPosition(ctx),
|
||||
ReceiptPosition: s.ReceiptStreamProvider.LatestPosition(ctx),
|
||||
InvitePosition: s.InviteStreamProvider.LatestPosition(ctx),
|
||||
SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx),
|
||||
AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx),
|
||||
NotificationDataPosition: s.NotificationDataStreamProvider.LatestPosition(ctx),
|
||||
DeviceListPosition: s.DeviceListStreamProvider.LatestPosition(ctx),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -190,7 +190,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
|||
currentPos.ApplyUpdates(userStreamListener.GetSyncPosition())
|
||||
}
|
||||
} else {
|
||||
syncReq.Log.Debugln("Responding to sync immediately")
|
||||
syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately")
|
||||
}
|
||||
|
||||
if syncReq.Since.IsEmpty() {
|
||||
|
|
@ -214,6 +214,9 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
|||
AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync(
|
||||
syncReq.Context, syncReq,
|
||||
),
|
||||
NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync(
|
||||
syncReq.Context, syncReq,
|
||||
),
|
||||
DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync(
|
||||
syncReq.Context, syncReq,
|
||||
),
|
||||
|
|
@ -245,6 +248,10 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
|||
syncReq.Context, syncReq,
|
||||
syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition,
|
||||
),
|
||||
NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync(
|
||||
syncReq.Context, syncReq,
|
||||
syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition,
|
||||
),
|
||||
DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync(
|
||||
syncReq.Context, syncReq,
|
||||
syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition,
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/syncapi/consumers"
|
||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||
"github.com/matrix-org/dendrite/syncapi/producers"
|
||||
"github.com/matrix-org/dendrite/syncapi/routing"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||
"github.com/matrix-org/dendrite/syncapi/streams"
|
||||
|
|
@ -64,6 +65,18 @@ func AddPublicRoutes(
|
|||
|
||||
requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier)
|
||||
|
||||
userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{
|
||||
JetStream: js,
|
||||
Topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputStreamEvent),
|
||||
}
|
||||
|
||||
userAPIReadUpdateProducer := &producers.UserAPIReadProducer{
|
||||
JetStream: js,
|
||||
Topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReadUpdate),
|
||||
}
|
||||
|
||||
_ = userAPIReadUpdateProducer
|
||||
|
||||
keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
|
||||
process, cfg, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent),
|
||||
js, keyAPI, rsAPI, syncDB, notifier,
|
||||
|
|
@ -75,7 +88,7 @@ func AddPublicRoutes(
|
|||
|
||||
roomConsumer := consumers.NewOutputRoomEventConsumer(
|
||||
process, cfg, js, syncDB, notifier, streams.PDUStreamProvider,
|
||||
streams.InviteStreamProvider, rsAPI,
|
||||
streams.InviteStreamProvider, rsAPI, userAPIStreamEventProducer,
|
||||
)
|
||||
if err = roomConsumer.Start(); err != nil {
|
||||
logrus.WithError(err).Panicf("failed to start room server consumer")
|
||||
|
|
@ -83,11 +96,19 @@ func AddPublicRoutes(
|
|||
|
||||
clientConsumer := consumers.NewOutputClientDataConsumer(
|
||||
process, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider,
|
||||
userAPIReadUpdateProducer,
|
||||
)
|
||||
if err = clientConsumer.Start(); err != nil {
|
||||
logrus.WithError(err).Panicf("failed to start client data consumer")
|
||||
}
|
||||
|
||||
notificationConsumer := consumers.NewOutputNotificationDataConsumer(
|
||||
process, cfg, js, syncDB, notifier, streams.NotificationDataStreamProvider,
|
||||
)
|
||||
if err = notificationConsumer.Start(); err != nil {
|
||||
logrus.WithError(err).Panicf("failed to start notification data consumer")
|
||||
}
|
||||
|
||||
typingConsumer := consumers.NewOutputTypingEventConsumer(
|
||||
process, cfg, js, syncDB, eduCache, notifier, streams.TypingStreamProvider,
|
||||
)
|
||||
|
|
@ -104,6 +125,7 @@ func AddPublicRoutes(
|
|||
|
||||
receiptConsumer := consumers.NewOutputReceiptEventConsumer(
|
||||
process, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider,
|
||||
userAPIReadUpdateProducer,
|
||||
)
|
||||
if err = receiptConsumer.Start(); err != nil {
|
||||
logrus.WithError(err).Panicf("failed to start receipts consumer")
|
||||
|
|
|
|||
|
|
@ -95,13 +95,14 @@ const (
|
|||
)
|
||||
|
||||
type StreamingToken struct {
|
||||
PDUPosition StreamPosition
|
||||
TypingPosition StreamPosition
|
||||
ReceiptPosition StreamPosition
|
||||
SendToDevicePosition StreamPosition
|
||||
InvitePosition StreamPosition
|
||||
AccountDataPosition StreamPosition
|
||||
DeviceListPosition StreamPosition
|
||||
PDUPosition StreamPosition
|
||||
TypingPosition StreamPosition
|
||||
ReceiptPosition StreamPosition
|
||||
SendToDevicePosition StreamPosition
|
||||
InvitePosition StreamPosition
|
||||
AccountDataPosition StreamPosition
|
||||
DeviceListPosition StreamPosition
|
||||
NotificationDataPosition StreamPosition
|
||||
}
|
||||
|
||||
// This will be used as a fallback by json.Marshal.
|
||||
|
|
@ -117,10 +118,11 @@ func (s *StreamingToken) UnmarshalText(text []byte) (err error) {
|
|||
|
||||
func (t StreamingToken) String() string {
|
||||
posStr := fmt.Sprintf(
|
||||
"s%d_%d_%d_%d_%d_%d_%d",
|
||||
"s%d_%d_%d_%d_%d_%d_%d_%d",
|
||||
t.PDUPosition, t.TypingPosition,
|
||||
t.ReceiptPosition, t.SendToDevicePosition,
|
||||
t.InvitePosition, t.AccountDataPosition, t.DeviceListPosition,
|
||||
t.InvitePosition, t.AccountDataPosition,
|
||||
t.DeviceListPosition, t.NotificationDataPosition,
|
||||
)
|
||||
return posStr
|
||||
}
|
||||
|
|
@ -142,12 +144,14 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool {
|
|||
return true
|
||||
case t.DeviceListPosition > other.DeviceListPosition:
|
||||
return true
|
||||
case t.NotificationDataPosition > other.NotificationDataPosition:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *StreamingToken) IsEmpty() bool {
|
||||
return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition == 0
|
||||
return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition+t.NotificationDataPosition == 0
|
||||
}
|
||||
|
||||
// WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken.
|
||||
|
|
@ -185,6 +189,9 @@ func (t *StreamingToken) ApplyUpdates(other StreamingToken) {
|
|||
if other.DeviceListPosition > t.DeviceListPosition {
|
||||
t.DeviceListPosition = other.DeviceListPosition
|
||||
}
|
||||
if other.NotificationDataPosition > t.NotificationDataPosition {
|
||||
t.NotificationDataPosition = other.NotificationDataPosition
|
||||
}
|
||||
}
|
||||
|
||||
type TopologyToken struct {
|
||||
|
|
@ -277,7 +284,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
|
|||
// s478_0_0_0_0_13.dl-0-2 but we have now removed partitioned stream positions
|
||||
tok = strings.Split(tok, ".")[0]
|
||||
parts := strings.Split(tok[1:], "_")
|
||||
var positions [7]StreamPosition
|
||||
var positions [8]StreamPosition
|
||||
for i, p := range parts {
|
||||
if i >= len(positions) {
|
||||
break
|
||||
|
|
@ -291,13 +298,14 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
|
|||
positions[i] = StreamPosition(pos)
|
||||
}
|
||||
token = StreamingToken{
|
||||
PDUPosition: positions[0],
|
||||
TypingPosition: positions[1],
|
||||
ReceiptPosition: positions[2],
|
||||
SendToDevicePosition: positions[3],
|
||||
InvitePosition: positions[4],
|
||||
AccountDataPosition: positions[5],
|
||||
DeviceListPosition: positions[6],
|
||||
PDUPosition: positions[0],
|
||||
TypingPosition: positions[1],
|
||||
ReceiptPosition: positions[2],
|
||||
SendToDevicePosition: positions[3],
|
||||
InvitePosition: positions[4],
|
||||
AccountDataPosition: positions[5],
|
||||
DeviceListPosition: positions[6],
|
||||
NotificationDataPosition: positions[7],
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
|
@ -383,6 +391,10 @@ type JoinResponse struct {
|
|||
AccountData struct {
|
||||
Events []gomatrixserverlib.ClientEvent `json:"events"`
|
||||
} `json:"account_data"`
|
||||
UnreadNotifications struct {
|
||||
HighlightCount int `json:"highlight_count"`
|
||||
NotificationCount int `json:"notification_count"`
|
||||
} `json:"unread_notifications"`
|
||||
}
|
||||
|
||||
// NewJoinResponse creates an empty response with initialised arrays.
|
||||
|
|
@ -462,3 +474,16 @@ type Peek struct {
|
|||
New bool
|
||||
Deleted bool
|
||||
}
|
||||
|
||||
type ReadUpdate struct {
|
||||
UserID string `json:"user_id"`
|
||||
RoomID string `json:"room_id"`
|
||||
Read StreamPosition `json:"read,omitempty"`
|
||||
FullyRead StreamPosition `json:"fully_read,omitempty"`
|
||||
}
|
||||
|
||||
// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
|
||||
type StreamedEvent struct {
|
||||
Event *gomatrixserverlib.HeaderedEvent `json:"event"`
|
||||
StreamPosition StreamPosition `json:"stream_position"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ import (
|
|||
|
||||
func TestSyncTokens(t *testing.T) {
|
||||
shouldPass := map[string]string{
|
||||
"s4_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0}.String(),
|
||||
"s3_1_0_0_0_0_2": StreamingToken{3, 1, 0, 0, 0, 0, 2}.String(),
|
||||
"s3_1_2_3_5_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0}.String(),
|
||||
"t3_1": TopologyToken{3, 1}.String(),
|
||||
"s4_0_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0}.String(),
|
||||
"s3_1_0_0_0_0_2_0": StreamingToken{3, 1, 0, 0, 0, 0, 2, 0}.String(),
|
||||
"s3_1_2_3_5_0_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0, 0}.String(),
|
||||
"t3_1": TopologyToken{3, 1}.String(),
|
||||
}
|
||||
|
||||
for a, b := range shouldPass {
|
||||
|
|
|
|||
|
|
@ -24,9 +24,14 @@ Local device key changes get to remote servers with correct prev_id
|
|||
|
||||
# Flakey
|
||||
Local device key changes appear in /keys/changes
|
||||
/context/ with lazy_load_members filter works
|
||||
|
||||
# we don't support groups
|
||||
Remove group category
|
||||
Remove group role
|
||||
|
||||
# Flakey
|
||||
AS-ghosted users can use rooms themselves
|
||||
|
||||
# Flakey, need additional investigation
|
||||
Messages that notify from another user increment notification_count
|
||||
Messages that highlight from another user increment unread highlight count
|
||||
|
|
|
|||
|
|
@ -339,17 +339,17 @@ Existing members see new members' join events
|
|||
Inbound federation can receive events
|
||||
Inbound federation can receive redacted events
|
||||
Can logout current device
|
||||
Can send a message directly to a device using PUT /sendToDevice
|
||||
Can recv a device message using /sync
|
||||
Can recv device messages until they are acknowledged
|
||||
Device messages with the same txn_id are deduplicated
|
||||
Device messages wake up /sync
|
||||
Can recv device messages over federation
|
||||
Device messages over federation wake up /sync
|
||||
Can send messages with a wildcard device id
|
||||
Can send messages with a wildcard device id to two devices
|
||||
Wildcard device messages wake up /sync
|
||||
Wildcard device messages over federation wake up /sync
|
||||
Can send a message directly to a device using PUT /sendToDevice
|
||||
Can recv a device message using /sync
|
||||
Can recv device messages until they are acknowledged
|
||||
Device messages with the same txn_id are deduplicated
|
||||
Device messages wake up /sync
|
||||
Can recv device messages over federation
|
||||
Device messages over federation wake up /sync
|
||||
Can send messages with a wildcard device id
|
||||
Can send messages with a wildcard device id to two devices
|
||||
Wildcard device messages wake up /sync
|
||||
Wildcard device messages over federation wake up /sync
|
||||
Can send a to-device message to two users which both receive it using /sync
|
||||
User can create and send/receive messages in a room with version 6
|
||||
local user can join room with version 6
|
||||
|
|
@ -477,7 +477,7 @@ Federation key API can act as a notary server via a GET request
|
|||
Inbound /make_join rejects attempts to join rooms where all users have left
|
||||
Inbound federation rejects invites which include invalid JSON for room version 6
|
||||
Inbound federation rejects invite rejections which include invalid JSON for room version 6
|
||||
GET /capabilities is present and well formed for registered user
|
||||
GET /capabilities is present and well formed for registered user
|
||||
m.room.history_visibility == "joined" allows/forbids appropriately for Guest users
|
||||
m.room.history_visibility == "joined" allows/forbids appropriately for Real users
|
||||
POST rejects invalid utf-8 in JSON
|
||||
|
|
@ -588,6 +588,59 @@ User can invite remote user to room with version 9
|
|||
Remote user can backfill in a room with version 9
|
||||
Can reject invites over federation for rooms with version 9
|
||||
Can receive redactions from regular users over federation in room version 9
|
||||
Pushers created with a different access token are deleted on password change
|
||||
Pushers created with a the same access token are not deleted on password change
|
||||
Can fetch a user's pushers
|
||||
Can add global push rule for room
|
||||
Can add global push rule for sender
|
||||
Can add global push rule for content
|
||||
Can add global push rule for override
|
||||
Can add global push rule for underride
|
||||
Can add global push rule for content
|
||||
New rules appear before old rules by default
|
||||
Can add global push rule before an existing rule
|
||||
Can add global push rule after an existing rule
|
||||
Can delete a push rule
|
||||
Can disable a push rule
|
||||
Adding the same push rule twice is idempotent
|
||||
Can change the actions of default rules
|
||||
Can change the actions of a user specified rule
|
||||
Adding a push rule wakes up an incremental /sync
|
||||
Disabling a push rule wakes up an incremental /sync
|
||||
Enabling a push rule wakes up an incremental /sync
|
||||
Setting actions for a push rule wakes up an incremental /sync
|
||||
Can enable/disable default rules
|
||||
Trying to add push rule with missing template fails with 400
|
||||
Trying to add push rule with missing rule_id fails with 400
|
||||
Trying to add push rule with empty rule_id fails with 400
|
||||
Trying to add push rule with invalid template fails with 400
|
||||
Trying to add push rule with rule_id with slashes fails with 400
|
||||
Trying to add push rule with override rule without conditions fails with 400
|
||||
Trying to add push rule with underride rule without conditions fails with 400
|
||||
Trying to add push rule with condition without kind fails with 400
|
||||
Trying to add push rule with content rule without pattern fails with 400
|
||||
Trying to add push rule with no actions fails with 400
|
||||
Trying to add push rule with invalid action fails with 400
|
||||
Trying to add push rule with invalid attr fails with 400
|
||||
Trying to add push rule with invalid value for enabled fails with 400
|
||||
Trying to get push rules with no trailing slash fails with 400
|
||||
Trying to get push rules with scope without trailing slash fails with 400
|
||||
Trying to get push rules with template without tailing slash fails with 400
|
||||
Trying to get push rules with unknown scope fails with 400
|
||||
Trying to get push rules with unknown template fails with 400
|
||||
Trying to get push rules with unknown attribute fails with 400
|
||||
Getting push rules doesn't corrupt the cache SYN-390
|
||||
Test that a message is pushed
|
||||
Invites are pushed
|
||||
Rooms with names are correctly named in pushes
|
||||
Rooms with canonical alias are correctly named in pushed
|
||||
Rooms with many users are correctly pushed
|
||||
Don't get pushed for rooms you've muted
|
||||
Rejected events are not pushed
|
||||
Test that rejected pushers are removed.
|
||||
Notifications can be viewed with GET /notifications
|
||||
Trying to add push rule with no scope fails with 400
|
||||
Trying to add push rule with invalid scope fails with 400
|
||||
Forward extremities remain so even after the next events are populated as outliers
|
||||
If a device list update goes missing, the server resyncs on the next one
|
||||
uploading self-signing key notifies over federation
|
||||
|
|
@ -607,4 +660,4 @@ registration accepts non-ascii passwords
|
|||
registration with inhibit_login inhibits login
|
||||
The operation must be consistent through an interactive authentication session
|
||||
Multiple calls to /sync should not cause 500 errors
|
||||
|
||||
/context/ with lazy_load_members filter works
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||
)
|
||||
|
||||
// UserInternalAPI is the internal API for information about users and devices.
|
||||
|
|
@ -28,6 +29,7 @@ type UserInternalAPI interface {
|
|||
LoginTokenInternalAPI
|
||||
|
||||
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
|
||||
|
||||
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
|
||||
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
|
||||
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
|
||||
|
|
@ -37,6 +39,10 @@ type UserInternalAPI interface {
|
|||
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
|
||||
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
|
||||
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
|
||||
PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error
|
||||
PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error
|
||||
PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error
|
||||
|
||||
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse)
|
||||
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
||||
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
|
||||
|
|
@ -45,6 +51,9 @@ type UserInternalAPI interface {
|
|||
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
|
||||
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
|
||||
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
|
||||
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
|
||||
QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error
|
||||
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
|
||||
}
|
||||
|
||||
type PerformKeyBackupRequest struct {
|
||||
|
|
@ -425,3 +434,77 @@ const (
|
|||
// AccountTypeAppService indicates this is an appservice account
|
||||
AccountTypeAppService AccountType = 4
|
||||
)
|
||||
|
||||
type QueryPushersRequest struct {
|
||||
Localpart string
|
||||
}
|
||||
|
||||
type QueryPushersResponse struct {
|
||||
Pushers []Pusher `json:"pushers"`
|
||||
}
|
||||
|
||||
type PerformPusherSetRequest struct {
|
||||
Pusher // Anonymous field because that's how clientapi unmarshals it.
|
||||
Localpart string
|
||||
Append bool `json:"append"`
|
||||
}
|
||||
|
||||
type PerformPusherDeletionRequest struct {
|
||||
Localpart string
|
||||
SessionID int64
|
||||
}
|
||||
|
||||
// Pusher represents a push notification subscriber
|
||||
type Pusher struct {
|
||||
SessionID int64 `json:"session_id,omitempty"`
|
||||
PushKey string `json:"pushkey"`
|
||||
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
|
||||
Kind PusherKind `json:"kind"`
|
||||
AppID string `json:"app_id"`
|
||||
AppDisplayName string `json:"app_display_name"`
|
||||
DeviceDisplayName string `json:"device_display_name"`
|
||||
ProfileTag string `json:"profile_tag"`
|
||||
Language string `json:"lang"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
|
||||
type PusherKind string
|
||||
|
||||
const (
|
||||
EmailKind PusherKind = "email"
|
||||
HTTPKind PusherKind = "http"
|
||||
)
|
||||
|
||||
type PerformPushRulesPutRequest struct {
|
||||
UserID string `json:"user_id"`
|
||||
RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
|
||||
}
|
||||
|
||||
type QueryPushRulesRequest struct {
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
type QueryPushRulesResponse struct {
|
||||
RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
|
||||
}
|
||||
|
||||
type QueryNotificationsRequest struct {
|
||||
Localpart string `json:"localpart"` // Required.
|
||||
From string `json:"from,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
Only string `json:"only,omitempty"`
|
||||
}
|
||||
|
||||
type QueryNotificationsResponse struct {
|
||||
NextToken string `json:"next_token"`
|
||||
Notifications []*Notification `json:"notifications"` // Required.
|
||||
}
|
||||
|
||||
type Notification struct {
|
||||
Actions []*pushrules.Action `json:"actions"` // Required.
|
||||
Event gomatrixserverlib.ClientEvent `json:"event"` // Required.
|
||||
ProfileTag string `json:"profile_tag"` // Required by Sytest, but actually optional.
|
||||
Read bool `json:"read"` // Required.
|
||||
RoomID string `json:"room_id"` // Required.
|
||||
TS gomatrixserverlib.Timestamp `json:"ts"` // Required.
|
||||
}
|
||||
|
|
|
|||
|
|
@ -79,6 +79,21 @@ func (t *UserInternalAPITrace) PerformKeyBackup(ctx context.Context, req *Perfor
|
|||
util.GetLogger(ctx).Infof("PerformKeyBackup req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
func (t *UserInternalAPITrace) PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error {
|
||||
err := t.Impl.PerformPusherSet(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("PerformPusherSet req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
func (t *UserInternalAPITrace) PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error {
|
||||
err := t.Impl.PerformPusherDeletion(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("PerformPusherDeletion req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error {
|
||||
err := t.Impl.PerformPushRulesPut(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) {
|
||||
t.Impl.QueryKeyBackup(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res))
|
||||
|
|
@ -118,6 +133,21 @@ func (t *UserInternalAPITrace) QueryOpenIDToken(ctx context.Context, req *QueryO
|
|||
util.GetLogger(ctx).Infof("QueryOpenIDToken req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
func (t *UserInternalAPITrace) QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error {
|
||||
err := t.Impl.QueryPushers(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("QueryPushers req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
func (t *UserInternalAPITrace) QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error {
|
||||
err := t.Impl.QueryPushRules(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("QueryPushRules req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
func (t *UserInternalAPITrace) QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error {
|
||||
err := t.Impl.QueryNotifications(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("QueryNotifications req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
func js(thing interface{}) string {
|
||||
b, err := json.Marshal(thing)
|
||||
|
|
|
|||
136
userapi/consumers/syncapi_readupdate.go
Normal file
136
userapi/consumers/syncapi_readupdate.go
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
package consumers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/producers"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/dendrite/userapi/util"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type OutputReadUpdateConsumer struct {
|
||||
ctx context.Context
|
||||
cfg *config.UserAPI
|
||||
jetstream nats.JetStreamContext
|
||||
durable string
|
||||
db storage.Database
|
||||
pgClient pushgateway.Client
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
topic string
|
||||
userAPI uapi.UserInternalAPI
|
||||
syncProducer *producers.SyncAPI
|
||||
}
|
||||
|
||||
func NewOutputReadUpdateConsumer(
|
||||
process *process.ProcessContext,
|
||||
cfg *config.UserAPI,
|
||||
js nats.JetStreamContext,
|
||||
store storage.Database,
|
||||
pgClient pushgateway.Client,
|
||||
userAPI uapi.UserInternalAPI,
|
||||
syncProducer *producers.SyncAPI,
|
||||
) *OutputReadUpdateConsumer {
|
||||
return &OutputReadUpdateConsumer{
|
||||
ctx: process.Context(),
|
||||
cfg: cfg,
|
||||
jetstream: js,
|
||||
db: store,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"),
|
||||
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReadUpdate),
|
||||
pgClient: pgClient,
|
||||
userAPI: userAPI,
|
||||
syncProducer: syncProducer,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OutputReadUpdateConsumer) Start() error {
|
||||
if err := jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
var read types.ReadUpdate
|
||||
if err := json.Unmarshal(msg.Data, &read); err != nil {
|
||||
log.WithError(err).Error("userapi clientapi consumer: message parse failure")
|
||||
return true
|
||||
}
|
||||
if read.FullyRead == 0 && read.Read == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
userID := string(msg.Header.Get(jetstream.UserID))
|
||||
roomID := string(msg.Header.Get(jetstream.RoomID))
|
||||
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("userapi clientapi consumer: SplitID failure")
|
||||
return true
|
||||
}
|
||||
if domain != s.ServerName {
|
||||
log.Error("userapi clientapi consumer: not a local user")
|
||||
return true
|
||||
}
|
||||
|
||||
log := log.WithFields(log.Fields{
|
||||
"room_id": roomID,
|
||||
"user_id": userID,
|
||||
})
|
||||
log.Tracef("Received read update from sync API: %#v", read)
|
||||
|
||||
if read.Read > 0 {
|
||||
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("userapi EDU consumer")
|
||||
return false
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil {
|
||||
log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed")
|
||||
return false
|
||||
}
|
||||
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
|
||||
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if read.FullyRead > 0 {
|
||||
deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead))
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed")
|
||||
return false
|
||||
}
|
||||
|
||||
if deleted {
|
||||
if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
|
||||
log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed")
|
||||
return false
|
||||
}
|
||||
|
||||
if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil {
|
||||
log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed")
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
591
userapi/consumers/syncapi_streamevent.go
Normal file
591
userapi/consumers/syncapi_streamevent.go
Normal file
|
|
@ -0,0 +1,591 @@
|
|||
package consumers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/producers"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/userapi/util"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type OutputStreamEventConsumer struct {
|
||||
ctx context.Context
|
||||
cfg *config.UserAPI
|
||||
userAPI api.UserInternalAPI
|
||||
rsAPI rsapi.RoomserverInternalAPI
|
||||
jetstream nats.JetStreamContext
|
||||
durable string
|
||||
db storage.Database
|
||||
topic string
|
||||
pgClient pushgateway.Client
|
||||
syncProducer *producers.SyncAPI
|
||||
}
|
||||
|
||||
func NewOutputStreamEventConsumer(
|
||||
process *process.ProcessContext,
|
||||
cfg *config.UserAPI,
|
||||
js nats.JetStreamContext,
|
||||
store storage.Database,
|
||||
pgClient pushgateway.Client,
|
||||
userAPI api.UserInternalAPI,
|
||||
rsAPI rsapi.RoomserverInternalAPI,
|
||||
syncProducer *producers.SyncAPI,
|
||||
) *OutputStreamEventConsumer {
|
||||
return &OutputStreamEventConsumer{
|
||||
ctx: process.Context(),
|
||||
cfg: cfg,
|
||||
jetstream: js,
|
||||
db: store,
|
||||
durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"),
|
||||
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputStreamEvent),
|
||||
pgClient: pgClient,
|
||||
userAPI: userAPI,
|
||||
rsAPI: rsAPI,
|
||||
syncProducer: syncProducer,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OutputStreamEventConsumer) Start() error {
|
||||
if err := jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
var output types.StreamedEvent
|
||||
output.Event = &gomatrixserverlib.HeaderedEvent{}
|
||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||
log.WithError(err).Errorf("userapi consumer: message parse failure")
|
||||
return true
|
||||
}
|
||||
if output.Event.Event == nil {
|
||||
log.Errorf("userapi consumer: expected event")
|
||||
return true
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"event_id": output.Event.EventID(),
|
||||
"event_type": output.Event.Type(),
|
||||
"stream_pos": output.StreamPosition,
|
||||
}).Tracef("Received message from sync API: %#v", output)
|
||||
|
||||
if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"event_id": output.Event.EventID(),
|
||||
}).WithError(err).Errorf("userapi consumer: process room event failure")
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error {
|
||||
members, roomSize, err := s.localRoomMembers(ctx, event.RoomID())
|
||||
if err != nil {
|
||||
return fmt.Errorf("s.localRoomMembers: %w", err)
|
||||
}
|
||||
|
||||
if event.Type() == gomatrixserverlib.MRoomMember {
|
||||
cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll)
|
||||
var member *localMembership
|
||||
member, err = newLocalMembership(&cevent)
|
||||
if err != nil {
|
||||
return fmt.Errorf("newLocalMembership: %w", err)
|
||||
}
|
||||
if member.Membership == gomatrixserverlib.Invite && member.Domain == s.cfg.Matrix.ServerName {
|
||||
// localRoomMembers only adds joined members. An invite
|
||||
// should also be pushed to the target user.
|
||||
members = append(members, member)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: run in parallel with localRoomMembers.
|
||||
roomName, err := s.roomName(ctx, event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s.roomName: %w", err)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"event_id": event.EventID(),
|
||||
"room_id": event.RoomID(),
|
||||
"num_members": len(members),
|
||||
"room_size": roomSize,
|
||||
}).Tracef("Notifying members")
|
||||
|
||||
// Notification.UserIsTarget is a per-member field, so we
|
||||
// cannot group all users in a single request.
|
||||
//
|
||||
// TODO: does it have to be set? It's not required, and
|
||||
// removing it means we can send all notifications to
|
||||
// e.g. Element's Push gateway in one go.
|
||||
for _, mem := range members {
|
||||
if p, err := s.db.GetPushers(ctx, mem.Localpart); err != nil || len(p) == 0 {
|
||||
continue
|
||||
}
|
||||
if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"localpart": mem.Localpart,
|
||||
}).WithError(err).Debugf("Unable to push to local user")
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type localMembership struct {
|
||||
gomatrixserverlib.MemberContent
|
||||
UserID string
|
||||
Localpart string
|
||||
Domain gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership, error) {
|
||||
if event.StateKey == nil {
|
||||
return nil, fmt.Errorf("missing state_key")
|
||||
}
|
||||
|
||||
var member localMembership
|
||||
if err := json.Unmarshal(event.Content, &member.MemberContent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', *event.StateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
member.UserID = *event.StateKey
|
||||
member.Localpart = localpart
|
||||
member.Domain = domain
|
||||
return &member, nil
|
||||
}
|
||||
|
||||
// localRoomMembers fetches the current local members of a room, and
|
||||
// the total number of members.
|
||||
func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) {
|
||||
req := &rsapi.QueryMembershipsForRoomRequest{
|
||||
RoomID: roomID,
|
||||
JoinedOnly: true,
|
||||
}
|
||||
var res rsapi.QueryMembershipsForRoomResponse
|
||||
|
||||
// XXX: This could potentially race if the state for the event is not known yet
|
||||
// e.g. the event came over federation but we do not have the full state persisted.
|
||||
if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var members []*localMembership
|
||||
var ntotal int
|
||||
for _, event := range res.JoinEvents {
|
||||
member, err := newLocalMembership(&event)
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("Parsing MemberContent")
|
||||
continue
|
||||
}
|
||||
if member.Membership != gomatrixserverlib.Join {
|
||||
continue
|
||||
}
|
||||
if member.Domain != s.cfg.Matrix.ServerName {
|
||||
continue
|
||||
}
|
||||
|
||||
ntotal++
|
||||
members = append(members, member)
|
||||
}
|
||||
|
||||
return members, ntotal, nil
|
||||
}
|
||||
|
||||
// roomName returns the name in the event (if type==m.room.name), or
|
||||
// looks it up in roomserver. If there is no name,
|
||||
// m.room.canonical_alias is consulted. Returns an empty string if the
|
||||
// room has no name.
|
||||
func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) {
|
||||
if event.Type() == gomatrixserverlib.MRoomName {
|
||||
name, err := unmarshalRoomName(event)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if name != "" {
|
||||
return name, nil
|
||||
}
|
||||
}
|
||||
|
||||
req := &rsapi.QueryCurrentStateRequest{
|
||||
RoomID: event.RoomID(),
|
||||
StateTuples: []gomatrixserverlib.StateKeyTuple{roomNameTuple, canonicalAliasTuple},
|
||||
}
|
||||
var res rsapi.QueryCurrentStateResponse
|
||||
|
||||
if err := s.rsAPI.QueryCurrentState(ctx, req, &res); err != nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if eventS := res.StateEvents[roomNameTuple]; eventS != nil {
|
||||
return unmarshalRoomName(eventS)
|
||||
}
|
||||
|
||||
if event.Type() == gomatrixserverlib.MRoomCanonicalAlias {
|
||||
alias, err := unmarshalCanonicalAlias(event)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if alias != "" {
|
||||
return alias, nil
|
||||
}
|
||||
}
|
||||
|
||||
if event = res.StateEvents[canonicalAliasTuple]; event != nil {
|
||||
return unmarshalCanonicalAlias(event)
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var (
|
||||
canonicalAliasTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias}
|
||||
roomNameTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomName}
|
||||
)
|
||||
|
||||
func unmarshalRoomName(event *gomatrixserverlib.HeaderedEvent) (string, error) {
|
||||
var nc eventutil.NameContent
|
||||
if err := json.Unmarshal(event.Content(), &nc); err != nil {
|
||||
return "", fmt.Errorf("unmarshaling NameContent: %w", err)
|
||||
}
|
||||
|
||||
return nc.Name, nil
|
||||
}
|
||||
|
||||
func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, error) {
|
||||
var cac eventutil.CanonicalAliasContent
|
||||
if err := json.Unmarshal(event.Content(), &cac); err != nil {
|
||||
return "", fmt.Errorf("unmarshaling CanonicalAliasContent: %w", err)
|
||||
}
|
||||
|
||||
return cac.Alias, nil
|
||||
}
|
||||
|
||||
// notifyLocal finds the right push actions for a local user, given an event.
|
||||
func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error {
|
||||
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a, tweaks, err := pushrules.ActionsToTweaks(actions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: support coalescing.
|
||||
if a != pushrules.NotifyAction && a != pushrules.CoalesceAction {
|
||||
log.WithFields(log.Fields{
|
||||
"event_id": event.EventID(),
|
||||
"room_id": event.RoomID(),
|
||||
"localpart": mem.Localpart,
|
||||
}).Tracef("Push rule evaluation rejected the event")
|
||||
return nil
|
||||
}
|
||||
|
||||
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n := &api.Notification{
|
||||
Actions: actions,
|
||||
// UNSPEC: the spec doesn't say this is a ClientEvent, but the
|
||||
// fields seem to match. room_id should be missing, which
|
||||
// matches the behaviour of FormatSync.
|
||||
Event: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatSync),
|
||||
// TODO: this is per-device, but it's not part of the primary
|
||||
// key. So inserting one notification per profile tag doesn't
|
||||
// make sense. What is this supposed to be? Sytests require it
|
||||
// to "work", but they only use a single device.
|
||||
ProfileTag: profileTag,
|
||||
RoomID: event.RoomID(),
|
||||
TS: gomatrixserverlib.AsTimestamp(time.Now()),
|
||||
}
|
||||
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We do this after InsertNotification. Thus, this should always return >=1.
|
||||
userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"event_id": event.EventID(),
|
||||
"room_id": event.RoomID(),
|
||||
"localpart": mem.Localpart,
|
||||
"num_urls": len(devicesByURLAndFormat),
|
||||
"num_unread": userNumUnreadNotifs,
|
||||
}).Tracef("Notifying single member")
|
||||
|
||||
// Push gateways are out of our control, and we cannot risk
|
||||
// looking up the server on a misbehaving push gateway. Each user
|
||||
// receives a goroutine now that all internal API calls have been
|
||||
// made.
|
||||
//
|
||||
// TODO: think about bounding this to one per user, and what
|
||||
// ordering guarantees we must provide.
|
||||
go func() {
|
||||
// This background processing cannot be tied to a request.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var rejected []*pushgateway.Device
|
||||
for url, fmts := range devicesByURLAndFormat {
|
||||
for format, devices := range fmts {
|
||||
// TODO: support "email".
|
||||
if !strings.HasPrefix(url, "http") {
|
||||
continue
|
||||
}
|
||||
|
||||
// UNSPEC: the specification suggests there can be
|
||||
// more than one device per request. There is at least
|
||||
// one Sytest that expects one HTTP request per
|
||||
// device, rather than per URL. For now, we must
|
||||
// notify each one separately.
|
||||
for _, dev := range devices {
|
||||
rej, err := s.notifyHTTP(ctx, event, url, format, []*pushgateway.Device{dev}, mem.Localpart, roomName, int(userNumUnreadNotifs))
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"event_id": event.EventID(),
|
||||
"localpart": mem.Localpart,
|
||||
}).WithError(err).Errorf("Unable to notify HTTP pusher")
|
||||
continue
|
||||
}
|
||||
rejected = append(rejected, rej...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(rejected) > 0 {
|
||||
s.deleteRejectedPushers(ctx, rejected, mem.Localpart)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// evaluatePushRules fetches and evaluates the push rules of a local
|
||||
// user. Returns actions (including dont_notify).
|
||||
func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
|
||||
if event.Sender() == mem.UserID {
|
||||
// SPEC: Homeservers MUST NOT notify the Push Gateway for
|
||||
// events that the user has sent themselves.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var res api.QueryPushRulesResponse
|
||||
if err := s.userAPI.QueryPushRules(ctx, &api.QueryPushRulesRequest{UserID: mem.UserID}, &res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ec := &ruleSetEvalContext{
|
||||
ctx: ctx,
|
||||
rsAPI: s.rsAPI,
|
||||
mem: mem,
|
||||
roomID: event.RoomID(),
|
||||
roomSize: roomSize,
|
||||
}
|
||||
eval := pushrules.NewRuleSetEvaluator(ec, &res.RuleSets.Global)
|
||||
rule, err := eval.MatchEvent(event.Event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rule == nil {
|
||||
// SPEC: If no rules match an event, the homeserver MUST NOT
|
||||
// notify the Push Gateway for that event.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"event_id": event.EventID(),
|
||||
"room_id": event.RoomID(),
|
||||
"localpart": mem.Localpart,
|
||||
"rule_id": rule.RuleID,
|
||||
}).Tracef("Matched a push rule")
|
||||
|
||||
return rule.Actions, nil
|
||||
}
|
||||
|
||||
type ruleSetEvalContext struct {
|
||||
ctx context.Context
|
||||
rsAPI rsapi.RoomserverInternalAPI
|
||||
mem *localMembership
|
||||
roomID string
|
||||
roomSize int
|
||||
}
|
||||
|
||||
func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.DisplayName }
|
||||
|
||||
func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil }
|
||||
|
||||
func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) {
|
||||
req := &rsapi.QueryLatestEventsAndStateRequest{
|
||||
RoomID: rse.roomID,
|
||||
StateToFetch: []gomatrixserverlib.StateKeyTuple{
|
||||
{EventType: gomatrixserverlib.MRoomPowerLevels},
|
||||
},
|
||||
}
|
||||
var res rsapi.QueryLatestEventsAndStateResponse
|
||||
if err := rse.rsAPI.QueryLatestEventsAndState(rse.ctx, req, &res); err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, ev := range res.StateEvents {
|
||||
if ev.Type() != gomatrixserverlib.MRoomPowerLevels {
|
||||
continue
|
||||
}
|
||||
|
||||
plc, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev.Event)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// localPushDevices pushes to the configured devices of a local
|
||||
// user. The map keys are [url][format].
|
||||
func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
|
||||
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
var profileTag string
|
||||
devicesByURL := make(map[string]map[string][]*pushgateway.Device, len(pusherDevices))
|
||||
for _, pusherDevice := range pusherDevices {
|
||||
if profileTag == "" {
|
||||
profileTag = pusherDevice.Pusher.ProfileTag
|
||||
}
|
||||
|
||||
url := pusherDevice.URL
|
||||
if devicesByURL[url] == nil {
|
||||
devicesByURL[url] = make(map[string][]*pushgateway.Device, 2)
|
||||
}
|
||||
devicesByURL[url][pusherDevice.Format] = append(devicesByURL[url][pusherDevice.Format], &pusherDevice.Device)
|
||||
}
|
||||
|
||||
return devicesByURL, profileTag, nil
|
||||
}
|
||||
|
||||
// notifyHTTP performs a notificatation to a Push Gateway.
|
||||
func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) {
|
||||
logger := log.WithFields(log.Fields{
|
||||
"event_id": event.EventID(),
|
||||
"url": url,
|
||||
"localpart": localpart,
|
||||
"num_devices": len(devices),
|
||||
})
|
||||
|
||||
var req pushgateway.NotifyRequest
|
||||
switch format {
|
||||
case "event_id_only":
|
||||
req = pushgateway.NotifyRequest{
|
||||
Notification: pushgateway.Notification{
|
||||
Counts: &pushgateway.Counts{},
|
||||
Devices: devices,
|
||||
EventID: event.EventID(),
|
||||
RoomID: event.RoomID(),
|
||||
},
|
||||
}
|
||||
|
||||
default:
|
||||
req = pushgateway.NotifyRequest{
|
||||
Notification: pushgateway.Notification{
|
||||
Content: event.Content(),
|
||||
Counts: &pushgateway.Counts{
|
||||
Unread: userNumUnreadNotifs,
|
||||
},
|
||||
Devices: devices,
|
||||
EventID: event.EventID(),
|
||||
ID: event.EventID(),
|
||||
RoomID: event.RoomID(),
|
||||
RoomName: roomName,
|
||||
Sender: event.Sender(),
|
||||
Type: event.Type(),
|
||||
},
|
||||
}
|
||||
if mem, err := event.Membership(); err == nil {
|
||||
req.Notification.Membership = mem
|
||||
}
|
||||
if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) {
|
||||
req.Notification.UserIsTarget = true
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debugf("Notifying push gateway %s", url)
|
||||
var res pushgateway.NotifyResponse
|
||||
if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil {
|
||||
logger.WithError(err).Errorf("Failed to notify push gateway %s", url)
|
||||
return nil, err
|
||||
}
|
||||
logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result")
|
||||
|
||||
if len(res.Rejected) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
devMap := make(map[string]*pushgateway.Device, len(devices))
|
||||
for _, d := range devices {
|
||||
devMap[d.PushKey] = d
|
||||
}
|
||||
rejected := make([]*pushgateway.Device, 0, len(res.Rejected))
|
||||
for _, pushKey := range res.Rejected {
|
||||
d := devMap[pushKey]
|
||||
if d != nil {
|
||||
rejected = append(rejected, d)
|
||||
}
|
||||
}
|
||||
|
||||
return rejected, nil
|
||||
}
|
||||
|
||||
// deleteRejectedPushers deletes the pushers associated with the given devices.
|
||||
func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
|
||||
log.WithFields(log.Fields{
|
||||
"localpart": localpart,
|
||||
"app_id0": devices[0].AppID,
|
||||
"num_devices": len(devices),
|
||||
}).Warnf("Deleting pushers rejected by the HTTP push gateway")
|
||||
|
||||
for _, d := range devices {
|
||||
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"localpart": localpart,
|
||||
}).WithError(err).Errorf("Unable to delete rejected pusher")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -20,6 +20,8 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
|
|
@ -27,16 +29,22 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/appservice/types"
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/producers"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
type UserInternalAPI struct {
|
||||
DB storage.Database
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
DB storage.Database
|
||||
SyncProducer *producers.SyncAPI
|
||||
|
||||
DisableTLSValidation bool
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
// AppServices is the list of all registered AS
|
||||
AppServices []config.ApplicationService
|
||||
KeyAPI keyapi.KeyInternalAPI
|
||||
|
|
@ -595,3 +603,162 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
|
|||
}
|
||||
res.Keys = result
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
|
||||
if req.Limit == 0 || req.Limit > 1000 {
|
||||
req.Limit = 1000
|
||||
}
|
||||
|
||||
var fromID int64
|
||||
var err error
|
||||
if req.From != "" {
|
||||
fromID, err = strconv.ParseInt(req.From, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("QueryNotifications: parsing 'from': %w", err)
|
||||
}
|
||||
}
|
||||
var filter tables.NotificationFilter = tables.AllNotifications
|
||||
if req.Only == "highlight" {
|
||||
filter = tables.HighlightNotifications
|
||||
}
|
||||
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if notifs == nil {
|
||||
// This ensures empty is JSON-encoded as [] instead of null.
|
||||
notifs = []*api.Notification{}
|
||||
}
|
||||
res.Notifications = notifs
|
||||
if lastID >= 0 {
|
||||
res.NextToken = strconv.FormatInt(lastID+1, 10)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.PerformPusherSetRequest, res *struct{}) error {
|
||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||
"localpart": req.Localpart,
|
||||
"pushkey": req.Pusher.PushKey,
|
||||
"display_name": req.Pusher.AppDisplayName,
|
||||
}).Info("PerformPusherCreation")
|
||||
if !req.Append {
|
||||
err := a.DB.RemovePushers(ctx, req.Pusher.AppID, req.Pusher.PushKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if req.Pusher.Kind == "" {
|
||||
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
|
||||
}
|
||||
if req.Pusher.PushKeyTS == 0 {
|
||||
req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now())
|
||||
}
|
||||
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
|
||||
pushers, err := a.DB.GetPushers(ctx, req.Localpart)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range pushers {
|
||||
logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
|
||||
if pushers[i].SessionID != req.SessionID {
|
||||
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
|
||||
var err error
|
||||
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart)
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformPushRulesPut(
|
||||
ctx context.Context,
|
||||
req *api.PerformPushRulesPutRequest,
|
||||
_ *struct{},
|
||||
) error {
|
||||
bs, err := json.Marshal(&req.RuleSets)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userReq := api.InputAccountDataRequest{
|
||||
UserID: req.UserID,
|
||||
DataType: pushRulesAccountDataType,
|
||||
AccountData: json.RawMessage(bs),
|
||||
}
|
||||
var userRes api.InputAccountDataResponse // empty
|
||||
if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := a.SyncProducer.SendAccountData(req.UserID, "" /* roomID */, pushRulesAccountDataType); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Errorf("syncProducer.SendData failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
|
||||
userReq := api.QueryAccountDataRequest{
|
||||
UserID: req.UserID,
|
||||
DataType: pushRulesAccountDataType,
|
||||
}
|
||||
var userRes api.QueryAccountDataResponse
|
||||
if err := a.QueryAccountData(ctx, &userReq, &userRes); err != nil {
|
||||
return err
|
||||
}
|
||||
bs, ok := userRes.GlobalAccountData[pushRulesAccountDataType]
|
||||
if ok {
|
||||
// Legacy Dendrite users will have completely empty push rules, so we should
|
||||
// detect that situation and set some defaults.
|
||||
var rules struct {
|
||||
G struct {
|
||||
Content []json.RawMessage `json:"content"`
|
||||
Override []json.RawMessage `json:"override"`
|
||||
Room []json.RawMessage `json:"room"`
|
||||
Sender []json.RawMessage `json:"sender"`
|
||||
Underride []json.RawMessage `json:"underride"`
|
||||
} `json:"global"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(bs), &rules); err == nil {
|
||||
count := len(rules.G.Content) + len(rules.G.Override) +
|
||||
len(rules.G.Room) + len(rules.G.Sender) + len(rules.G.Underride)
|
||||
ok = count > 0
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
// If we didn't find any default push rules then we should just generate some
|
||||
// fresh ones.
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
|
||||
}
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, a.ServerName)
|
||||
prbs, err := json.Marshal(pushRuleSets)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal default push rules: %w", err)
|
||||
}
|
||||
if err := a.DB.SaveAccountData(ctx, localpart, "", pushRulesAccountDataType, json.RawMessage(prbs)); err != nil {
|
||||
return fmt.Errorf("failed to save default push rules: %w", err)
|
||||
}
|
||||
res.RuleSets = pushRuleSets
|
||||
return nil
|
||||
}
|
||||
var data pushrules.AccountRuleSets
|
||||
if err := json.Unmarshal([]byte(bs), &data); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("json.Unmarshal of push rules failed")
|
||||
return err
|
||||
}
|
||||
res.RuleSets = &data
|
||||
return nil
|
||||
}
|
||||
|
||||
const pushRulesAccountDataType = "m.push_rules"
|
||||
|
|
|
|||
|
|
@ -37,6 +37,9 @@ const (
|
|||
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
|
||||
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
|
||||
PerformKeyBackupPath = "/userapi/performKeyBackup"
|
||||
PerformPusherSetPath = "/pushserver/performPusherSet"
|
||||
PerformPusherDeletionPath = "/pushserver/performPusherDeletion"
|
||||
PerformPushRulesPutPath = "/pushserver/performPushRulesPut"
|
||||
|
||||
QueryKeyBackupPath = "/userapi/queryKeyBackup"
|
||||
QueryProfilePath = "/userapi/queryProfile"
|
||||
|
|
@ -46,6 +49,9 @@ const (
|
|||
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
|
||||
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
|
||||
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
|
||||
QueryPushersPath = "/pushserver/queryPushers"
|
||||
QueryPushRulesPath = "/pushserver/queryPushRules"
|
||||
QueryNotificationsPath = "/pushserver/queryNotifications"
|
||||
)
|
||||
|
||||
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
||||
|
|
@ -248,4 +254,59 @@ func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.Query
|
|||
if err != nil {
|
||||
res.Error = err.Error()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryNotifications")
|
||||
defer span.Finish()
|
||||
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, h.apiURL+QueryNotificationsPath, req, res)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) PerformPusherSet(
|
||||
ctx context.Context,
|
||||
request *api.PerformPusherSetRequest,
|
||||
response *struct{},
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherSet")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + PerformPusherSetPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherDeletion")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + PerformPusherDeletionPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushers")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + QueryPushersPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) PerformPushRulesPut(
|
||||
ctx context.Context,
|
||||
request *api.PerformPushRulesPutRequest,
|
||||
response *struct{},
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPushRulesPut")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + PerformPushRulesPutPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushRules")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + QueryPushRulesPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -265,4 +265,86 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
|||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(QueryNotificationsPath,
|
||||
httputil.MakeInternalAPI("queryNotifications", func(req *http.Request) util.JSONResponse {
|
||||
var request api.QueryNotificationsRequest
|
||||
var response api.QueryNotificationsResponse
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := s.QueryNotifications(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
|
||||
internalAPIMux.Handle(PerformPusherSetPath,
|
||||
httputil.MakeInternalAPI("performPusherSet", func(req *http.Request) util.JSONResponse {
|
||||
request := api.PerformPusherSetRequest{}
|
||||
response := struct{}{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := s.PerformPusherSet(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(PerformPusherDeletionPath,
|
||||
httputil.MakeInternalAPI("performPusherDeletion", func(req *http.Request) util.JSONResponse {
|
||||
request := api.PerformPusherDeletionRequest{}
|
||||
response := struct{}{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := s.PerformPusherDeletion(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
|
||||
internalAPIMux.Handle(QueryPushersPath,
|
||||
httputil.MakeInternalAPI("queryPushers", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryPushersRequest{}
|
||||
response := api.QueryPushersResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := s.QueryPushers(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
|
||||
internalAPIMux.Handle(PerformPushRulesPutPath,
|
||||
httputil.MakeInternalAPI("performPushRulesPut", func(req *http.Request) util.JSONResponse {
|
||||
request := api.PerformPushRulesPutRequest{}
|
||||
response := struct{}{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := s.PerformPushRulesPut(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
|
||||
internalAPIMux.Handle(QueryPushRulesPath,
|
||||
httputil.MakeInternalAPI("queryPushRules", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryPushRulesRequest{}
|
||||
response := api.QueryPushRulesResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := s.QueryPushRules(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
|
|
|||
104
userapi/producers/syncapi.go
Normal file
104
userapi/producers/syncapi.go
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
package producers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type JetStreamPublisher interface {
|
||||
PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error)
|
||||
}
|
||||
|
||||
// SyncAPI produces messages for the Sync API server to consume.
|
||||
type SyncAPI struct {
|
||||
db storage.Database
|
||||
producer JetStreamPublisher
|
||||
clientDataTopic string
|
||||
notificationDataTopic string
|
||||
}
|
||||
|
||||
func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI {
|
||||
return &SyncAPI{
|
||||
db: db,
|
||||
producer: js,
|
||||
clientDataTopic: clientDataTopic,
|
||||
notificationDataTopic: notificationDataTopic,
|
||||
}
|
||||
}
|
||||
|
||||
// SendAccountData sends account data to the Sync API server.
|
||||
func (p *SyncAPI) SendAccountData(userID string, roomID string, dataType string) error {
|
||||
m := &nats.Msg{
|
||||
Subject: p.clientDataTopic,
|
||||
Header: nats.Header{},
|
||||
}
|
||||
m.Header.Set(jetstream.UserID, userID)
|
||||
|
||||
var err error
|
||||
m.Data, err = json.Marshal(eventutil.AccountData{
|
||||
RoomID: roomID,
|
||||
Type: dataType,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"user_id": userID,
|
||||
"room_id": roomID,
|
||||
"data_type": dataType,
|
||||
}).Tracef("Producing to topic '%s'", p.clientDataTopic)
|
||||
|
||||
_, err = p.producer.PublishMsg(m)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAndSendNotificationData reads the database and sends data about unread
|
||||
// notifications to the Sync API server.
|
||||
func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return p.sendNotificationData(userID, &eventutil.NotificationData{
|
||||
RoomID: roomID,
|
||||
UnreadHighlightCount: int(nhighlight),
|
||||
UnreadNotificationCount: int(ntotal),
|
||||
})
|
||||
}
|
||||
|
||||
// sendNotificationData sends data about unread notifications to the Sync API server.
|
||||
func (p *SyncAPI) sendNotificationData(userID string, data *eventutil.NotificationData) error {
|
||||
m := &nats.Msg{
|
||||
Subject: p.notificationDataTopic,
|
||||
Header: nats.Header{},
|
||||
}
|
||||
m.Header.Set(jetstream.UserID, userID)
|
||||
|
||||
var err error
|
||||
m.Data, err = json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"user_id": userID,
|
||||
"room_id": data.RoomID,
|
||||
}).Tracef("Producing to topic '%s'", p.clientDataTopic)
|
||||
|
||||
_, err = p.producer.PublishMsg(m)
|
||||
return err
|
||||
}
|
||||
|
|
@ -21,6 +21,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
type Database interface {
|
||||
|
|
@ -90,6 +91,18 @@ type Database interface {
|
|||
// May return sql.ErrNoRows.
|
||||
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
|
||||
|
||||
InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error
|
||||
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error)
|
||||
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error)
|
||||
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
|
||||
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
|
||||
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
|
||||
|
||||
UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
|
||||
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
|
||||
RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
|
||||
RemovePushers(ctx context.Context, appid, pushkey string) error
|
||||
|
||||
AllUsers(ctx context.Context) (result int64, err error)
|
||||
NonBridgedUsers(ctx context.Context) (result int64, err error)
|
||||
RegisteredUserByType(ctx context.Context) (map[string]int64, error)
|
||||
|
|
|
|||
219
userapi/storage/postgres/notifications_table.go
Normal file
219
userapi/storage/postgres/notifications_table.go
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type notificationsStatements struct {
|
||||
insertStmt *sql.Stmt
|
||||
deleteUpToStmt *sql.Stmt
|
||||
updateReadStmt *sql.Stmt
|
||||
selectStmt *sql.Stmt
|
||||
selectCountStmt *sql.Stmt
|
||||
selectRoomCountsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
const notificationSchema = `
|
||||
CREATE TABLE IF NOT EXISTS userapi_notifications (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
localpart TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
stream_pos BIGINT NOT NULL,
|
||||
ts_ms BIGINT NOT NULL,
|
||||
highlight BOOLEAN NOT NULL,
|
||||
notification_json TEXT NOT NULL,
|
||||
read BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
|
||||
`
|
||||
|
||||
const insertNotificationSQL = "" +
|
||||
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||
|
||||
const deleteNotificationsUpToSQL = "" +
|
||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
|
||||
|
||||
const updateNotificationReadSQL = "" +
|
||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
|
||||
|
||||
const selectNotificationSQL = "" +
|
||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
|
||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read ORDER BY localpart, id LIMIT $4"
|
||||
|
||||
const selectNotificationCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
|
||||
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read"
|
||||
|
||||
const selectRoomNotificationCountsSQL = "" +
|
||||
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
|
||||
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
|
||||
|
||||
func NewPostgresNotificationTable(db *sql.DB) (tables.NotificationTable, error) {
|
||||
s := ¬ificationsStatements{}
|
||||
_, err := db.Exec(notificationSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertStmt, insertNotificationSQL},
|
||||
{&s.deleteUpToStmt, deleteNotificationsUpToSQL},
|
||||
{&s.updateReadStmt, updateNotificationReadSQL},
|
||||
{&s.selectStmt, selectNotificationSQL},
|
||||
{&s.selectCountStmt, selectNotificationCountSQL},
|
||||
{&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
// Insert inserts a notification into the database.
|
||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error {
|
||||
roomID, tsMS := n.RoomID, n.TS
|
||||
nn := *n
|
||||
// Clears out fields that have their own columns to (1) shrink the
|
||||
// data and (2) avoid difficult-to-debug inconsistency bugs.
|
||||
nn.RoomID = ""
|
||||
nn.TS, nn.Read = 0, false
|
||||
bs, err := json.Marshal(nn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteUpTo deletes all previous notifications, up to and including the event.
|
||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
nrows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows)
|
||||
return nrows > 0, nil
|
||||
}
|
||||
|
||||
// UpdateRead updates the "read" value for an event.
|
||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
nrows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows)
|
||||
return nrows > 0, nil
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
|
||||
|
||||
var maxID int64 = -1
|
||||
var notifs []*api.Notification
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var roomID string
|
||||
var ts gomatrixserverlib.Timestamp
|
||||
var read bool
|
||||
var jsonStr string
|
||||
err = rows.Scan(
|
||||
&id,
|
||||
&roomID,
|
||||
&ts,
|
||||
&read,
|
||||
&jsonStr)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var n api.Notification
|
||||
err := json.Unmarshal([]byte(jsonStr), &n)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
n.RoomID = roomID
|
||||
n.TS = ts
|
||||
n.Read = read
|
||||
notifs = append(notifs, &n)
|
||||
|
||||
if maxID < id {
|
||||
maxID = id
|
||||
}
|
||||
}
|
||||
return notifs, maxID, rows.Err()
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter))
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
|
||||
|
||||
if rows.Next() {
|
||||
var count int64
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
return 0, rows.Err()
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
|
||||
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
|
||||
|
||||
if rows.Next() {
|
||||
var total, highlight int64
|
||||
if err := rows.Scan(&total, &highlight); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return total, highlight, nil
|
||||
}
|
||||
return 0, 0, rows.Err()
|
||||
}
|
||||
157
userapi/storage/postgres/pusher_table.go
Normal file
157
userapi/storage/postgres/pusher_table.go
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
|
||||
const pushersSchema = `
|
||||
CREATE TABLE IF NOT EXISTS userapi_pushers (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
-- The Matrix user ID localpart for this pusher
|
||||
localpart TEXT NOT NULL,
|
||||
session_id BIGINT DEFAULT NULL,
|
||||
profile_tag TEXT,
|
||||
kind TEXT NOT NULL,
|
||||
app_id TEXT NOT NULL,
|
||||
app_display_name TEXT NOT NULL,
|
||||
device_display_name TEXT NOT NULL,
|
||||
pushkey TEXT NOT NULL,
|
||||
pushkey_ts_ms BIGINT NOT NULL DEFAULT 0,
|
||||
lang TEXT NOT NULL,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- For faster deleting by app_id, pushkey pair.
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
|
||||
|
||||
-- For faster retrieving by localpart.
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
|
||||
|
||||
-- Pushkey must be unique for a given user and app.
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
|
||||
`
|
||||
|
||||
const insertPusherSQL = "" +
|
||||
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
|
||||
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
|
||||
|
||||
const selectPushersSQL = "" +
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
|
||||
|
||||
const deletePusherSQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
|
||||
|
||||
const deletePushersByAppIdAndPushKeySQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
|
||||
|
||||
func NewPostgresPusherTable(db *sql.DB) (tables.PusherTable, error) {
|
||||
s := &pushersStatements{}
|
||||
_, err := db.Exec(pushersSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertPusherStmt, insertPusherSQL},
|
||||
{&s.selectPushersStmt, selectPushersSQL},
|
||||
{&s.deletePusherStmt, deletePusherSQL},
|
||||
{&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
type pushersStatements struct {
|
||||
insertPusherStmt *sql.Stmt
|
||||
selectPushersStmt *sql.Stmt
|
||||
deletePusherStmt *sql.Stmt
|
||||
deletePushersByAppIdAndPushKeyStmt *sql.Stmt
|
||||
}
|
||||
|
||||
// insertPusher creates a new pusher.
|
||||
// Returns an error if the user already has a pusher with the given pusher pushkey.
|
||||
// Returns nil error success.
|
||||
func (s *pushersStatements) InsertPusher(
|
||||
ctx context.Context, txn *sql.Tx, session_id int64,
|
||||
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
logrus.Debugf("Created pusher %d", session_id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *pushersStatements) SelectPushers(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
) ([]api.Pusher, error) {
|
||||
pushers := []api.Pusher{}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart)
|
||||
|
||||
if err != nil {
|
||||
return pushers, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
var pusher api.Pusher
|
||||
var data []byte
|
||||
err = rows.Scan(
|
||||
&pusher.SessionID,
|
||||
&pusher.PushKey,
|
||||
&pusher.PushKeyTS,
|
||||
&pusher.Kind,
|
||||
&pusher.AppID,
|
||||
&pusher.AppDisplayName,
|
||||
&pusher.DeviceDisplayName,
|
||||
&pusher.ProfileTag,
|
||||
&pusher.Language,
|
||||
&data)
|
||||
if err != nil {
|
||||
return pushers, err
|
||||
}
|
||||
err := json.Unmarshal(data, &pusher.Data)
|
||||
if err != nil {
|
||||
return pushers, err
|
||||
}
|
||||
pushers = append(pushers, pusher)
|
||||
}
|
||||
|
||||
logrus.Debugf("Database returned %d pushers", len(pushers))
|
||||
return pushers, rows.Err()
|
||||
}
|
||||
|
||||
// deletePusher removes a single pusher by pushkey and user localpart.
|
||||
func (s *pushersStatements) DeletePusher(
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *pushersStatements) DeletePushers(
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey string,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePushersByAppIdAndPushKeyStmt).ExecContext(ctx, appid, pushkey)
|
||||
return err
|
||||
}
|
||||
|
|
@ -85,6 +85,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err)
|
||||
}
|
||||
pusherTable, err := NewPostgresPusherTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresPusherTable: %w", err)
|
||||
}
|
||||
notificationsTable, err := NewPostgresNotificationTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err)
|
||||
}
|
||||
statsTable, err := NewPostgresStatsTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresStatsTable: %w", err)
|
||||
|
|
@ -99,6 +107,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
|||
OpenIDTokens: openIDTable,
|
||||
Profiles: profilesTable,
|
||||
ThreePIDs: threePIDTable,
|
||||
Pushers: pusherTable,
|
||||
Notifications: notificationsTable,
|
||||
Stats: statsTable,
|
||||
ServerName: serverName,
|
||||
DB: db,
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ import (
|
|||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
|
|
@ -47,6 +48,8 @@ type Database struct {
|
|||
KeyBackupVersions tables.KeyBackupVersionTable
|
||||
Devices tables.DevicesTable
|
||||
LoginTokens tables.LoginTokenTable
|
||||
Notifications tables.NotificationTable
|
||||
Pushers tables.PusherTable
|
||||
Stats tables.StatsTable
|
||||
LoginTokenLifetime time.Duration
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
|
|
@ -161,15 +164,12 @@ func (d *Database) createAccount(
|
|||
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
|
||||
"global": {
|
||||
"content": [],
|
||||
"override": [],
|
||||
"room": [],
|
||||
"sender": [],
|
||||
"underride": []
|
||||
}
|
||||
}`)); err != nil {
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
|
||||
prbs, err := json.Marshal(pushRuleSets)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
|
|
@ -672,6 +672,97 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
|
|||
return d.LoginTokens.SelectLoginToken(ctx, token)
|
||||
}
|
||||
|
||||
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter)
|
||||
}
|
||||
|
||||
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) {
|
||||
return d.Notifications.SelectCount(ctx, nil, localpart, filter)
|
||||
}
|
||||
|
||||
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) {
|
||||
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID)
|
||||
}
|
||||
|
||||
func (d *Database) UpsertPusher(
|
||||
ctx context.Context, p api.Pusher, localpart string,
|
||||
) error {
|
||||
data, err := json.Marshal(p.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Pushers.InsertPusher(
|
||||
ctx, txn,
|
||||
p.SessionID,
|
||||
p.PushKey,
|
||||
p.PushKeyTS,
|
||||
p.Kind,
|
||||
p.AppID,
|
||||
p.AppDisplayName,
|
||||
p.DeviceDisplayName,
|
||||
p.ProfileTag,
|
||||
p.Language,
|
||||
string(data),
|
||||
localpart)
|
||||
})
|
||||
}
|
||||
|
||||
// GetPushers returns the pushers matching the given localpart.
|
||||
func (d *Database) GetPushers(
|
||||
ctx context.Context, localpart string,
|
||||
) ([]api.Pusher, error) {
|
||||
return d.Pushers.SelectPushers(ctx, nil, localpart)
|
||||
}
|
||||
|
||||
// RemovePusher deletes one pusher
|
||||
// Invoked when `append` is true and `kind` is null in
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
|
||||
func (d *Database) RemovePusher(
|
||||
ctx context.Context, appid, pushkey, localpart string,
|
||||
) error {
|
||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// RemovePushers deletes all pushers that match given App Id and Push Key pair.
|
||||
// Invoked when `append` parameter is false in
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
|
||||
func (d *Database) RemovePushers(
|
||||
ctx context.Context, appid, pushkey string,
|
||||
) error {
|
||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||
return d.Pushers.DeletePushers(ctx, txn, appid, pushkey)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) AllUsers(ctx context.Context) (result int64, err error) {
|
||||
return d.Stats.AllUsers(ctx, nil)
|
||||
}
|
||||
|
|
|
|||
219
userapi/storage/sqlite3/notifications_table.go
Normal file
219
userapi/storage/sqlite3/notifications_table.go
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type notificationsStatements struct {
|
||||
insertStmt *sql.Stmt
|
||||
deleteUpToStmt *sql.Stmt
|
||||
updateReadStmt *sql.Stmt
|
||||
selectStmt *sql.Stmt
|
||||
selectCountStmt *sql.Stmt
|
||||
selectRoomCountsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
const notificationSchema = `
|
||||
CREATE TABLE IF NOT EXISTS userapi_notifications (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
localpart TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
stream_pos BIGINT NOT NULL,
|
||||
ts_ms BIGINT NOT NULL,
|
||||
highlight BOOLEAN NOT NULL,
|
||||
notification_json TEXT NOT NULL,
|
||||
read BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
|
||||
`
|
||||
|
||||
const insertNotificationSQL = "" +
|
||||
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||
|
||||
const deleteNotificationsUpToSQL = "" +
|
||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
|
||||
|
||||
const updateNotificationReadSQL = "" +
|
||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
|
||||
|
||||
const selectNotificationSQL = "" +
|
||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
|
||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read ORDER BY localpart, id LIMIT $4"
|
||||
|
||||
const selectNotificationCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
|
||||
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read"
|
||||
|
||||
const selectRoomNotificationCountsSQL = "" +
|
||||
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
|
||||
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
|
||||
|
||||
func NewSQLiteNotificationTable(db *sql.DB) (tables.NotificationTable, error) {
|
||||
s := ¬ificationsStatements{}
|
||||
_, err := db.Exec(notificationSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertStmt, insertNotificationSQL},
|
||||
{&s.deleteUpToStmt, deleteNotificationsUpToSQL},
|
||||
{&s.updateReadStmt, updateNotificationReadSQL},
|
||||
{&s.selectStmt, selectNotificationSQL},
|
||||
{&s.selectCountStmt, selectNotificationCountSQL},
|
||||
{&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
// Insert inserts a notification into the database.
|
||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error {
|
||||
roomID, tsMS := n.RoomID, n.TS
|
||||
nn := *n
|
||||
// Clears out fields that have their own columns to (1) shrink the
|
||||
// data and (2) avoid difficult-to-debug inconsistency bugs.
|
||||
nn.RoomID = ""
|
||||
nn.TS, nn.Read = 0, false
|
||||
bs, err := json.Marshal(nn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteUpTo deletes all previous notifications, up to and including the event.
|
||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
nrows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows)
|
||||
return nrows > 0, nil
|
||||
}
|
||||
|
||||
// UpdateRead updates the "read" value for an event.
|
||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
nrows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows)
|
||||
return nrows > 0, nil
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
|
||||
|
||||
var maxID int64 = -1
|
||||
var notifs []*api.Notification
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var roomID string
|
||||
var ts gomatrixserverlib.Timestamp
|
||||
var read bool
|
||||
var jsonStr string
|
||||
err = rows.Scan(
|
||||
&id,
|
||||
&roomID,
|
||||
&ts,
|
||||
&read,
|
||||
&jsonStr)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var n api.Notification
|
||||
err := json.Unmarshal([]byte(jsonStr), &n)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
n.RoomID = roomID
|
||||
n.TS = ts
|
||||
n.Read = read
|
||||
notifs = append(notifs, &n)
|
||||
|
||||
if maxID < id {
|
||||
maxID = id
|
||||
}
|
||||
}
|
||||
return notifs, maxID, rows.Err()
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter))
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
|
||||
|
||||
if rows.Next() {
|
||||
var count int64
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
return 0, rows.Err()
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
|
||||
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
|
||||
|
||||
if rows.Next() {
|
||||
var total, highlight int64
|
||||
if err := rows.Scan(&total, &highlight); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return total, highlight, nil
|
||||
}
|
||||
return 0, 0, rows.Err()
|
||||
}
|
||||
157
userapi/storage/sqlite3/pusher_table.go
Normal file
157
userapi/storage/sqlite3/pusher_table.go
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
|
||||
const pushersSchema = `
|
||||
CREATE TABLE IF NOT EXISTS userapi_pushers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The Matrix user ID localpart for this pusher
|
||||
localpart TEXT NOT NULL,
|
||||
session_id BIGINT DEFAULT NULL,
|
||||
profile_tag TEXT,
|
||||
kind TEXT NOT NULL,
|
||||
app_id TEXT NOT NULL,
|
||||
app_display_name TEXT NOT NULL,
|
||||
device_display_name TEXT NOT NULL,
|
||||
pushkey TEXT NOT NULL,
|
||||
pushkey_ts_ms BIGINT NOT NULL DEFAULT 0,
|
||||
lang TEXT NOT NULL,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- For faster deleting by app_id, pushkey pair.
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
|
||||
|
||||
-- For faster retrieving by localpart.
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
|
||||
|
||||
-- Pushkey must be unique for a given user and app.
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
|
||||
`
|
||||
|
||||
const insertPusherSQL = "" +
|
||||
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
|
||||
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
|
||||
|
||||
const selectPushersSQL = "" +
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
|
||||
|
||||
const deletePusherSQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
|
||||
|
||||
const deletePushersByAppIdAndPushKeySQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
|
||||
|
||||
func NewSQLitePusherTable(db *sql.DB) (tables.PusherTable, error) {
|
||||
s := &pushersStatements{}
|
||||
_, err := db.Exec(pushersSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertPusherStmt, insertPusherSQL},
|
||||
{&s.selectPushersStmt, selectPushersSQL},
|
||||
{&s.deletePusherStmt, deletePusherSQL},
|
||||
{&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
type pushersStatements struct {
|
||||
insertPusherStmt *sql.Stmt
|
||||
selectPushersStmt *sql.Stmt
|
||||
deletePusherStmt *sql.Stmt
|
||||
deletePushersByAppIdAndPushKeyStmt *sql.Stmt
|
||||
}
|
||||
|
||||
// insertPusher creates a new pusher.
|
||||
// Returns an error if the user already has a pusher with the given pusher pushkey.
|
||||
// Returns nil error success.
|
||||
func (s *pushersStatements) InsertPusher(
|
||||
ctx context.Context, txn *sql.Tx, session_id int64,
|
||||
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
) error {
|
||||
_, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
logrus.Debugf("Created pusher %d", session_id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *pushersStatements) SelectPushers(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
) ([]api.Pusher, error) {
|
||||
pushers := []api.Pusher{}
|
||||
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart)
|
||||
|
||||
if err != nil {
|
||||
return pushers, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
var pusher api.Pusher
|
||||
var data []byte
|
||||
err = rows.Scan(
|
||||
&pusher.SessionID,
|
||||
&pusher.PushKey,
|
||||
&pusher.PushKeyTS,
|
||||
&pusher.Kind,
|
||||
&pusher.AppID,
|
||||
&pusher.AppDisplayName,
|
||||
&pusher.DeviceDisplayName,
|
||||
&pusher.ProfileTag,
|
||||
&pusher.Language,
|
||||
&data)
|
||||
if err != nil {
|
||||
return pushers, err
|
||||
}
|
||||
err := json.Unmarshal(data, &pusher.Data)
|
||||
if err != nil {
|
||||
return pushers, err
|
||||
}
|
||||
pushers = append(pushers, pusher)
|
||||
}
|
||||
|
||||
logrus.Debugf("Database returned %d pushers", len(pushers))
|
||||
return pushers, rows.Err()
|
||||
}
|
||||
|
||||
// deletePusher removes a single pusher by pushkey and user localpart.
|
||||
func (s *pushersStatements) DeletePusher(
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
|
||||
) error {
|
||||
_, err := s.deletePusherStmt.ExecContext(ctx, appid, pushkey, localpart)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *pushersStatements) DeletePushers(
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey string,
|
||||
) error {
|
||||
_, err := s.deletePushersByAppIdAndPushKeyStmt.ExecContext(ctx, appid, pushkey)
|
||||
return err
|
||||
}
|
||||
|
|
@ -86,6 +86,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err)
|
||||
}
|
||||
pusherTable, err := NewSQLitePusherTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresPusherTable: %w", err)
|
||||
}
|
||||
notificationsTable, err := NewSQLiteNotificationTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err)
|
||||
}
|
||||
statsTable, err := NewSQLiteStatsTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err)
|
||||
|
|
@ -100,6 +108,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
|||
OpenIDTokens: openIDTable,
|
||||
Profiles: profilesTable,
|
||||
ThreePIDs: threePIDTable,
|
||||
Pushers: pusherTable,
|
||||
Notifications: notificationsTable,
|
||||
Stats: statsTable,
|
||||
ServerName: serverName,
|
||||
DB: db,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type AccountDataTable interface {
|
||||
|
|
@ -94,6 +95,22 @@ type ThreePIDTable interface {
|
|||
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
|
||||
}
|
||||
|
||||
type PusherTable interface {
|
||||
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
|
||||
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
|
||||
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
|
||||
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
|
||||
}
|
||||
|
||||
type NotificationTable interface {
|
||||
Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error
|
||||
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error)
|
||||
UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error)
|
||||
Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
|
||||
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
|
||||
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
|
||||
}
|
||||
|
||||
type StatsTable interface {
|
||||
AllUsers(ctx context.Context, txn *sql.Tx) (result int64, err error)
|
||||
NonBridgedUsers(ctx context.Context, txn *sql.Tx) (result int64, err error)
|
||||
|
|
@ -102,4 +119,27 @@ type StatsTable interface {
|
|||
MonthlyUsers(ctx context.Context, txn *sql.Tx) (result int64, err error)
|
||||
R30Users(ctx context.Context, txn *sql.Tx) (map[string]int64, error)
|
||||
R30UsersV2(ctx context.Context, txn *sql.Tx) (map[string]int64, error)
|
||||
}
|
||||
}
|
||||
|
||||
type NotificationFilter uint32
|
||||
|
||||
const (
|
||||
// HighlightNotifications returns notifications that had a
|
||||
// "highlight" tweak assigned to them from evaluating push rules.
|
||||
HighlightNotifications NotificationFilter = 1 << iota
|
||||
|
||||
// NonHighlightNotifications returns notifications that don't
|
||||
// match HighlightNotifications.
|
||||
NonHighlightNotifications
|
||||
|
||||
// NoNotifications is a filter to exclude all types of
|
||||
// notifications. It's useful as a zero value, but isn't likely to
|
||||
// be used in a call to Notifications.Select*.
|
||||
NoNotifications NotificationFilter = 0
|
||||
|
||||
// AllNotifications is a filter to include all types of
|
||||
// notifications in Notifications.Select*. Note that PostgreSQL
|
||||
// balks if this doesn't fit in INTEGER, even though we use
|
||||
// uint32.
|
||||
AllNotifications NotificationFilter = (1 << 31) - 1
|
||||
)
|
||||
|
|
|
|||
|
|
@ -25,13 +25,19 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
version "github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/consumers"
|
||||
"github.com/matrix-org/dendrite/userapi/internal"
|
||||
"github.com/matrix-org/dendrite/userapi/inthttp"
|
||||
"github.com/matrix-org/dendrite/userapi/producers"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
|
@ -46,28 +52,51 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
|
|||
// NewInternalAPI returns a concerete implementation of the internal API. Callers
|
||||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||
func NewInternalAPI(
|
||||
accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
|
||||
base *base.BaseDendrite, db storage.Database, cfg *config.UserAPI,
|
||||
appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
|
||||
rsAPI rsapi.RoomserverInternalAPI, pgClient pushgateway.Client,
|
||||
) api.UserInternalAPI {
|
||||
db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to device db")
|
||||
}
|
||||
|
||||
return newInternalAPI(db, cfg, appServices, keyAPI)
|
||||
}
|
||||
js := jetstream.Prepare(&cfg.Matrix.JetStream)
|
||||
|
||||
func newInternalAPI(
|
||||
db storage.Database,
|
||||
cfg *config.UserAPI,
|
||||
appServices []config.ApplicationService,
|
||||
keyAPI keyapi.KeyInternalAPI,
|
||||
) api.UserInternalAPI {
|
||||
return &internal.UserInternalAPI{
|
||||
DB: db,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
AppServices: appServices,
|
||||
KeyAPI: keyAPI,
|
||||
syncProducer := producers.NewSyncAPI(
|
||||
db, js,
|
||||
// TODO: user API should handle syncs for account data. Right now,
|
||||
// it's handled by clientapi, and hence uses its topic. When user
|
||||
// API handles it for all account data, we can remove it from
|
||||
// here.
|
||||
cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData),
|
||||
cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData),
|
||||
)
|
||||
|
||||
userAPI := &internal.UserInternalAPI{
|
||||
DB: db,
|
||||
SyncProducer: syncProducer,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
AppServices: appServices,
|
||||
KeyAPI: keyAPI,
|
||||
DisableTLSValidation: cfg.PushGatewayDisableTLSValidation,
|
||||
}
|
||||
|
||||
readConsumer := consumers.NewOutputReadUpdateConsumer(
|
||||
base.ProcessContext, cfg, js, db, pgClient, userAPI, syncProducer,
|
||||
)
|
||||
if err := readConsumer.Start(); err != nil {
|
||||
logrus.WithError(err).Panic("failed to start user API read update consumer")
|
||||
}
|
||||
|
||||
eventConsumer := consumers.NewOutputStreamEventConsumer(
|
||||
base.ProcessContext, cfg, js, db, pgClient, userAPI, rsAPI, syncProducer,
|
||||
)
|
||||
if err := eventConsumer.Start(); err != nil {
|
||||
logrus.WithError(err).Panic("failed to start user API streamed event consumer")
|
||||
}
|
||||
|
||||
return userAPI
|
||||
}
|
||||
|
||||
type phoneHomeStats struct {
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/test"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/internal"
|
||||
"github.com/matrix-org/dendrite/userapi/inthttp"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
)
|
||||
|
|
@ -62,7 +63,10 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s
|
|||
},
|
||||
}
|
||||
|
||||
return newInternalAPI(accountDB, cfg, nil, nil), accountDB
|
||||
return &internal.UserInternalAPI{
|
||||
DB: accountDB,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
}, accountDB
|
||||
}
|
||||
|
||||
func TestQueryProfile(t *testing.T) {
|
||||
|
|
|
|||
100
userapi/util/devices.go
Normal file
100
userapi/util/devices.go
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type PusherDevice struct {
|
||||
Device pushgateway.Device
|
||||
Pusher *api.Pusher
|
||||
URL string
|
||||
Format string
|
||||
}
|
||||
|
||||
// GetPushDevices pushes to the configured devices of a local user.
|
||||
func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
|
||||
pushers, err := db.GetPushers(ctx, localpart)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
devices := make([]*PusherDevice, 0, len(pushers))
|
||||
for _, pusher := range pushers {
|
||||
var url, format string
|
||||
data := pusher.Data
|
||||
switch pusher.Kind {
|
||||
case api.EmailKind:
|
||||
url = "mailto:"
|
||||
|
||||
case api.HTTPKind:
|
||||
// TODO: The spec says only event_id_only is supported,
|
||||
// but Sytests assume "" means "full notification".
|
||||
fmtIface := pusher.Data["format"]
|
||||
var ok bool
|
||||
format, ok = fmtIface.(string)
|
||||
if ok && format != "event_id_only" {
|
||||
log.WithFields(log.Fields{
|
||||
"localpart": localpart,
|
||||
"app_id": pusher.AppID,
|
||||
}).Errorf("Only data.format event_id_only or empty is supported")
|
||||
continue
|
||||
}
|
||||
|
||||
urlIface := pusher.Data["url"]
|
||||
url, ok = urlIface.(string)
|
||||
if !ok {
|
||||
log.WithFields(log.Fields{
|
||||
"localpart": localpart,
|
||||
"app_id": pusher.AppID,
|
||||
}).Errorf("No data.url configured for HTTP Pusher")
|
||||
continue
|
||||
}
|
||||
data = mapWithout(data, "url")
|
||||
|
||||
default:
|
||||
log.WithFields(log.Fields{
|
||||
"localpart": localpart,
|
||||
"app_id": pusher.AppID,
|
||||
"kind": pusher.Kind,
|
||||
}).Errorf("Unhandled pusher kind")
|
||||
continue
|
||||
}
|
||||
|
||||
devices = append(devices, &PusherDevice{
|
||||
Device: pushgateway.Device{
|
||||
AppID: pusher.AppID,
|
||||
Data: data,
|
||||
PushKey: pusher.PushKey,
|
||||
PushKeyTS: pusher.PushKeyTS,
|
||||
Tweaks: tweaks,
|
||||
},
|
||||
Pusher: &pusher,
|
||||
URL: url,
|
||||
Format: format,
|
||||
})
|
||||
}
|
||||
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
// mapWithout returns a shallow copy of the map, without the given
|
||||
// key. Returns nil if the resulting map is empty.
|
||||
func mapWithout(m map[string]interface{}, key string) map[string]interface{} {
|
||||
ret := make(map[string]interface{}, len(m))
|
||||
for k, v := range m {
|
||||
// The specification says we do not send "url".
|
||||
if k == key {
|
||||
continue
|
||||
}
|
||||
ret[k] = v
|
||||
}
|
||||
if len(ret) == 0 {
|
||||
return nil
|
||||
}
|
||||
return ret
|
||||
}
|
||||
76
userapi/util/notify.go
Normal file
76
userapi/util/notify.go
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// NotifyUserCountsAsync sends notifications to a local user's
|
||||
// notification destinations. Database lookups run synchronously, but
|
||||
// a single goroutine is started when talking to the Push
|
||||
// gateways. There is no way to know when the background goroutine has
|
||||
// finished.
|
||||
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error {
|
||||
pusherDevices, err := GetPushDevices(ctx, localpart, nil, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(pusherDevices) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"localpart": localpart,
|
||||
"app_id0": pusherDevices[0].Device.AppID,
|
||||
"pushkey": pusherDevices[0].Device.PushKey,
|
||||
}).Tracef("Notifying HTTP push gateway about notification counts")
|
||||
|
||||
// TODO: think about bounding this to one per user, and what
|
||||
// ordering guarantees we must provide.
|
||||
go func() {
|
||||
// This background processing cannot be tied to a request.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// TODO: we could batch all devices with the same URL, but
|
||||
// Sytest requires consumers/roomserver.go to do it
|
||||
// one-by-one, so we do the same here.
|
||||
for _, pusherDevice := range pusherDevices {
|
||||
// TODO: support "email".
|
||||
if !strings.HasPrefix(pusherDevice.URL, "http") {
|
||||
continue
|
||||
}
|
||||
|
||||
req := pushgateway.NotifyRequest{
|
||||
Notification: pushgateway.Notification{
|
||||
Counts: &pushgateway.Counts{
|
||||
Unread: int(userNumUnreadNotifs),
|
||||
},
|
||||
Devices: []*pushgateway.Device{&pusherDevice.Device},
|
||||
},
|
||||
}
|
||||
if err := pgClient.Notify(ctx, pusherDevice.URL, &req, &pushgateway.NotifyResponse{}); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"localpart": localpart,
|
||||
"app_id0": pusherDevice.Device.AppID,
|
||||
"pushkey": pusherDevice.Device.PushKey,
|
||||
}).WithError(err).Error("HTTP push gateway request failed")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
Loading…
Reference in a new issue