Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking

This commit is contained in:
Till Faelligen 2022-03-04 09:18:34 +01:00
commit e6e62497c9
137 changed files with 6820 additions and 1294 deletions

View file

@ -63,7 +63,7 @@ jobs:
# Run Complement # Run Complement
- run: | - run: |
set -o pipefail && set -o pipefail &&
go test -v -p 1 -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt go test -v -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt
shell: bash shell: bash
name: Run Complement Tests name: Run Complement Tests
env: env:

4
.gitignore vendored
View file

@ -62,5 +62,7 @@ cmd/dendrite-demo-yggdrasil/embed/fs*.go
# Test dependencies # Test dependencies
test/wasm/node_modules test/wasm/node_modules
media_store/ # Ignore complement folder when running locally
complement/
media_store/

View file

@ -318,6 +318,17 @@ user_api:
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
# Configuration for the Push Server API.
push_server:
internal_api:
listen: http://localhost:7782
connect: http://localhost:7782
database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_pushserver?sslmode=disable
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
# Configuration for Opentracing. # Configuration for Opentracing.
# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on # See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on
# how this works and how to set it up. # how this works and how to set it up.

View file

@ -312,7 +312,7 @@ func (m *DendriteMonolith) Start() {
) )
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
m.userAPI = userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) m.userAPI = userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(m.userAPI) keyAPI.SetUserAPI(m.userAPI)
eduInputAPI := eduserver.NewInternalAPI( eduInputAPI := eduserver.NewInternalAPI(

View file

@ -116,7 +116,7 @@ func (m *DendriteMonolith) Start() {
) )
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI) keyAPI.SetUserAPI(userAPI)
eduInputAPI := eduserver.NewInternalAPI( eduInputAPI := eduserver.NewInternalAPI(

View file

@ -144,21 +144,23 @@ func (u *UserInteractive) AddCompletedStage(sessionID, authType string) {
delete(u.Sessions, sessionID) delete(u.Sessions, sessionID)
} }
// Challenge returns an HTTP 401 with the supported flows for authenticating type Challenge struct {
func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
return &util.JSONResponse{
Code: 401,
JSON: struct {
Completed []string `json:"completed"` Completed []string `json:"completed"`
Flows []userInteractiveFlow `json:"flows"` Flows []userInteractiveFlow `json:"flows"`
Session string `json:"session"` Session string `json:"session"`
// TODO: Return any additional `params` // TODO: Return any additional `params`
Params map[string]interface{} `json:"params"` Params map[string]interface{} `json:"params"`
}{ }
u.Completed,
u.Flows, // Challenge returns an HTTP 401 with the supported flows for authenticating
sessionID, func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
make(map[string]interface{}), return &util.JSONResponse{
Code: 401,
JSON: Challenge{
Completed: u.Completed,
Flows: u.Flows,
Session: sessionID,
Params: make(map[string]interface{}),
}, },
} }
} }

View file

@ -59,7 +59,8 @@ func AddPublicRoutes(
routing.Setup( routing.Setup(
router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI, router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI,
accountsDB, userAPI, federation, accountsDB, userAPI, federation,
syncProducer, transactionsCache, fsAPI, keyAPI, extRoomsProvider, mscCfg, syncProducer, transactionsCache, fsAPI, keyAPI,
extRoomsProvider, mscCfg,
) )
} }

View file

@ -30,7 +30,7 @@ type SyncAPIProducer struct {
} }
// SendData sends account data to the sync API server // SendData sends account data to the sync API server
func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string) error { func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string, readMarker *eventutil.ReadMarkerJSON) error {
m := &nats.Msg{ m := &nats.Msg{
Subject: p.Topic, Subject: p.Topic,
Header: nats.Header{}, Header: nats.Header{},
@ -40,6 +40,7 @@ func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string
data := eventutil.AccountData{ data := eventutil.AccountData{
RoomID: roomID, RoomID: roomID,
Type: dataType, Type: dataType,
ReadMarker: readMarker,
} }
var err error var err error
m.Data, err = json.Marshal(data) m.Data, err = json.Marshal(data)

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal/eventutil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
@ -127,7 +128,7 @@ func SaveAccountData(
} }
// TODO: user API should do this since it's account data // TODO: user API should do this since it's account data
if err := syncProducer.SendData(userID, roomID, dataType); err != nil { if err := syncProducer.SendData(userID, roomID, dataType, nil); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -138,11 +139,6 @@ func SaveAccountData(
} }
} }
type readMarkerJSON struct {
FullyRead string `json:"m.fully_read"`
Read string `json:"m.read"`
}
type fullyReadEvent struct { type fullyReadEvent struct {
EventID string `json:"event_id"` EventID string `json:"event_id"`
} }
@ -159,7 +155,7 @@ func SaveReadMarker(
return *resErr return *resErr
} }
var r readMarkerJSON var r eventutil.ReadMarkerJSON
resErr = httputil.UnmarshalJSONRequest(req, &r) resErr = httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
@ -189,7 +185,7 @@ func SaveReadMarker(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read"); err != nil { if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read", &r); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/tidwall/gjson"
) )
// https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-devices // https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-devices
@ -163,6 +164,15 @@ func DeleteDeviceById(
req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.UserInternalAPI, device *api.Device, req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.UserInternalAPI, device *api.Device,
deviceID string, deviceID string,
) util.JSONResponse { ) util.JSONResponse {
var (
deleteOK bool
sessionID string
)
defer func() {
if deleteOK {
sessions.deleteSession(sessionID)
}
}()
ctx := req.Context() ctx := req.Context()
defer req.Body.Close() // nolint:errcheck defer req.Body.Close() // nolint:errcheck
bodyBytes, err := ioutil.ReadAll(req.Body) bodyBytes, err := ioutil.ReadAll(req.Body)
@ -172,8 +182,29 @@ func DeleteDeviceById(
JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()), JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()),
} }
} }
// check that we know this session, and it matches with the device to delete
s := gjson.GetBytes(bodyBytes, "auth.session").Str
if dev, ok := sessions.getDeviceToDelete(s); ok {
if dev != deviceID {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("session & device mismatch"),
}
}
}
if s != "" {
sessionID = s
}
login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, device) login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, device)
if errRes != nil { if errRes != nil {
switch data := errRes.JSON.(type) {
case auth.Challenge:
sessions.addDeviceToDelete(data.Session, deviceID)
default:
}
return *errRes return *errRes
} }
@ -201,6 +232,8 @@ func DeleteDeviceById(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
deleteOK = true
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: struct{}{}, JSON: struct{}{},

View file

@ -0,0 +1,63 @@
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"net/http"
"strconv"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// GetNotifications handles /_matrix/client/r0/notifications
func GetNotifications(
req *http.Request, device *userapi.Device,
userAPI userapi.UserInternalAPI,
) util.JSONResponse {
var limit int64
if limitStr := req.URL.Query().Get("limit"); limitStr != "" {
var err error
limit, err = strconv.ParseInt(limitStr, 10, 64)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("ParseInt(limit) failed")
return jsonerror.InternalServerError()
}
}
var queryRes userapi.QueryNotificationsResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError()
}
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
Localpart: localpart,
From: req.URL.Query().Get("from"),
Limit: int(limit),
Only: req.URL.Query().Get("only"),
}, &queryRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed")
return jsonerror.InternalServerError()
}
util.GetLogger(req.Context()).WithField("from", req.URL.Query().Get("from")).WithField("limit", limit).WithField("only", req.URL.Query().Get("only")).WithField("next", queryRes.NextToken).Infof("QueryNotifications: len %d", len(queryRes.Notifications))
return util.JSONResponse{
Code: http.StatusOK,
JSON: queryRes,
}
}

View file

@ -12,6 +12,7 @@ import (
userdb "github.com/matrix-org/dendrite/userapi/storage" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
type newPasswordRequest struct { type newPasswordRequest struct {
@ -37,6 +38,11 @@ func Password(
var r newPasswordRequest var r newPasswordRequest
r.LogoutDevices = true r.LogoutDevices = true
logrus.WithFields(logrus.Fields{
"sessionId": device.SessionID,
"userId": device.UserID,
}).Debug("Changing password")
// Unmarshal the request. // Unmarshal the request.
resErr := httputil.UnmarshalJSONRequest(req, &r) resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil { if resErr != nil {
@ -116,6 +122,15 @@ func Password(
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
pushersReq := &api.PerformPusherDeletionRequest{
Localpart: localpart,
SessionID: device.SessionID,
}
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")
return jsonerror.InternalServerError()
}
} }
// Return a success code. // Return a success code.

View file

@ -286,7 +286,7 @@ func SetDisplayName(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, false); err != nil { if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

114
clientapi/routing/pusher.go Normal file
View file

@ -0,0 +1,114 @@
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"net/http"
"net/url"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// GetPushers handles /_matrix/client/r0/pushers
func GetPushers(
req *http.Request, device *userapi.Device,
userAPI userapi.UserInternalAPI,
) util.JSONResponse {
var queryRes userapi.QueryPushersResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError()
}
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
Localpart: localpart,
}, &queryRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
return jsonerror.InternalServerError()
}
for i := range queryRes.Pushers {
queryRes.Pushers[i].SessionID = 0
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: queryRes,
}
}
// SetPusher handles /_matrix/client/r0/pushers/set
// This endpoint allows the creation, modification and deletion of pushers for this user ID.
// The behaviour of this endpoint varies depending on the values in the JSON body.
func SetPusher(
req *http.Request, device *userapi.Device,
userAPI userapi.UserInternalAPI,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError()
}
body := userapi.PerformPusherSetRequest{}
if resErr := httputil.UnmarshalJSONRequest(req, &body); resErr != nil {
return *resErr
}
if len(body.AppID) > 64 {
return invalidParam("length of app_id must be no more than 64 characters")
}
if len(body.PushKey) > 512 {
return invalidParam("length of pushkey must be no more than 512 bytes")
}
uInt := body.Data["url"]
if uInt != nil {
u, ok := uInt.(string)
if !ok {
return invalidParam("url must be string")
}
if u != "" {
var pushUrl *url.URL
pushUrl, err = url.Parse(u)
if err != nil {
return invalidParam("malformed url passed")
}
if pushUrl.Scheme != "https" {
return invalidParam("only https scheme is allowed")
}
}
}
body.Localpart = localpart
body.SessionID = device.SessionID
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherSet failed")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
func invalidParam(msg string) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidParam(msg),
}
}

View file

@ -0,0 +1,386 @@
package routing
import (
"context"
"encoding/json"
"io"
"net/http"
"reflect"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/pushrules"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
func errorResponse(ctx context.Context, err error, msg string, args ...interface{}) util.JSONResponse {
if eerr, ok := err.(*jsonerror.MatrixError); ok {
var status int
switch eerr.ErrCode {
case "M_INVALID_ARGUMENT_VALUE":
status = http.StatusBadRequest
case "M_NOT_FOUND":
status = http.StatusNotFound
default:
status = http.StatusInternalServerError
}
return util.MatrixErrorResponse(status, eerr.ErrCode, eerr.Err)
}
util.GetLogger(ctx).WithError(err).Errorf(msg, args...)
return jsonerror.InternalServerError()
}
func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRulesJSON failed")
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: ruleSets,
}
}
func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRulesJSON failed")
}
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
if ruleSet == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: ruleSet,
}
}
func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
if ruleSet == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
}
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: *rulesPtr,
}
}
func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
if ruleSet == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
}
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
}
i := pushRuleIndexByID(*rulesPtr, ruleID)
if i < 0 {
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: (*rulesPtr)[i],
}
}
func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
var newRule pushrules.Rule
if err := json.NewDecoder(body).Decode(&newRule); err != nil {
return errorResponse(ctx, err, "JSON Decode failed")
}
newRule.RuleID = ruleID
errs := pushrules.ValidateRule(pushrules.Kind(kind), &newRule)
if len(errs) > 0 {
return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs)
}
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
if ruleSet == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
}
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
}
i := pushRuleIndexByID(*rulesPtr, ruleID)
if i >= 0 && afterRuleID == "" && beforeRuleID == "" {
// Modify rule at the same index.
// TODO: The spec does not say what to do in this case, but
// this feels reasonable.
*((*rulesPtr)[i]) = newRule
util.GetLogger(ctx).Infof("Modified existing push rule at %d", i)
} else {
if i >= 0 {
// Delete old rule.
*rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
util.GetLogger(ctx).Infof("Deleted old push rule at %d", i)
} else {
// SPEC: When creating push rules, they MUST be enabled by default.
//
// TODO: it's unclear if we must reject disabled rules, or force
// the value to true. Sytests fail if we don't force it.
newRule.Enabled = true
}
// Add new rule.
i, err := findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID)
if err != nil {
return errorResponse(ctx, err, "findPushRuleInsertionIndex failed")
}
*rulesPtr = append((*rulesPtr)[:i], append([]*pushrules.Rule{&newRule}, (*rulesPtr)[i:]...)...)
util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i)
}
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
return errorResponse(ctx, err, "putPushRules failed")
}
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
}
func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
if ruleSet == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
}
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
}
i := pushRuleIndexByID(*rulesPtr, ruleID)
if i < 0 {
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
}
*rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
return errorResponse(ctx, err, "putPushRules failed")
}
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
}
func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
attrGet, err := pushRuleAttrGetter(attr)
if err != nil {
return errorResponse(ctx, err, "pushRuleAttrGetter failed")
}
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
if ruleSet == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
}
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
}
i := pushRuleIndexByID(*rulesPtr, ruleID)
if i < 0 {
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: map[string]interface{}{
attr: attrGet((*rulesPtr)[i]),
},
}
}
func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
var newPartialRule pushrules.Rule
if err := json.NewDecoder(body).Decode(&newPartialRule); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(err.Error()),
}
}
if newPartialRule.Actions == nil {
// This ensures json.Marshal encodes the empty list as [] rather than null.
newPartialRule.Actions = []*pushrules.Action{}
}
attrGet, err := pushRuleAttrGetter(attr)
if err != nil {
return errorResponse(ctx, err, "pushRuleAttrGetter failed")
}
attrSet, err := pushRuleAttrSetter(attr)
if err != nil {
return errorResponse(ctx, err, "pushRuleAttrSetter failed")
}
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
if ruleSet == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
}
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
}
i := pushRuleIndexByID(*rulesPtr, ruleID)
if i < 0 {
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
}
if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) {
attrSet((*rulesPtr)[i], &newPartialRule)
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
return errorResponse(ctx, err, "putPushRules failed")
}
}
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
}
func queryPushRules(ctx context.Context, userID string, userAPI userapi.UserInternalAPI) (*pushrules.AccountRuleSets, error) {
var res userapi.QueryPushRulesResponse
if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed")
return nil, err
}
return res.RuleSets, nil
}
func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.UserInternalAPI) error {
req := userapi.PerformPushRulesPutRequest{
UserID: userID,
RuleSets: ruleSets,
}
var res struct{}
if err := userAPI.PerformPushRulesPut(ctx, &req, &res); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.PerformPushRulesPut failed")
return err
}
return nil
}
func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet {
switch scope {
case pushrules.GlobalScope:
return &ruleSets.Global
default:
return nil
}
}
func pushRuleSetKindPointer(ruleSet *pushrules.RuleSet, kind pushrules.Kind) *[]*pushrules.Rule {
switch kind {
case pushrules.OverrideKind:
return &ruleSet.Override
case pushrules.ContentKind:
return &ruleSet.Content
case pushrules.RoomKind:
return &ruleSet.Room
case pushrules.SenderKind:
return &ruleSet.Sender
case pushrules.UnderrideKind:
return &ruleSet.Underride
default:
return nil
}
}
func pushRuleIndexByID(rules []*pushrules.Rule, id string) int {
for i, rule := range rules {
if rule.RuleID == id {
return i
}
}
return -1
}
func pushRuleAttrGetter(attr string) (func(*pushrules.Rule) interface{}, error) {
switch attr {
case "actions":
return func(rule *pushrules.Rule) interface{} { return rule.Actions }, nil
case "enabled":
return func(rule *pushrules.Rule) interface{} { return rule.Enabled }, nil
default:
return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute")
}
}
func pushRuleAttrSetter(attr string) (func(dest, src *pushrules.Rule), error) {
switch attr {
case "actions":
return func(dest, src *pushrules.Rule) { dest.Actions = src.Actions }, nil
case "enabled":
return func(dest, src *pushrules.Rule) { dest.Enabled = src.Enabled }, nil
default:
return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute")
}
}
func findPushRuleInsertionIndex(rules []*pushrules.Rule, afterID, beforeID string) (int, error) {
var i int
if afterID != "" {
for ; i < len(rules); i++ {
if rules[i].RuleID == afterID {
break
}
}
if i == len(rules) {
return 0, jsonerror.NotFound("after: rule ID not found")
}
if rules[i].Default {
return 0, jsonerror.NotFound("after: rule ID must not be a default rule")
}
// We stopped on the "after" match to differentiate
// not-found from is-last-entry. Now we move to the earliest
// insertion point.
i++
}
if beforeID != "" {
for ; i < len(rules); i++ {
if rules[i].RuleID == beforeID {
break
}
}
if i == len(rules) {
return 0, jsonerror.NotFound("before: rule ID not found")
}
if rules[i].Default {
return 0, jsonerror.NotFound("before: rule ID must not be a default rule")
}
}
// UNSPEC: The spec does not say what to do if no after/before is
// given. Sytest fails if it doesn't go first.
return i, nil
}

View file

@ -76,6 +76,10 @@ type sessionsDict struct {
sessions map[string][]authtypes.LoginType sessions map[string][]authtypes.LoginType
params map[string]registerRequest params map[string]registerRequest
timer map[string]*time.Timer timer map[string]*time.Timer
// deleteSessionToDeviceID protects requests to DELETE /devices/{deviceID} from being abused.
// If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
// the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
deleteSessionToDeviceID map[string]string
} }
// defaultTimeout is the timeout used to clean up sessions // defaultTimeout is the timeout used to clean up sessions
@ -115,6 +119,7 @@ func (d *sessionsDict) deleteSession(sessionID string) {
defer d.Unlock() defer d.Unlock()
delete(d.params, sessionID) delete(d.params, sessionID)
delete(d.sessions, sessionID) delete(d.sessions, sessionID)
delete(d.deleteSessionToDeviceID, sessionID)
// stop the timer, e.g. because the registration was completed // stop the timer, e.g. because the registration was completed
if t, ok := d.timer[sessionID]; ok { if t, ok := d.timer[sessionID]; ok {
if !t.Stop() { if !t.Stop() {
@ -132,6 +137,7 @@ func newSessionsDict() *sessionsDict {
sessions: make(map[string][]authtypes.LoginType), sessions: make(map[string][]authtypes.LoginType),
params: make(map[string]registerRequest), params: make(map[string]registerRequest),
timer: make(map[string]*time.Timer), timer: make(map[string]*time.Timer),
deleteSessionToDeviceID: make(map[string]string),
} }
} }
@ -165,6 +171,20 @@ func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtype
d.sessions[sessionID] = append(sessions.sessions[sessionID], stage) d.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
} }
func (d *sessionsDict) addDeviceToDelete(sessionID, deviceID string) {
d.startTimer(defaultTimeOut, sessionID)
d.Lock()
defer d.Unlock()
d.deleteSessionToDeviceID[sessionID] = deviceID
}
func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) {
d.RLock()
defer d.RUnlock()
deviceID, ok := d.deleteSessionToDeviceID[sessionID]
return deviceID, ok
}
var ( var (
sessions = newSessionsDict() sessions = newSessionsDict()
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)

View file

@ -214,19 +214,19 @@ func TestSessionCleanUp(t *testing.T) {
s := newSessionsDict() s := newSessionsDict()
t.Run("session is cleaned up after a while", func(t *testing.T) { t.Run("session is cleaned up after a while", func(t *testing.T) {
t.Parallel() // t.Parallel()
dummySession := "helloWorld" dummySession := "helloWorld"
// manually added, as s.addParams() would start the timer with the default timeout // manually added, as s.addParams() would start the timer with the default timeout
s.params[dummySession] = registerRequest{Username: "Testing"} s.params[dummySession] = registerRequest{Username: "Testing"}
s.startTimer(time.Millisecond, dummySession) s.startTimer(time.Millisecond, dummySession)
time.Sleep(time.Millisecond * 5) time.Sleep(time.Millisecond * 50)
if data, ok := s.getParams(dummySession); ok { if data, ok := s.getParams(dummySession); ok {
t.Errorf("expected session to be deleted: %+v", data) t.Errorf("expected session to be deleted: %+v", data)
} }
}) })
t.Run("session is deleted, once the registration completed", func(t *testing.T) { t.Run("session is deleted, once the registration completed", func(t *testing.T) {
t.Parallel() // t.Parallel()
dummySession := "helloWorld2" dummySession := "helloWorld2"
s.startTimer(time.Minute, dummySession) s.startTimer(time.Minute, dummySession)
s.deleteSession(dummySession) s.deleteSession(dummySession)
@ -236,18 +236,28 @@ func TestSessionCleanUp(t *testing.T) {
}) })
t.Run("session timer is restarted after second call", func(t *testing.T) { t.Run("session timer is restarted after second call", func(t *testing.T) {
t.Parallel() // t.Parallel()
dummySession := "helloWorld3" dummySession := "helloWorld3"
// the following will start a timer with the default timeout of 5min // the following will start a timer with the default timeout of 5min
s.addParams(dummySession, registerRequest{Username: "Testing"}) s.addParams(dummySession, registerRequest{Username: "Testing"})
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha) s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha)
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy) s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy)
s.addDeviceToDelete(dummySession, "dummyDevice")
s.getCompletedStages(dummySession) s.getCompletedStages(dummySession)
// reset the timer with a lower timeout // reset the timer with a lower timeout
s.startTimer(time.Millisecond, dummySession) s.startTimer(time.Millisecond, dummySession)
time.Sleep(time.Millisecond * 5) time.Sleep(time.Millisecond * 50)
if data, ok := s.getParams(dummySession); ok { if data, ok := s.getParams(dummySession); ok {
t.Errorf("expected session to be deleted: %+v", data) t.Errorf("expected session to be deleted: %+v", data)
} }
if _, ok := s.timer[dummySession]; ok {
t.Error("expected timer to be delete")
}
if _, ok := s.sessions[dummySession]; ok {
t.Error("expected session to be delete")
}
if _, ok := s.getDeviceToDelete(dummySession); ok {
t.Error("expected session to device to be delete")
}
}) })
} }

View file

@ -99,7 +99,7 @@ func PutTag(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil { if err = syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
} }
@ -152,7 +152,7 @@ func DeleteTag(
} }
// TODO: user API should do this since it's account data // TODO: user API should do this since it's account data
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { if err := syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
} }

View file

@ -16,7 +16,6 @@ package routing
import ( import (
"context" "context"
"encoding/json"
"net/http" "net/http"
"strings" "strings"
@ -581,25 +580,142 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
v3mux.Handle("/pushrules/", // Push rules
httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse {
// TODO: Implement push rules API v3mux.Handle("/pushrules",
res := json.RawMessage(`{ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`)
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusBadRequest,
JSON: &res, JSON: jsonerror.InvalidArgumentValue("missing trailing slash"),
} }
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetAllPushRules(req.Context(), device, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("scope, kind and rule ID must be specified"),
}
}),
).Methods(http.MethodPut)
v3mux.Handle("/pushrules/{scope}/",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetPushRulesByScope(req.Context(), vars["scope"], device, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/{scope}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("missing trailing slash after scope"),
}
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/{scope:[^/]+/?}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("kind and rule ID must be specified"),
}
}),
).Methods(http.MethodPut)
v3mux.Handle("/pushrules/{scope}/{kind}/",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetPushRulesByKind(req.Context(), vars["scope"], vars["kind"], device, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/{scope}/{kind}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("missing trailing slash after kind"),
}
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/{scope}/{kind:[^/]+/?}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("rule ID must be specified"),
}
}),
).Methods(http.MethodPut)
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
query := req.URL.Query()
return PutPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], query.Get("after"), query.Get("before"), req.Body, device, userAPI)
}),
).Methods(http.MethodPut)
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return DeletePushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI)
}),
).Methods(http.MethodDelete)
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], device, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return PutPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], req.Body, device, userAPI)
}),
).Methods(http.MethodPut)
// Element user settings // Element user settings
v3mux.Handle("/profile/{userID}", v3mux.Handle("/profile/{userID}",
@ -905,6 +1021,27 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/notifications",
httputil.MakeAuthAPI("get_notifications", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetNotifications(req, device, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushers",
httputil.MakeAuthAPI("get_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetPushers(req, device, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushers/set",
httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
}
return SetPusher(req, device, userAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Stub implementations for sytest // Stub implementations for sytest
v3mux.Handle("/events", v3mux.Handle("/events",
httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {

View file

@ -48,6 +48,8 @@ Example:
# read password from stdin # read password from stdin
%s --config dendrite.yaml -username alice -passwordstdin < my.pass %s --config dendrite.yaml -username alice -passwordstdin < my.pass
cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin
# reset password for a user, can be used with a combination above to read the password
%s --config dendrite.yaml -reset-password -username alice -password foobarbaz
Arguments: Arguments:
@ -60,12 +62,13 @@ var (
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
askPass = flag.Bool("ask-pass", false, "Ask for the password to use") askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
isAdmin = flag.Bool("admin", false, "Create an admin account") isAdmin = flag.Bool("admin", false, "Create an admin account")
resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username")
) )
func main() { func main() {
name := os.Args[0] name := os.Args[0]
flag.Usage = func() { flag.Usage = func() {
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name) _, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name, name)
flag.PrintDefaults() flag.PrintDefaults()
} }
cfg := setup.ParseFlags(true) cfg := setup.ParseFlags(true)
@ -93,6 +96,19 @@ func main() {
if *isAdmin { if *isAdmin {
accType = api.AccountTypeAdmin accType = api.AccountTypeAdmin
} }
if *resetPassword {
err = accountDB.SetPassword(context.Background(), *username, pass)
if err != nil {
logrus.Fatalf("Failed to update password for user %s: %s", *username, err.Error())
}
if _, err = accountDB.RemoveAllDevices(context.Background(), *username, ""); err != nil {
logrus.Fatalf("Failed to remove all devices: %s", err.Error())
}
logrus.Infof("Updated password for user %s and invalidated all logins\n", *username)
return
}
policyVersion := "" policyVersion := ""
if cfg.Global.UserConsentOptions.Enabled { if cfg.Global.UserConsentOptions.Enabled {
policyVersion = cfg.Global.UserConsentOptions.Version policyVersion = cfg.Global.UserConsentOptions.Version

View file

@ -144,12 +144,14 @@ func main() {
accountDB := base.Base.CreateAccountsDB() accountDB := base.Base.CreateAccountsDB()
federation := createFederationClient(base) federation := createFederationClient(base)
keyAPI := keyserver.NewInternalAPI(&base.Base, &base.Base.Cfg.KeyServer, federation) keyAPI := keyserver.NewInternalAPI(&base.Base, &base.Base.Cfg.KeyServer, federation)
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
keyAPI.SetUserAPI(userAPI)
rsAPI := roomserver.NewInternalAPI( rsAPI := roomserver.NewInternalAPI(
&base.Base, &base.Base,
) )
userAPI := userapi.NewInternalAPI(&base.Base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.Base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)
eduInputAPI := eduserver.NewInternalAPI( eduInputAPI := eduserver.NewInternalAPI(
&base.Base, cache.New(), userAPI, &base.Base, cache.New(), userAPI,
) )

View file

@ -187,7 +187,7 @@ func main() {
) )
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI) keyAPI.SetUserAPI(userAPI)
eduInputAPI := eduserver.NewInternalAPI( eduInputAPI := eduserver.NewInternalAPI(

View file

@ -111,14 +111,15 @@ func main() {
keyRing := serverKeyAPI.KeyRing() keyRing := serverKeyAPI.KeyRing()
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
keyAPI.SetUserAPI(userAPI)
rsComponent := roomserver.NewInternalAPI( rsComponent := roomserver.NewInternalAPI(
base, base,
) )
rsAPI := rsComponent rsAPI := rsComponent
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)
eduInputAPI := eduserver.NewInternalAPI( eduInputAPI := eduserver.NewInternalAPI(
base, cache.New(), userAPI, base, cache.New(), userAPI,
) )

View file

@ -106,7 +106,8 @@ func main() {
keyAPI = base.KeyServerHTTPClient() keyAPI = base.KeyServerHTTPClient()
} }
userImpl := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) pgClient := base.PushGatewayHTTPClient()
userImpl := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient)
userAPI := userImpl userAPI := userImpl
if base.UseHTTPAPIs { if base.UseHTTPAPIs {
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)

View file

@ -23,7 +23,11 @@ import (
func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
accountDB := base.CreateAccountsDB() accountDB := base.CreateAccountsDB()
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient()) userAPI := userapi.NewInternalAPI(
base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices,
base.KeyServerHTTPClient(), base.RoomserverHTTPClient(),
base.PushGatewayHTTPClient(),
)
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)

View file

@ -184,13 +184,15 @@ func startup() {
accountDB := base.CreateAccountsDB() accountDB := base.CreateAccountsDB()
federation := conn.CreateFederationClient(base, pSessions) federation := conn.CreateFederationClient(base, pSessions)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
keyAPI.SetUserAPI(userAPI)
serverKeyAPI := &signing.YggdrasilKeys{} serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing() keyRing := serverKeyAPI.KeyRing()
rsAPI := roomserver.NewInternalAPI(base) rsAPI := roomserver.NewInternalAPI(base)
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)
eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI) eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI)
asQuery := appservice.NewInternalAPI( asQuery := appservice.NewInternalAPI(
base, userAPI, rsAPI, base, userAPI, rsAPI,

View file

@ -212,6 +212,8 @@ func main() {
rsAPI.SetFederationAPI(fedSenderAPI, keyRing) rsAPI.SetFederationAPI(fedSenderAPI, keyRing)
p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation) p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation)
psAPI := pushserver.NewInternalAPI(base)
monolith := setup.Monolith{ monolith := setup.Monolith{
Config: base.Cfg, Config: base.Cfg,
AccountDB: accountDB, AccountDB: accountDB,
@ -225,6 +227,7 @@ func main() {
RoomserverAPI: rsAPI, RoomserverAPI: rsAPI,
UserAPI: userAPI, UserAPI: userAPI,
KeyAPI: keyAPI, KeyAPI: keyAPI,
PushserverAPI: psAPI,
//ServerKeyAPI: serverKeyAPI, //ServerKeyAPI: serverKeyAPI,
ExtPublicRoomsProvider: p2pPublicRoomProvider, ExtPublicRoomsProvider: p2pPublicRoomProvider,
} }

View file

@ -374,11 +374,6 @@ user_api:
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
device_database:
connection_string: file:userapi_devices.db
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
# The length of time that a token issued for a relying party from # The length of time that a token issued for a relying party from
# /_matrix/client/r0/user/{userId}/openid/request_token endpoint # /_matrix/client/r0/user/{userId}/openid/request_token endpoint
# is considered to be valid in milliseconds. # is considered to be valid in milliseconds.
@ -403,9 +398,9 @@ tracing:
# Logging configuration # Logging configuration
logging: logging:
- type: std - type: std
level: info level: info
- type: file - type: file
# The logging level, must be one of debug, info, warn, error, fatal, panic. # The logging level, must be one of debug, info, warn, error, fatal, panic.
level: info level: info
params: params:

View file

@ -13,6 +13,7 @@ Group=dendrite
WorkingDirectory=/opt/dendrite/ WorkingDirectory=/opt/dendrite/
ExecStart=/opt/dendrite/bin/dendrite-monolith-server ExecStart=/opt/dendrite/bin/dendrite-monolith-server
Restart=always Restart=always
LimitNOFILE=65535
[Install] [Install]
WantedBy=multi-user.target WantedBy=multi-user.target

View file

@ -21,7 +21,7 @@ type FederationClient interface {
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error) MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)

View file

@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
} }
func (a *FederationInternalAPI) MSC2946Spaces( func (a *FederationInternalAPI) MSC2946Spaces(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest, ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.MSC2946Spaces(ctx, s, roomID, r) return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly)
}) })
if err != nil { if err != nil {
return res, err return res, err

View file

@ -527,21 +527,21 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships(
type spacesReq struct { type spacesReq struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Req gomatrixserverlib.MSC2946SpacesRequest SuggestedOnly bool
RoomID string RoomID string
Res gomatrixserverlib.MSC2946SpacesResponse Res gomatrixserverlib.MSC2946SpacesResponse
Err *api.FederationClientError Err *api.FederationClientError
} }
func (h *httpFederationInternalAPI) MSC2946Spaces( func (h *httpFederationInternalAPI) MSC2946Spaces(
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest, ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces") span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces")
defer span.Finish() defer span.Finish()
request := spacesReq{ request := spacesReq{
S: dst, S: dst,
Req: r, SuggestedOnly: suggestedOnly,
RoomID: roomID, RoomID: roomID,
} }
var response spacesReq var response spacesReq

View file

@ -378,7 +378,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error()) return util.MessageResponse(http.StatusBadRequest, err.Error())
} }
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.Req) res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly)
if err != nil { if err != nil {
ferr, ok := err.(*api.FederationClientError) ferr, ok := err.(*api.FederationClientError)
if ok { if ok {

13
go.mod
View file

@ -1,6 +1,6 @@
module github.com/matrix-org/dendrite module github.com/matrix-org/dendrite
replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.4-0.20220302103432-6b04b9f12740
replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c
@ -18,12 +18,13 @@ require (
github.com/frankban/quicktest v1.14.0 // indirect github.com/frankban/quicktest v1.14.0 // indirect
github.com/getsentry/sentry-go v0.12.0 github.com/getsentry/sentry-go v0.12.0
github.com/gologme/log v1.3.0 github.com/gologme/log v1.3.0
github.com/google/go-cmp v0.5.6
github.com/google/uuid v1.2.0
github.com/gorilla/mux v1.8.0 github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.4.2 github.com/gorilla/websocket v1.4.2
github.com/h2non/filetype v1.1.3 // indirect github.com/h2non/filetype v1.1.3 // indirect
github.com/hashicorp/golang-lru v0.5.4 github.com/hashicorp/golang-lru v0.5.4
github.com/juju/testing v0.0.0-20211215003918-77eb13d6cad2 // indirect github.com/juju/testing v0.0.0-20211215003918-77eb13d6cad2 // indirect
github.com/klauspost/compress v1.14.2 // indirect
github.com/lib/pq v1.10.4 github.com/lib/pq v1.10.4
github.com/libp2p/go-libp2p v0.13.0 github.com/libp2p/go-libp2p v0.13.0
github.com/libp2p/go-libp2p-circuit v0.4.0 github.com/libp2p/go-libp2p-circuit v0.4.0
@ -39,12 +40,12 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052 github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.10 github.com/mattn/go-sqlite3 v1.14.10
github.com/morikuni/aec v1.0.0 // indirect github.com/morikuni/aec v1.0.0 // indirect
github.com/nats-io/nats-server/v2 v2.3.2 github.com/nats-io/nats-server/v2 v2.7.3
github.com/nats-io/nats.go v1.13.1-0.20220121202836-972a071d373d github.com/nats-io/nats.go v1.13.1-0.20220121202836-972a071d373d
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
@ -61,11 +62,11 @@ require (
github.com/uber/jaeger-lib v2.4.1+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.2 github.com/yggdrasil-network/yggdrasil-go v0.4.2
go.uber.org/atomic v1.9.0 go.uber.org/atomic v1.9.0
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a golang.org/x/crypto v0.0.0-20220214200702-86341886e292
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 golang.org/x/image v0.0.0-20211028202545-6944b10bf410
golang.org/x/mobile v0.0.0-20220112015953-858099ff7816 golang.org/x/mobile v0.0.0-20220112015953-858099ff7816
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
gopkg.in/h2non/bimg.v1 v1.1.5 gopkg.in/h2non/bimg.v1 v1.1.5
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0

26
go.sum
View file

@ -480,7 +480,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/gologme/log v1.2.0/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U= github.com/gologme/log v1.2.0/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U=
github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo= github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo=
github.com/gologme/log v1.3.0/go.mod h1:yKT+DvIPdDdDoPtqFrFxheooyVmoqi0BAsw+erN3wA4= github.com/gologme/log v1.3.0/go.mod h1:yKT+DvIPdDdDoPtqFrFxheooyVmoqi0BAsw+erN3wA4=
@ -712,9 +711,8 @@ github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0
github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/klauspost/compress v1.11.13/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.13/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.14.4 h1:eijASRJcobkVtSt81Olfh7JX43osYLwy5krOJo6YEu4=
github.com/klauspost/compress v1.14.2 h1:S0OHlFk/Gbon/yauFJ4FfJJF5V0fc5HbBTJazi28pRw= github.com/klauspost/compress v1.14.4/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/klauspost/compress v1.14.2/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
@ -983,8 +981,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052 h1:+4Q/JQ3fGgA7sIHaLMlqREX8yEpsI+HlVoW9WId7SNc= github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 h1:WHlrE8BYh/hzn1RKwq3YMAlhHivX47jQKAjZFtkJyPE=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo=
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk=
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
@ -1029,8 +1027,8 @@ github.com/miekg/dns v1.1.31/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7
github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 h1:lYpkrQH5ajf0OXOcUbGjvZxxijuBwbbmlSxLiuofa+g= github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 h1:lYpkrQH5ajf0OXOcUbGjvZxxijuBwbbmlSxLiuofa+g=
github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1/go.mod h1:pD8RvIylQ358TN4wwqatJ8rNavkEINozVn9DtGI3dfQ= github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1/go.mod h1:pD8RvIylQ358TN4wwqatJ8rNavkEINozVn9DtGI3dfQ=
github.com/minio/highwayhash v1.0.1 h1:dZ6IIu8Z14VlC0VpfKofAhCy74wu/Qb5gcn52yWoz/0= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g=
github.com/minio/highwayhash v1.0.1/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY=
github.com/minio/sha256-simd v0.0.0-20190131020904-2d45a736cd16/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U= github.com/minio/sha256-simd v0.0.0-20190131020904-2d45a736cd16/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U=
github.com/minio/sha256-simd v0.0.0-20190328051042-05b4dd3047e5/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U= github.com/minio/sha256-simd v0.0.0-20190328051042-05b4dd3047e5/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U=
github.com/minio/sha256-simd v0.1.0/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U= github.com/minio/sha256-simd v0.1.0/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U=
@ -1132,8 +1130,8 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY
github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM=
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad h1:Z2nWMQsXWWqzj89nW6OaLJSdkFknqhaR5whEOz4++Y8= github.com/neilalexander/nats-server/v2 v2.7.4-0.20220302103432-6b04b9f12740 h1:RJrc+z35RHZlrjR6UBt9UmVRAlFh4SgYyEA0YpQdPHM=
github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad/go.mod h1:tckmrt0M6bVaDT3kmh9UrIq/CBOBBse+TpXQi5ldaa8= github.com/neilalexander/nats-server/v2 v2.7.4-0.20220302103432-6b04b9f12740/go.mod h1:eJUrA5gm0ch6sJTEv85xmXIgQWsB0OyjkTsKXvlHbYc=
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q= github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q=
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8=
@ -1510,8 +1508,8 @@ golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo= golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE=
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -1737,8 +1735,8 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc= golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs=
golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=

View file

@ -10,6 +10,7 @@ const (
FederationEventCacheName = "federation_event" FederationEventCacheName = "federation_event"
FederationEventCacheMaxEntries = 256 FederationEventCacheMaxEntries = 256
FederationEventCacheMutable = true // to allow use of Unset only FederationEventCacheMutable = true // to allow use of Unset only
FederationEventCacheMaxAge = CacheNoMaxAge
) )
// FederationCache contains the subset of functions needed for // FederationCache contains the subset of functions needed for

View file

@ -1,6 +1,8 @@
package caching package caching
import ( import (
"time"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -16,6 +18,7 @@ const (
RoomInfoCacheName = "roominfo" RoomInfoCacheName = "roominfo"
RoomInfoCacheMaxEntries = 1024 RoomInfoCacheMaxEntries = 1024
RoomInfoCacheMutable = true RoomInfoCacheMutable = true
RoomInfoCacheMaxAge = time.Minute * 5
) )
// RoomInfosCache contains the subset of functions needed for // RoomInfosCache contains the subset of functions needed for

View file

@ -10,6 +10,7 @@ const (
RoomServerRoomIDsCacheName = "roomserver_room_ids" RoomServerRoomIDsCacheName = "roomserver_room_ids"
RoomServerRoomIDsCacheMaxEntries = 1024 RoomServerRoomIDsCacheMaxEntries = 1024
RoomServerRoomIDsCacheMutable = false RoomServerRoomIDsCacheMutable = false
RoomServerRoomIDsCacheMaxAge = CacheNoMaxAge
) )
type RoomServerCaches interface { type RoomServerCaches interface {

View file

@ -6,6 +6,7 @@ const (
RoomVersionCacheName = "room_versions" RoomVersionCacheName = "room_versions"
RoomVersionCacheMaxEntries = 1024 RoomVersionCacheMaxEntries = 1024
RoomVersionCacheMutable = false RoomVersionCacheMutable = false
RoomVersionCacheMaxAge = CacheNoMaxAge
) )
// RoomVersionsCache contains the subset of functions needed for // RoomVersionsCache contains the subset of functions needed for

View file

@ -10,6 +10,7 @@ const (
ServerKeyCacheName = "server_key" ServerKeyCacheName = "server_key"
ServerKeyCacheMaxEntries = 4096 ServerKeyCacheMaxEntries = 4096
ServerKeyCacheMutable = true ServerKeyCacheMutable = true
ServerKeyCacheMaxAge = CacheNoMaxAge
) )
// ServerKeyCache contains the subset of functions needed for // ServerKeyCache contains the subset of functions needed for

View file

@ -0,0 +1,33 @@
package caching
import (
"time"
"github.com/matrix-org/gomatrixserverlib"
)
const (
SpaceSummaryRoomsCacheName = "space_summary_rooms"
SpaceSummaryRoomsCacheMaxEntries = 100
SpaceSummaryRoomsCacheMutable = true
SpaceSummaryRoomsCacheMaxAge = time.Minute * 5
)
type SpaceSummaryRoomsCache interface {
GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool)
StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse)
}
func (c Caches) GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool) {
val, found := c.SpaceSummaryRooms.Get(roomID)
if found && val != nil {
if resp, ok := val.(gomatrixserverlib.MSC2946SpacesResponse); ok {
return resp, true
}
}
return r, false
}
func (c Caches) StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse) {
c.SpaceSummaryRooms.Set(roomID, r)
}

View file

@ -1,5 +1,7 @@
package caching package caching
import "time"
// Caches contains a set of references to caches. They may be // Caches contains a set of references to caches. They may be
// different implementations as long as they satisfy the Cache // different implementations as long as they satisfy the Cache
// interface. // interface.
@ -10,6 +12,7 @@ type Caches struct {
RoomServerRoomIDs Cache // RoomServerNIDsCache RoomServerRoomIDs Cache // RoomServerNIDsCache
RoomInfos Cache // RoomInfoCache RoomInfos Cache // RoomInfoCache
FederationEvents Cache // FederationEventsCache FederationEvents Cache // FederationEventsCache
SpaceSummaryRooms Cache // SpaceSummaryRoomsCache
} }
// Cache is the interface that an implementation must satisfy. // Cache is the interface that an implementation must satisfy.
@ -18,3 +21,5 @@ type Cache interface {
Set(key string, value interface{}) Set(key string, value interface{})
Unset(key string) Unset(key string)
} }
const CacheNoMaxAge = time.Duration(0)

View file

@ -14,6 +14,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
RoomVersionCacheName, RoomVersionCacheName,
RoomVersionCacheMutable, RoomVersionCacheMutable,
RoomVersionCacheMaxEntries, RoomVersionCacheMaxEntries,
RoomVersionCacheMaxAge,
enablePrometheus, enablePrometheus,
) )
if err != nil { if err != nil {
@ -23,6 +24,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
ServerKeyCacheName, ServerKeyCacheName,
ServerKeyCacheMutable, ServerKeyCacheMutable,
ServerKeyCacheMaxEntries, ServerKeyCacheMaxEntries,
ServerKeyCacheMaxAge,
enablePrometheus, enablePrometheus,
) )
if err != nil { if err != nil {
@ -32,6 +34,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
RoomServerRoomIDsCacheName, RoomServerRoomIDsCacheName,
RoomServerRoomIDsCacheMutable, RoomServerRoomIDsCacheMutable,
RoomServerRoomIDsCacheMaxEntries, RoomServerRoomIDsCacheMaxEntries,
RoomServerRoomIDsCacheMaxAge,
enablePrometheus, enablePrometheus,
) )
if err != nil { if err != nil {
@ -41,6 +44,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
RoomInfoCacheName, RoomInfoCacheName,
RoomInfoCacheMutable, RoomInfoCacheMutable,
RoomInfoCacheMaxEntries, RoomInfoCacheMaxEntries,
RoomInfoCacheMaxAge,
enablePrometheus, enablePrometheus,
) )
if err != nil { if err != nil {
@ -50,6 +54,17 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
FederationEventCacheName, FederationEventCacheName,
FederationEventCacheMutable, FederationEventCacheMutable,
FederationEventCacheMaxEntries, FederationEventCacheMaxEntries,
FederationEventCacheMaxAge,
enablePrometheus,
)
if err != nil {
return nil, err
}
spaceRooms, err := NewInMemoryLRUCachePartition(
SpaceSummaryRoomsCacheName,
SpaceSummaryRoomsCacheMutable,
SpaceSummaryRoomsCacheMaxEntries,
SpaceSummaryRoomsCacheMaxAge,
enablePrometheus, enablePrometheus,
) )
if err != nil { if err != nil {
@ -57,7 +72,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
} }
go cacheCleaner( go cacheCleaner(
roomVersions, serverKeys, roomServerRoomIDs, roomVersions, serverKeys, roomServerRoomIDs,
roomInfos, federationEvents, roomInfos, federationEvents, spaceRooms,
) )
return &Caches{ return &Caches{
RoomVersions: roomVersions, RoomVersions: roomVersions,
@ -65,6 +80,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
RoomServerRoomIDs: roomServerRoomIDs, RoomServerRoomIDs: roomServerRoomIDs,
RoomInfos: roomInfos, RoomInfos: roomInfos,
FederationEvents: federationEvents, FederationEvents: federationEvents,
SpaceSummaryRooms: spaceRooms,
}, nil }, nil
} }
@ -86,15 +102,22 @@ type InMemoryLRUCachePartition struct {
name string name string
mutable bool mutable bool
maxEntries int maxEntries int
maxAge time.Duration
lru *lru.Cache lru *lru.Cache
} }
func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, enablePrometheus bool) (*InMemoryLRUCachePartition, error) { type inMemoryLRUCacheEntry struct {
value interface{}
created time.Time
}
func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, maxAge time.Duration, enablePrometheus bool) (*InMemoryLRUCachePartition, error) {
var err error var err error
cache := InMemoryLRUCachePartition{ cache := InMemoryLRUCachePartition{
name: name, name: name,
mutable: mutable, mutable: mutable,
maxEntries: maxEntries, maxEntries: maxEntries,
maxAge: maxAge,
} }
cache.lru, err = lru.New(maxEntries) cache.lru, err = lru.New(maxEntries)
if err != nil { if err != nil {
@ -114,11 +137,16 @@ func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, ena
func (c *InMemoryLRUCachePartition) Set(key string, value interface{}) { func (c *InMemoryLRUCachePartition) Set(key string, value interface{}) {
if !c.mutable { if !c.mutable {
if peek, ok := c.lru.Peek(key); ok && peek != value { if peek, ok := c.lru.Peek(key); ok {
if entry, ok := peek.(*inMemoryLRUCacheEntry); ok && entry.value != value {
panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key)) panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key))
} }
} }
c.lru.Add(key, value) }
c.lru.Add(key, &inMemoryLRUCacheEntry{
value: value,
created: time.Now(),
})
} }
func (c *InMemoryLRUCachePartition) Unset(key string) { func (c *InMemoryLRUCachePartition) Unset(key string) {
@ -129,5 +157,20 @@ func (c *InMemoryLRUCachePartition) Unset(key string) {
} }
func (c *InMemoryLRUCachePartition) Get(key string) (value interface{}, ok bool) { func (c *InMemoryLRUCachePartition) Get(key string) (value interface{}, ok bool) {
return c.lru.Get(key) v, ok := c.lru.Get(key)
if !ok {
return nil, false
}
entry, ok := v.(*inMemoryLRUCacheEntry)
switch {
case ok && c.maxAge == CacheNoMaxAge:
return entry.value, ok // There's no maximum age policy
case ok && time.Since(entry.created) < c.maxAge:
return entry.value, ok // The value for the key isn't stale
default:
// Either the key was found and it was stale, or the key
// wasn't found at all
c.lru.Remove(key)
return nil, false
}
} }

View file

@ -28,6 +28,28 @@ var ErrProfileNoExists = errors.New("no known profile for given user ID")
type AccountData struct { type AccountData struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
Type string `json:"type"` Type string `json:"type"`
ReadMarker *ReadMarkerJSON `json:"read_marker,omitempty"` // optional
}
type ReadMarkerJSON struct {
FullyRead string `json:"m.fully_read"`
Read string `json:"m.read"`
}
// NotificationData contains statistics about notifications, sent from
// the Push Server to the Sync API server.
type NotificationData struct {
// RoomID identifies the scope of the statistics, together with
// MXID (which is encoded in the Kafka key).
RoomID string `json:"room_id"`
// HighlightCount is the number of unread notifications with the
// highlight tweak.
UnreadHighlightCount int `json:"unread_highlight_count"`
// UnreadNotificationCount is the total number of unread
// notifications.
UnreadNotificationCount int `json:"unread_notification_count"`
} }
// ProfileResponse is a struct containing all known user profile data // ProfileResponse is a struct containing all known user profile data

View file

@ -0,0 +1,66 @@
package pushgateway
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/opentracing/opentracing-go"
)
type httpClient struct {
hc *http.Client
}
// NewHTTPClient creates a new Push Gateway client.
func NewHTTPClient(disableTLSValidation bool) Client {
hc := &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
DisableKeepAlives: true,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: disableTLSValidation,
},
},
}
return &httpClient{hc: hc}
}
func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "Notify")
defer span.Finish()
body, err := json.Marshal(req)
if err != nil {
return err
}
hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return err
}
hreq.Header.Set("Content-Type", "application/json")
hresp, err := h.hc.Do(hreq)
if err != nil {
return err
}
//nolint:errcheck
defer hresp.Body.Close()
if hresp.StatusCode == http.StatusOK {
return json.NewDecoder(hresp.Body).Decode(resp)
}
var errorBody struct {
Message string `json:"message"`
}
if err := json.NewDecoder(hresp.Body).Decode(&errorBody); err == nil {
return fmt.Errorf("push gateway: %d from %s: %s", hresp.StatusCode, url, errorBody.Message)
}
return fmt.Errorf("push gateway: %d from %s", hresp.StatusCode, url)
}

View file

@ -0,0 +1,62 @@
package pushgateway
import (
"context"
"encoding/json"
"github.com/matrix-org/gomatrixserverlib"
)
// A Client is how interactions with a Push Gateway is done.
type Client interface {
// Notify sends a notification to the gateway at the given URL.
Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error
}
type NotifyRequest struct {
Notification Notification `json:"notification"` // Required
}
type NotifyResponse struct {
// Rejected is the list of device push keys that were rejected
// during the push. The caller should remove the push keys so they
// are not used again.
Rejected []string `json:"rejected"` // Required
}
type Notification struct {
Content json.RawMessage `json:"content,omitempty"`
Counts *Counts `json:"counts,omitempty"`
Devices []*Device `json:"devices"` // Required
EventID string `json:"event_id,omitempty"`
ID string `json:"id,omitempty"` // Deprecated name for EventID.
Membership string `json:"membership,omitempty"` // UNSPEC: required for Sytest.
Prio Prio `json:"prio,omitempty"`
RoomAlias string `json:"room_alias,omitempty"`
RoomID string `json:"room_id,omitempty"`
RoomName string `json:"room_name,omitempty"`
Sender string `json:"sender,omitempty"`
SenderDisplayName string `json:"sender_display_name,omitempty"`
Type string `json:"type,omitempty"`
UserIsTarget bool `json:"user_is_target,omitempty"`
}
type Counts struct {
MissedCalls int `json:"missed_calls,omitempty"`
Unread int `json:"unread"` // TODO: UNSPEC: the spec says zero must be omitted, but Sytest 61push/01message-pushed.pl requires it.
}
type Device struct {
AppID string `json:"app_id"` // Required
Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys.
PushKey string `json:"pushkey"` // Required
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
Tweaks map[string]interface{} `json:"tweaks,omitempty"`
}
type Prio string
const (
HighPrio Prio = "high"
LowPrio Prio = "low"
)

View file

@ -0,0 +1,102 @@
package pushrules
import (
"bytes"
"encoding/json"
"fmt"
)
// An Action is (part of) an outcome of a rule. There are
// (unofficially) terminal actions, and modifier actions.
type Action struct {
// Kind is the type of action. Has custom encoding in JSON.
Kind ActionKind `json:"-"`
// Tweak is the property to tweak. Has custom encoding in JSON.
Tweak TweakKey `json:"-"`
// Value is some value interpreted according to Kind and Tweak.
Value interface{} `json:"value"`
}
func (a *Action) MarshalJSON() ([]byte, error) {
if a.Tweak == UnknownTweak && a.Value == nil {
return json.Marshal(a.Kind)
}
if a.Kind != SetTweakAction {
return nil, fmt.Errorf("only set_tweak actions may have a value, but got kind %q", a.Kind)
}
m := map[string]interface{}{
string(a.Kind): a.Tweak,
}
if a.Value != nil {
m["value"] = a.Value
}
return json.Marshal(m)
}
func (a *Action) UnmarshalJSON(bs []byte) error {
if bytes.HasPrefix(bs, []byte("\"")) {
return json.Unmarshal(bs, &a.Kind)
}
var raw struct {
SetTweak TweakKey `json:"set_tweak"`
Value interface{} `json:"value"`
}
if err := json.Unmarshal(bs, &raw); err != nil {
return err
}
if raw.SetTweak == UnknownTweak {
return fmt.Errorf("got unknown action JSON: %s", string(bs))
}
a.Kind = SetTweakAction
a.Tweak = raw.SetTweak
if raw.Value != nil {
a.Value = raw.Value
}
return nil
}
// ActionKind is the primary discriminator for actions.
type ActionKind string
const (
UnknownAction ActionKind = ""
// NotifyAction indicates the clients should show a notification.
NotifyAction ActionKind = "notify"
// DontNotifyAction indicates the clients should not show a notification.
DontNotifyAction ActionKind = "dont_notify"
// CoalesceAction tells the clients to show a notification, and
// tells both servers and clients that multiple events can be
// coalesced into a single notification. The behaviour is
// implementation-specific.
CoalesceAction ActionKind = "coalesce"
// SetTweakAction uses the Tweak and Value fields to add a
// tweak. Multiple SetTweakAction can be provided in a rule,
// combined with NotifyAction or CoalesceAction.
SetTweakAction ActionKind = "set_tweak"
)
// A TweakKey describes a property to be modified/tweaked for events
// that match the rule.
type TweakKey string
const (
UnknownTweak TweakKey = ""
// SoundTweak describes which sound to play. Using "default" means
// "enable sound".
SoundTweak TweakKey = "sound"
// HighlightTweak asks the clients to highlight the conversation.
HighlightTweak TweakKey = "highlight"
)

View file

@ -0,0 +1,39 @@
package pushrules
import (
"encoding/json"
"fmt"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestActionJSON(t *testing.T) {
tsts := []struct {
Want Action
}{
{Action{Kind: NotifyAction}},
{Action{Kind: DontNotifyAction}},
{Action{Kind: CoalesceAction}},
{Action{Kind: SetTweakAction}},
{Action{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}},
{Action{Kind: SetTweakAction, Tweak: HighlightTweak}},
{Action{Kind: SetTweakAction, Tweak: HighlightTweak, Value: "false"}},
}
for _, tst := range tsts {
t.Run(fmt.Sprintf("%+v", tst.Want), func(t *testing.T) {
bs, err := json.Marshal(&tst.Want)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
var got Action
if err := json.Unmarshal(bs, &got); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if diff := cmp.Diff(tst.Want, got); diff != "" {
t.Errorf("+got -want:\n%s", diff)
}
})
}
}

View file

@ -0,0 +1,49 @@
package pushrules
// A Condition dictates extra conditions for a matching rules. See
// ConditionKind.
type Condition struct {
// Kind is the primary discriminator for the condition
// type. Required.
Kind ConditionKind `json:"kind"`
// Key indicates the dot-separated path of Event fields to
// match. Required for EventMatchCondition and
// SenderNotificationPermissionCondition.
Key string `json:"key,omitempty"`
// Pattern indicates the value pattern that must match. Required
// for EventMatchCondition.
Pattern string `json:"pattern,omitempty"`
// Is indicates the condition that must be fulfilled. Required for
// RoomMemberCountCondition.
Is string `json:"is,omitempty"`
}
// ConditionKind represents a kind of condition.
//
// SPEC: Unrecognised conditions MUST NOT match any events,
// effectively making the push rule disabled.
type ConditionKind string
const (
UnknownCondition ConditionKind = ""
// EventMatchCondition indicates the condition looks for a key
// path and matches a pattern. How paths that don't reference a
// simple value match against rules is implementation-specific.
EventMatchCondition ConditionKind = "event_match"
// ContainsDisplayNameCondition indicates the current user's
// display name must be found in the content body.
ContainsDisplayNameCondition ConditionKind = "contains_display_name"
// RoomMemberCountCondition matches a simple arithmetic comparison
// against the total number of members in a room.
RoomMemberCountCondition ConditionKind = "room_member_count"
// SenderNotificationPermissionCondition compares power level for
// the sender in the event's room.
SenderNotificationPermissionCondition ConditionKind = "sender_notification_permission"
)

View file

@ -0,0 +1,23 @@
package pushrules
import (
"github.com/matrix-org/gomatrixserverlib"
)
// DefaultAccountRuleSets is the complete set of default push rules
// for an account.
func DefaultAccountRuleSets(localpart string, serverName gomatrixserverlib.ServerName) *AccountRuleSets {
return &AccountRuleSets{
Global: *DefaultGlobalRuleSet(localpart, serverName),
}
}
// DefaultGlobalRuleSet returns the default ruleset for a given (fully
// qualified) MXID.
func DefaultGlobalRuleSet(localpart string, serverName gomatrixserverlib.ServerName) *RuleSet {
return &RuleSet{
Override: defaultOverrideRules("@" + localpart + ":" + string(serverName)),
Content: defaultContentRules(localpart),
Underride: defaultUnderrideRules,
}
}

View file

@ -0,0 +1,33 @@
package pushrules
func defaultContentRules(localpart string) []*Rule {
return []*Rule{
mRuleContainsUserNameDefinition(localpart),
}
}
const (
MRuleContainsUserName = ".m.rule.contains_user_name"
)
func mRuleContainsUserNameDefinition(localpart string) *Rule {
return &Rule{
RuleID: MRuleContainsUserName,
Default: true,
Enabled: true,
Pattern: localpart,
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: SoundTweak,
Value: "default",
},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: true,
},
},
}
}

View file

@ -0,0 +1,165 @@
package pushrules
func defaultOverrideRules(userID string) []*Rule {
return []*Rule{
&mRuleMasterDefinition,
&mRuleSuppressNoticesDefinition,
mRuleInviteForMeDefinition(userID),
&mRuleMemberEventDefinition,
&mRuleContainsDisplayNameDefinition,
&mRuleTombstoneDefinition,
&mRuleRoomNotifDefinition,
}
}
const (
MRuleMaster = ".m.rule.master"
MRuleSuppressNotices = ".m.rule.suppress_notices"
MRuleInviteForMe = ".m.rule.invite_for_me"
MRuleMemberEvent = ".m.rule.member_event"
MRuleContainsDisplayName = ".m.rule.contains_display_name"
MRuleTombstone = ".m.rule.tombstone"
MRuleRoomNotif = ".m.rule.roomnotif"
)
var (
mRuleMasterDefinition = Rule{
RuleID: MRuleMaster,
Default: true,
Enabled: false,
Conditions: []*Condition{},
Actions: []*Action{{Kind: DontNotifyAction}},
}
mRuleSuppressNoticesDefinition = Rule{
RuleID: MRuleSuppressNotices,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "content.msgtype",
Pattern: "m.notice",
},
},
Actions: []*Action{{Kind: DontNotifyAction}},
}
mRuleMemberEventDefinition = Rule{
RuleID: MRuleMemberEvent,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.room.member",
},
},
Actions: []*Action{{Kind: DontNotifyAction}},
}
mRuleContainsDisplayNameDefinition = Rule{
RuleID: MRuleContainsDisplayName,
Default: true,
Enabled: true,
Conditions: []*Condition{{Kind: ContainsDisplayNameCondition}},
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: SoundTweak,
Value: "default",
},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: true,
},
},
}
mRuleTombstoneDefinition = Rule{
RuleID: MRuleTombstone,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.room.tombstone",
},
{
Kind: EventMatchCondition,
Key: "state_key",
Pattern: "",
},
},
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
},
}
mRuleRoomNotifDefinition = Rule{
RuleID: MRuleRoomNotif,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "content.body",
Pattern: "@room",
},
{
Kind: SenderNotificationPermissionCondition,
Key: "room",
},
},
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
},
}
)
func mRuleInviteForMeDefinition(userID string) *Rule {
return &Rule{
RuleID: MRuleInviteForMe,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.room.member",
},
{
Kind: EventMatchCondition,
Key: "content.membership",
Pattern: "invite",
},
{
Kind: EventMatchCondition,
Key: "state_key",
Pattern: userID,
},
},
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: SoundTweak,
Value: "default",
},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
},
}
}

View file

@ -0,0 +1,119 @@
package pushrules
const (
MRuleCall = ".m.rule.call"
MRuleEncryptedRoomOneToOne = ".m.rule.encrypted_room_one_to_one"
MRuleRoomOneToOne = ".m.rule.room_one_to_one"
MRuleMessage = ".m.rule.message"
MRuleEncrypted = ".m.rule.encrypted"
)
var defaultUnderrideRules = []*Rule{
&mRuleCallDefinition,
&mRuleEncryptedRoomOneToOneDefinition,
&mRuleRoomOneToOneDefinition,
&mRuleMessageDefinition,
&mRuleEncryptedDefinition,
}
var (
mRuleCallDefinition = Rule{
RuleID: MRuleCall,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.call.invite",
},
},
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: SoundTweak,
Value: "ring",
},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
},
}
mRuleEncryptedRoomOneToOneDefinition = Rule{
RuleID: MRuleEncryptedRoomOneToOne,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: RoomMemberCountCondition,
Is: "2",
},
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.room.encrypted",
},
},
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
},
}
mRuleRoomOneToOneDefinition = Rule{
RuleID: MRuleRoomOneToOne,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: RoomMemberCountCondition,
Is: "2",
},
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.room.message",
},
},
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
},
}
mRuleMessageDefinition = Rule{
RuleID: MRuleMessage,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.room.message",
},
},
Actions: []*Action{{Kind: NotifyAction}},
}
mRuleEncryptedDefinition = Rule{
RuleID: MRuleEncrypted,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.room.encrypted",
},
},
Actions: []*Action{{Kind: NotifyAction}},
}
)

View file

@ -0,0 +1,165 @@
package pushrules
import (
"encoding/json"
"fmt"
"strings"
"github.com/matrix-org/gomatrixserverlib"
)
// A RuleSetEvaluator encapsulates context to evaluate an event
// against a rule set.
type RuleSetEvaluator struct {
ec EvaluationContext
ruleSet []kindAndRules
}
// An EvaluationContext gives a RuleSetEvaluator access to the
// environment, for rules that require that.
type EvaluationContext interface {
// UserDisplayName returns the current user's display name.
UserDisplayName() string
// RoomMemberCount returns the number of members in the room of
// the current event.
RoomMemberCount() (int, error)
// HasPowerLevel returns whether the user has at least the given
// power in the room of the current event.
HasPowerLevel(userID, levelKey string) (bool, error)
}
// A kindAndRules is just here to simplify iteration of the (ordered)
// kinds of rules.
type kindAndRules struct {
Kind Kind
Rules []*Rule
}
// NewRuleSetEvaluator creates a new evaluator for the given rule set.
func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluator {
return &RuleSetEvaluator{
ec: ec,
ruleSet: []kindAndRules{
{OverrideKind, ruleSet.Override},
{ContentKind, ruleSet.Content},
{RoomKind, ruleSet.Room},
{SenderKind, ruleSet.Sender},
{UnderrideKind, ruleSet.Underride},
},
}
}
// MatchEvent returns the first matching rule. Returns nil if there
// was no match rule.
func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule, error) {
// TODO: server-default rules have lower priority than user rules,
// but they are stored together with the user rules. It's a bit
// unclear what the specification (11.14.1.4 Predefined rules)
// means the ordering should be.
//
// The most reasonable interpretation is that default overrides
// still have lower priority than user content rules, so we
// iterate twice.
for _, rsat := range rse.ruleSet {
for _, defRules := range []bool{false, true} {
for _, rule := range rsat.Rules {
if rule.Default != defRules {
continue
}
ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec)
if err != nil {
return nil, err
}
if ok {
return rule, nil
}
}
}
}
// No matching rule.
return nil, nil
}
func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
if !rule.Enabled {
return false, nil
}
switch kind {
case OverrideKind, UnderrideKind:
for _, cond := range rule.Conditions {
ok, err := conditionMatches(cond, event, ec)
if err != nil {
return false, err
}
if !ok {
return false, nil
}
}
return true, nil
case ContentKind:
// TODO: "These configure behaviour for (unencrypted) messages
// that match certain patterns." - Does that mean "content.body"?
return patternMatches("content.body", rule.Pattern, event)
case RoomKind:
return rule.RuleID == event.RoomID(), nil
case SenderKind:
return rule.RuleID == event.Sender(), nil
default:
return false, nil
}
}
func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
switch cond.Kind {
case EventMatchCondition:
return patternMatches(cond.Key, cond.Pattern, event)
case ContainsDisplayNameCondition:
return patternMatches("content.body", ec.UserDisplayName(), event)
case RoomMemberCountCondition:
cmp, err := parseRoomMemberCountCondition(cond.Is)
if err != nil {
return false, fmt.Errorf("parsing room_member_count condition: %w", err)
}
n, err := ec.RoomMemberCount()
if err != nil {
return false, fmt.Errorf("RoomMemberCount failed: %w", err)
}
return cmp(n), nil
case SenderNotificationPermissionCondition:
return ec.HasPowerLevel(event.Sender(), cond.Key)
default:
return false, nil
}
}
func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) {
re, err := globToRegexp(pattern)
if err != nil {
return false, err
}
var eventMap map[string]interface{}
if err = json.Unmarshal(event.JSON(), &eventMap); err != nil {
return false, fmt.Errorf("parsing event: %w", err)
}
v, err := lookupMapPath(strings.Split(key, "."), eventMap)
if err != nil {
// An unknown path is a benign error that shouldn't stop rule
// processing. It's just a non-match.
return false, nil
}
return re.MatchString(fmt.Sprint(v)), nil
}

View file

@ -0,0 +1,189 @@
package pushrules
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/matrix-org/gomatrixserverlib"
)
func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
ev := mustEventFromJSON(t, `{}`)
defaultEnabled := &Rule{
RuleID: ".default.enabled",
Default: true,
Enabled: true,
}
userEnabled := &Rule{
RuleID: ".user.enabled",
Default: false,
Enabled: true,
}
userEnabled2 := &Rule{
RuleID: ".user.enabled.2",
Default: false,
Enabled: true,
}
tsts := []struct {
Name string
RuleSet RuleSet
Want *Rule
}{
{"empty", RuleSet{}, nil},
{"defaultCanWin", RuleSet{Override: []*Rule{defaultEnabled}}, defaultEnabled},
{"userWins", RuleSet{Override: []*Rule{defaultEnabled, userEnabled}}, userEnabled},
{"defaultOverrideWins", RuleSet{Override: []*Rule{defaultEnabled}, Underride: []*Rule{userEnabled}}, defaultEnabled},
{"overrideContent", RuleSet{Override: []*Rule{userEnabled}, Content: []*Rule{userEnabled2}}, userEnabled},
{"overrideRoom", RuleSet{Override: []*Rule{userEnabled}, Room: []*Rule{userEnabled2}}, userEnabled},
{"overrideSender", RuleSet{Override: []*Rule{userEnabled}, Sender: []*Rule{userEnabled2}}, userEnabled},
{"overrideUnderride", RuleSet{Override: []*Rule{userEnabled}, Underride: []*Rule{userEnabled2}}, userEnabled},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
rse := NewRuleSetEvaluator(nil, &tst.RuleSet)
got, err := rse.MatchEvent(ev)
if err != nil {
t.Fatalf("MatchEvent failed: %v", err)
}
if diff := cmp.Diff(tst.Want, got); diff != "" {
t.Errorf("MatchEvent rule: +got -want:\n%s", diff)
}
})
}
}
func TestRuleMatches(t *testing.T) {
emptyRule := Rule{Enabled: true}
tsts := []struct {
Name string
Kind Kind
Rule Rule
EventJSON string
Want bool
}{
{"emptyOverride", OverrideKind, emptyRule, `{}`, true},
{"emptyContent", ContentKind, emptyRule, `{}`, false},
{"emptyRoom", RoomKind, emptyRule, `{}`, true},
{"emptySender", SenderKind, emptyRule, `{}`, true},
{"emptyUnderride", UnderrideKind, emptyRule, `{}`, true},
{"disabled", OverrideKind, Rule{}, `{}`, false},
{"overrideConditionMatch", OverrideKind, Rule{Enabled: true}, `{}`, true},
{"overrideConditionNoMatch", OverrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
{"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true},
{"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
{"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true},
{"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false},
{"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true},
{"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false},
{"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true},
{"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil)
if err != nil {
t.Fatalf("ruleMatches failed: %v", err)
}
if got != tst.Want {
t.Errorf("ruleMatches: got %v, want %v", got, tst.Want)
}
})
}
}
func TestConditionMatches(t *testing.T) {
tsts := []struct {
Name string
Cond Condition
EventJSON string
Want bool
}{
{"empty", Condition{}, `{}`, false},
{"empty", Condition{Kind: "unknownstring"}, `{}`, false},
{"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, true},
{"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false},
{"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true},
{"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false},
{"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true},
{"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false},
{"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true},
{"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false},
{"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true},
{"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false},
{"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true},
{"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false},
{"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true},
{"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true},
{"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{})
if err != nil {
t.Fatalf("conditionMatches failed: %v", err)
}
if got != tst.Want {
t.Errorf("conditionMatches: got %v, want %v", got, tst.Want)
}
})
}
}
type fakeEvaluationContext struct{}
func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" }
func (fakeEvaluationContext) RoomMemberCount() (int, error) { return 2, nil }
func (fakeEvaluationContext) HasPowerLevel(userID, levelKey string) (bool, error) {
return userID == "@poweruser:example.com" && levelKey == "powerlevel", nil
}
func TestPatternMatches(t *testing.T) {
tsts := []struct {
Name string
Key string
Pattern string
EventJSON string
Want bool
}{
{"empty", "", "", `{}`, false},
// Note that an empty pattern contains no wildcard characters,
// which implicitly means "*".
{"patternEmpty", "content", "", `{"content":{}}`, true},
{"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true},
{"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true},
{"singlePattern", "content.creator", "acr?ator", `{"content":{"creator":"acreator"}}`, true},
{"multiPattern", "content.creator", "a*ea*r", `{"content":{"creator":"acreator"}}`, true},
{"patternNoSubstring", "content.creator", "r*t", `{"content":{"creator":"acreator"}}`, false},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
got, err := patternMatches(tst.Key, tst.Pattern, mustEventFromJSON(t, tst.EventJSON))
if err != nil {
t.Fatalf("patternMatches failed: %v", err)
}
if got != tst.Want {
t.Errorf("patternMatches: got %v, want %v", got, tst.Want)
}
})
}
}
func mustEventFromJSON(t *testing.T, json string) *gomatrixserverlib.Event {
ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(json), false, gomatrixserverlib.RoomVersionV7)
if err != nil {
t.Fatal(err)
}
return ev
}

View file

@ -0,0 +1,71 @@
package pushrules
// An AccountRuleSets carries the rule sets associated with an
// account.
type AccountRuleSets struct {
Global RuleSet `json:"global"` // Required
}
// A RuleSet contains all the various push rules for an
// account. Listed in decreasing order of priority.
type RuleSet struct {
Override []*Rule `json:"override,omitempty"`
Content []*Rule `json:"content,omitempty"`
Room []*Rule `json:"room,omitempty"`
Sender []*Rule `json:"sender,omitempty"`
Underride []*Rule `json:"underride,omitempty"`
}
// A Rule contains matchers, conditions and final actions. While
// evaluating, at most one rule is considered matching.
//
// Kind and scope are part of the push rules request/responses, but
// not of the core data model.
type Rule struct {
// RuleID is either a free identifier, or the sender's MXID for
// SenderKind. Required.
RuleID string `json:"rule_id"`
// Default indicates whether this is a server-defined default, or
// a user-provided rule. Required.
//
// The server-default rules have the lowest priority.
Default bool `json:"default"`
// Enabled allows the user to disable rules while keeping them
// around. Required.
Enabled bool `json:"enabled"`
// Actions describe the desired outcome, should the rule
// match. Required.
Actions []*Action `json:"actions"`
// Conditions provide the rule's conditions for OverrideKind and
// UnderrideKind. Not allowed for other kinds.
Conditions []*Condition `json:"conditions"`
// Pattern is the body pattern to match for ContentKind. Required
// for that kind. The interpretation is the same as that of
// Condition.Pattern.
Pattern string `json:"pattern"`
}
// Scope only has one valid value. See also AccountRuleSets.
type Scope string
const (
UnknownScope Scope = ""
GlobalScope Scope = "global"
)
// Kind is the type of push rule. See also RuleSet.
type Kind string
const (
UnknownKind Kind = ""
OverrideKind Kind = "override"
ContentKind Kind = "content"
RoomKind Kind = "room"
SenderKind Kind = "sender"
UnderrideKind Kind = "underride"
)

125
internal/pushrules/util.go Normal file
View file

@ -0,0 +1,125 @@
package pushrules
import (
"fmt"
"regexp"
"strconv"
"strings"
)
// ActionsToTweaks converts a list of actions into a primary action
// kind and a tweaks map. Returns a nil map if it would have been
// empty.
func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) {
var kind ActionKind
tweaks := map[string]interface{}{}
for _, a := range as {
if a.Kind == SetTweakAction {
tweaks[string(a.Tweak)] = a.Value
continue
}
if kind != UnknownAction {
return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind)
}
kind = a.Kind
}
if len(tweaks) == 0 {
tweaks = nil
}
return kind, tweaks, nil
}
// BoolTweakOr returns the named tweak as a boolean, and returns `def`
// on failure.
func BoolTweakOr(tweaks map[string]interface{}, key TweakKey, def bool) bool {
v, ok := tweaks[string(key)]
if !ok {
return def
}
b, ok := v.(bool)
if !ok {
return def
}
return b
}
// globToRegexp converts a Matrix glob-style pattern to a Regular expression.
func globToRegexp(pattern string) (*regexp.Regexp, error) {
// TODO: It's unclear which glob characters are supported. The only
// place this is discussed is for the unrelated "m.policy.rule.*"
// events. Assuming, the same: /[*?]/
if !strings.ContainsAny(pattern, "*?") {
pattern = "*" + pattern + "*"
}
// The defined syntax doesn't allow escaping the glob wildcard
// characters, which makes this a straight-forward
// replace-after-quote.
pattern = globNonMetaRegexp.ReplaceAllStringFunc(pattern, regexp.QuoteMeta)
pattern = strings.Replace(pattern, "*", ".*", -1)
pattern = strings.Replace(pattern, "?", ".", -1)
return regexp.Compile("^(" + pattern + ")$")
}
// globNonMetaRegexp are the characters that are not considered glob
// meta-characters (i.e. may need escaping).
var globNonMetaRegexp = regexp.MustCompile("[^*?]+")
// lookupMapPath traverses a hierarchical map structure, like the one
// produced by json.Unmarshal, to return the leaf value. Traversing
// arrays/slices is not supported, only objects/maps.
func lookupMapPath(path []string, m map[string]interface{}) (interface{}, error) {
if len(path) == 0 {
return nil, fmt.Errorf("empty path")
}
var v interface{} = m
for i, key := range path {
m, ok := v.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("expected an object for path %q, but got %T", strings.Join(path[:i+1], "."), v)
}
v, ok = m[key]
if !ok {
return nil, fmt.Errorf("path not found: %s", strings.Join(path[:i+1], "."))
}
}
return v, nil
}
// parseRoomMemberCountCondition parses a string like "2", "==2", "<2"
// into a function that checks if the argument to it fulfils the
// condition.
func parseRoomMemberCountCondition(s string) (func(int) bool, error) {
var b int
var cmp = func(a int) bool { return a == b }
switch {
case strings.HasPrefix(s, "<="):
cmp = func(a int) bool { return a <= b }
s = s[2:]
case strings.HasPrefix(s, ">="):
cmp = func(a int) bool { return a >= b }
s = s[2:]
case strings.HasPrefix(s, "<"):
cmp = func(a int) bool { return a < b }
s = s[1:]
case strings.HasPrefix(s, ">"):
cmp = func(a int) bool { return a > b }
s = s[1:]
case strings.HasPrefix(s, "=="):
// Same cmp as the default.
s = s[2:]
}
v, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return nil, err
}
b = int(v)
return cmp, nil
}

View file

@ -0,0 +1,169 @@
package pushrules
import (
"fmt"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestActionsToTweaks(t *testing.T) {
tsts := []struct {
Name string
Input []*Action
WantKind ActionKind
WantTweaks map[string]interface{}
}{
{"empty", nil, UnknownAction, nil},
{"zero", []*Action{{}}, UnknownAction, nil},
{"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil},
{"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}},
{"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}},
{
"all",
[]*Action{
{Kind: CoalesceAction},
{Kind: SetTweakAction, Tweak: HighlightTweak},
{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"},
},
CoalesceAction,
map[string]interface{}{"highlight": nil, "sound": "default"},
},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
gotKind, gotTweaks, err := ActionsToTweaks(tst.Input)
if err != nil {
t.Fatalf("ActionsToTweaks failed: %v", err)
}
if gotKind != tst.WantKind {
t.Errorf("kind: got %v, want %v", gotKind, tst.WantKind)
}
if diff := cmp.Diff(tst.WantTweaks, gotTweaks); diff != "" {
t.Errorf("tweaks: +got -want:\n%s", diff)
}
})
}
}
func TestBoolTweakOr(t *testing.T) {
tsts := []struct {
Name string
Input map[string]interface{}
Def bool
Want bool
}{
{"nil", nil, false, false},
{"nilValue", map[string]interface{}{"highlight": nil}, true, true},
{"false", map[string]interface{}{"highlight": false}, true, false},
{"true", map[string]interface{}{"highlight": true}, false, true},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
got := BoolTweakOr(tst.Input, HighlightTweak, tst.Def)
if got != tst.Want {
t.Errorf("BoolTweakOr: got %v, want %v", got, tst.Want)
}
})
}
}
func TestGlobToRegexp(t *testing.T) {
tsts := []struct {
Input string
Want string
}{
{"", "^(.*.*)$"},
{"a", "^(.*a.*)$"},
{"a.b", "^(.*a\\.b.*)$"},
{"a?b", "^(a.b)$"},
{"a*b*", "^(a.*b.*)$"},
{"a*b?", "^(a.*b.)$"},
}
for _, tst := range tsts {
t.Run(tst.Want, func(t *testing.T) {
got, err := globToRegexp(tst.Input)
if err != nil {
t.Fatalf("globToRegexp failed: %v", err)
}
if got.String() != tst.Want {
t.Errorf("got %v, want %v", got.String(), tst.Want)
}
})
}
}
func TestLookupMapPath(t *testing.T) {
tsts := []struct {
Path []string
Root map[string]interface{}
Want interface{}
}{
{[]string{"a"}, map[string]interface{}{"a": "b"}, "b"},
{[]string{"a"}, map[string]interface{}{"a": 42}, 42},
{[]string{"a", "b"}, map[string]interface{}{"a": map[string]interface{}{"b": "c"}}, "c"},
}
for _, tst := range tsts {
t.Run(fmt.Sprint(tst.Path, "/", tst.Want), func(t *testing.T) {
got, err := lookupMapPath(tst.Path, tst.Root)
if err != nil {
t.Fatalf("lookupMapPath failed: %v", err)
}
if diff := cmp.Diff(tst.Want, got); diff != "" {
t.Errorf("+got -want:\n%s", diff)
}
})
}
}
func TestLookupMapPathInvalid(t *testing.T) {
tsts := []struct {
Path []string
Root map[string]interface{}
}{
{nil, nil},
{[]string{"a"}, nil},
{[]string{"a", "b"}, map[string]interface{}{"a": "c"}},
}
for _, tst := range tsts {
t.Run(fmt.Sprint(tst.Path), func(t *testing.T) {
got, err := lookupMapPath(tst.Path, tst.Root)
if err == nil {
t.Fatalf("lookupMapPath succeeded with %#v, but want failure", got)
}
})
}
}
func TestParseRoomMemberCountCondition(t *testing.T) {
tsts := []struct {
Input string
WantTrue []int
WantFalse []int
}{
{"1", []int{1}, []int{0, 2}},
{"==1", []int{1}, []int{0, 2}},
{"<1", []int{0}, []int{1, 2}},
{"<=1", []int{0, 1}, []int{2}},
{">1", []int{2}, []int{0, 1}},
{">=42", []int{42, 43}, []int{41}},
}
for _, tst := range tsts {
t.Run(tst.Input, func(t *testing.T) {
got, err := parseRoomMemberCountCondition(tst.Input)
if err != nil {
t.Fatalf("parseRoomMemberCountCondition failed: %v", err)
}
for _, v := range tst.WantTrue {
if !got(v) {
t.Errorf("parseRoomMemberCountCondition(%q)(%d): got false, want true", tst.Input, v)
}
}
for _, v := range tst.WantFalse {
if got(v) {
t.Errorf("parseRoomMemberCountCondition(%q)(%d): got true, want false", tst.Input, v)
}
}
})
}
}

View file

@ -0,0 +1,85 @@
package pushrules
import (
"fmt"
"regexp"
)
// ValidateRule checks the rule for errors. These follow from Sytests
// and the specification.
func ValidateRule(kind Kind, rule *Rule) []error {
var errs []error
if !validRuleIDRE.MatchString(rule.RuleID) {
errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID))
}
if len(rule.Actions) == 0 {
errs = append(errs, fmt.Errorf("missing actions"))
}
for _, action := range rule.Actions {
errs = append(errs, validateAction(action)...)
}
for _, cond := range rule.Conditions {
errs = append(errs, validateCondition(cond)...)
}
switch kind {
case OverrideKind, UnderrideKind:
// The empty list is allowed, but for JSON-encoding reasons,
// it must not be nil.
if rule.Conditions == nil {
errs = append(errs, fmt.Errorf("missing rule conditions"))
}
case ContentKind:
if rule.Pattern == "" {
errs = append(errs, fmt.Errorf("missing content rule pattern"))
}
case RoomKind, SenderKind:
// Do nothing.
default:
errs = append(errs, fmt.Errorf("invalid rule kind: %s", kind))
}
return errs
}
// validRuleIDRE is a regexp for valid IDs.
//
// TODO: the specification doesn't seem to say what the rule ID syntax
// is. A Sytest fails if it contains a backslash.
var validRuleIDRE = regexp.MustCompile(`^([^\\]+)$`)
// validateAction returns issues with an Action.
func validateAction(action *Action) []error {
var errs []error
switch action.Kind {
case NotifyAction, DontNotifyAction, CoalesceAction, SetTweakAction:
// Do nothing.
default:
errs = append(errs, fmt.Errorf("invalid rule action kind: %s", action.Kind))
}
return errs
}
// validateCondition returns issues with a Condition.
func validateCondition(cond *Condition) []error {
var errs []error
switch cond.Kind {
case EventMatchCondition, ContainsDisplayNameCondition, RoomMemberCountCondition, SenderNotificationPermissionCondition:
// Do nothing.
default:
errs = append(errs, fmt.Errorf("invalid rule condition kind: %s", cond.Kind))
}
return errs
}

View file

@ -0,0 +1,163 @@
package pushrules
import (
"strings"
"testing"
)
func TestValidateRuleNegatives(t *testing.T) {
tsts := []struct {
Name string
Kind Kind
Rule Rule
WantErrString string
}{
{"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"},
{"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"},
{"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"},
{"noActions", OverrideKind, Rule{}, "missing actions"},
{"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"},
{"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"},
{"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"},
{"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"},
{"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
errs := ValidateRule(tst.Kind, &tst.Rule)
var foundErr error
for _, err := range errs {
t.Logf("Got error %#v", err)
if strings.Contains(err.Error(), tst.WantErrString) {
foundErr = err
}
}
if foundErr == nil {
t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
}
})
}
}
func TestValidateRulePositives(t *testing.T) {
tsts := []struct {
Name string
Kind Kind
Rule Rule
WantNoErrString string
}{
{"invalidKind", OverrideKind, Rule{}, "invalid rule kind"},
{"invalidActionNoActions", OverrideKind, Rule{}, "invalid rule action kind"},
{"invalidConditionNoConditions", OverrideKind, Rule{}, "invalid rule condition kind"},
{"contentNoCondition", ContentKind, Rule{}, "missing rule conditions"},
{"roomNoCondition", RoomKind, Rule{}, "missing rule conditions"},
{"senderNoCondition", SenderKind, Rule{}, "missing rule conditions"},
{"overrideNoPattern", OverrideKind, Rule{}, "missing content rule pattern"},
{"overrideEmptyConditions", OverrideKind, Rule{Conditions: []*Condition{}}, "missing rule conditions"},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
errs := ValidateRule(tst.Kind, &tst.Rule)
for _, err := range errs {
t.Logf("Got error %#v", err)
if strings.Contains(err.Error(), tst.WantNoErrString) {
t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
}
}
})
}
}
func TestValidateActionNegatives(t *testing.T) {
tsts := []struct {
Name string
Action Action
WantErrString string
}{
{"emptyKind", Action{}, "invalid rule action kind"},
{"invalidKind", Action{Kind: ActionKind("something else")}, "invalid rule action kind"},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
errs := validateAction(&tst.Action)
var foundErr error
for _, err := range errs {
t.Logf("Got error %#v", err)
if strings.Contains(err.Error(), tst.WantErrString) {
foundErr = err
}
}
if foundErr == nil {
t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
}
})
}
}
func TestValidateActionPositives(t *testing.T) {
tsts := []struct {
Name string
Action Action
WantNoErrString string
}{
{"invalidKind", Action{Kind: NotifyAction}, "invalid rule action kind"},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
errs := validateAction(&tst.Action)
for _, err := range errs {
t.Logf("Got error %#v", err)
if strings.Contains(err.Error(), tst.WantNoErrString) {
t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
}
}
})
}
}
func TestValidateConditionNegatives(t *testing.T) {
tsts := []struct {
Name string
Condition Condition
WantErrString string
}{
{"emptyKind", Condition{}, "invalid rule condition kind"},
{"invalidKind", Condition{Kind: ConditionKind("something else")}, "invalid rule condition kind"},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
errs := validateCondition(&tst.Condition)
var foundErr error
for _, err := range errs {
t.Logf("Got error %#v", err)
if strings.Contains(err.Error(), tst.WantErrString) {
foundErr = err
}
}
if foundErr == nil {
t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
}
})
}
}
func TestValidateConditionPositives(t *testing.T) {
tsts := []struct {
Name string
Condition Condition
WantNoErrString string
}{
{"invalidKind", Condition{Kind: EventMatchCondition}, "invalid rule condition kind"},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
errs := validateCondition(&tst.Condition)
for _, err := range errs {
t.Logf("Got error %#v", err)
if strings.Contains(err.Error(), tst.WantNoErrString) {
t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
}
}
})
}
}

View file

@ -163,6 +163,7 @@ type StatementList []struct {
func (s StatementList) Prepare(db *sql.DB) (err error) { func (s StatementList) Prepare(db *sql.DB) (err error) {
for _, statement := range s { for _, statement := range s {
if *statement.Statement, err = db.Prepare(statement.SQL); err != nil { if *statement.Statement, err = db.Prepare(statement.SQL); err != nil {
err = fmt.Errorf("Error %q while preparing statement: %s", err, statement.SQL)
return return
} }
} }

View file

@ -166,8 +166,10 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
} }
// We can't have a self-signing or user-signing key without a master // We can't have a self-signing or user-signing key without a master
// key, so make sure we have one of those. // key, so make sure we have one of those. We will also only actually do
if !hasMasterKey { // something if any of the specified keys in the request are different
// to what we've got in the database, to avoid generating key change
// notifications unnecessarily.
existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
@ -176,18 +178,43 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
return return
} }
_, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]
}
// If we still can't find a master key for the user then stop the upload. // If we still can't find a master key for the user then stop the upload.
// This satisfies the "Fails to upload self-signing key without master key" test. // This satisfies the "Fails to upload self-signing key without master key" test.
if !hasMasterKey { if !hasMasterKey {
if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: "No master key was found", Err: "No master key was found",
IsMissingParam: true, IsMissingParam: true,
} }
return return
} }
}
// Check if anything actually changed compared to what we have in the database.
changed := false
for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{
gomatrixserverlib.CrossSigningKeyPurposeMaster,
gomatrixserverlib.CrossSigningKeyPurposeSelfSigning,
gomatrixserverlib.CrossSigningKeyPurposeUserSigning,
} {
old, gotOld := existingKeys[purpose]
new, gotNew := toStore[purpose]
if gotOld != gotNew {
// A new key purpose has been specified that we didn't know before,
// or one has been removed.
changed = true
break
}
if !bytes.Equal(old, new) {
// One of the existing keys for a purpose we already knew about has
// changed.
changed = true
break
}
}
if !changed {
return
}
// Store the keys. // Store the keys.
if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil { if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil {

View file

@ -48,13 +48,19 @@ type mockDeviceListUpdaterDatabase struct {
staleUsers map[string]bool staleUsers map[string]bool
prevIDsExist func(string, []int) bool prevIDsExist func(string, []int) bool
storedKeys []api.DeviceMessage storedKeys []api.DeviceMessage
mu sync.Mutex // protect staleUsers
} }
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned. // If no domains are given, all user IDs with stale device lists are returned.
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
d.mu.Lock()
defer d.mu.Unlock()
var result []string var result []string
for userID := range d.staleUsers { for userID, isStale := range d.staleUsers {
if !isStale {
continue
}
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID) _, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -75,10 +81,18 @@ func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, do
// MarkDeviceListStale sets the stale bit for this user to isStale. // MarkDeviceListStale sets the stale bit for this user to isStale.
func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
d.mu.Lock()
defer d.mu.Unlock()
d.staleUsers[userID] = isStale d.staleUsers[userID] = isStale
return nil return nil
} }
func (d *mockDeviceListUpdaterDatabase) isStale(userID string) bool {
d.mu.Lock()
defer d.mu.Unlock()
return d.staleUsers[userID]
}
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device). Does not modify the stream ID for keys. // for this (user, device). Does not modify the stream ID for keys.
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error { func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
@ -161,7 +175,7 @@ func TestUpdateHavePrevID(t *testing.T) {
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
} }
if db.staleUsers[event.UserID] { if db.isStale(event.UserID) {
t.Errorf("%s incorrectly marked as stale", event.UserID) t.Errorf("%s incorrectly marked as stale", event.UserID)
} }
} }
@ -235,7 +249,7 @@ func TestUpdateNoPrevID(t *testing.T) {
}, },
} }
// Now we should have a fresh list and the keys and emitted something // Now we should have a fresh list and the keys and emitted something
if db.staleUsers[event.UserID] { if db.isStale(event.UserID) {
t.Errorf("%s still marked as stale", event.UserID) t.Errorf("%s still marked as stale", event.UserID)
} }
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
@ -247,3 +261,83 @@ func TestUpdateNoPrevID(t *testing.T) {
} }
} }
// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the
// update is still ongoing.
func TestDebounce(t *testing.T) {
t.Skipf("panic on closed channel on GHA")
db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool),
prevIDsExist: func(string, []int) bool {
return true
},
}
ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{}
fedCh := make(chan *http.Response, 1)
srv := gomatrixserverlib.ServerName("example.com")
userID := "@alice:example.com"
keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
incomingFedReq := make(chan struct{})
fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) {
return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
}
close(incomingFedReq)
return <-fedCh, nil
})
updater := NewDeviceListUpdater(db, ap, producer, fedClient, 1)
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}
// hit this 5 times
var wg sync.WaitGroup
wg.Add(5)
for i := 0; i < 5; i++ {
go func() {
defer wg.Done()
if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil {
t.Errorf("ManualUpdate: %s", err)
}
}()
}
// wait until the updater hits federation
select {
case <-incomingFedReq:
case <-time.After(time.Second):
t.Fatalf("timed out waiting for updater to hit federation")
}
// user should be marked as stale
if !db.isStale(userID) {
t.Errorf("user %s not marked as stale", userID)
}
// now send the response over federation
fedCh <- &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`
{
"user_id": "` + userID + `",
"stream_id": 5,
"devices": [
{
"device_id": "JLAFKJWSCS",
"keys": ` + keyJSON + `,
"device_display_name": "Mobile Phone"
}
]
}
`)),
}
close(fedCh)
// wait until all 5 ManualUpdates return. If we hit federation again we won't send a response
// and should panic with read on a closed channel
wg.Wait()
// user is no longer stale now
if db.isStale(userID) {
t.Errorf("user %s is marked as stale", userID)
}
}

View file

@ -53,8 +53,7 @@ func Setup(
) { ) {
rateLimits := httputil.NewRateLimits(rateLimit) rateLimits := httputil.NewRateLimits(rateLimit)
r0mux := publicAPIMux.PathPrefix("/r0").Subrouter() v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v1|v3)}/").Subrouter()
v1mux := publicAPIMux.PathPrefix("/v1").Subrouter()
activeThumbnailGeneration := &types.ActiveThumbnailGeneration{ activeThumbnailGeneration := &types.ActiveThumbnailGeneration{
PathToResult: map[string]*types.ThumbnailGenerationResult{}, PathToResult: map[string]*types.ThumbnailGenerationResult{},
@ -77,21 +76,18 @@ func Setup(
} }
}) })
r0mux.Handle("/upload", uploadHandler).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/upload", uploadHandler).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/upload", uploadHandler).Methods(http.MethodPost, http.MethodOptions)
activeRemoteRequests := &types.ActiveRemoteRequests{ activeRemoteRequests := &types.ActiveRemoteRequests{
MXCToResult: map[string]*types.RemoteRequestResult{}, MXCToResult: map[string]*types.RemoteRequestResult{},
} }
downloadHandler := makeDownloadAPI("download", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration) downloadHandler := makeDownloadAPI("download", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration)
r0mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) // TODO: remove when synapse is fixed
v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) // TODO: remove when synapse is fixed
r0mux.Handle("/thumbnail/{serverName}/{mediaId}", v3mux.Handle("/thumbnail/{serverName}/{mediaId}",
makeDownloadAPI("thumbnail", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration), makeDownloadAPI("thumbnail", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
} }

BIN
q.sqlite Normal file

Binary file not shown.

View file

@ -269,6 +269,7 @@ type QueryAuthChainResponse struct {
type QuerySharedUsersRequest struct { type QuerySharedUsersRequest struct {
UserID string UserID string
OtherUserIDs []string
ExcludeRoomIDs []string ExcludeRoomIDs []string
IncludeRoomIDs []string IncludeRoomIDs []string
} }
@ -313,6 +314,9 @@ type QueryBulkStateContentResponse struct {
type QueryCurrentStateRequest struct { type QueryCurrentStateRequest struct {
RoomID string RoomID string
AllowWildcards bool
// State key tuples. If a state_key has '*' and AllowWidlcards is true, returns all matching
// state events with that event type.
StateTuples []gomatrixserverlib.StateKeyTuple StateTuples []gomatrixserverlib.StateKeyTuple
} }

View file

@ -51,12 +51,8 @@ func SendEventWithState(
state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent,
origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool,
) error { ) error {
outliers, err := state.Events(event.RoomVersion) outliers := state.Events(event.RoomVersion)
if err != nil { ires := make([]InputRoomEvent, 0, len(outliers))
return err
}
var ires []InputRoomEvent
for _, outlier := range outliers { for _, outlier := range outliers {
if haveEventIDs[outlier.EventID()] { if haveEventIDs[outlier.EventID()] {
continue continue

View file

@ -23,6 +23,21 @@ type parsedRespState struct {
StateEvents []*gomatrixserverlib.Event StateEvents []*gomatrixserverlib.Event
} }
func (p *parsedRespState) Events() []*gomatrixserverlib.Event {
eventsByID := make(map[string]*gomatrixserverlib.Event, len(p.AuthEvents)+len(p.StateEvents))
for i, event := range p.AuthEvents {
eventsByID[event.EventID()] = p.AuthEvents[i]
}
for i, event := range p.StateEvents {
eventsByID[event.EventID()] = p.StateEvents[i]
}
allEvents := make([]*gomatrixserverlib.Event, 0, len(eventsByID))
for _, event := range eventsByID {
allEvents = append(allEvents, event)
}
return gomatrixserverlib.ReverseTopologicalOrdering(allEvents, gomatrixserverlib.TopologicalOrderByAuthEvents)
}
type missingStateReq struct { type missingStateReq struct {
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
db storage.Database db storage.Database
@ -124,11 +139,8 @@ func (t *missingStateReq) processEventWithMissingState(
t.hadEventsMutex.Unlock() t.hadEventsMutex.Unlock()
sendOutliers := func(resolvedState *parsedRespState) error { sendOutliers := func(resolvedState *parsedRespState) error {
outliers, oerr := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion) outliers := resolvedState.Events()
if oerr != nil { outlierRoomEvents := make([]api.InputRoomEvent, 0, len(outliers))
return fmt.Errorf("gomatrixserverlib.OrderAuthAndStateEvents: %w", oerr)
}
var outlierRoomEvents []api.InputRoomEvent
for _, outlier := range outliers { for _, outlier := range outliers {
if hadEvents[outlier.EventID()] { if hadEvents[outlier.EventID()] {
continue continue

View file

@ -621,6 +621,18 @@ func (r *Queryer) QueryPublishedRooms(
func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
for _, tuple := range req.StateTuples { for _, tuple := range req.StateTuples {
if tuple.StateKey == "*" && req.AllowWildcards {
events, err := r.DB.GetStateEventsWithEventType(ctx, req.RoomID, tuple.EventType)
if err != nil {
return err
}
for _, e := range events {
res.StateEvents[gomatrixserverlib.StateKeyTuple{
EventType: e.Type(),
StateKey: *e.StateKey(),
}] = e
}
} else {
ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey) ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey)
if err != nil { if err != nil {
return err return err
@ -629,6 +641,7 @@ func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentSt
res.StateEvents[tuple] = ev res.StateEvents[tuple] = ev
} }
} }
}
return nil return nil
} }
@ -696,7 +709,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser
} }
roomIDs = roomIDs[:j] roomIDs = roomIDs[:j]
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs) users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs)
if err != nil { if err != nil {
return err return err
} }

View file

@ -146,13 +146,14 @@ type Database interface {
// If no event could be found, returns nil // If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error // If there was an issue during the retrieval, returns an error
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error)
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error)
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. // JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms.
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error)
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
// GetServerInRoom returns true if we think a server is in a given room or false otherwise. // GetServerInRoom returns true if we think a server is in a given room or false otherwise.

View file

@ -66,7 +66,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
` `
var selectJoinedUsersSetForRoomsSQL = "" + var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid" " GROUP BY target_nid"
@ -322,13 +323,10 @@ func (s *membershipStatements) SelectRoomsWithMembership(
func (s *membershipStatements) SelectJoinedUsersSetForRooms( func (s *membershipStatements) SelectJoinedUsersSetForRooms(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNIDs []types.RoomNID, roomNIDs []types.RoomNID,
userNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]int, error) { ) (map[types.EventStateKeyNID]int, error) {
roomIDarray := make([]int64, len(roomNIDs))
for i := range roomNIDs {
roomIDarray[i] = int64(roomNIDs[i])
}
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -13,7 +13,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@ -979,6 +978,62 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
return nil, nil return nil, nil
} }
// Same as GetStateEvent but returns all matching state events with this event type. Returns no error
// if there are no events with this event type.
func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) {
roomInfo, err := d.RoomInfo(ctx, roomID)
if err != nil {
return nil, err
}
if roomInfo == nil {
return nil, fmt.Errorf("room %s doesn't exist", roomID)
}
// e.g invited rooms
if roomInfo.IsStub {
return nil, nil
}
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
if err == sql.ErrNoRows {
// No rooms have an event of this type, otherwise we'd have an event type NID
return nil, nil
}
if err != nil {
return nil, err
}
entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
if err != nil {
return nil, err
}
var eventNIDs []types.EventNID
for _, e := range entries {
if e.EventTypeNID == eventTypeNID {
eventNIDs = append(eventNIDs, e.EventNID)
}
}
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
// return the events requested
eventPairs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
if err != nil {
return nil, err
}
if len(eventPairs) == 0 {
return nil, nil
}
var result []*gomatrixserverlib.HeaderedEvent
for _, pair := range eventPairs {
ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[pair.EventNID], pair.EventJSON, false, roomInfo.RoomVersion)
if err != nil {
return nil, err
}
result = append(result, ev.Headered(roomInfo.RoomVersion))
}
return result, nil
}
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) {
var membershipState tables.MembershipState var membershipState tables.MembershipState
@ -1106,13 +1161,23 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
return result, nil return result, nil
} }
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. // JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs) roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs) userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs)
if err != nil {
return nil, err
}
userNIDs := make([]types.EventStateKeyNID, 0, len(userNIDsMap))
nidToUserID := make(map[types.EventStateKeyNID]string, len(userNIDsMap))
for id, nid := range userNIDsMap {
userNIDs = append(userNIDs, nid)
nidToUserID[nid] = id
}
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1122,13 +1187,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
stateKeyNIDs[i] = nid stateKeyNIDs[i] = nid
i++ i++
} }
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
if err != nil {
return nil, err
}
if len(nidToUserID) != len(userNIDToCount) {
logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID))
}
result := make(map[string]int, len(userNIDToCount)) result := make(map[string]int, len(userNIDToCount))
for nid, count := range userNIDToCount { for nid, count := range userNIDToCount {
result[nidToUserID[nid]] = count result[nidToUserID[nid]] = count

View file

@ -42,7 +42,8 @@ const membershipSchema = `
` `
var selectJoinedUsersSetForRoomsSQL = "" + var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid" " GROUP BY target_nid"
@ -296,18 +297,22 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil return roomNIDs, nil
} }
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs)) params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs))
for i, v := range roomNIDs { for _, v := range roomNIDs {
iRoomNIDs[i] = v params = append(params, v)
} }
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) for _, v := range userNIDs {
params = append(params, v)
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
var rows *sql.Rows var rows *sql.Rows
var err error var err error
if txn != nil { if txn != nil {
rows, err = txn.QueryContext(ctx, query, iRoomNIDs...) rows, err = txn.QueryContext(ctx, query, params...)
} else { } else {
rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...) rows, err = s.db.QueryContext(ctx, query, params...)
} }
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -128,9 +128,8 @@ type Membership interface {
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
// counts of how many rooms they are joined. SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error)
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)

63
run-sytest.sh Executable file
View file

@ -0,0 +1,63 @@
#!/bin/bash
#
# Runs SyTest either from Docker Hub, or from ../sytest. If it's run
# locally, the Docker image is rebuilt first.
#
# Logs are stored in ../sytestout/logs.
set -e
set -o pipefail
main() {
local tag=buster
local base_image=debian:$tag
local runargs=()
cd "$(dirname "$0")"
if [ -d ../sytest ]; then
local tmpdir
tmpdir="$(mktemp -d --tmpdir run-systest.XXXXXXXXXX)"
trap "rm -r '$tmpdir'" EXIT
if [ -z "$DISABLE_BUILDING_SYTEST" ]; then
echo "Re-building ../sytest Docker images..."
local status
(
cd ../sytest
docker build -f docker/base.Dockerfile --build-arg BASE_IMAGE="$base_image" --tag matrixdotorg/sytest:"$tag" .
docker build -f docker/dendrite.Dockerfile --build-arg SYTEST_IMAGE_TAG="$tag" --tag matrixdotorg/sytest-dendrite:latest .
) &>"$tmpdir/buildlog" || status=$?
if (( status != 0 )); then
# Docker is very verbose, and we don't really care about
# building SyTest. So we accumulate and only output on
# failure.
cat "$tmpdir/buildlog" >&2
return $status
fi
fi
runargs+=( -v "$PWD/../sytest:/sytest:ro" )
fi
if [ -n "$SYTEST_POSTGRES" ]; then
runargs+=( -e POSTGRES=1 )
fi
local sytestout=$PWD/../sytestout
mkdir -p "$sytestout"/{logs,cache/go-build,cache/go-pkg}
docker run \
--rm \
--name "sytest-dendrite-${LOGNAME}" \
-e LOGS_USER=$(id -u) \
-e LOGS_GROUP=$(id -g) \
-v "$PWD:/src/:ro" \
-v "$sytestout/logs:/logs/" \
-v "$sytestout/cache/go-build:/root/.cache/go-build" \
-v "$sytestout/cache/go-pkg:/gopath/pkg" \
"${runargs[@]}" \
matrixdotorg/sytest-dendrite:latest "$@"
}
main "$@"

View file

@ -31,6 +31,7 @@ import (
sentryhttp "github.com/getsentry/sentry-go/http" sentryhttp "github.com/getsentry/sentry-go/http"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"go.uber.org/atomic" "go.uber.org/atomic"
@ -270,6 +271,11 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
return f return f
} }
// PushGatewayHTTPClient returns a new client for interacting with (external) Push Gateways.
func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client {
return pushgateway.NewHTTPClient(b.Cfg.UserAPI.PushGatewayDisableTLSValidation)
}
// CreateAccountsDB creates a new instance of the accounts database. Should only // CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component. // be called once per component.
func (b *BaseDendrite) CreateAccountsDB() userdb.Database { func (b *BaseDendrite) CreateAccountsDB() userdb.Database {

View file

@ -205,6 +205,11 @@ user_api:
max_open_conns: 100 max_open_conns: 100
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
pusher_database:
connection_string: file:pushserver.db
max_open_conns: 100
max_idle_conns: 2
conn_max_lifetime: -1
tracing: tracing:
enabled: false enabled: false
jaeger: jaeger:

View file

@ -13,6 +13,9 @@ type UserAPI struct {
// The length of time an OpenID token is condidered valid in milliseconds // The length of time an OpenID token is condidered valid in milliseconds
OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"` OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"`
// Disable TLS validation on HTTPS calls to push gatways. NOT RECOMMENDED!
PushGatewayDisableTLSValidation bool `yaml:"push_gateway_disable_tls_validation"`
// The Account database stores the login details and account information // The Account database stores the login details and account information
// for local users. It is accessed by the UserAPI. // for local users. It is accessed by the UserAPI.
AccountDatabase DatabaseOptions `yaml:"account_database"` AccountDatabase DatabaseOptions `yaml:"account_database"`

View file

@ -18,7 +18,10 @@ var (
OutputKeyChangeEvent = "OutputKeyChangeEvent" OutputKeyChangeEvent = "OutputKeyChangeEvent"
OutputTypingEvent = "OutputTypingEvent" OutputTypingEvent = "OutputTypingEvent"
OutputClientData = "OutputClientData" OutputClientData = "OutputClientData"
OutputNotificationData = "OutputNotificationData"
OutputReceiptEvent = "OutputReceiptEvent" OutputReceiptEvent = "OutputReceiptEvent"
OutputStreamEvent = "OutputStreamEvent"
OutputReadUpdate = "OutputReadUpdate"
) )
var streams = []*nats.StreamConfig{ var streams = []*nats.StreamConfig{
@ -58,4 +61,19 @@ var streams = []*nats.StreamConfig{
Retention: nats.InterestPolicy, Retention: nats.InterestPolicy,
Storage: nats.FileStorage, Storage: nats.FileStorage,
}, },
{
Name: OutputNotificationData,
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
{
Name: OutputStreamEvent,
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
{
Name: OutputReadUpdate,
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
} }

View file

@ -60,8 +60,8 @@ func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ss
csMux, synapseMux, &m.Config.ClientAPI, m.AccountDB, csMux, synapseMux, &m.Config.ClientAPI, m.AccountDB,
m.FedClient, m.RoomserverAPI, m.FedClient, m.RoomserverAPI,
m.EDUInternalAPI, m.AppserviceAPI, transactions.New(), m.EDUInternalAPI, m.AppserviceAPI, transactions.New(),
m.FederationAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider, m.FederationAPI, m.UserAPI, m.KeyAPI,
&m.Config.MSCs, m.ExtPublicRoomsProvider, &m.Config.MSCs,
) )
federationapi.AddPublicRoutes( federationapi.AddPublicRoutes(
ssMux, keyMux, wkMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient, ssMux, keyMux, wkMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient,

View file

@ -654,11 +654,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo
AuthEvents: res.AuthChain, AuthEvents: res.AuthChain,
StateEvents: stateEvents, StateEvents: stateEvents,
} }
eventsInOrder, err := respState.Events(rc.roomVersion) eventsInOrder := respState.Events(rc.roomVersion)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("failed to calculate order to send events in MSC2836EventRelationshipsResponse")
return
}
// everything gets sent as an outlier because auth chain events may be disjoint from the DAG // everything gets sent as an outlier because auth chain events may be disjoint from the DAG
// as may the threaded events. // as may the threaded events.
var ires []roomserver.InputRoomEvent var ires []roomserver.InputRoomEvent
@ -669,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo
}) })
} }
// we've got the data by this point so use a background context // we've got the data by this point so use a background context
err = roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false)
if err != nil { if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver") util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver")
} }

View file

@ -18,17 +18,19 @@ package msc2946
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"net/url"
"sort"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/google/uuid"
"github.com/gorilla/mux" "github.com/gorilla/mux"
chttputil "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
fs "github.com/matrix-org/dendrite/federationapi/api" fs "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
roomserver "github.com/matrix-org/dendrite/roomserver/api" roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
@ -40,41 +42,26 @@ import (
const ( const (
ConstCreateEventContentKey = "type" ConstCreateEventContentKey = "type"
ConstCreateEventContentValueSpace = "m.space"
ConstSpaceChildEventType = "m.space.child" ConstSpaceChildEventType = "m.space.child"
ConstSpaceParentEventType = "m.space.parent" ConstSpaceParentEventType = "m.space.parent"
) )
// Defaults sets the request defaults type MSC2946ClientResponse struct {
func Defaults(r *gomatrixserverlib.MSC2946SpacesRequest) { Rooms []gomatrixserverlib.MSC2946Room `json:"rooms"`
r.Limit = 2000 NextBatch string `json:"next_batch,omitempty"`
r.MaxRoomsPerSpace = -1
} }
// Enable this MSC // Enable this MSC
func Enable( func Enable(
base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI, base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI,
fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, cache caching.SpaceSummaryRoomsCache,
) error { ) error {
db, err := NewDatabase(&base.Cfg.MSCs.Database) clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName))
if err != nil { base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions)
return fmt.Errorf("cannot enable MSC2946: %w", err) base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions)
}
hooks.Enable()
hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) {
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
hookErr := db.StoreReference(context.Background(), he)
if hookErr != nil {
util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error(
"failed to StoreReference",
)
}
})
base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/spaces", fedAPI := httputil.MakeExternalAPI(
httputil.MakeAuthAPI("spaces", userAPI, base.Cfg.Global.UserConsentOptions, false, spacesHandler(db, rsAPI, fsAPI, base.Cfg.Global.ServerName)),
).Methods(http.MethodPost, http.MethodOptions)
base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/spaces/{roomID}", httputil.MakeExternalAPI(
"msc2946_fed_spaces", func(req *http.Request) util.JSONResponse { "msc2946_fed_spaces", func(req *http.Request) util.JSONResponse {
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
req, time.Now(), base.Cfg.Global.ServerName, keyRing, req, time.Now(), base.Cfg.Global.ServerName, keyRing,
@ -88,252 +75,308 @@ func Enable(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
roomID := params["roomID"] roomID := params["roomID"]
return federatedSpacesHandler(req.Context(), fedReq, roomID, db, rsAPI, fsAPI, base.Cfg.Global.ServerName) return federatedSpacesHandler(req.Context(), fedReq, roomID, cache, rsAPI, fsAPI, base.Cfg.Global.ServerName)
}, },
)).Methods(http.MethodPost, http.MethodOptions) )
base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet)
base.PublicFederationAPIMux.Handle("/v1/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet)
return nil return nil
} }
func federatedSpacesHandler( func federatedSpacesHandler(
ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, db Database, ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string,
cache caching.SpaceSummaryRoomsCache,
rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
thisServer gomatrixserverlib.ServerName, thisServer gomatrixserverlib.ServerName,
) util.JSONResponse { ) util.JSONResponse {
inMemoryBatchCache := make(map[string]set) u, err := url.Parse(fedReq.RequestURI())
var r gomatrixserverlib.MSC2946SpacesRequest if err != nil {
Defaults(&r)
if err := json.Unmarshal(fedReq.Content(), &r); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: 400,
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), JSON: jsonerror.InvalidParam("bad request uri"),
} }
} }
w := walker{ w := walker{
req: &r,
rootRoomID: roomID, rootRoomID: roomID,
serverName: fedReq.Origin(), serverName: fedReq.Origin(),
thisServer: thisServer, thisServer: thisServer,
ctx: ctx, ctx: ctx,
cache: cache,
suggestedOnly: u.Query().Get("suggested_only") == "true",
limit: 1000,
// The main difference is that it does not recurse into spaces and does not support pagination.
// This is somewhat equivalent to a Client-Server request with a max_depth=1.
maxDepth: 1,
db: db,
rsAPI: rsAPI, rsAPI: rsAPI,
fsAPI: fsAPI, fsAPI: fsAPI,
inMemoryBatchCache: inMemoryBatchCache, // inline cache as we don't have pagination in federation mode
} paginationCache: make(map[string]paginationInfo),
res := w.walk()
return util.JSONResponse{
Code: 200,
JSON: res,
} }
return w.walk()
} }
func spacesHandler( func spacesHandler(
db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, rsAPI roomserver.RoomserverInternalAPI,
fsAPI fs.FederationInternalAPI,
cache caching.SpaceSummaryRoomsCache,
thisServer gomatrixserverlib.ServerName, thisServer gomatrixserverlib.ServerName,
) func(*http.Request, *userapi.Device) util.JSONResponse { ) func(*http.Request, *userapi.Device) util.JSONResponse {
// declared outside the returned handler so it persists between calls
// TODO: clear based on... time?
paginationCache := make(map[string]paginationInfo)
return func(req *http.Request, device *userapi.Device) util.JSONResponse { return func(req *http.Request, device *userapi.Device) util.JSONResponse {
inMemoryBatchCache := make(map[string]set)
// Extract the room ID from the request. Sanity check request data. // Extract the room ID from the request. Sanity check request data.
params, err := httputil.URLDecodeMapValues(mux.Vars(req)) params, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
roomID := params["roomID"] roomID := params["roomID"]
var r gomatrixserverlib.MSC2946SpacesRequest
Defaults(&r)
if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil {
return *resErr
}
w := walker{ w := walker{
req: &r, suggestedOnly: req.URL.Query().Get("suggested_only") == "true",
limit: parseInt(req.URL.Query().Get("limit"), 1000),
maxDepth: parseInt(req.URL.Query().Get("max_depth"), -1),
paginationToken: req.URL.Query().Get("from"),
rootRoomID: roomID, rootRoomID: roomID,
caller: device, caller: device,
thisServer: thisServer, thisServer: thisServer,
ctx: req.Context(), ctx: req.Context(),
cache: cache,
db: db,
rsAPI: rsAPI, rsAPI: rsAPI,
fsAPI: fsAPI, fsAPI: fsAPI,
inMemoryBatchCache: inMemoryBatchCache, paginationCache: paginationCache,
}
res := w.walk()
return util.JSONResponse{
Code: 200,
JSON: res,
} }
return w.walk()
} }
} }
type paginationInfo struct {
processed set
unvisited []roomVisit
}
type walker struct { type walker struct {
req *gomatrixserverlib.MSC2946SpacesRequest
rootRoomID string rootRoomID string
caller *userapi.Device caller *userapi.Device
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
thisServer gomatrixserverlib.ServerName thisServer gomatrixserverlib.ServerName
db Database
rsAPI roomserver.RoomserverInternalAPI rsAPI roomserver.RoomserverInternalAPI
fsAPI fs.FederationInternalAPI fsAPI fs.FederationInternalAPI
ctx context.Context ctx context.Context
cache caching.SpaceSummaryRoomsCache
suggestedOnly bool
limit int
maxDepth int
paginationToken string
// user ID|device ID|batch_num => event/room IDs sent to client paginationCache map[string]paginationInfo
inMemoryBatchCache map[string]set
mu sync.Mutex mu sync.Mutex
} }
func (w *walker) roomIsExcluded(roomID string) bool { func (w *walker) newPaginationCache() (string, paginationInfo) {
for _, exclRoom := range w.req.ExcludeRooms { p := paginationInfo{
if exclRoom == roomID { processed: make(set),
return true unvisited: nil,
} }
} tok := uuid.NewString()
return false return tok, p
} }
func (w *walker) callerID() string { func (w *walker) loadPaginationCache(paginationToken string) *paginationInfo {
w.mu.Lock()
defer w.mu.Unlock()
p := w.paginationCache[paginationToken]
return &p
}
func (w *walker) storePaginationCache(paginationToken string, cache paginationInfo) {
w.mu.Lock()
defer w.mu.Unlock()
w.paginationCache[paginationToken] = cache
}
type roomVisit struct {
roomID string
depth int
vias []string // vias to query this room by
}
func (w *walker) walk() util.JSONResponse {
if !w.authorised(w.rootRoomID) {
if w.caller != nil { if w.caller != nil {
return w.caller.UserID + "|" + w.caller.ID // CS API format
return util.JSONResponse{
Code: 403,
JSON: jsonerror.Forbidden("room is unknown/forbidden"),
} }
return string(w.serverName) } else {
} // SS API format
return util.JSONResponse{
func (w *walker) alreadySent(id string) bool { Code: 404,
w.mu.Lock() JSON: jsonerror.NotFound("room is unknown/forbidden"),
defer w.mu.Unlock()
m, ok := w.inMemoryBatchCache[w.callerID()]
if !ok {
return false
} }
return m[id]
}
func (w *walker) markSent(id string) {
w.mu.Lock()
defer w.mu.Unlock()
m := w.inMemoryBatchCache[w.callerID()]
if m == nil {
m = make(set)
} }
m[id] = true
w.inMemoryBatchCache[w.callerID()] = m
}
func (w *walker) walk() *gomatrixserverlib.MSC2946SpacesResponse {
var res gomatrixserverlib.MSC2946SpacesResponse
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
unvisited := []string{w.rootRoomID}
processed := make(set)
for len(unvisited) > 0 {
roomID := unvisited[0]
unvisited = unvisited[1:]
// If this room has already been processed, skip. NB: do not remember this between calls
if processed[roomID] || roomID == "" {
continue
} }
// Mark this room as processed.
processed[roomID] = true
// Collect rooms/events to send back (either locally or fetched via federation)
var discoveredRooms []gomatrixserverlib.MSC2946Room var discoveredRooms []gomatrixserverlib.MSC2946Room
var discoveredEvents []gomatrixserverlib.MSC2946StrippedEvent
// If we know about this room and the caller is authorised (joined/world_readable) then pull var cache *paginationInfo
// events locally if w.paginationToken != "" {
if w.roomExists(roomID) && w.authorised(roomID) { cache = w.loadPaginationCache(w.paginationToken)
// Get all `m.space.child` and `m.space.parent` state events for the room. *In addition*, get if cache == nil {
// all `m.space.child` and `m.space.parent` state events which *point to* (via `state_key` or `content.room_id`) return util.JSONResponse{
// this room. This requires servers to store reverse lookups. Code: 400,
events, err := w.references(roomID) JSON: jsonerror.InvalidArgumentValue("invalid from"),
if err != nil { }
util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Error("failed to extract references for room") }
} else {
tok, c := w.newPaginationCache()
cache = &c
w.paginationToken = tok
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
c.unvisited = append(c.unvisited, roomVisit{
roomID: w.rootRoomID,
depth: 0,
})
}
processed := cache.processed
unvisited := cache.unvisited
// Depth first -> stack data structure
for len(unvisited) > 0 {
if len(discoveredRooms) >= w.limit {
break
}
// pop the stack
rv := unvisited[len(unvisited)-1]
unvisited = unvisited[:len(unvisited)-1]
// If this room has already been processed, skip.
// If this room exceeds the specified depth, skip.
if processed.isSet(rv.roomID) || rv.roomID == "" || (w.maxDepth > 0 && rv.depth > w.maxDepth) {
continue continue
} }
discoveredEvents = events
pubRoom := w.publicRoomsChunk(roomID) // Mark this room as processed.
roomType := "" processed.set(rv.roomID)
create := w.stateEvent(roomID, gomatrixserverlib.MRoomCreate, "")
// if this room is not a space room, skip.
var roomType string
create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "")
if create != nil { if create != nil {
// escape the `.`s so gjson doesn't think it's nested // escape the `.`s so gjson doesn't think it's nested
roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
} }
// Add the total number of events to `PublicRoomsChunk` under `num_refs`. Add `PublicRoomsChunk` to `rooms`. // Collect rooms/events to send back (either locally or fetched via federation)
var discoveredChildEvents []gomatrixserverlib.MSC2946StrippedEvent
// If we know about this room and the caller is authorised (joined/world_readable) then pull
// events locally
if w.roomExists(rv.roomID) && w.authorised(rv.roomID) {
// Get all `m.space.child` state events for this room
events, err := w.childReferences(rv.roomID)
if err != nil {
util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Error("failed to extract references for room")
continue
}
discoveredChildEvents = events
pubRoom := w.publicRoomsChunk(rv.roomID)
discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{ discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{
PublicRoom: *pubRoom, PublicRoom: *pubRoom,
NumRefs: len(discoveredEvents),
RoomType: roomType, RoomType: roomType,
ChildrenState: events,
}) })
} else { } else {
// attempt to query this room over federation, as either we've never heard of it before // attempt to query this room over federation, as either we've never heard of it before
// or we've left it and hence are not authorised (but info may be exposed regardless) // or we've left it and hence are not authorised (but info may be exposed regardless)
fedRes, err := w.federatedRoomInfo(roomID) fedRes, err := w.federatedRoomInfo(rv.roomID, rv.vias)
if err != nil { if err != nil {
util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Errorf("failed to query federated spaces") util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Errorf("failed to query federated spaces")
continue continue
} }
if fedRes != nil { if fedRes != nil {
discoveredRooms = fedRes.Rooms discoveredChildEvents = fedRes.Room.ChildrenState
discoveredEvents = fedRes.Events discoveredRooms = append(discoveredRooms, fedRes.Room)
if len(fedRes.Children) > 0 {
discoveredRooms = append(discoveredRooms, fedRes.Children...)
}
// mark this room as a space room as the federated server responded.
// we need to do this so we add the children of this room to the unvisited stack
// as these children may be rooms we do know about.
roomType = ConstCreateEventContentValueSpace
} }
} }
// If this room has not ever been in `rooms` (across multiple requests), send it now // don't walk the children
for _, room := range discoveredRooms { // if the parent is not a space room
if !w.alreadySent(room.RoomID) && !w.roomIsExcluded(room.RoomID) { if roomType != ConstCreateEventContentValueSpace {
res.Rooms = append(res.Rooms, room)
w.markSent(room.RoomID)
}
}
uniqueRooms := make(set)
// If this is the root room from the original request, insert all these events into `events` if
// they haven't been added before (across multiple requests).
if w.rootRoomID == roomID {
for _, ev := range discoveredEvents {
if !w.alreadySent(eventKey(&ev)) {
res.Events = append(res.Events, ev)
uniqueRooms[ev.RoomID] = true
uniqueRooms[spaceTargetStripped(&ev)] = true
w.markSent(eventKey(&ev))
}
}
} else {
// Else add them to `events` honouring the `limit` and `max_rooms_per_space` values. If either
// are exceeded, stop adding events. If the event has already been added, do not add it again.
numAdded := 0
for _, ev := range discoveredEvents {
if w.req.Limit > 0 && len(res.Events) >= w.req.Limit {
break
}
if w.req.MaxRoomsPerSpace > 0 && numAdded >= w.req.MaxRoomsPerSpace {
break
}
if w.alreadySent(eventKey(&ev)) {
continue continue
} }
// Skip the room if it's part of exclude_rooms but ONLY IF the source matches, as we still
// want to catch arrows which point to excluded rooms.
if w.roomIsExcluded(ev.RoomID) {
continue
}
res.Events = append(res.Events, ev)
uniqueRooms[ev.RoomID] = true
uniqueRooms[spaceTargetStripped(&ev)] = true
w.markSent(eventKey(&ev))
// we don't distinguish between child state events and parent state events for the purposes of
// max_rooms_per_space, maybe we should?
numAdded++
}
}
// For each referenced room ID in the events being returned to the caller (both parent and child) // For each referenced room ID in the child events being returned to the caller
// add the room ID to the queue of unvisited rooms. Loop from the beginning. // add the room ID to the queue of unvisited rooms. Loop from the beginning.
for roomID := range uniqueRooms { // We need to invert the order here because the child events are lo->hi on the timestamp,
unvisited = append(unvisited, roomID) // so we need to ensure we pop in the same lo->hi order, which won't be the case if we
// insert the highest timestamp last in a stack.
for i := len(discoveredChildEvents) - 1; i >= 0; i-- {
spaceContent := struct {
Via []string `json:"via"`
}{}
ev := discoveredChildEvents[i]
_ = json.Unmarshal(ev.Content, &spaceContent)
unvisited = append(unvisited, roomVisit{
roomID: ev.StateKey,
depth: rv.depth + 1,
vias: spaceContent.Via,
})
} }
} }
return &res
if len(unvisited) > 0 {
// we still have more rooms so we need to send back a pagination token,
// we probably hit a room limit
cache.processed = processed
cache.unvisited = unvisited
w.storePaginationCache(w.paginationToken, *cache)
} else {
// clear the pagination token so we don't send it back to the client
// Note we do NOT nuke the cache just in case this response is lost
// and the client retries it.
w.paginationToken = ""
}
if w.caller != nil {
// return CS API format
return util.JSONResponse{
Code: 200,
JSON: MSC2946ClientResponse{
Rooms: discoveredRooms,
NextBatch: w.paginationToken,
},
}
}
// return SS API format
// the first discovered room will be the room asked for, and subsequent ones the depth=1 children
if len(discoveredRooms) == 0 {
return util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("room is unknown/forbidden"),
}
}
return util.JSONResponse{
Code: 200,
JSON: gomatrixserverlib.MSC2946SpacesResponse{
Room: discoveredRooms[0],
Children: discoveredRooms[1:],
},
}
} }
func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent { func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent {
@ -366,46 +409,41 @@ func (w *walker) publicRoomsChunk(roomID string) *gomatrixserverlib.PublicRoom {
// federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was // federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was
// unsuccessful. // unsuccessful.
func (w *walker) federatedRoomInfo(roomID string) (*gomatrixserverlib.MSC2946SpacesResponse, error) { func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
// only do federated requests for client requests // only do federated requests for client requests
if w.caller == nil { if w.caller == nil {
return nil, nil return nil, nil
} }
// extract events which point to this room ID and extract their vias resp, ok := w.cache.GetSpaceSummary(roomID)
events, err := w.db.References(w.ctx, roomID) if ok {
if err != nil { util.GetLogger(w.ctx).Debugf("Returning cached response for %s", roomID)
return nil, fmt.Errorf("failed to get References events: %w", err) return &resp, nil
} }
vias := make(set) util.GetLogger(w.ctx).Debugf("Querying %s via %+v", roomID, vias)
for _, ev := range events {
if ev.StateKeyEquals(roomID) {
// event points at this room, extract vias
content := struct {
Vias []string `json:"via"`
}{}
if err = json.Unmarshal(ev.Content(), &content); err != nil {
continue // silently ignore corrupted state events
}
for _, v := range content.Vias {
vias[v] = true
}
}
}
util.GetLogger(w.ctx).Infof("Querying federatedRoomInfo via %+v", vias)
ctx := context.Background() ctx := context.Background()
// query more of the spaces graph using these servers // query more of the spaces graph using these servers
for serverName := range vias { for _, serverName := range vias {
if serverName == string(w.thisServer) { if serverName == string(w.thisServer) {
continue continue
} }
res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, gomatrixserverlib.MSC2946SpacesRequest{ res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly)
Limit: w.req.Limit,
MaxRoomsPerSpace: w.req.MaxRoomsPerSpace,
})
if err != nil { if err != nil {
util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName) util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName)
continue continue
} }
// ensure nil slices are empty as we send this to the client sometimes
if res.Room.ChildrenState == nil {
res.Room.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{}
}
for i := 0; i < len(res.Children); i++ {
child := res.Children[i]
if child.ChildrenState == nil {
child.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{}
}
res.Children[i] = child
}
w.cache.StoreSpaceSummary(roomID, res)
return &res, nil return &res, nil
} }
return nil, nil return nil, nil
@ -501,7 +539,7 @@ func (w *walker) authorisedUser(roomID string) bool {
hisVisEv := queryRes.StateEvents[hisVisTuple] hisVisEv := queryRes.StateEvents[hisVisTuple]
if memberEv != nil { if memberEv != nil {
membership, _ := memberEv.Membership() membership, _ := memberEv.Membership()
if membership == gomatrixserverlib.Join { if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite {
return true return true
} }
} }
@ -514,29 +552,73 @@ func (w *walker) authorisedUser(roomID string) bool {
return false return false
} }
// references returns all references pointing to or from this room. // references returns all child references pointing to or from this room.
func (w *walker) references(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) { func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) {
events, err := w.db.References(w.ctx, roomID) createTuple := gomatrixserverlib.StateKeyTuple{
EventType: gomatrixserverlib.MRoomCreate,
StateKey: "",
}
var res roomserver.QueryCurrentStateResponse
err := w.rsAPI.QueryCurrentState(context.Background(), &roomserver.QueryCurrentStateRequest{
RoomID: roomID,
AllowWildcards: true,
StateTuples: []gomatrixserverlib.StateKeyTuple{
createTuple, {
EventType: ConstSpaceChildEventType,
StateKey: "*",
},
},
}, &res)
if err != nil { if err != nil {
return nil, err return nil, err
} }
el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(events))
for _, ev := range events { // don't return any child refs if the room is not a space room
if res.StateEvents[createTuple] != nil {
// escape the `.`s so gjson doesn't think it's nested
roomType := gjson.GetBytes(res.StateEvents[createTuple].Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
if roomType != ConstCreateEventContentValueSpace {
return []gomatrixserverlib.MSC2946StrippedEvent{}, nil
}
}
delete(res.StateEvents, createTuple)
el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(res.StateEvents))
for _, ev := range res.StateEvents {
content := gjson.ParseBytes(ev.Content())
// only return events that have a `via` key as per MSC1772 // only return events that have a `via` key as per MSC1772
// else we'll incorrectly walk redacted events (as the link // else we'll incorrectly walk redacted events (as the link
// is in the state_key) // is in the state_key)
if gjson.GetBytes(ev.Content(), "via").Exists() { if content.Get("via").Exists() {
strip := stripped(ev.Event) strip := stripped(ev.Event)
if strip == nil { if strip == nil {
continue continue
} }
// if suggested only and this child isn't suggested, skip it.
// if suggested only = false we include everything so don't need to check the content.
if w.suggestedOnly && !content.Get("suggested").Bool() {
continue
}
el = append(el, *strip) el = append(el, *strip)
} }
} }
// sort by origin_server_ts as per MSC2946
sort.Slice(el, func(i, j int) bool {
return el[i].OriginServerTS < el[j].OriginServerTS
})
return el, nil return el, nil
} }
type set map[string]bool type set map[string]struct{}
func (s set) set(val string) {
s[val] = struct{}{}
}
func (s set) isSet(val string) bool {
_, ok := s[val]
return ok
}
func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEvent { func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEvent {
if ev.StateKey() == nil { if ev.StateKey() == nil {
@ -548,6 +630,7 @@ func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEve
Content: ev.Content(), Content: ev.Content(),
Sender: ev.Sender(), Sender: ev.Sender(),
RoomID: ev.RoomID(), RoomID: ev.RoomID(),
OriginServerTS: ev.OriginServerTS(),
} }
} }
@ -567,3 +650,11 @@ func spaceTargetStripped(event *gomatrixserverlib.MSC2946StrippedEvent) string {
} }
return "" return ""
} }
func parseInt(intstr string, defaultVal int) int {
i, err := strconv.ParseInt(intstr, 10, 32)
if err != nil {
return defaultVal
}
return int(i)
}

View file

@ -1,464 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msc2946_test
import (
"bytes"
"context"
"crypto/ed25519"
"encoding/json"
"io/ioutil"
"net/http"
"net/url"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/hooks"
"github.com/matrix-org/dendrite/internal/httputil"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/mscs/msc2946"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
var (
client = &http.Client{
Timeout: 10 * time.Second,
}
roomVer = gomatrixserverlib.RoomVersionV6
)
// Basic sanity check of MSC2946 logic. Tests a single room with a few state events
// and a bit of recursion to subspaces. Makes a graph like:
// Root
// ____|_____
// | | |
// R1 R2 S1
// |_________
// | | |
// R3 R4 S2
// | <-- this link is just a parent, not a child
// R5
//
// Alice is not joined to R4, but R4 is "world_readable".
func TestMSC2946(t *testing.T) {
alice := "@alice:localhost"
// give access token to alice
nopUserAPI := &testUserAPI{
accessTokens: make(map[string]userapi.Device),
}
nopUserAPI.accessTokens["alice"] = userapi.Device{
AccessToken: "alice",
DisplayName: "Alice",
UserID: alice,
}
rootSpace := "!rootspace:localhost"
subSpaceS1 := "!subspaceS1:localhost"
subSpaceS2 := "!subspaceS2:localhost"
room1 := "!room1:localhost"
room2 := "!room2:localhost"
room3 := "!room3:localhost"
room4 := "!room4:localhost"
empty := ""
room5 := "!room5:localhost"
allRooms := []string{
rootSpace, subSpaceS1, subSpaceS2,
room1, room2, room3, room4, room5,
}
rootToR1 := mustCreateEvent(t, fledglingEvent{
RoomID: rootSpace,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room1,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
rootToR2 := mustCreateEvent(t, fledglingEvent{
RoomID: rootSpace,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room2,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
rootToS1 := mustCreateEvent(t, fledglingEvent{
RoomID: rootSpace,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &subSpaceS1,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
s1ToR3 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room3,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
s1ToR4 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room4,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
s1ToS2 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &subSpaceS2,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
// This is a parent link only
s2ToR5 := mustCreateEvent(t, fledglingEvent{
RoomID: room5,
Sender: alice,
Type: msc2946.ConstSpaceParentEventType,
StateKey: &subSpaceS2,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
// history visibility for R4
r4HisVis := mustCreateEvent(t, fledglingEvent{
RoomID: room4,
Sender: "@someone:localhost",
Type: gomatrixserverlib.MRoomHistoryVisibility,
StateKey: &empty,
Content: map[string]interface{}{
"history_visibility": "world_readable",
},
})
var joinEvents []*gomatrixserverlib.HeaderedEvent
for _, roomID := range allRooms {
if roomID == room4 {
continue // not joined to that room
}
joinEvents = append(joinEvents, mustCreateEvent(t, fledglingEvent{
RoomID: roomID,
Sender: alice,
StateKey: &alice,
Type: gomatrixserverlib.MRoomMember,
Content: map[string]interface{}{
"membership": "join",
},
}))
}
roomNameTuple := gomatrixserverlib.StateKeyTuple{
EventType: "m.room.name",
StateKey: "",
}
hisVisTuple := gomatrixserverlib.StateKeyTuple{
EventType: "m.room.history_visibility",
StateKey: "",
}
nopRsAPI := &testRoomserverAPI{
joinEvents: joinEvents,
events: map[string]*gomatrixserverlib.HeaderedEvent{
rootToR1.EventID(): rootToR1,
rootToR2.EventID(): rootToR2,
rootToS1.EventID(): rootToS1,
s1ToR3.EventID(): s1ToR3,
s1ToR4.EventID(): s1ToR4,
s1ToS2.EventID(): s1ToS2,
s2ToR5.EventID(): s2ToR5,
r4HisVis.EventID(): r4HisVis,
},
pubRoomState: map[string]map[gomatrixserverlib.StateKeyTuple]string{
rootSpace: {
roomNameTuple: "Root",
hisVisTuple: "shared",
},
subSpaceS1: {
roomNameTuple: "Sub-Space 1",
hisVisTuple: "joined",
},
subSpaceS2: {
roomNameTuple: "Sub-Space 2",
hisVisTuple: "shared",
},
room1: {
hisVisTuple: "joined",
},
room2: {
hisVisTuple: "joined",
},
room3: {
hisVisTuple: "joined",
},
room4: {
hisVisTuple: "world_readable",
},
room5: {
hisVisTuple: "joined",
},
},
}
allEvents := []*gomatrixserverlib.HeaderedEvent{
rootToR1, rootToR2, rootToS1,
s1ToR3, s1ToR4, s1ToS2,
s2ToR5, r4HisVis,
}
allEvents = append(allEvents, joinEvents...)
router := injectEvents(t, nopUserAPI, nopRsAPI, allEvents)
cancel := runServer(t, router)
defer cancel()
t.Run("returns no events for unknown rooms", func(t *testing.T) {
res := postSpaces(t, 200, "alice", "!unknown:localhost", newReq(t, map[string]interface{}{}))
if len(res.Events) > 0 {
t.Errorf("got %d events, want 0", len(res.Events))
}
if len(res.Rooms) > 0 {
t.Errorf("got %d rooms, want 0", len(res.Rooms))
}
})
t.Run("returns the entire graph", func(t *testing.T) {
res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{}))
if len(res.Events) != 7 {
t.Errorf("got %d events, want 7", len(res.Events))
}
if len(res.Rooms) != len(allRooms) {
t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms))
}
})
t.Run("can update the graph", func(t *testing.T) {
// remove R3 from the graph
rmS1ToR3 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room3,
Content: map[string]interface{}{}, // redacted
})
nopRsAPI.events[rmS1ToR3.EventID()] = rmS1ToR3
hooks.Run(hooks.KindNewEventPersisted, rmS1ToR3)
res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{}))
if len(res.Events) != 6 { // one less since we don't return redacted events
t.Errorf("got %d events, want 6", len(res.Events))
}
if len(res.Rooms) != (len(allRooms) - 1) { // one less due to lack of R3
t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)-1)
}
})
}
func newReq(t *testing.T, jsonBody map[string]interface{}) *gomatrixserverlib.MSC2946SpacesRequest {
t.Helper()
b, err := json.Marshal(jsonBody)
if err != nil {
t.Fatalf("Failed to marshal request: %s", err)
}
var r gomatrixserverlib.MSC2946SpacesRequest
if err := json.Unmarshal(b, &r); err != nil {
t.Fatalf("Failed to unmarshal request: %s", err)
}
return &r
}
func runServer(t *testing.T, router *mux.Router) func() {
t.Helper()
externalServ := &http.Server{
Addr: string(":8010"),
WriteTimeout: 60 * time.Second,
Handler: router,
}
go func() {
externalServ.ListenAndServe()
}()
// wait to listen on the port
time.Sleep(500 * time.Millisecond)
return func() {
externalServ.Shutdown(context.TODO())
}
}
func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *gomatrixserverlib.MSC2946SpacesRequest) *gomatrixserverlib.MSC2946SpacesResponse {
t.Helper()
var r gomatrixserverlib.MSC2946SpacesRequest
msc2946.Defaults(&r)
data, err := json.Marshal(req)
if err != nil {
t.Fatalf("failed to marshal request: %s", err)
}
httpReq, err := http.NewRequest(
"POST", "http://localhost:8010/_matrix/client/unstable/org.matrix.msc2946/rooms/"+url.PathEscape(roomID)+"/spaces",
bytes.NewBuffer(data),
)
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
if err != nil {
t.Fatalf("failed to prepare request: %s", err)
}
res, err := client.Do(httpReq)
if err != nil {
t.Fatalf("failed to do request: %s", err)
}
if res.StatusCode != expectCode {
body, _ := ioutil.ReadAll(res.Body)
t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body))
}
if res.StatusCode == 200 {
var result gomatrixserverlib.MSC2946SpacesResponse
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("response 200 OK but failed to read response body: %s", err)
}
t.Logf("Body: %s", string(body))
if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body))
}
return &result
}
return nil
}
type testUserAPI struct {
userapi.UserInternalAPITrace
accessTokens map[string]userapi.Device
}
func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error {
dev, ok := u.accessTokens[req.AccessToken]
if !ok {
res.Err = "unknown token"
return nil
}
res.Device = &dev
return nil
}
type testRoomserverAPI struct {
// use a trace API as it implements method stubs so we don't need to have them here.
// We'll override the functions we care about.
roomserver.RoomserverInternalAPITrace
joinEvents []*gomatrixserverlib.HeaderedEvent
events map[string]*gomatrixserverlib.HeaderedEvent
pubRoomState map[string]map[gomatrixserverlib.StateKeyTuple]string
}
func (r *testRoomserverAPI) QueryServerJoinedToRoom(ctx context.Context, req *roomserver.QueryServerJoinedToRoomRequest, res *roomserver.QueryServerJoinedToRoomResponse) error {
res.IsInRoom = true
res.RoomExists = true
return nil
}
func (r *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *roomserver.QueryBulkStateContentRequest, res *roomserver.QueryBulkStateContentResponse) error {
res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string)
for _, roomID := range req.RoomIDs {
pubRoomData, ok := r.pubRoomState[roomID]
if ok {
res.Rooms[roomID] = pubRoomData
}
}
return nil
}
func (r *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *roomserver.QueryCurrentStateRequest, res *roomserver.QueryCurrentStateResponse) error {
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
checkEvent := func(he *gomatrixserverlib.HeaderedEvent) {
if he.RoomID() != req.RoomID {
return
}
if he.StateKey() == nil {
return
}
tuple := gomatrixserverlib.StateKeyTuple{
EventType: he.Type(),
StateKey: *he.StateKey(),
}
for _, t := range req.StateTuples {
if t == tuple {
res.StateEvents[t] = he
}
}
}
for _, he := range r.joinEvents {
checkEvent(he)
}
for _, he := range r.events {
checkEvent(he)
}
return nil
}
func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router {
t.Helper()
cfg := &config.Dendrite{}
cfg.Defaults(true)
cfg.Global.ServerName = "localhost"
cfg.MSCs.Database.ConnectionString = "file:msc2946_test.db"
cfg.MSCs.MSCs = []string{"msc2946"}
base := &base.BaseDendrite{
Cfg: cfg,
PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(),
PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(),
}
err := msc2946.Enable(base, rsAPI, userAPI, nil, nil)
if err != nil {
t.Fatalf("failed to enable MSC2946: %s", err)
}
for _, ev := range events {
hooks.Run(hooks.KindNewEventPersisted, ev)
}
return base.PublicClientAPIMux
}
type fledglingEvent struct {
Type string
StateKey *string
Content interface{}
Sender string
RoomID string
}
func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) {
t.Helper()
seed := make([]byte, ed25519.SeedSize) // zero seed
key := ed25519.NewKeyFromSeed(seed)
eb := gomatrixserverlib.EventBuilder{
Sender: ev.Sender,
Depth: 999,
Type: ev.Type,
StateKey: ev.StateKey,
RoomID: ev.RoomID,
}
err := eb.SetContent(ev.Content)
if err != nil {
t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content)
}
// make sure the origin_server_ts changes so we can test recency
time.Sleep(1 * time.Millisecond)
signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer)
if err != nil {
t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
}
h := signedEvent.Headered(roomVer)
return h
}

View file

@ -1,182 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msc2946
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
var (
relTypes = map[string]int{
ConstSpaceChildEventType: 1,
ConstSpaceParentEventType: 2,
}
)
type Database interface {
// StoreReference persists a child or parent space mapping.
StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error
// References returns all events which have the given roomID as a parent or child space.
References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error)
}
type DB struct {
db *sql.DB
writer sqlutil.Writer
insertEdgeStmt *sql.Stmt
selectEdgesStmt *sql.Stmt
}
// NewDatabase loads the database for msc2836
func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
if dbOpts.ConnectionString.IsPostgres() {
return newPostgresDatabase(dbOpts)
}
return newSQLiteDatabase(dbOpts)
}
func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
d := DB{
writer: sqlutil.NewDummyWriter(),
}
var err error
if d.db, err = sqlutil.Open(dbOpts); err != nil {
return nil, err
}
_, err = d.db.Exec(`
CREATE TABLE IF NOT EXISTS msc2946_edges (
room_version TEXT NOT NULL,
-- the room ID of the event, the source of the arrow
source_room_id TEXT NOT NULL,
-- the target room ID, the arrow destination
dest_room_id TEXT NOT NULL,
-- the kind of relation, either child or parent (1,2)
rel_type SMALLINT NOT NULL,
event_json TEXT NOT NULL,
CONSTRAINT msc2946_edges_uniq UNIQUE (source_room_id, dest_room_id, rel_type)
);
`)
if err != nil {
return nil, err
}
if d.insertEdgeStmt, err = d.db.Prepare(`
INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT ON CONSTRAINT msc2946_edges_uniq DO UPDATE SET event_json = $5
`); err != nil {
return nil, err
}
if d.selectEdgesStmt, err = d.db.Prepare(`
SELECT room_version, event_json FROM msc2946_edges
WHERE source_room_id = $1 OR dest_room_id = $2
`); err != nil {
return nil, err
}
return &d, err
}
func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
d := DB{
writer: sqlutil.NewExclusiveWriter(),
}
var err error
if d.db, err = sqlutil.Open(dbOpts); err != nil {
return nil, err
}
_, err = d.db.Exec(`
CREATE TABLE IF NOT EXISTS msc2946_edges (
room_version TEXT NOT NULL,
-- the room ID of the event, the source of the arrow
source_room_id TEXT NOT NULL,
-- the target room ID, the arrow destination
dest_room_id TEXT NOT NULL,
-- the kind of relation, either child or parent (1,2)
rel_type SMALLINT NOT NULL,
event_json TEXT NOT NULL,
UNIQUE (source_room_id, dest_room_id, rel_type)
);
`)
if err != nil {
return nil, err
}
if d.insertEdgeStmt, err = d.db.Prepare(`
INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT (source_room_id, dest_room_id, rel_type) DO UPDATE SET event_json = $5
`); err != nil {
return nil, err
}
if d.selectEdgesStmt, err = d.db.Prepare(`
SELECT room_version, event_json FROM msc2946_edges
WHERE source_room_id = $1 OR dest_room_id = $2
`); err != nil {
return nil, err
}
return &d, err
}
func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error {
target := SpaceTarget(he)
if target == "" {
return nil // malformed event
}
relType := relTypes[he.Type()]
_, err := d.insertEdgeStmt.ExecContext(ctx, he.RoomVersion, he.RoomID(), target, relType, he.JSON())
return err
}
func (d *DB) References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) {
rows, err := d.selectEdgesStmt.QueryContext(ctx, roomID, roomID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "failed to close References")
refs := make([]*gomatrixserverlib.HeaderedEvent, 0)
for rows.Next() {
var roomVer string
var jsonBytes []byte
if err := rows.Scan(&roomVer, &jsonBytes); err != nil {
return nil, err
}
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(jsonBytes, false, gomatrixserverlib.RoomVersion(roomVer))
if err != nil {
return nil, err
}
he := ev.Headered(gomatrixserverlib.RoomVersion(roomVer))
refs = append(refs, he)
}
return refs, nil
}
// SpaceTarget returns the destination room ID for the space event. This is either a child or a parent
// depending on the event type.
func SpaceTarget(he *gomatrixserverlib.HeaderedEvent) string {
if he.StateKey() == nil {
return "" // no-op
}
switch he.Type() {
case ConstSpaceParentEventType:
return *he.StateKey()
case ConstSpaceChildEventType:
return *he.StateKey()
}
return ""
}

View file

@ -42,7 +42,7 @@ func EnableMSC(base *base.BaseDendrite, monolith *setup.Monolith, msc string) er
case "msc2836": case "msc2836":
return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing) return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing)
case "msc2946": case "msc2946":
return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing) return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing, base.Caches)
case "msc2444": // enabled inside federationapi case "msc2444": // enabled inside federationapi
case "msc2753": // enabled inside clientapi case "msc2753": // enabled inside clientapi
default: default:

View file

@ -16,7 +16,9 @@ package consumers
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"fmt"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
@ -24,9 +26,12 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -39,6 +44,8 @@ type OutputClientDataConsumer struct {
db storage.Database db storage.Database
stream types.StreamProvider stream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName
producer *producers.UserAPIReadProducer
} }
// NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers. // NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers.
@ -49,6 +56,7 @@ func NewOutputClientDataConsumer(
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream types.StreamProvider,
producer *producers.UserAPIReadProducer,
) *OutputClientDataConsumer { ) *OutputClientDataConsumer {
return &OutputClientDataConsumer{ return &OutputClientDataConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -58,6 +66,8 @@ func NewOutputClientDataConsumer(
db: store, db: store,
notifier: notifier, notifier: notifier,
stream: stream, stream: stream,
serverName: cfg.Matrix.ServerName,
producer: producer,
} }
} }
@ -100,8 +110,48 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg)
}).Panicf("could not save account data") }).Panicf("could not save account data")
} }
if err = s.sendReadUpdate(ctx, userID, output); err != nil {
log.WithError(err).WithFields(logrus.Fields{
"user_id": userID,
"room_id": output.RoomID,
}).Errorf("Failed to generate read update")
sentry.CaptureException(err)
return false
}
s.stream.Advance(streamPos) s.stream.Advance(streamPos)
s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos}) s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos})
return true return true
} }
func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error {
if output.Type != "m.fully_read" || output.ReadMarker == nil {
return nil
}
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if serverName != s.serverName {
return nil
}
var readPos types.StreamPosition
var fullyReadPos types.StreamPosition
if output.ReadMarker.Read != "" {
if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
}
}
if output.ReadMarker.FullyRead != "" {
if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err)
}
}
if readPos > 0 || fullyReadPos > 0 {
if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil {
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
}
}
return nil
}

View file

@ -16,7 +16,9 @@ package consumers
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"fmt"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
@ -24,9 +26,12 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -39,6 +44,8 @@ type OutputReceiptEventConsumer struct {
db storage.Database db storage.Database
stream types.StreamProvider stream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName
producer *producers.UserAPIReadProducer
} }
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
@ -50,6 +57,7 @@ func NewOutputReceiptEventConsumer(
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream types.StreamProvider,
producer *producers.UserAPIReadProducer,
) *OutputReceiptEventConsumer { ) *OutputReceiptEventConsumer {
return &OutputReceiptEventConsumer{ return &OutputReceiptEventConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -59,6 +67,8 @@ func NewOutputReceiptEventConsumer(
db: store, db: store,
notifier: notifier, notifier: notifier,
stream: stream, stream: stream,
serverName: cfg.Matrix.ServerName,
producer: producer,
} }
} }
@ -92,8 +102,42 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Ms
return true return true
} }
if err = s.sendReadUpdate(ctx, output); err != nil {
log.WithError(err).WithFields(logrus.Fields{
"user_id": output.UserID,
"room_id": output.RoomID,
}).Errorf("Failed to generate read update")
sentry.CaptureException(err)
return false
}
s.stream.Advance(streamPos) s.stream.Advance(streamPos)
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
return true return true
} }
func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output api.OutputReceiptEvent) error {
if output.Type != "m.read" {
return nil
}
_, serverName, err := gomatrixserverlib.SplitID('@', output.UserID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if serverName != s.serverName {
return nil
}
var readPos types.StreamPosition
if output.EventID != "" {
if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
}
}
if readPos > 0 {
if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil {
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
}
}
return nil
}

View file

@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -45,6 +46,7 @@ type OutputRoomEventConsumer struct {
pduStream types.StreamProvider pduStream types.StreamProvider
inviteStream types.StreamProvider inviteStream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
producer *producers.UserAPIStreamEventProducer
} }
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -57,6 +59,7 @@ func NewOutputRoomEventConsumer(
pduStream types.StreamProvider, pduStream types.StreamProvider,
inviteStream types.StreamProvider, inviteStream types.StreamProvider,
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
producer *producers.UserAPIStreamEventProducer,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -69,6 +72,7 @@ func NewOutputRoomEventConsumer(
pduStream: pduStream, pduStream: pduStream,
inviteStream: inviteStream, inviteStream: inviteStream,
rsAPI: rsAPI, rsAPI: rsAPI,
producer: producer,
} }
} }
@ -194,6 +198,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
return nil return nil
} }
if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID())
sentry.CaptureException(err)
return err
}
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
sentry.CaptureException(err) sentry.CaptureException(err)

View file

@ -0,0 +1,110 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"context"
"encoding/json"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// OutputNotificationDataConsumer consumes events that originated in
// the Push server.
type OutputNotificationDataConsumer struct {
ctx context.Context
jetstream nats.JetStreamContext
durable string
topic string
db storage.Database
notifier *notifier.Notifier
stream types.StreamProvider
}
// NewOutputNotificationDataConsumer creates a new consumer. Call
// Start() to begin consuming.
func NewOutputNotificationDataConsumer(
process *process.ProcessContext,
cfg *config.SyncAPI,
js nats.JetStreamContext,
store storage.Database,
notifier *notifier.Notifier,
stream types.StreamProvider,
) *OutputNotificationDataConsumer {
s := &OutputNotificationDataConsumer{
ctx: process.Context(),
jetstream: js,
durable: cfg.Matrix.JetStream.Durable("SyncAPINotificationDataConsumer"),
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData),
db: store,
notifier: notifier,
stream: stream,
}
return s
}
// Start starts consumption.
func (s *OutputNotificationDataConsumer) Start() error {
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
)
}
// onMessage is called when the Sync server receives a new event from
// the push server. It is not safe for this function to be called from
// multiple goroutines, or else the sync stream position may race and
// be incorrectly calculated.
func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
userID := string(msg.Header.Get(jetstream.UserID))
// Parse out the event JSON
var data eventutil.NotificationData
if err := json.Unmarshal(msg.Data, &data); err != nil {
sentry.CaptureException(err)
log.WithField("user_id", userID).WithError(err).Error("user API consumer: message parse failure")
return true
}
streamPos, err := s.db.UpsertRoomUnreadNotificationCounts(ctx, userID, data.RoomID, data.UnreadNotificationCount, data.UnreadHighlightCount)
if err != nil {
sentry.CaptureException(err)
log.WithFields(log.Fields{
"user_id": userID,
"room_id": data.RoomID,
}).WithError(err).Error("Could not save notification counts")
return false
}
s.stream.Advance(streamPos)
s.notifier.OnNewNotificationData(userID, types.StreamingToken{NotificationDataPosition: streamPos})
log.WithFields(log.Fields{
"user_id": userID,
"room_id": data.RoomID,
"streamPos": streamPos,
}).Trace("Received notification data from user API")
return true
}

View file

@ -82,7 +82,16 @@ func DeviceListCatchup(
util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed") util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed")
return to, hasNew, nil return to, hasNew, nil
} }
// QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user.
// Work out which user IDs we care about — that includes those in the original request,
// the response from QueryKeyChanges (which includes ALL users who have changed keys)
// as well as every user who has a join or leave event in the current sync response. We
// will request information about which rooms these users are joined to, so that we can
// see if we still share any rooms with them.
joinUserIDs, leaveUserIDs := membershipEvents(res)
queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...)
queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...)
queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs)
var sharedUsersMap map[string]int var sharedUsersMap map[string]int
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs) sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs)
util.GetLogger(ctx).Debugf( util.GetLogger(ctx).Debugf(
@ -100,9 +109,8 @@ func DeviceListCatchup(
userSet[userID] = true userSet[userID] = true
} }
} }
// if the response has any join/leave events, add them now. // Finally, add in users who have joined or left.
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them. // TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
joinUserIDs, leaveUserIDs := membershipEvents(res)
for _, userID := range joinUserIDs { for _, userID := range joinUserIDs {
if !userSet[userID] { if !userSet[userID] {
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
@ -214,6 +222,7 @@ func filterSharedUsers(
var sharedUsersRes roomserverAPI.QuerySharedUsersResponse var sharedUsersRes roomserverAPI.QuerySharedUsersResponse
err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{
UserID: userID, UserID: userID,
OtherUserIDs: usersWithChangedKeys,
}, &sharedUsersRes) }, &sharedUsersRes)
if err != nil { if err != nil {
// default to all users so we do needless queries rather than miss some important device update // default to all users so we do needless queries rather than miss some important device update

View file

@ -217,6 +217,17 @@ func (n *Notifier) OnNewInvite(
n.wakeupUsers([]string{wakeUserID}, nil, n.currPos) n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
} }
func (n *Notifier) OnNewNotificationData(
userID string,
posUpdate types.StreamingToken,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers([]string{userID}, nil, n.currPos)
}
// GetListener returns a UserStreamListener that can be used to wait for // GetListener returns a UserStreamListener that can be used to wait for
// updates for a user. Must be closed. // updates for a user. Must be closed.
// notify for anything before sincePos // notify for anything before sincePos

View file

@ -219,7 +219,7 @@ func TestEDUWakeup(t *testing.T) {
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter)) pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter))
if err != nil { if err != nil {
t.Errorf("TestNewInviteEventForUser error: %w", err) t.Errorf("TestNewInviteEventForUser error: %v", err)
} }
mustEqualPositions(t, pos, syncPositionNewEDU) mustEqualPositions(t, pos, syncPositionNewEDU)
wg.Done() wg.Done()

View file

@ -0,0 +1,62 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package producers
import (
"encoding/json"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// UserAPIProducer produces events for the user API server to consume
type UserAPIReadProducer struct {
Topic string
JetStream nats.JetStreamContext
}
// SendData sends account data to the user API server
func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error {
m := &nats.Msg{
Subject: p.Topic,
Header: nats.Header{},
}
m.Header.Set(jetstream.UserID, userID)
m.Header.Set(jetstream.RoomID, roomID)
data := types.ReadUpdate{
UserID: userID,
RoomID: roomID,
Read: readPos,
FullyRead: fullyReadPos,
}
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"user_id": userID,
"room_id": roomID,
"read_pos": readPos,
"fully_read_pos": fullyReadPos,
}).Tracef("Producing to topic '%s'", p.Topic)
_, err = p.JetStream.PublishMsg(m)
return err
}

View file

@ -0,0 +1,60 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package producers
import (
"encoding/json"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// UserAPIProducer produces events for the user API server to consume
type UserAPIStreamEventProducer struct {
Topic string
JetStream nats.JetStreamContext
}
// SendData sends account data to the user API server
func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error {
m := &nats.Msg{
Subject: p.Topic,
Header: nats.Header{},
}
m.Header.Set(jetstream.RoomID, roomID)
data := types.StreamedEvent{
Event: event,
StreamPosition: pos,
}
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"room_id": roomID,
"event_id": event.EventID(),
"event_type": event.Type(),
"stream_pos": pos,
}).Tracef("Producing to topic '%s'", p.Topic)
_, err = p.JetStream.PublishMsg(m)
return err
}

View file

@ -17,6 +17,7 @@ package routing
import ( import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"strconv" "strconv"
@ -44,7 +45,7 @@ func Context(
syncDB storage.Database, syncDB storage.Database,
roomID, eventID string, roomID, eventID string,
) util.JSONResponse { ) util.JSONResponse {
filter, err := parseContextParams(req) filter, err := parseRoomEventFilter(req)
if err != nil { if err != nil {
errMsg := "" errMsg := ""
switch err.(type) { switch err.(type) {
@ -102,6 +103,12 @@ func Context(
id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID) id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID)
if err != nil { if err != nil {
if err == sql.ErrNoRows {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(fmt.Sprintf("Event %s not found", eventID)),
}
}
logrus.WithError(err).WithField("eventID", eventID).Error("unable to find requested event") logrus.WithError(err).WithField("eventID", eventID).Error("unable to find requested event")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -164,7 +171,7 @@ func applyLazyLoadMembers(filter *gomatrixserverlib.RoomEventFilter, eventsAfter
return newState return newState
} }
func parseContextParams(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) { func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) {
// Default room filter // Default room filter
filter := &gomatrixserverlib.RoomEventFilter{Limit: 10} filter := &gomatrixserverlib.RoomEventFilter{Limit: 10}

View file

@ -55,13 +55,13 @@ func Test_parseContextParams(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
gotFilter, err := parseContextParams(tt.req) gotFilter, err := parseRoomEventFilter(tt.req)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("parseContextParams() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("parseRoomEventFilter() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(gotFilter, tt.wantFilter) { if !reflect.DeepEqual(gotFilter, tt.wantFilter) {
t.Errorf("parseContextParams() gotFilter = %v, want %v", gotFilter, tt.wantFilter) t.Errorf("parseRoomEventFilter() gotFilter = %v, want %v", gotFilter, tt.wantFilter)
} }
}) })
} }

View file

@ -19,7 +19,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"sort" "sort"
"strconv"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -45,8 +44,8 @@ type messagesReq struct {
fromStream *types.StreamingToken fromStream *types.StreamingToken
device *userapi.Device device *userapi.Device
wasToProvided bool wasToProvided bool
limit int
backwardOrdering bool backwardOrdering bool
filter *gomatrixserverlib.RoomEventFilter
} }
type messagesResp struct { type messagesResp struct {
@ -54,10 +53,9 @@ type messagesResp struct {
StartStream string `json:"start_stream,omitempty"` // NOTSPEC: so clients can hit /messages then immediately /sync with a latest sync token StartStream string `json:"start_stream,omitempty"` // NOTSPEC: so clients can hit /messages then immediately /sync with a latest sync token
End string `json:"end"` End string `json:"end"`
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
State []gomatrixserverlib.ClientEvent `json:"state"`
} }
const defaultMessagesLimit = 10
// OnIncomingMessagesRequest implements the /messages endpoint from the // OnIncomingMessagesRequest implements the /messages endpoint from the
// client-server API. // client-server API.
// See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages // See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages
@ -83,6 +81,14 @@ func OnIncomingMessagesRequest(
} }
} }
filter, err := parseRoomEventFilter(req)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("unable to parse filter"),
}
}
// Extract parameters from the request's URL. // Extract parameters from the request's URL.
// Pagination tokens. // Pagination tokens.
var fromStream *types.StreamingToken var fromStream *types.StreamingToken
@ -143,18 +149,6 @@ func OnIncomingMessagesRequest(
wasToProvided = false wasToProvided = false
} }
// Maximum number of events to return; defaults to 10.
limit := defaultMessagesLimit
if len(req.URL.Query().Get("limit")) > 0 {
limit, err = strconv.Atoi(req.URL.Query().Get("limit"))
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("limit could not be parsed into an integer: " + err.Error()),
}
}
}
// TODO: Implement filtering (#587) // TODO: Implement filtering (#587)
// Check the room ID's format. // Check the room ID's format.
@ -176,7 +170,7 @@ func OnIncomingMessagesRequest(
to: &to, to: &to,
fromStream: fromStream, fromStream: fromStream,
wasToProvided: wasToProvided, wasToProvided: wasToProvided,
limit: limit, filter: filter,
backwardOrdering: backwardOrdering, backwardOrdering: backwardOrdering,
device: device, device: device,
} }
@ -187,10 +181,27 @@ func OnIncomingMessagesRequest(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
// at least fetch the membership events for the users returned in chunk if LazyLoadMembers is set
state := []gomatrixserverlib.ClientEvent{}
if filter.LazyLoadMembers {
memberShipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent)
for _, evt := range clientEvents {
memberShip, err := db.GetStateEvent(req.Context(), roomID, gomatrixserverlib.MRoomMember, evt.Sender)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("failed to get membership event for user")
continue
}
memberShipToUser[evt.Sender] = memberShip
}
for _, evt := range memberShipToUser {
state = append(state, gomatrixserverlib.HeaderedToClientEvent(evt, gomatrixserverlib.FormatAll))
}
}
util.GetLogger(req.Context()).WithFields(logrus.Fields{ util.GetLogger(req.Context()).WithFields(logrus.Fields{
"from": from.String(), "from": from.String(),
"to": to.String(), "to": to.String(),
"limit": limit, "limit": filter.Limit,
"backwards": backwardOrdering, "backwards": backwardOrdering,
"return_start": start.String(), "return_start": start.String(),
"return_end": end.String(), "return_end": end.String(),
@ -200,6 +211,7 @@ func OnIncomingMessagesRequest(
Chunk: clientEvents, Chunk: clientEvents,
Start: start.String(), Start: start.String(),
End: end.String(), End: end.String(),
State: state,
} }
if emptyFromSupplied { if emptyFromSupplied {
res.StartStream = fromStream.String() res.StartStream = fromStream.String()
@ -234,19 +246,18 @@ func (r *messagesReq) retrieveEvents() (
clientEvents []gomatrixserverlib.ClientEvent, start, clientEvents []gomatrixserverlib.ClientEvent, start,
end types.TopologyToken, err error, end types.TopologyToken, err error,
) { ) {
eventFilter := gomatrixserverlib.DefaultRoomEventFilter() eventFilter := r.filter
eventFilter.Limit = r.limit
// Retrieve the events from the local database. // Retrieve the events from the local database.
var streamEvents []types.StreamEvent var streamEvents []types.StreamEvent
if r.fromStream != nil { if r.fromStream != nil {
toStream := r.to.StreamToken() toStream := r.to.StreamToken()
streamEvents, err = r.db.GetEventsInStreamingRange( streamEvents, err = r.db.GetEventsInStreamingRange(
r.ctx, r.fromStream, &toStream, r.roomID, &eventFilter, r.backwardOrdering, r.ctx, r.fromStream, &toStream, r.roomID, eventFilter, r.backwardOrdering,
) )
} else { } else {
streamEvents, err = r.db.GetEventsInTopologicalRange( streamEvents, err = r.db.GetEventsInTopologicalRange(
r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
) )
} }
if err != nil { if err != nil {
@ -434,7 +445,7 @@ func (r *messagesReq) handleEmptyEventsSlice() (
// Check if we have backward extremities for this room. // Check if we have backward extremities for this room.
if len(backwardExtremities) > 0 { if len(backwardExtremities) > 0 {
// If so, retrieve as much events as needed through backfilling. // If so, retrieve as much events as needed through backfilling.
events, err = r.backfill(r.roomID, backwardExtremities, r.limit) events, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit)
if err != nil { if err != nil {
return return
} }
@ -456,7 +467,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
events []*gomatrixserverlib.HeaderedEvent, err error, events []*gomatrixserverlib.HeaderedEvent, err error,
) { ) {
// Check if we have enough events. // Check if we have enough events.
isSetLargeEnough := len(streamEvents) >= r.limit isSetLargeEnough := len(streamEvents) >= r.filter.Limit
if !isSetLargeEnough { if !isSetLargeEnough {
// it might be fine we don't have up to 'limit' events, let's find out // it might be fine we don't have up to 'limit' events, let's find out
if r.backwardOrdering { if r.backwardOrdering {
@ -483,7 +494,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering { if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering {
var pdus []*gomatrixserverlib.HeaderedEvent var pdus []*gomatrixserverlib.HeaderedEvent
// Only ask the remote server for enough events to reach the limit. // Only ask the remote server for enough events to reach the limit.
pdus, err = r.backfill(r.roomID, backwardExtremities, r.limit-len(streamEvents)) pdus, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit-len(streamEvents))
if err != nil { if err != nil {
return return
} }

View file

@ -18,6 +18,7 @@ import (
"context" "context"
eduAPI "github.com/matrix-org/dendrite/eduserver/api" eduAPI "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -31,6 +32,7 @@ type Database interface {
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error)
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
@ -138,6 +140,12 @@ type Database interface {
// GetRoomReceipts gets all receipts for a given roomID // GetRoomReceipts gets all receipts for a given roomID
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error)
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
// GetUserUnreadNotificationCounts returns statistics per room a user is interested in.
GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error)
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)

View file

@ -0,0 +1,108 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)
func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, error) {
_, err := db.Exec(notificationDataSchema)
if err != nil {
return nil, err
}
r := &notificationDataStatements{}
return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
{&r.selectMaxID, selectMaxNotificationIDSQL},
}.Prepare(db)
}
type notificationDataStatements struct {
upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt
selectMaxID *sql.Stmt
}
const notificationDataSchema = `
CREATE TABLE IF NOT EXISTS syncapi_notification_data (
id BIGSERIAL PRIMARY KEY,
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
notification_count BIGINT NOT NULL DEFAULT 0,
highlight_count BIGINT NOT NULL DEFAULT 0,
CONSTRAINT syncapi_notification_data_unique UNIQUE (user_id, room_id)
);`
const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_data
(user_id, room_id, notification_count, highlight_count)
VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, room_id)
DO UPDATE SET notification_count = $3, highlight_count = $4
RETURNING id`
const selectUserUnreadNotificationCountsSQL = `SELECT
id, room_id, notification_count, highlight_count
FROM syncapi_notification_data
WHERE
user_id = $1 AND
id BETWEEN $2 + 1 AND $3`
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos)
return
}
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed")
roomCounts := map[string]*eventutil.NotificationData{}
for rows.Next() {
var id types.StreamPosition
var roomID string
var notificationCount, highlightCount int
if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil {
return nil, err
}
roomCounts[roomID] = &eventutil.NotificationData{
RoomID: roomID,
UnreadNotificationCount: notificationCount,
UnreadHighlightCount: highlightCount,
}
}
return roomCounts, rows.Err()
}
func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) {
var id int64
err := r.selectMaxID.QueryRowContext(ctx).Scan(&id)
return id, err
}

View file

@ -96,7 +96,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
} }
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) { func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
lastPos := streamPos var lastPos types.StreamPosition
rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos) rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos)
if err != nil { if err != nil {
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err) return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)

Some files were not shown because too many files have changed in this diff Show more