Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking
This commit is contained in:
commit
e6e62497c9
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
|
@ -63,7 +63,7 @@ jobs:
|
||||||
# Run Complement
|
# Run Complement
|
||||||
- run: |
|
- run: |
|
||||||
set -o pipefail &&
|
set -o pipefail &&
|
||||||
go test -v -p 1 -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt
|
go test -v -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt
|
||||||
shell: bash
|
shell: bash
|
||||||
name: Run Complement Tests
|
name: Run Complement Tests
|
||||||
env:
|
env:
|
||||||
|
|
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -62,5 +62,7 @@ cmd/dendrite-demo-yggdrasil/embed/fs*.go
|
||||||
# Test dependencies
|
# Test dependencies
|
||||||
test/wasm/node_modules
|
test/wasm/node_modules
|
||||||
|
|
||||||
media_store/
|
# Ignore complement folder when running locally
|
||||||
|
complement/
|
||||||
|
|
||||||
|
media_store/
|
||||||
|
|
|
@ -318,6 +318,17 @@ user_api:
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
|
|
||||||
|
# Configuration for the Push Server API.
|
||||||
|
push_server:
|
||||||
|
internal_api:
|
||||||
|
listen: http://localhost:7782
|
||||||
|
connect: http://localhost:7782
|
||||||
|
database:
|
||||||
|
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_pushserver?sslmode=disable
|
||||||
|
max_open_conns: 10
|
||||||
|
max_idle_conns: 2
|
||||||
|
conn_max_lifetime: -1
|
||||||
|
|
||||||
# Configuration for Opentracing.
|
# Configuration for Opentracing.
|
||||||
# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on
|
# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on
|
||||||
# how this works and how to set it up.
|
# how this works and how to set it up.
|
||||||
|
|
|
@ -312,7 +312,7 @@ func (m *DendriteMonolith) Start() {
|
||||||
)
|
)
|
||||||
|
|
||||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
||||||
m.userAPI = userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI)
|
m.userAPI = userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||||
keyAPI.SetUserAPI(m.userAPI)
|
keyAPI.SetUserAPI(m.userAPI)
|
||||||
|
|
||||||
eduInputAPI := eduserver.NewInternalAPI(
|
eduInputAPI := eduserver.NewInternalAPI(
|
||||||
|
|
|
@ -116,7 +116,7 @@ func (m *DendriteMonolith) Start() {
|
||||||
)
|
)
|
||||||
|
|
||||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
||||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI)
|
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||||
keyAPI.SetUserAPI(userAPI)
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
eduInputAPI := eduserver.NewInternalAPI(
|
eduInputAPI := eduserver.NewInternalAPI(
|
||||||
|
|
|
@ -144,21 +144,23 @@ func (u *UserInteractive) AddCompletedStage(sessionID, authType string) {
|
||||||
delete(u.Sessions, sessionID)
|
delete(u.Sessions, sessionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Challenge struct {
|
||||||
|
Completed []string `json:"completed"`
|
||||||
|
Flows []userInteractiveFlow `json:"flows"`
|
||||||
|
Session string `json:"session"`
|
||||||
|
// TODO: Return any additional `params`
|
||||||
|
Params map[string]interface{} `json:"params"`
|
||||||
|
}
|
||||||
|
|
||||||
// Challenge returns an HTTP 401 with the supported flows for authenticating
|
// Challenge returns an HTTP 401 with the supported flows for authenticating
|
||||||
func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
|
func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
|
||||||
return &util.JSONResponse{
|
return &util.JSONResponse{
|
||||||
Code: 401,
|
Code: 401,
|
||||||
JSON: struct {
|
JSON: Challenge{
|
||||||
Completed []string `json:"completed"`
|
Completed: u.Completed,
|
||||||
Flows []userInteractiveFlow `json:"flows"`
|
Flows: u.Flows,
|
||||||
Session string `json:"session"`
|
Session: sessionID,
|
||||||
// TODO: Return any additional `params`
|
Params: make(map[string]interface{}),
|
||||||
Params map[string]interface{} `json:"params"`
|
|
||||||
}{
|
|
||||||
u.Completed,
|
|
||||||
u.Flows,
|
|
||||||
sessionID,
|
|
||||||
make(map[string]interface{}),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,7 +59,8 @@ func AddPublicRoutes(
|
||||||
routing.Setup(
|
routing.Setup(
|
||||||
router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI,
|
router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI,
|
||||||
accountsDB, userAPI, federation,
|
accountsDB, userAPI, federation,
|
||||||
syncProducer, transactionsCache, fsAPI, keyAPI, extRoomsProvider, mscCfg,
|
syncProducer, transactionsCache, fsAPI, keyAPI,
|
||||||
|
extRoomsProvider, mscCfg,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ type SyncAPIProducer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendData sends account data to the sync API server
|
// SendData sends account data to the sync API server
|
||||||
func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string) error {
|
func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string, readMarker *eventutil.ReadMarkerJSON) error {
|
||||||
m := &nats.Msg{
|
m := &nats.Msg{
|
||||||
Subject: p.Topic,
|
Subject: p.Topic,
|
||||||
Header: nats.Header{},
|
Header: nats.Header{},
|
||||||
|
@ -38,8 +38,9 @@ func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string
|
||||||
m.Header.Set(jetstream.UserID, userID)
|
m.Header.Set(jetstream.UserID, userID)
|
||||||
|
|
||||||
data := eventutil.AccountData{
|
data := eventutil.AccountData{
|
||||||
RoomID: roomID,
|
RoomID: roomID,
|
||||||
Type: dataType,
|
Type: dataType,
|
||||||
|
ReadMarker: readMarker,
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
m.Data, err = json.Marshal(data)
|
m.Data, err = json.Marshal(data)
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/clientapi/producers"
|
"github.com/matrix-org/dendrite/clientapi/producers"
|
||||||
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
|
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
|
@ -127,7 +128,7 @@ func SaveAccountData(
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: user API should do this since it's account data
|
// TODO: user API should do this since it's account data
|
||||||
if err := syncProducer.SendData(userID, roomID, dataType); err != nil {
|
if err := syncProducer.SendData(userID, roomID, dataType, nil); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
|
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
@ -138,11 +139,6 @@ func SaveAccountData(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type readMarkerJSON struct {
|
|
||||||
FullyRead string `json:"m.fully_read"`
|
|
||||||
Read string `json:"m.read"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type fullyReadEvent struct {
|
type fullyReadEvent struct {
|
||||||
EventID string `json:"event_id"`
|
EventID string `json:"event_id"`
|
||||||
}
|
}
|
||||||
|
@ -159,7 +155,7 @@ func SaveReadMarker(
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
|
|
||||||
var r readMarkerJSON
|
var r eventutil.ReadMarkerJSON
|
||||||
resErr = httputil.UnmarshalJSONRequest(req, &r)
|
resErr = httputil.UnmarshalJSONRequest(req, &r)
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
|
@ -189,7 +185,7 @@ func SaveReadMarker(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read"); err != nil {
|
if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read", &r); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
|
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-devices
|
// https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-devices
|
||||||
|
@ -163,6 +164,15 @@ func DeleteDeviceById(
|
||||||
req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.UserInternalAPI, device *api.Device,
|
req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.UserInternalAPI, device *api.Device,
|
||||||
deviceID string,
|
deviceID string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
var (
|
||||||
|
deleteOK bool
|
||||||
|
sessionID string
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
if deleteOK {
|
||||||
|
sessions.deleteSession(sessionID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
ctx := req.Context()
|
ctx := req.Context()
|
||||||
defer req.Body.Close() // nolint:errcheck
|
defer req.Body.Close() // nolint:errcheck
|
||||||
bodyBytes, err := ioutil.ReadAll(req.Body)
|
bodyBytes, err := ioutil.ReadAll(req.Body)
|
||||||
|
@ -172,8 +182,29 @@ func DeleteDeviceById(
|
||||||
JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()),
|
JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check that we know this session, and it matches with the device to delete
|
||||||
|
s := gjson.GetBytes(bodyBytes, "auth.session").Str
|
||||||
|
if dev, ok := sessions.getDeviceToDelete(s); ok {
|
||||||
|
if dev != deviceID {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.Forbidden("session & device mismatch"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s != "" {
|
||||||
|
sessionID = s
|
||||||
|
}
|
||||||
|
|
||||||
login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, device)
|
login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, device)
|
||||||
if errRes != nil {
|
if errRes != nil {
|
||||||
|
switch data := errRes.JSON.(type) {
|
||||||
|
case auth.Challenge:
|
||||||
|
sessions.addDeviceToDelete(data.Session, deviceID)
|
||||||
|
default:
|
||||||
|
}
|
||||||
return *errRes
|
return *errRes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -201,6 +232,8 @@ func DeleteDeviceById(
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
deleteOK = true
|
||||||
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: struct{}{},
|
JSON: struct{}{},
|
||||||
|
|
63
clientapi/routing/notification.go
Normal file
63
clientapi/routing/notification.go
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetNotifications handles /_matrix/client/r0/notifications
|
||||||
|
func GetNotifications(
|
||||||
|
req *http.Request, device *userapi.Device,
|
||||||
|
userAPI userapi.UserInternalAPI,
|
||||||
|
) util.JSONResponse {
|
||||||
|
var limit int64
|
||||||
|
if limitStr := req.URL.Query().Get("limit"); limitStr != "" {
|
||||||
|
var err error
|
||||||
|
limit, err = strconv.ParseInt(limitStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("ParseInt(limit) failed")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var queryRes userapi.QueryNotificationsResponse
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
|
||||||
|
Localpart: localpart,
|
||||||
|
From: req.URL.Query().Get("from"),
|
||||||
|
Limit: int(limit),
|
||||||
|
Only: req.URL.Query().Get("only"),
|
||||||
|
}, &queryRes)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
util.GetLogger(req.Context()).WithField("from", req.URL.Query().Get("from")).WithField("limit", limit).WithField("only", req.URL.Query().Get("only")).WithField("next", queryRes.NextToken).Infof("QueryNotifications: len %d", len(queryRes.Notifications))
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: queryRes,
|
||||||
|
}
|
||||||
|
}
|
|
@ -12,6 +12,7 @@ import (
|
||||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
type newPasswordRequest struct {
|
type newPasswordRequest struct {
|
||||||
|
@ -37,6 +38,11 @@ func Password(
|
||||||
var r newPasswordRequest
|
var r newPasswordRequest
|
||||||
r.LogoutDevices = true
|
r.LogoutDevices = true
|
||||||
|
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"sessionId": device.SessionID,
|
||||||
|
"userId": device.UserID,
|
||||||
|
}).Debug("Changing password")
|
||||||
|
|
||||||
// Unmarshal the request.
|
// Unmarshal the request.
|
||||||
resErr := httputil.UnmarshalJSONRequest(req, &r)
|
resErr := httputil.UnmarshalJSONRequest(req, &r)
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
|
@ -116,6 +122,15 @@ func Password(
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
|
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pushersReq := &api.PerformPusherDeletionRequest{
|
||||||
|
Localpart: localpart,
|
||||||
|
SessionID: device.SessionID,
|
||||||
|
}
|
||||||
|
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return a success code.
|
// Return a success code.
|
||||||
|
|
|
@ -286,7 +286,7 @@ func SetDisplayName(
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, false); err != nil {
|
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
|
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
114
clientapi/routing/pusher.go
Normal file
114
clientapi/routing/pusher.go
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetPushers handles /_matrix/client/r0/pushers
|
||||||
|
func GetPushers(
|
||||||
|
req *http.Request, device *userapi.Device,
|
||||||
|
userAPI userapi.UserInternalAPI,
|
||||||
|
) util.JSONResponse {
|
||||||
|
var queryRes userapi.QueryPushersResponse
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
|
||||||
|
Localpart: localpart,
|
||||||
|
}, &queryRes)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
for i := range queryRes.Pushers {
|
||||||
|
queryRes.Pushers[i].SessionID = 0
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: queryRes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPusher handles /_matrix/client/r0/pushers/set
|
||||||
|
// This endpoint allows the creation, modification and deletion of pushers for this user ID.
|
||||||
|
// The behaviour of this endpoint varies depending on the values in the JSON body.
|
||||||
|
func SetPusher(
|
||||||
|
req *http.Request, device *userapi.Device,
|
||||||
|
userAPI userapi.UserInternalAPI,
|
||||||
|
) util.JSONResponse {
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
body := userapi.PerformPusherSetRequest{}
|
||||||
|
if resErr := httputil.UnmarshalJSONRequest(req, &body); resErr != nil {
|
||||||
|
return *resErr
|
||||||
|
}
|
||||||
|
if len(body.AppID) > 64 {
|
||||||
|
return invalidParam("length of app_id must be no more than 64 characters")
|
||||||
|
}
|
||||||
|
if len(body.PushKey) > 512 {
|
||||||
|
return invalidParam("length of pushkey must be no more than 512 bytes")
|
||||||
|
}
|
||||||
|
uInt := body.Data["url"]
|
||||||
|
if uInt != nil {
|
||||||
|
u, ok := uInt.(string)
|
||||||
|
if !ok {
|
||||||
|
return invalidParam("url must be string")
|
||||||
|
}
|
||||||
|
if u != "" {
|
||||||
|
var pushUrl *url.URL
|
||||||
|
pushUrl, err = url.Parse(u)
|
||||||
|
if err != nil {
|
||||||
|
return invalidParam("malformed url passed")
|
||||||
|
}
|
||||||
|
if pushUrl.Scheme != "https" {
|
||||||
|
return invalidParam("only https scheme is allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
body.Localpart = localpart
|
||||||
|
body.SessionID = device.SessionID
|
||||||
|
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherSet failed")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: struct{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func invalidParam(msg string) util.JSONResponse {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidParam(msg),
|
||||||
|
}
|
||||||
|
}
|
386
clientapi/routing/pushrules.go
Normal file
386
clientapi/routing/pushrules.go
Normal file
|
@ -0,0 +1,386 @@
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func errorResponse(ctx context.Context, err error, msg string, args ...interface{}) util.JSONResponse {
|
||||||
|
if eerr, ok := err.(*jsonerror.MatrixError); ok {
|
||||||
|
var status int
|
||||||
|
switch eerr.ErrCode {
|
||||||
|
case "M_INVALID_ARGUMENT_VALUE":
|
||||||
|
status = http.StatusBadRequest
|
||||||
|
case "M_NOT_FOUND":
|
||||||
|
status = http.StatusNotFound
|
||||||
|
default:
|
||||||
|
status = http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
return util.MatrixErrorResponse(status, eerr.ErrCode, eerr.Err)
|
||||||
|
}
|
||||||
|
util.GetLogger(ctx).WithError(err).Errorf(msg, args...)
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||||
|
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||||
|
if err != nil {
|
||||||
|
return errorResponse(ctx, err, "queryPushRulesJSON failed")
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: ruleSets,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||||
|
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||||
|
if err != nil {
|
||||||
|
return errorResponse(ctx, err, "queryPushRulesJSON failed")
|
||||||
|
}
|
||||||
|
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||||
|
if ruleSet == nil {
|
||||||
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: ruleSet,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||||
|
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||||
|
if err != nil {
|
||||||
|
return errorResponse(ctx, err, "queryPushRules failed")
|
||||||
|
}
|
||||||
|
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||||
|
if ruleSet == nil {
|
||||||
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||||
|
}
|
||||||
|
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||||
|
if rulesPtr == nil {
|
||||||
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: *rulesPtr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||||
|
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||||
|
if err != nil {
|
||||||
|
return errorResponse(ctx, err, "queryPushRules failed")
|
||||||
|
}
|
||||||
|
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||||
|
if ruleSet == nil {
|
||||||
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||||
|
}
|
||||||
|
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||||
|
if rulesPtr == nil {
|
||||||
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||||
|
}
|
||||||
|
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||||
|
if i < 0 {
|
||||||
|
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: (*rulesPtr)[i],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||||
|
var newRule pushrules.Rule
|
||||||
|
if err := json.NewDecoder(body).Decode(&newRule); err != nil {
|
||||||
|
return errorResponse(ctx, err, "JSON Decode failed")
|
||||||
|
}
|
||||||
|
newRule.RuleID = ruleID
|
||||||
|
|
||||||
|
errs := pushrules.ValidateRule(pushrules.Kind(kind), &newRule)
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||||
|
if err != nil {
|
||||||
|
return errorResponse(ctx, err, "queryPushRules failed")
|
||||||
|
}
|
||||||
|
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||||
|
if ruleSet == nil {
|
||||||
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||||
|
}
|
||||||
|
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||||
|
if rulesPtr == nil {
|
||||||
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||||
|
}
|
||||||
|
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||||
|
if i >= 0 && afterRuleID == "" && beforeRuleID == "" {
|
||||||
|
// Modify rule at the same index.
|
||||||
|
|
||||||
|
// TODO: The spec does not say what to do in this case, but
|
||||||
|
// this feels reasonable.
|
||||||
|
*((*rulesPtr)[i]) = newRule
|
||||||
|
util.GetLogger(ctx).Infof("Modified existing push rule at %d", i)
|
||||||
|
} else {
|
||||||
|
if i >= 0 {
|
||||||
|
// Delete old rule.
|
||||||
|
*rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
|
||||||
|
util.GetLogger(ctx).Infof("Deleted old push rule at %d", i)
|
||||||
|
} else {
|
||||||
|
// SPEC: When creating push rules, they MUST be enabled by default.
|
||||||
|
//
|
||||||
|
// TODO: it's unclear if we must reject disabled rules, or force
|
||||||
|
// the value to true. Sytests fail if we don't force it.
|
||||||
|
newRule.Enabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new rule.
|
||||||
|
i, err := findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID)
|
||||||
|
if err != nil {
|
||||||
|
return errorResponse(ctx, err, "findPushRuleInsertionIndex failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
*rulesPtr = append((*rulesPtr)[:i], append([]*pushrules.Rule{&newRule}, (*rulesPtr)[i:]...)...)
|
||||||
|
util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
|
||||||
|
return errorResponse(ctx, err, "putPushRules failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||||
|
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||||
|
if err != nil {
|
||||||
|
return errorResponse(ctx, err, "queryPushRules failed")
|
||||||
|
}
|
||||||
|
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||||
|
if ruleSet == nil {
|
||||||
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||||
|
}
|
||||||
|
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||||
|
if rulesPtr == nil {
|
||||||
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||||
|
}
|
||||||
|
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||||
|
if i < 0 {
|
||||||
|
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
*rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
|
||||||
|
|
||||||
|
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
|
||||||
|
return errorResponse(ctx, err, "putPushRules failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||||
|
attrGet, err := pushRuleAttrGetter(attr)
|
||||||
|
if err != nil {
|
||||||
|
return errorResponse(ctx, err, "pushRuleAttrGetter failed")
|
||||||
|
}
|
||||||
|
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||||
|
if err != nil {
|
||||||
|
return errorResponse(ctx, err, "queryPushRules failed")
|
||||||
|
}
|
||||||
|
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||||
|
if ruleSet == nil {
|
||||||
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||||
|
}
|
||||||
|
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||||
|
if rulesPtr == nil {
|
||||||
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||||
|
}
|
||||||
|
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||||
|
if i < 0 {
|
||||||
|
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: map[string]interface{}{
|
||||||
|
attr: attrGet((*rulesPtr)[i]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||||
|
var newPartialRule pushrules.Rule
|
||||||
|
if err := json.NewDecoder(body).Decode(&newPartialRule); err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON(err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if newPartialRule.Actions == nil {
|
||||||
|
// This ensures json.Marshal encodes the empty list as [] rather than null.
|
||||||
|
newPartialRule.Actions = []*pushrules.Action{}
|
||||||
|
}
|
||||||
|
|
||||||
|
attrGet, err := pushRuleAttrGetter(attr)
|
||||||
|
if err != nil {
|
||||||
|
return errorResponse(ctx, err, "pushRuleAttrGetter failed")
|
||||||
|
}
|
||||||
|
attrSet, err := pushRuleAttrSetter(attr)
|
||||||
|
if err != nil {
|
||||||
|
return errorResponse(ctx, err, "pushRuleAttrSetter failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||||
|
if err != nil {
|
||||||
|
return errorResponse(ctx, err, "queryPushRules failed")
|
||||||
|
}
|
||||||
|
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||||
|
if ruleSet == nil {
|
||||||
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||||
|
}
|
||||||
|
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||||
|
if rulesPtr == nil {
|
||||||
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||||
|
}
|
||||||
|
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||||
|
if i < 0 {
|
||||||
|
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) {
|
||||||
|
attrSet((*rulesPtr)[i], &newPartialRule)
|
||||||
|
|
||||||
|
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
|
||||||
|
return errorResponse(ctx, err, "putPushRules failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryPushRules(ctx context.Context, userID string, userAPI userapi.UserInternalAPI) (*pushrules.AccountRuleSets, error) {
|
||||||
|
var res userapi.QueryPushRulesResponse
|
||||||
|
if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return res.RuleSets, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.UserInternalAPI) error {
|
||||||
|
req := userapi.PerformPushRulesPutRequest{
|
||||||
|
UserID: userID,
|
||||||
|
RuleSets: ruleSets,
|
||||||
|
}
|
||||||
|
var res struct{}
|
||||||
|
if err := userAPI.PerformPushRulesPut(ctx, &req, &res); err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("userAPI.PerformPushRulesPut failed")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet {
|
||||||
|
switch scope {
|
||||||
|
case pushrules.GlobalScope:
|
||||||
|
return &ruleSets.Global
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func pushRuleSetKindPointer(ruleSet *pushrules.RuleSet, kind pushrules.Kind) *[]*pushrules.Rule {
|
||||||
|
switch kind {
|
||||||
|
case pushrules.OverrideKind:
|
||||||
|
return &ruleSet.Override
|
||||||
|
case pushrules.ContentKind:
|
||||||
|
return &ruleSet.Content
|
||||||
|
case pushrules.RoomKind:
|
||||||
|
return &ruleSet.Room
|
||||||
|
case pushrules.SenderKind:
|
||||||
|
return &ruleSet.Sender
|
||||||
|
case pushrules.UnderrideKind:
|
||||||
|
return &ruleSet.Underride
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func pushRuleIndexByID(rules []*pushrules.Rule, id string) int {
|
||||||
|
for i, rule := range rules {
|
||||||
|
if rule.RuleID == id {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
func pushRuleAttrGetter(attr string) (func(*pushrules.Rule) interface{}, error) {
|
||||||
|
switch attr {
|
||||||
|
case "actions":
|
||||||
|
return func(rule *pushrules.Rule) interface{} { return rule.Actions }, nil
|
||||||
|
case "enabled":
|
||||||
|
return func(rule *pushrules.Rule) interface{} { return rule.Enabled }, nil
|
||||||
|
default:
|
||||||
|
return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func pushRuleAttrSetter(attr string) (func(dest, src *pushrules.Rule), error) {
|
||||||
|
switch attr {
|
||||||
|
case "actions":
|
||||||
|
return func(dest, src *pushrules.Rule) { dest.Actions = src.Actions }, nil
|
||||||
|
case "enabled":
|
||||||
|
return func(dest, src *pushrules.Rule) { dest.Enabled = src.Enabled }, nil
|
||||||
|
default:
|
||||||
|
return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func findPushRuleInsertionIndex(rules []*pushrules.Rule, afterID, beforeID string) (int, error) {
|
||||||
|
var i int
|
||||||
|
|
||||||
|
if afterID != "" {
|
||||||
|
for ; i < len(rules); i++ {
|
||||||
|
if rules[i].RuleID == afterID {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if i == len(rules) {
|
||||||
|
return 0, jsonerror.NotFound("after: rule ID not found")
|
||||||
|
}
|
||||||
|
if rules[i].Default {
|
||||||
|
return 0, jsonerror.NotFound("after: rule ID must not be a default rule")
|
||||||
|
}
|
||||||
|
// We stopped on the "after" match to differentiate
|
||||||
|
// not-found from is-last-entry. Now we move to the earliest
|
||||||
|
// insertion point.
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
if beforeID != "" {
|
||||||
|
for ; i < len(rules); i++ {
|
||||||
|
if rules[i].RuleID == beforeID {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if i == len(rules) {
|
||||||
|
return 0, jsonerror.NotFound("before: rule ID not found")
|
||||||
|
}
|
||||||
|
if rules[i].Default {
|
||||||
|
return 0, jsonerror.NotFound("before: rule ID must not be a default rule")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UNSPEC: The spec does not say what to do if no after/before is
|
||||||
|
// given. Sytest fails if it doesn't go first.
|
||||||
|
return i, nil
|
||||||
|
}
|
|
@ -76,6 +76,10 @@ type sessionsDict struct {
|
||||||
sessions map[string][]authtypes.LoginType
|
sessions map[string][]authtypes.LoginType
|
||||||
params map[string]registerRequest
|
params map[string]registerRequest
|
||||||
timer map[string]*time.Timer
|
timer map[string]*time.Timer
|
||||||
|
// deleteSessionToDeviceID protects requests to DELETE /devices/{deviceID} from being abused.
|
||||||
|
// If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
|
||||||
|
// the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
|
||||||
|
deleteSessionToDeviceID map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultTimeout is the timeout used to clean up sessions
|
// defaultTimeout is the timeout used to clean up sessions
|
||||||
|
@ -115,6 +119,7 @@ func (d *sessionsDict) deleteSession(sessionID string) {
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
delete(d.params, sessionID)
|
delete(d.params, sessionID)
|
||||||
delete(d.sessions, sessionID)
|
delete(d.sessions, sessionID)
|
||||||
|
delete(d.deleteSessionToDeviceID, sessionID)
|
||||||
// stop the timer, e.g. because the registration was completed
|
// stop the timer, e.g. because the registration was completed
|
||||||
if t, ok := d.timer[sessionID]; ok {
|
if t, ok := d.timer[sessionID]; ok {
|
||||||
if !t.Stop() {
|
if !t.Stop() {
|
||||||
|
@ -129,9 +134,10 @@ func (d *sessionsDict) deleteSession(sessionID string) {
|
||||||
|
|
||||||
func newSessionsDict() *sessionsDict {
|
func newSessionsDict() *sessionsDict {
|
||||||
return &sessionsDict{
|
return &sessionsDict{
|
||||||
sessions: make(map[string][]authtypes.LoginType),
|
sessions: make(map[string][]authtypes.LoginType),
|
||||||
params: make(map[string]registerRequest),
|
params: make(map[string]registerRequest),
|
||||||
timer: make(map[string]*time.Timer),
|
timer: make(map[string]*time.Timer),
|
||||||
|
deleteSessionToDeviceID: make(map[string]string),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,6 +171,20 @@ func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtype
|
||||||
d.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
|
d.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *sessionsDict) addDeviceToDelete(sessionID, deviceID string) {
|
||||||
|
d.startTimer(defaultTimeOut, sessionID)
|
||||||
|
d.Lock()
|
||||||
|
defer d.Unlock()
|
||||||
|
d.deleteSessionToDeviceID[sessionID] = deviceID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) {
|
||||||
|
d.RLock()
|
||||||
|
defer d.RUnlock()
|
||||||
|
deviceID, ok := d.deleteSessionToDeviceID[sessionID]
|
||||||
|
return deviceID, ok
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
sessions = newSessionsDict()
|
sessions = newSessionsDict()
|
||||||
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
||||||
|
|
|
@ -214,19 +214,19 @@ func TestSessionCleanUp(t *testing.T) {
|
||||||
s := newSessionsDict()
|
s := newSessionsDict()
|
||||||
|
|
||||||
t.Run("session is cleaned up after a while", func(t *testing.T) {
|
t.Run("session is cleaned up after a while", func(t *testing.T) {
|
||||||
t.Parallel()
|
// t.Parallel()
|
||||||
dummySession := "helloWorld"
|
dummySession := "helloWorld"
|
||||||
// manually added, as s.addParams() would start the timer with the default timeout
|
// manually added, as s.addParams() would start the timer with the default timeout
|
||||||
s.params[dummySession] = registerRequest{Username: "Testing"}
|
s.params[dummySession] = registerRequest{Username: "Testing"}
|
||||||
s.startTimer(time.Millisecond, dummySession)
|
s.startTimer(time.Millisecond, dummySession)
|
||||||
time.Sleep(time.Millisecond * 5)
|
time.Sleep(time.Millisecond * 50)
|
||||||
if data, ok := s.getParams(dummySession); ok {
|
if data, ok := s.getParams(dummySession); ok {
|
||||||
t.Errorf("expected session to be deleted: %+v", data)
|
t.Errorf("expected session to be deleted: %+v", data)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("session is deleted, once the registration completed", func(t *testing.T) {
|
t.Run("session is deleted, once the registration completed", func(t *testing.T) {
|
||||||
t.Parallel()
|
// t.Parallel()
|
||||||
dummySession := "helloWorld2"
|
dummySession := "helloWorld2"
|
||||||
s.startTimer(time.Minute, dummySession)
|
s.startTimer(time.Minute, dummySession)
|
||||||
s.deleteSession(dummySession)
|
s.deleteSession(dummySession)
|
||||||
|
@ -236,18 +236,28 @@ func TestSessionCleanUp(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("session timer is restarted after second call", func(t *testing.T) {
|
t.Run("session timer is restarted after second call", func(t *testing.T) {
|
||||||
t.Parallel()
|
// t.Parallel()
|
||||||
dummySession := "helloWorld3"
|
dummySession := "helloWorld3"
|
||||||
// the following will start a timer with the default timeout of 5min
|
// the following will start a timer with the default timeout of 5min
|
||||||
s.addParams(dummySession, registerRequest{Username: "Testing"})
|
s.addParams(dummySession, registerRequest{Username: "Testing"})
|
||||||
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha)
|
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha)
|
||||||
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy)
|
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy)
|
||||||
|
s.addDeviceToDelete(dummySession, "dummyDevice")
|
||||||
s.getCompletedStages(dummySession)
|
s.getCompletedStages(dummySession)
|
||||||
// reset the timer with a lower timeout
|
// reset the timer with a lower timeout
|
||||||
s.startTimer(time.Millisecond, dummySession)
|
s.startTimer(time.Millisecond, dummySession)
|
||||||
time.Sleep(time.Millisecond * 5)
|
time.Sleep(time.Millisecond * 50)
|
||||||
if data, ok := s.getParams(dummySession); ok {
|
if data, ok := s.getParams(dummySession); ok {
|
||||||
t.Errorf("expected session to be deleted: %+v", data)
|
t.Errorf("expected session to be deleted: %+v", data)
|
||||||
}
|
}
|
||||||
|
if _, ok := s.timer[dummySession]; ok {
|
||||||
|
t.Error("expected timer to be delete")
|
||||||
|
}
|
||||||
|
if _, ok := s.sessions[dummySession]; ok {
|
||||||
|
t.Error("expected session to be delete")
|
||||||
|
}
|
||||||
|
if _, ok := s.getDeviceToDelete(dummySession); ok {
|
||||||
|
t.Error("expected session to device to be delete")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
|
@ -99,7 +99,7 @@ func PutTag(
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
|
if err = syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil {
|
||||||
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
|
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,7 +152,7 @@ func DeleteTag(
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: user API should do this since it's account data
|
// TODO: user API should do this since it's account data
|
||||||
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
|
if err := syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil {
|
||||||
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
|
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,6 @@ package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -581,25 +580,142 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/pushrules/",
|
// Push rules
|
||||||
httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse {
|
|
||||||
// TODO: Implement push rules API
|
v3mux.Handle("/pushrules",
|
||||||
res := json.RawMessage(`{
|
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
"global": {
|
|
||||||
"content": [],
|
|
||||||
"override": [],
|
|
||||||
"room": [],
|
|
||||||
"sender": [],
|
|
||||||
"underride": []
|
|
||||||
}
|
|
||||||
}`)
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusBadRequest,
|
||||||
JSON: &res,
|
JSON: jsonerror.InvalidArgumentValue("missing trailing slash"),
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
v3mux.Handle("/pushrules/",
|
||||||
|
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
return GetAllPushRules(req.Context(), device, userAPI)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
v3mux.Handle("/pushrules/",
|
||||||
|
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidArgumentValue("scope, kind and rule ID must be specified"),
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPut)
|
||||||
|
|
||||||
|
v3mux.Handle("/pushrules/{scope}/",
|
||||||
|
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return GetPushRulesByScope(req.Context(), vars["scope"], device, userAPI)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
v3mux.Handle("/pushrules/{scope}",
|
||||||
|
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidArgumentValue("missing trailing slash after scope"),
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
v3mux.Handle("/pushrules/{scope:[^/]+/?}",
|
||||||
|
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidArgumentValue("kind and rule ID must be specified"),
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPut)
|
||||||
|
|
||||||
|
v3mux.Handle("/pushrules/{scope}/{kind}/",
|
||||||
|
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return GetPushRulesByKind(req.Context(), vars["scope"], vars["kind"], device, userAPI)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
v3mux.Handle("/pushrules/{scope}/{kind}",
|
||||||
|
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidArgumentValue("missing trailing slash after kind"),
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
v3mux.Handle("/pushrules/{scope}/{kind:[^/]+/?}",
|
||||||
|
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidArgumentValue("rule ID must be specified"),
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPut)
|
||||||
|
|
||||||
|
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
|
||||||
|
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return GetPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
|
||||||
|
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
if r := rateLimits.Limit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
query := req.URL.Query()
|
||||||
|
return PutPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], query.Get("after"), query.Get("before"), req.Body, device, userAPI)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPut)
|
||||||
|
|
||||||
|
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
|
||||||
|
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return DeletePushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodDelete)
|
||||||
|
|
||||||
|
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
|
||||||
|
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return GetPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], device, userAPI)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
|
||||||
|
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return PutPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], req.Body, device, userAPI)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPut)
|
||||||
|
|
||||||
// Element user settings
|
// Element user settings
|
||||||
|
|
||||||
v3mux.Handle("/profile/{userID}",
|
v3mux.Handle("/profile/{userID}",
|
||||||
|
@ -905,6 +1021,27 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
v3mux.Handle("/notifications",
|
||||||
|
httputil.MakeAuthAPI("get_notifications", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
return GetNotifications(req, device, userAPI)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
v3mux.Handle("/pushers",
|
||||||
|
httputil.MakeAuthAPI("get_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
return GetPushers(req, device, userAPI)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
v3mux.Handle("/pushers/set",
|
||||||
|
httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
if r := rateLimits.Limit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
|
return SetPusher(req, device, userAPI)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
// Stub implementations for sytest
|
// Stub implementations for sytest
|
||||||
v3mux.Handle("/events",
|
v3mux.Handle("/events",
|
||||||
httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
|
httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
|
||||||
|
|
|
@ -48,24 +48,27 @@ Example:
|
||||||
# read password from stdin
|
# read password from stdin
|
||||||
%s --config dendrite.yaml -username alice -passwordstdin < my.pass
|
%s --config dendrite.yaml -username alice -passwordstdin < my.pass
|
||||||
cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin
|
cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin
|
||||||
|
# reset password for a user, can be used with a combination above to read the password
|
||||||
|
%s --config dendrite.yaml -reset-password -username alice -password foobarbaz
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
|
||||||
`
|
`
|
||||||
|
|
||||||
var (
|
var (
|
||||||
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
|
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
|
||||||
password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)")
|
password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)")
|
||||||
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
|
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
|
||||||
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
|
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
|
||||||
askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
|
askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
|
||||||
isAdmin = flag.Bool("admin", false, "Create an admin account")
|
isAdmin = flag.Bool("admin", false, "Create an admin account")
|
||||||
|
resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username")
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
name := os.Args[0]
|
name := os.Args[0]
|
||||||
flag.Usage = func() {
|
flag.Usage = func() {
|
||||||
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name)
|
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name, name)
|
||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
}
|
}
|
||||||
cfg := setup.ParseFlags(true)
|
cfg := setup.ParseFlags(true)
|
||||||
|
@ -93,6 +96,19 @@ func main() {
|
||||||
if *isAdmin {
|
if *isAdmin {
|
||||||
accType = api.AccountTypeAdmin
|
accType = api.AccountTypeAdmin
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if *resetPassword {
|
||||||
|
err = accountDB.SetPassword(context.Background(), *username, pass)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Fatalf("Failed to update password for user %s: %s", *username, err.Error())
|
||||||
|
}
|
||||||
|
if _, err = accountDB.RemoveAllDevices(context.Background(), *username, ""); err != nil {
|
||||||
|
logrus.Fatalf("Failed to remove all devices: %s", err.Error())
|
||||||
|
}
|
||||||
|
logrus.Infof("Updated password for user %s and invalidated all logins\n", *username)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
policyVersion := ""
|
policyVersion := ""
|
||||||
if cfg.Global.UserConsentOptions.Enabled {
|
if cfg.Global.UserConsentOptions.Enabled {
|
||||||
policyVersion = cfg.Global.UserConsentOptions.Version
|
policyVersion = cfg.Global.UserConsentOptions.Version
|
||||||
|
|
|
@ -144,12 +144,14 @@ func main() {
|
||||||
accountDB := base.Base.CreateAccountsDB()
|
accountDB := base.Base.CreateAccountsDB()
|
||||||
federation := createFederationClient(base)
|
federation := createFederationClient(base)
|
||||||
keyAPI := keyserver.NewInternalAPI(&base.Base, &base.Base.Cfg.KeyServer, federation)
|
keyAPI := keyserver.NewInternalAPI(&base.Base, &base.Base.Cfg.KeyServer, federation)
|
||||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
|
|
||||||
keyAPI.SetUserAPI(userAPI)
|
|
||||||
|
|
||||||
rsAPI := roomserver.NewInternalAPI(
|
rsAPI := roomserver.NewInternalAPI(
|
||||||
&base.Base,
|
&base.Base,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
userAPI := userapi.NewInternalAPI(&base.Base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.Base.PushGatewayHTTPClient())
|
||||||
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
eduInputAPI := eduserver.NewInternalAPI(
|
eduInputAPI := eduserver.NewInternalAPI(
|
||||||
&base.Base, cache.New(), userAPI,
|
&base.Base, cache.New(), userAPI,
|
||||||
)
|
)
|
||||||
|
|
|
@ -187,7 +187,7 @@ func main() {
|
||||||
)
|
)
|
||||||
|
|
||||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
||||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
|
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||||
keyAPI.SetUserAPI(userAPI)
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
eduInputAPI := eduserver.NewInternalAPI(
|
eduInputAPI := eduserver.NewInternalAPI(
|
||||||
|
|
|
@ -111,14 +111,15 @@ func main() {
|
||||||
keyRing := serverKeyAPI.KeyRing()
|
keyRing := serverKeyAPI.KeyRing()
|
||||||
|
|
||||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
||||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
|
|
||||||
keyAPI.SetUserAPI(userAPI)
|
|
||||||
|
|
||||||
rsComponent := roomserver.NewInternalAPI(
|
rsComponent := roomserver.NewInternalAPI(
|
||||||
base,
|
base,
|
||||||
)
|
)
|
||||||
rsAPI := rsComponent
|
rsAPI := rsComponent
|
||||||
|
|
||||||
|
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||||
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
eduInputAPI := eduserver.NewInternalAPI(
|
eduInputAPI := eduserver.NewInternalAPI(
|
||||||
base, cache.New(), userAPI,
|
base, cache.New(), userAPI,
|
||||||
)
|
)
|
||||||
|
|
|
@ -106,7 +106,8 @@ func main() {
|
||||||
keyAPI = base.KeyServerHTTPClient()
|
keyAPI = base.KeyServerHTTPClient()
|
||||||
}
|
}
|
||||||
|
|
||||||
userImpl := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI)
|
pgClient := base.PushGatewayHTTPClient()
|
||||||
|
userImpl := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient)
|
||||||
userAPI := userImpl
|
userAPI := userImpl
|
||||||
if base.UseHTTPAPIs {
|
if base.UseHTTPAPIs {
|
||||||
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)
|
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)
|
||||||
|
|
|
@ -23,7 +23,11 @@ import (
|
||||||
func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
|
func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
|
||||||
accountDB := base.CreateAccountsDB()
|
accountDB := base.CreateAccountsDB()
|
||||||
|
|
||||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient())
|
userAPI := userapi.NewInternalAPI(
|
||||||
|
base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices,
|
||||||
|
base.KeyServerHTTPClient(), base.RoomserverHTTPClient(),
|
||||||
|
base.PushGatewayHTTPClient(),
|
||||||
|
)
|
||||||
|
|
||||||
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)
|
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)
|
||||||
|
|
||||||
|
|
|
@ -184,13 +184,15 @@ func startup() {
|
||||||
accountDB := base.CreateAccountsDB()
|
accountDB := base.CreateAccountsDB()
|
||||||
federation := conn.CreateFederationClient(base, pSessions)
|
federation := conn.CreateFederationClient(base, pSessions)
|
||||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
||||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
|
|
||||||
keyAPI.SetUserAPI(userAPI)
|
|
||||||
|
|
||||||
serverKeyAPI := &signing.YggdrasilKeys{}
|
serverKeyAPI := &signing.YggdrasilKeys{}
|
||||||
keyRing := serverKeyAPI.KeyRing()
|
keyRing := serverKeyAPI.KeyRing()
|
||||||
|
|
||||||
rsAPI := roomserver.NewInternalAPI(base)
|
rsAPI := roomserver.NewInternalAPI(base)
|
||||||
|
|
||||||
|
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||||
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI)
|
eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI)
|
||||||
asQuery := appservice.NewInternalAPI(
|
asQuery := appservice.NewInternalAPI(
|
||||||
base, userAPI, rsAPI,
|
base, userAPI, rsAPI,
|
||||||
|
|
|
@ -212,6 +212,8 @@ func main() {
|
||||||
rsAPI.SetFederationAPI(fedSenderAPI, keyRing)
|
rsAPI.SetFederationAPI(fedSenderAPI, keyRing)
|
||||||
p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation)
|
p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation)
|
||||||
|
|
||||||
|
psAPI := pushserver.NewInternalAPI(base)
|
||||||
|
|
||||||
monolith := setup.Monolith{
|
monolith := setup.Monolith{
|
||||||
Config: base.Cfg,
|
Config: base.Cfg,
|
||||||
AccountDB: accountDB,
|
AccountDB: accountDB,
|
||||||
|
@ -225,6 +227,7 @@ func main() {
|
||||||
RoomserverAPI: rsAPI,
|
RoomserverAPI: rsAPI,
|
||||||
UserAPI: userAPI,
|
UserAPI: userAPI,
|
||||||
KeyAPI: keyAPI,
|
KeyAPI: keyAPI,
|
||||||
|
PushserverAPI: psAPI,
|
||||||
//ServerKeyAPI: serverKeyAPI,
|
//ServerKeyAPI: serverKeyAPI,
|
||||||
ExtPublicRoomsProvider: p2pPublicRoomProvider,
|
ExtPublicRoomsProvider: p2pPublicRoomProvider,
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,8 +61,8 @@ global:
|
||||||
# Lists of domains that the server will trust as identity servers to verify third
|
# Lists of domains that the server will trust as identity servers to verify third
|
||||||
# party identifiers such as phone numbers and email addresses.
|
# party identifiers such as phone numbers and email addresses.
|
||||||
trusted_third_party_id_servers:
|
trusted_third_party_id_servers:
|
||||||
- matrix.org
|
- matrix.org
|
||||||
- vector.im
|
- vector.im
|
||||||
|
|
||||||
# Disables federation. Dendrite will not be able to make any outbound HTTP requests
|
# Disables federation. Dendrite will not be able to make any outbound HTTP requests
|
||||||
# to other servers and the federation API will not be exposed.
|
# to other servers and the federation API will not be exposed.
|
||||||
|
@ -116,7 +116,7 @@ global:
|
||||||
# in monolith mode. It is required to specify the address of at least one
|
# in monolith mode. It is required to specify the address of at least one
|
||||||
# NATS Server node if running in polylith mode.
|
# NATS Server node if running in polylith mode.
|
||||||
addresses:
|
addresses:
|
||||||
# - localhost:4222
|
# - localhost:4222
|
||||||
|
|
||||||
# Keep all NATS streams in memory, rather than persisting it to the storage
|
# Keep all NATS streams in memory, rather than persisting it to the storage
|
||||||
# path below. This option is present primarily for integration testing and
|
# path below. This option is present primarily for integration testing and
|
||||||
|
@ -155,7 +155,7 @@ global:
|
||||||
# Configuration for the Appservice API.
|
# Configuration for the Appservice API.
|
||||||
app_service_api:
|
app_service_api:
|
||||||
internal_api:
|
internal_api:
|
||||||
listen: http://localhost:7777 # Only used in polylith deployments
|
listen: http://localhost:7777 # Only used in polylith deployments
|
||||||
connect: http://localhost:7777 # Only used in polylith deployments
|
connect: http://localhost:7777 # Only used in polylith deployments
|
||||||
database:
|
database:
|
||||||
connection_string: file:appservice.db
|
connection_string: file:appservice.db
|
||||||
|
@ -174,7 +174,7 @@ app_service_api:
|
||||||
# Configuration for the Client API.
|
# Configuration for the Client API.
|
||||||
client_api:
|
client_api:
|
||||||
internal_api:
|
internal_api:
|
||||||
listen: http://localhost:7771 # Only used in polylith deployments
|
listen: http://localhost:7771 # Only used in polylith deployments
|
||||||
connect: http://localhost:7771 # Only used in polylith deployments
|
connect: http://localhost:7771 # Only used in polylith deployments
|
||||||
external_api:
|
external_api:
|
||||||
listen: http://[::]:8071
|
listen: http://[::]:8071
|
||||||
|
@ -219,13 +219,13 @@ client_api:
|
||||||
# Configuration for the EDU server.
|
# Configuration for the EDU server.
|
||||||
edu_server:
|
edu_server:
|
||||||
internal_api:
|
internal_api:
|
||||||
listen: http://localhost:7778 # Only used in polylith deployments
|
listen: http://localhost:7778 # Only used in polylith deployments
|
||||||
connect: http://localhost:7778 # Only used in polylith deployments
|
connect: http://localhost:7778 # Only used in polylith deployments
|
||||||
|
|
||||||
# Configuration for the Federation API.
|
# Configuration for the Federation API.
|
||||||
federation_api:
|
federation_api:
|
||||||
internal_api:
|
internal_api:
|
||||||
listen: http://localhost:7772 # Only used in polylith deployments
|
listen: http://localhost:7772 # Only used in polylith deployments
|
||||||
connect: http://localhost:7772 # Only used in polylith deployments
|
connect: http://localhost:7772 # Only used in polylith deployments
|
||||||
external_api:
|
external_api:
|
||||||
listen: http://[::]:8072
|
listen: http://[::]:8072
|
||||||
|
@ -253,12 +253,12 @@ federation_api:
|
||||||
# be required to satisfy key requests for servers that are no longer online when
|
# be required to satisfy key requests for servers that are no longer online when
|
||||||
# joining some rooms.
|
# joining some rooms.
|
||||||
key_perspectives:
|
key_perspectives:
|
||||||
- server_name: matrix.org
|
- server_name: matrix.org
|
||||||
keys:
|
keys:
|
||||||
- key_id: ed25519:auto
|
- key_id: ed25519:auto
|
||||||
public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw
|
public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw
|
||||||
- key_id: ed25519:a_RXGa
|
- key_id: ed25519:a_RXGa
|
||||||
public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ
|
public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ
|
||||||
|
|
||||||
# This option will control whether Dendrite will prefer to look up keys directly
|
# This option will control whether Dendrite will prefer to look up keys directly
|
||||||
# or whether it should try perspective servers first, using direct fetches as a
|
# or whether it should try perspective servers first, using direct fetches as a
|
||||||
|
@ -268,7 +268,7 @@ federation_api:
|
||||||
# Configuration for the Key Server (for end-to-end encryption).
|
# Configuration for the Key Server (for end-to-end encryption).
|
||||||
key_server:
|
key_server:
|
||||||
internal_api:
|
internal_api:
|
||||||
listen: http://localhost:7779 # Only used in polylith deployments
|
listen: http://localhost:7779 # Only used in polylith deployments
|
||||||
connect: http://localhost:7779 # Only used in polylith deployments
|
connect: http://localhost:7779 # Only used in polylith deployments
|
||||||
database:
|
database:
|
||||||
connection_string: file:keyserver.db
|
connection_string: file:keyserver.db
|
||||||
|
@ -279,7 +279,7 @@ key_server:
|
||||||
# Configuration for the Media API.
|
# Configuration for the Media API.
|
||||||
media_api:
|
media_api:
|
||||||
internal_api:
|
internal_api:
|
||||||
listen: http://localhost:7774 # Only used in polylith deployments
|
listen: http://localhost:7774 # Only used in polylith deployments
|
||||||
connect: http://localhost:7774 # Only used in polylith deployments
|
connect: http://localhost:7774 # Only used in polylith deployments
|
||||||
external_api:
|
external_api:
|
||||||
listen: http://[::]:8074
|
listen: http://[::]:8074
|
||||||
|
@ -305,15 +305,15 @@ media_api:
|
||||||
|
|
||||||
# A list of thumbnail sizes to be generated for media content.
|
# A list of thumbnail sizes to be generated for media content.
|
||||||
thumbnail_sizes:
|
thumbnail_sizes:
|
||||||
- width: 32
|
- width: 32
|
||||||
height: 32
|
height: 32
|
||||||
method: crop
|
method: crop
|
||||||
- width: 96
|
- width: 96
|
||||||
height: 96
|
height: 96
|
||||||
method: crop
|
method: crop
|
||||||
- width: 640
|
- width: 640
|
||||||
height: 480
|
height: 480
|
||||||
method: scale
|
method: scale
|
||||||
|
|
||||||
# Configuration for experimental MSC's
|
# Configuration for experimental MSC's
|
||||||
mscs:
|
mscs:
|
||||||
|
@ -331,7 +331,7 @@ mscs:
|
||||||
# Configuration for the Room Server.
|
# Configuration for the Room Server.
|
||||||
room_server:
|
room_server:
|
||||||
internal_api:
|
internal_api:
|
||||||
listen: http://localhost:7770 # Only used in polylith deployments
|
listen: http://localhost:7770 # Only used in polylith deployments
|
||||||
connect: http://localhost:7770 # Only used in polylith deployments
|
connect: http://localhost:7770 # Only used in polylith deployments
|
||||||
database:
|
database:
|
||||||
connection_string: file:roomserver.db
|
connection_string: file:roomserver.db
|
||||||
|
@ -342,7 +342,7 @@ room_server:
|
||||||
# Configuration for the Sync API.
|
# Configuration for the Sync API.
|
||||||
sync_api:
|
sync_api:
|
||||||
internal_api:
|
internal_api:
|
||||||
listen: http://localhost:7773 # Only used in polylith deployments
|
listen: http://localhost:7773 # Only used in polylith deployments
|
||||||
connect: http://localhost:7773 # Only used in polylith deployments
|
connect: http://localhost:7773 # Only used in polylith deployments
|
||||||
external_api:
|
external_api:
|
||||||
listen: http://[::]:8073
|
listen: http://[::]:8073
|
||||||
|
@ -367,18 +367,13 @@ user_api:
|
||||||
# This value can be low if performing tests or on embedded Dendrite instances (e.g WASM builds)
|
# This value can be low if performing tests or on embedded Dendrite instances (e.g WASM builds)
|
||||||
# bcrypt_cost: 10
|
# bcrypt_cost: 10
|
||||||
internal_api:
|
internal_api:
|
||||||
listen: http://localhost:7781 # Only used in polylith deployments
|
listen: http://localhost:7781 # Only used in polylith deployments
|
||||||
connect: http://localhost:7781 # Only used in polylith deployments
|
connect: http://localhost:7781 # Only used in polylith deployments
|
||||||
account_database:
|
account_database:
|
||||||
connection_string: file:userapi_accounts.db
|
connection_string: file:userapi_accounts.db
|
||||||
max_open_conns: 10
|
max_open_conns: 10
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
device_database:
|
|
||||||
connection_string: file:userapi_devices.db
|
|
||||||
max_open_conns: 10
|
|
||||||
max_idle_conns: 2
|
|
||||||
conn_max_lifetime: -1
|
|
||||||
# The length of time that a token issued for a relying party from
|
# The length of time that a token issued for a relying party from
|
||||||
# /_matrix/client/r0/user/{userId}/openid/request_token endpoint
|
# /_matrix/client/r0/user/{userId}/openid/request_token endpoint
|
||||||
# is considered to be valid in milliseconds.
|
# is considered to be valid in milliseconds.
|
||||||
|
@ -403,10 +398,10 @@ tracing:
|
||||||
|
|
||||||
# Logging configuration
|
# Logging configuration
|
||||||
logging:
|
logging:
|
||||||
- type: std
|
- type: std
|
||||||
level: info
|
level: info
|
||||||
- type: file
|
- type: file
|
||||||
# The logging level, must be one of debug, info, warn, error, fatal, panic.
|
# The logging level, must be one of debug, info, warn, error, fatal, panic.
|
||||||
level: info
|
level: info
|
||||||
params:
|
params:
|
||||||
path: ./logs
|
path: ./logs
|
||||||
|
|
|
@ -13,6 +13,7 @@ Group=dendrite
|
||||||
WorkingDirectory=/opt/dendrite/
|
WorkingDirectory=/opt/dendrite/
|
||||||
ExecStart=/opt/dendrite/bin/dendrite-monolith-server
|
ExecStart=/opt/dendrite/bin/dendrite-monolith-server
|
||||||
Restart=always
|
Restart=always
|
||||||
|
LimitNOFILE=65535
|
||||||
|
|
||||||
[Install]
|
[Install]
|
||||||
WantedBy=multi-user.target
|
WantedBy=multi-user.target
|
||||||
|
|
|
@ -21,7 +21,7 @@ type FederationClient interface {
|
||||||
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
|
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
|
||||||
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
|
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
|
||||||
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
|
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
|
||||||
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
|
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
|
||||||
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
|
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
|
||||||
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
|
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
|
||||||
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
|
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
|
||||||
|
|
|
@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *FederationInternalAPI) MSC2946Spaces(
|
func (a *FederationInternalAPI) MSC2946Spaces(
|
||||||
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
|
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
|
||||||
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
|
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||||
return a.federation.MSC2946Spaces(ctx, s, roomID, r)
|
return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return res, err
|
return res, err
|
||||||
|
|
|
@ -526,23 +526,23 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships(
|
||||||
}
|
}
|
||||||
|
|
||||||
type spacesReq struct {
|
type spacesReq struct {
|
||||||
S gomatrixserverlib.ServerName
|
S gomatrixserverlib.ServerName
|
||||||
Req gomatrixserverlib.MSC2946SpacesRequest
|
SuggestedOnly bool
|
||||||
RoomID string
|
RoomID string
|
||||||
Res gomatrixserverlib.MSC2946SpacesResponse
|
Res gomatrixserverlib.MSC2946SpacesResponse
|
||||||
Err *api.FederationClientError
|
Err *api.FederationClientError
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpFederationInternalAPI) MSC2946Spaces(
|
func (h *httpFederationInternalAPI) MSC2946Spaces(
|
||||||
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
|
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
|
||||||
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
|
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
|
||||||
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces")
|
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces")
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
|
|
||||||
request := spacesReq{
|
request := spacesReq{
|
||||||
S: dst,
|
S: dst,
|
||||||
Req: r,
|
SuggestedOnly: suggestedOnly,
|
||||||
RoomID: roomID,
|
RoomID: roomID,
|
||||||
}
|
}
|
||||||
var response spacesReq
|
var response spacesReq
|
||||||
apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath
|
apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath
|
||||||
|
|
|
@ -378,7 +378,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
|
||||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
}
|
}
|
||||||
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.Req)
|
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ferr, ok := err.(*api.FederationClientError)
|
ferr, ok := err.(*api.FederationClientError)
|
||||||
if ok {
|
if ok {
|
||||||
|
|
13
go.mod
13
go.mod
|
@ -1,6 +1,6 @@
|
||||||
module github.com/matrix-org/dendrite
|
module github.com/matrix-org/dendrite
|
||||||
|
|
||||||
replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad
|
replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.4-0.20220302103432-6b04b9f12740
|
||||||
|
|
||||||
replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c
|
replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c
|
||||||
|
|
||||||
|
@ -18,12 +18,13 @@ require (
|
||||||
github.com/frankban/quicktest v1.14.0 // indirect
|
github.com/frankban/quicktest v1.14.0 // indirect
|
||||||
github.com/getsentry/sentry-go v0.12.0
|
github.com/getsentry/sentry-go v0.12.0
|
||||||
github.com/gologme/log v1.3.0
|
github.com/gologme/log v1.3.0
|
||||||
|
github.com/google/go-cmp v0.5.6
|
||||||
|
github.com/google/uuid v1.2.0
|
||||||
github.com/gorilla/mux v1.8.0
|
github.com/gorilla/mux v1.8.0
|
||||||
github.com/gorilla/websocket v1.4.2
|
github.com/gorilla/websocket v1.4.2
|
||||||
github.com/h2non/filetype v1.1.3 // indirect
|
github.com/h2non/filetype v1.1.3 // indirect
|
||||||
github.com/hashicorp/golang-lru v0.5.4
|
github.com/hashicorp/golang-lru v0.5.4
|
||||||
github.com/juju/testing v0.0.0-20211215003918-77eb13d6cad2 // indirect
|
github.com/juju/testing v0.0.0-20211215003918-77eb13d6cad2 // indirect
|
||||||
github.com/klauspost/compress v1.14.2 // indirect
|
|
||||||
github.com/lib/pq v1.10.4
|
github.com/lib/pq v1.10.4
|
||||||
github.com/libp2p/go-libp2p v0.13.0
|
github.com/libp2p/go-libp2p v0.13.0
|
||||||
github.com/libp2p/go-libp2p-circuit v0.4.0
|
github.com/libp2p/go-libp2p-circuit v0.4.0
|
||||||
|
@ -39,12 +40,12 @@ require (
|
||||||
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
|
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902
|
||||||
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa
|
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa
|
||||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||||
github.com/mattn/go-sqlite3 v1.14.10
|
github.com/mattn/go-sqlite3 v1.14.10
|
||||||
github.com/morikuni/aec v1.0.0 // indirect
|
github.com/morikuni/aec v1.0.0 // indirect
|
||||||
github.com/nats-io/nats-server/v2 v2.3.2
|
github.com/nats-io/nats-server/v2 v2.7.3
|
||||||
github.com/nats-io/nats.go v1.13.1-0.20220121202836-972a071d373d
|
github.com/nats-io/nats.go v1.13.1-0.20220121202836-972a071d373d
|
||||||
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
|
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||||
|
@ -61,11 +62,11 @@ require (
|
||||||
github.com/uber/jaeger-lib v2.4.1+incompatible
|
github.com/uber/jaeger-lib v2.4.1+incompatible
|
||||||
github.com/yggdrasil-network/yggdrasil-go v0.4.2
|
github.com/yggdrasil-network/yggdrasil-go v0.4.2
|
||||||
go.uber.org/atomic v1.9.0
|
go.uber.org/atomic v1.9.0
|
||||||
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a
|
golang.org/x/crypto v0.0.0-20220214200702-86341886e292
|
||||||
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
|
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
|
||||||
golang.org/x/mobile v0.0.0-20220112015953-858099ff7816
|
golang.org/x/mobile v0.0.0-20220112015953-858099ff7816
|
||||||
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
|
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
|
||||||
golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect
|
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
|
||||||
gopkg.in/h2non/bimg.v1 v1.1.5
|
gopkg.in/h2non/bimg.v1 v1.1.5
|
||||||
gopkg.in/yaml.v2 v2.4.0
|
gopkg.in/yaml.v2 v2.4.0
|
||||||
|
|
26
go.sum
26
go.sum
|
@ -480,7 +480,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
|
||||||
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
|
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
|
||||||
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
|
||||||
github.com/gologme/log v1.2.0/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U=
|
github.com/gologme/log v1.2.0/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U=
|
||||||
github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo=
|
github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo=
|
||||||
github.com/gologme/log v1.3.0/go.mod h1:yKT+DvIPdDdDoPtqFrFxheooyVmoqi0BAsw+erN3wA4=
|
github.com/gologme/log v1.3.0/go.mod h1:yKT+DvIPdDdDoPtqFrFxheooyVmoqi0BAsw+erN3wA4=
|
||||||
|
@ -712,9 +711,8 @@ github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0
|
||||||
github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||||
github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||||
github.com/klauspost/compress v1.11.13/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
github.com/klauspost/compress v1.11.13/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||||
github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
|
github.com/klauspost/compress v1.14.4 h1:eijASRJcobkVtSt81Olfh7JX43osYLwy5krOJo6YEu4=
|
||||||
github.com/klauspost/compress v1.14.2 h1:S0OHlFk/Gbon/yauFJ4FfJJF5V0fc5HbBTJazi28pRw=
|
github.com/klauspost/compress v1.14.4/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
|
||||||
github.com/klauspost/compress v1.14.2/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
|
|
||||||
github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
|
github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
|
@ -983,8 +981,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
|
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052 h1:+4Q/JQ3fGgA7sIHaLMlqREX8yEpsI+HlVoW9WId7SNc=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 h1:WHlrE8BYh/hzn1RKwq3YMAlhHivX47jQKAjZFtkJyPE=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo=
|
||||||
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk=
|
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk=
|
||||||
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk=
|
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk=
|
||||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||||
|
@ -1029,8 +1027,8 @@ github.com/miekg/dns v1.1.31/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7
|
||||||
github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
|
github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
|
||||||
github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 h1:lYpkrQH5ajf0OXOcUbGjvZxxijuBwbbmlSxLiuofa+g=
|
github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 h1:lYpkrQH5ajf0OXOcUbGjvZxxijuBwbbmlSxLiuofa+g=
|
||||||
github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1/go.mod h1:pD8RvIylQ358TN4wwqatJ8rNavkEINozVn9DtGI3dfQ=
|
github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1/go.mod h1:pD8RvIylQ358TN4wwqatJ8rNavkEINozVn9DtGI3dfQ=
|
||||||
github.com/minio/highwayhash v1.0.1 h1:dZ6IIu8Z14VlC0VpfKofAhCy74wu/Qb5gcn52yWoz/0=
|
github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g=
|
||||||
github.com/minio/highwayhash v1.0.1/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY=
|
github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY=
|
||||||
github.com/minio/sha256-simd v0.0.0-20190131020904-2d45a736cd16/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U=
|
github.com/minio/sha256-simd v0.0.0-20190131020904-2d45a736cd16/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U=
|
||||||
github.com/minio/sha256-simd v0.0.0-20190328051042-05b4dd3047e5/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U=
|
github.com/minio/sha256-simd v0.0.0-20190328051042-05b4dd3047e5/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U=
|
||||||
github.com/minio/sha256-simd v0.1.0/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U=
|
github.com/minio/sha256-simd v0.1.0/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U=
|
||||||
|
@ -1132,8 +1130,8 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY
|
||||||
github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM=
|
github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM=
|
||||||
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
|
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
|
||||||
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
|
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
|
||||||
github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad h1:Z2nWMQsXWWqzj89nW6OaLJSdkFknqhaR5whEOz4++Y8=
|
github.com/neilalexander/nats-server/v2 v2.7.4-0.20220302103432-6b04b9f12740 h1:RJrc+z35RHZlrjR6UBt9UmVRAlFh4SgYyEA0YpQdPHM=
|
||||||
github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad/go.mod h1:tckmrt0M6bVaDT3kmh9UrIq/CBOBBse+TpXQi5ldaa8=
|
github.com/neilalexander/nats-server/v2 v2.7.4-0.20220302103432-6b04b9f12740/go.mod h1:eJUrA5gm0ch6sJTEv85xmXIgQWsB0OyjkTsKXvlHbYc=
|
||||||
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q=
|
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q=
|
||||||
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
|
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
|
||||||
github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8=
|
github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8=
|
||||||
|
@ -1510,8 +1508,8 @@ golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5
|
||||||
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
|
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||||
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo=
|
golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE=
|
||||||
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
|
@ -1737,8 +1735,8 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc=
|
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs=
|
||||||
golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
|
||||||
|
|
|
@ -10,6 +10,7 @@ const (
|
||||||
FederationEventCacheName = "federation_event"
|
FederationEventCacheName = "federation_event"
|
||||||
FederationEventCacheMaxEntries = 256
|
FederationEventCacheMaxEntries = 256
|
||||||
FederationEventCacheMutable = true // to allow use of Unset only
|
FederationEventCacheMutable = true // to allow use of Unset only
|
||||||
|
FederationEventCacheMaxAge = CacheNoMaxAge
|
||||||
)
|
)
|
||||||
|
|
||||||
// FederationCache contains the subset of functions needed for
|
// FederationCache contains the subset of functions needed for
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package caching
|
package caching
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,6 +18,7 @@ const (
|
||||||
RoomInfoCacheName = "roominfo"
|
RoomInfoCacheName = "roominfo"
|
||||||
RoomInfoCacheMaxEntries = 1024
|
RoomInfoCacheMaxEntries = 1024
|
||||||
RoomInfoCacheMutable = true
|
RoomInfoCacheMutable = true
|
||||||
|
RoomInfoCacheMaxAge = time.Minute * 5
|
||||||
)
|
)
|
||||||
|
|
||||||
// RoomInfosCache contains the subset of functions needed for
|
// RoomInfosCache contains the subset of functions needed for
|
||||||
|
|
|
@ -10,6 +10,7 @@ const (
|
||||||
RoomServerRoomIDsCacheName = "roomserver_room_ids"
|
RoomServerRoomIDsCacheName = "roomserver_room_ids"
|
||||||
RoomServerRoomIDsCacheMaxEntries = 1024
|
RoomServerRoomIDsCacheMaxEntries = 1024
|
||||||
RoomServerRoomIDsCacheMutable = false
|
RoomServerRoomIDsCacheMutable = false
|
||||||
|
RoomServerRoomIDsCacheMaxAge = CacheNoMaxAge
|
||||||
)
|
)
|
||||||
|
|
||||||
type RoomServerCaches interface {
|
type RoomServerCaches interface {
|
||||||
|
|
|
@ -6,6 +6,7 @@ const (
|
||||||
RoomVersionCacheName = "room_versions"
|
RoomVersionCacheName = "room_versions"
|
||||||
RoomVersionCacheMaxEntries = 1024
|
RoomVersionCacheMaxEntries = 1024
|
||||||
RoomVersionCacheMutable = false
|
RoomVersionCacheMutable = false
|
||||||
|
RoomVersionCacheMaxAge = CacheNoMaxAge
|
||||||
)
|
)
|
||||||
|
|
||||||
// RoomVersionsCache contains the subset of functions needed for
|
// RoomVersionsCache contains the subset of functions needed for
|
||||||
|
|
|
@ -10,6 +10,7 @@ const (
|
||||||
ServerKeyCacheName = "server_key"
|
ServerKeyCacheName = "server_key"
|
||||||
ServerKeyCacheMaxEntries = 4096
|
ServerKeyCacheMaxEntries = 4096
|
||||||
ServerKeyCacheMutable = true
|
ServerKeyCacheMutable = true
|
||||||
|
ServerKeyCacheMaxAge = CacheNoMaxAge
|
||||||
)
|
)
|
||||||
|
|
||||||
// ServerKeyCache contains the subset of functions needed for
|
// ServerKeyCache contains the subset of functions needed for
|
||||||
|
|
33
internal/caching/cache_space_rooms.go
Normal file
33
internal/caching/cache_space_rooms.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package caching
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
SpaceSummaryRoomsCacheName = "space_summary_rooms"
|
||||||
|
SpaceSummaryRoomsCacheMaxEntries = 100
|
||||||
|
SpaceSummaryRoomsCacheMutable = true
|
||||||
|
SpaceSummaryRoomsCacheMaxAge = time.Minute * 5
|
||||||
|
)
|
||||||
|
|
||||||
|
type SpaceSummaryRoomsCache interface {
|
||||||
|
GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool)
|
||||||
|
StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Caches) GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool) {
|
||||||
|
val, found := c.SpaceSummaryRooms.Get(roomID)
|
||||||
|
if found && val != nil {
|
||||||
|
if resp, ok := val.(gomatrixserverlib.MSC2946SpacesResponse); ok {
|
||||||
|
return resp, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Caches) StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse) {
|
||||||
|
c.SpaceSummaryRooms.Set(roomID, r)
|
||||||
|
}
|
|
@ -1,5 +1,7 @@
|
||||||
package caching
|
package caching
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
// Caches contains a set of references to caches. They may be
|
// Caches contains a set of references to caches. They may be
|
||||||
// different implementations as long as they satisfy the Cache
|
// different implementations as long as they satisfy the Cache
|
||||||
// interface.
|
// interface.
|
||||||
|
@ -10,6 +12,7 @@ type Caches struct {
|
||||||
RoomServerRoomIDs Cache // RoomServerNIDsCache
|
RoomServerRoomIDs Cache // RoomServerNIDsCache
|
||||||
RoomInfos Cache // RoomInfoCache
|
RoomInfos Cache // RoomInfoCache
|
||||||
FederationEvents Cache // FederationEventsCache
|
FederationEvents Cache // FederationEventsCache
|
||||||
|
SpaceSummaryRooms Cache // SpaceSummaryRoomsCache
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache is the interface that an implementation must satisfy.
|
// Cache is the interface that an implementation must satisfy.
|
||||||
|
@ -18,3 +21,5 @@ type Cache interface {
|
||||||
Set(key string, value interface{})
|
Set(key string, value interface{})
|
||||||
Unset(key string)
|
Unset(key string)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const CacheNoMaxAge = time.Duration(0)
|
||||||
|
|
|
@ -14,6 +14,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
||||||
RoomVersionCacheName,
|
RoomVersionCacheName,
|
||||||
RoomVersionCacheMutable,
|
RoomVersionCacheMutable,
|
||||||
RoomVersionCacheMaxEntries,
|
RoomVersionCacheMaxEntries,
|
||||||
|
RoomVersionCacheMaxAge,
|
||||||
enablePrometheus,
|
enablePrometheus,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -23,6 +24,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
||||||
ServerKeyCacheName,
|
ServerKeyCacheName,
|
||||||
ServerKeyCacheMutable,
|
ServerKeyCacheMutable,
|
||||||
ServerKeyCacheMaxEntries,
|
ServerKeyCacheMaxEntries,
|
||||||
|
ServerKeyCacheMaxAge,
|
||||||
enablePrometheus,
|
enablePrometheus,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -32,6 +34,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
||||||
RoomServerRoomIDsCacheName,
|
RoomServerRoomIDsCacheName,
|
||||||
RoomServerRoomIDsCacheMutable,
|
RoomServerRoomIDsCacheMutable,
|
||||||
RoomServerRoomIDsCacheMaxEntries,
|
RoomServerRoomIDsCacheMaxEntries,
|
||||||
|
RoomServerRoomIDsCacheMaxAge,
|
||||||
enablePrometheus,
|
enablePrometheus,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -41,6 +44,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
||||||
RoomInfoCacheName,
|
RoomInfoCacheName,
|
||||||
RoomInfoCacheMutable,
|
RoomInfoCacheMutable,
|
||||||
RoomInfoCacheMaxEntries,
|
RoomInfoCacheMaxEntries,
|
||||||
|
RoomInfoCacheMaxAge,
|
||||||
enablePrometheus,
|
enablePrometheus,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -50,6 +54,17 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
||||||
FederationEventCacheName,
|
FederationEventCacheName,
|
||||||
FederationEventCacheMutable,
|
FederationEventCacheMutable,
|
||||||
FederationEventCacheMaxEntries,
|
FederationEventCacheMaxEntries,
|
||||||
|
FederationEventCacheMaxAge,
|
||||||
|
enablePrometheus,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
spaceRooms, err := NewInMemoryLRUCachePartition(
|
||||||
|
SpaceSummaryRoomsCacheName,
|
||||||
|
SpaceSummaryRoomsCacheMutable,
|
||||||
|
SpaceSummaryRoomsCacheMaxEntries,
|
||||||
|
SpaceSummaryRoomsCacheMaxAge,
|
||||||
enablePrometheus,
|
enablePrometheus,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -57,7 +72,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
||||||
}
|
}
|
||||||
go cacheCleaner(
|
go cacheCleaner(
|
||||||
roomVersions, serverKeys, roomServerRoomIDs,
|
roomVersions, serverKeys, roomServerRoomIDs,
|
||||||
roomInfos, federationEvents,
|
roomInfos, federationEvents, spaceRooms,
|
||||||
)
|
)
|
||||||
return &Caches{
|
return &Caches{
|
||||||
RoomVersions: roomVersions,
|
RoomVersions: roomVersions,
|
||||||
|
@ -65,6 +80,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
||||||
RoomServerRoomIDs: roomServerRoomIDs,
|
RoomServerRoomIDs: roomServerRoomIDs,
|
||||||
RoomInfos: roomInfos,
|
RoomInfos: roomInfos,
|
||||||
FederationEvents: federationEvents,
|
FederationEvents: federationEvents,
|
||||||
|
SpaceSummaryRooms: spaceRooms,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,15 +102,22 @@ type InMemoryLRUCachePartition struct {
|
||||||
name string
|
name string
|
||||||
mutable bool
|
mutable bool
|
||||||
maxEntries int
|
maxEntries int
|
||||||
|
maxAge time.Duration
|
||||||
lru *lru.Cache
|
lru *lru.Cache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, enablePrometheus bool) (*InMemoryLRUCachePartition, error) {
|
type inMemoryLRUCacheEntry struct {
|
||||||
|
value interface{}
|
||||||
|
created time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, maxAge time.Duration, enablePrometheus bool) (*InMemoryLRUCachePartition, error) {
|
||||||
var err error
|
var err error
|
||||||
cache := InMemoryLRUCachePartition{
|
cache := InMemoryLRUCachePartition{
|
||||||
name: name,
|
name: name,
|
||||||
mutable: mutable,
|
mutable: mutable,
|
||||||
maxEntries: maxEntries,
|
maxEntries: maxEntries,
|
||||||
|
maxAge: maxAge,
|
||||||
}
|
}
|
||||||
cache.lru, err = lru.New(maxEntries)
|
cache.lru, err = lru.New(maxEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -114,11 +137,16 @@ func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, ena
|
||||||
|
|
||||||
func (c *InMemoryLRUCachePartition) Set(key string, value interface{}) {
|
func (c *InMemoryLRUCachePartition) Set(key string, value interface{}) {
|
||||||
if !c.mutable {
|
if !c.mutable {
|
||||||
if peek, ok := c.lru.Peek(key); ok && peek != value {
|
if peek, ok := c.lru.Peek(key); ok {
|
||||||
panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key))
|
if entry, ok := peek.(*inMemoryLRUCacheEntry); ok && entry.value != value {
|
||||||
|
panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.lru.Add(key, value)
|
c.lru.Add(key, &inMemoryLRUCacheEntry{
|
||||||
|
value: value,
|
||||||
|
created: time.Now(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *InMemoryLRUCachePartition) Unset(key string) {
|
func (c *InMemoryLRUCachePartition) Unset(key string) {
|
||||||
|
@ -129,5 +157,20 @@ func (c *InMemoryLRUCachePartition) Unset(key string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *InMemoryLRUCachePartition) Get(key string) (value interface{}, ok bool) {
|
func (c *InMemoryLRUCachePartition) Get(key string) (value interface{}, ok bool) {
|
||||||
return c.lru.Get(key)
|
v, ok := c.lru.Get(key)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
entry, ok := v.(*inMemoryLRUCacheEntry)
|
||||||
|
switch {
|
||||||
|
case ok && c.maxAge == CacheNoMaxAge:
|
||||||
|
return entry.value, ok // There's no maximum age policy
|
||||||
|
case ok && time.Since(entry.created) < c.maxAge:
|
||||||
|
return entry.value, ok // The value for the key isn't stale
|
||||||
|
default:
|
||||||
|
// Either the key was found and it was stale, or the key
|
||||||
|
// wasn't found at all
|
||||||
|
c.lru.Remove(key)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,8 +26,30 @@ var ErrProfileNoExists = errors.New("no known profile for given user ID")
|
||||||
// AccountData represents account data sent from the client API server to the
|
// AccountData represents account data sent from the client API server to the
|
||||||
// sync API server
|
// sync API server
|
||||||
type AccountData struct {
|
type AccountData struct {
|
||||||
|
RoomID string `json:"room_id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
ReadMarker *ReadMarkerJSON `json:"read_marker,omitempty"` // optional
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReadMarkerJSON struct {
|
||||||
|
FullyRead string `json:"m.fully_read"`
|
||||||
|
Read string `json:"m.read"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotificationData contains statistics about notifications, sent from
|
||||||
|
// the Push Server to the Sync API server.
|
||||||
|
type NotificationData struct {
|
||||||
|
// RoomID identifies the scope of the statistics, together with
|
||||||
|
// MXID (which is encoded in the Kafka key).
|
||||||
RoomID string `json:"room_id"`
|
RoomID string `json:"room_id"`
|
||||||
Type string `json:"type"`
|
|
||||||
|
// HighlightCount is the number of unread notifications with the
|
||||||
|
// highlight tweak.
|
||||||
|
UnreadHighlightCount int `json:"unread_highlight_count"`
|
||||||
|
|
||||||
|
// UnreadNotificationCount is the total number of unread
|
||||||
|
// notifications.
|
||||||
|
UnreadNotificationCount int `json:"unread_notification_count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProfileResponse is a struct containing all known user profile data
|
// ProfileResponse is a struct containing all known user profile data
|
||||||
|
|
66
internal/pushgateway/client.go
Normal file
66
internal/pushgateway/client.go
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
package pushgateway
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/opentracing/opentracing-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
type httpClient struct {
|
||||||
|
hc *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHTTPClient creates a new Push Gateway client.
|
||||||
|
func NewHTTPClient(disableTLSValidation bool) Client {
|
||||||
|
hc := &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DisableKeepAlives: true,
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: disableTLSValidation,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return &httpClient{hc: hc}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "Notify")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
body, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
hreq.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
hresp, err := h.hc.Do(hreq)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:errcheck
|
||||||
|
defer hresp.Body.Close()
|
||||||
|
|
||||||
|
if hresp.StatusCode == http.StatusOK {
|
||||||
|
return json.NewDecoder(hresp.Body).Decode(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
var errorBody struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(hresp.Body).Decode(&errorBody); err == nil {
|
||||||
|
return fmt.Errorf("push gateway: %d from %s: %s", hresp.StatusCode, url, errorBody.Message)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("push gateway: %d from %s", hresp.StatusCode, url)
|
||||||
|
}
|
62
internal/pushgateway/pushgateway.go
Normal file
62
internal/pushgateway/pushgateway.go
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
package pushgateway
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Client is how interactions with a Push Gateway is done.
|
||||||
|
type Client interface {
|
||||||
|
// Notify sends a notification to the gateway at the given URL.
|
||||||
|
Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type NotifyRequest struct {
|
||||||
|
Notification Notification `json:"notification"` // Required
|
||||||
|
}
|
||||||
|
|
||||||
|
type NotifyResponse struct {
|
||||||
|
// Rejected is the list of device push keys that were rejected
|
||||||
|
// during the push. The caller should remove the push keys so they
|
||||||
|
// are not used again.
|
||||||
|
Rejected []string `json:"rejected"` // Required
|
||||||
|
}
|
||||||
|
|
||||||
|
type Notification struct {
|
||||||
|
Content json.RawMessage `json:"content,omitempty"`
|
||||||
|
Counts *Counts `json:"counts,omitempty"`
|
||||||
|
Devices []*Device `json:"devices"` // Required
|
||||||
|
EventID string `json:"event_id,omitempty"`
|
||||||
|
ID string `json:"id,omitempty"` // Deprecated name for EventID.
|
||||||
|
Membership string `json:"membership,omitempty"` // UNSPEC: required for Sytest.
|
||||||
|
Prio Prio `json:"prio,omitempty"`
|
||||||
|
RoomAlias string `json:"room_alias,omitempty"`
|
||||||
|
RoomID string `json:"room_id,omitempty"`
|
||||||
|
RoomName string `json:"room_name,omitempty"`
|
||||||
|
Sender string `json:"sender,omitempty"`
|
||||||
|
SenderDisplayName string `json:"sender_display_name,omitempty"`
|
||||||
|
Type string `json:"type,omitempty"`
|
||||||
|
UserIsTarget bool `json:"user_is_target,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Counts struct {
|
||||||
|
MissedCalls int `json:"missed_calls,omitempty"`
|
||||||
|
Unread int `json:"unread"` // TODO: UNSPEC: the spec says zero must be omitted, but Sytest 61push/01message-pushed.pl requires it.
|
||||||
|
}
|
||||||
|
|
||||||
|
type Device struct {
|
||||||
|
AppID string `json:"app_id"` // Required
|
||||||
|
Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys.
|
||||||
|
PushKey string `json:"pushkey"` // Required
|
||||||
|
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
|
||||||
|
Tweaks map[string]interface{} `json:"tweaks,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Prio string
|
||||||
|
|
||||||
|
const (
|
||||||
|
HighPrio Prio = "high"
|
||||||
|
LowPrio Prio = "low"
|
||||||
|
)
|
102
internal/pushrules/action.go
Normal file
102
internal/pushrules/action.go
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
package pushrules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// An Action is (part of) an outcome of a rule. There are
|
||||||
|
// (unofficially) terminal actions, and modifier actions.
|
||||||
|
type Action struct {
|
||||||
|
// Kind is the type of action. Has custom encoding in JSON.
|
||||||
|
Kind ActionKind `json:"-"`
|
||||||
|
|
||||||
|
// Tweak is the property to tweak. Has custom encoding in JSON.
|
||||||
|
Tweak TweakKey `json:"-"`
|
||||||
|
|
||||||
|
// Value is some value interpreted according to Kind and Tweak.
|
||||||
|
Value interface{} `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Action) MarshalJSON() ([]byte, error) {
|
||||||
|
if a.Tweak == UnknownTweak && a.Value == nil {
|
||||||
|
return json.Marshal(a.Kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.Kind != SetTweakAction {
|
||||||
|
return nil, fmt.Errorf("only set_tweak actions may have a value, but got kind %q", a.Kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
m := map[string]interface{}{
|
||||||
|
string(a.Kind): a.Tweak,
|
||||||
|
}
|
||||||
|
if a.Value != nil {
|
||||||
|
m["value"] = a.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Action) UnmarshalJSON(bs []byte) error {
|
||||||
|
if bytes.HasPrefix(bs, []byte("\"")) {
|
||||||
|
return json.Unmarshal(bs, &a.Kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
var raw struct {
|
||||||
|
SetTweak TweakKey `json:"set_tweak"`
|
||||||
|
Value interface{} `json:"value"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(bs, &raw); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if raw.SetTweak == UnknownTweak {
|
||||||
|
return fmt.Errorf("got unknown action JSON: %s", string(bs))
|
||||||
|
}
|
||||||
|
a.Kind = SetTweakAction
|
||||||
|
a.Tweak = raw.SetTweak
|
||||||
|
if raw.Value != nil {
|
||||||
|
a.Value = raw.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ActionKind is the primary discriminator for actions.
|
||||||
|
type ActionKind string
|
||||||
|
|
||||||
|
const (
|
||||||
|
UnknownAction ActionKind = ""
|
||||||
|
|
||||||
|
// NotifyAction indicates the clients should show a notification.
|
||||||
|
NotifyAction ActionKind = "notify"
|
||||||
|
|
||||||
|
// DontNotifyAction indicates the clients should not show a notification.
|
||||||
|
DontNotifyAction ActionKind = "dont_notify"
|
||||||
|
|
||||||
|
// CoalesceAction tells the clients to show a notification, and
|
||||||
|
// tells both servers and clients that multiple events can be
|
||||||
|
// coalesced into a single notification. The behaviour is
|
||||||
|
// implementation-specific.
|
||||||
|
CoalesceAction ActionKind = "coalesce"
|
||||||
|
|
||||||
|
// SetTweakAction uses the Tweak and Value fields to add a
|
||||||
|
// tweak. Multiple SetTweakAction can be provided in a rule,
|
||||||
|
// combined with NotifyAction or CoalesceAction.
|
||||||
|
SetTweakAction ActionKind = "set_tweak"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A TweakKey describes a property to be modified/tweaked for events
|
||||||
|
// that match the rule.
|
||||||
|
type TweakKey string
|
||||||
|
|
||||||
|
const (
|
||||||
|
UnknownTweak TweakKey = ""
|
||||||
|
|
||||||
|
// SoundTweak describes which sound to play. Using "default" means
|
||||||
|
// "enable sound".
|
||||||
|
SoundTweak TweakKey = "sound"
|
||||||
|
|
||||||
|
// HighlightTweak asks the clients to highlight the conversation.
|
||||||
|
HighlightTweak TweakKey = "highlight"
|
||||||
|
)
|
39
internal/pushrules/action_test.go
Normal file
39
internal/pushrules/action_test.go
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
package pushrules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestActionJSON(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Want Action
|
||||||
|
}{
|
||||||
|
{Action{Kind: NotifyAction}},
|
||||||
|
{Action{Kind: DontNotifyAction}},
|
||||||
|
{Action{Kind: CoalesceAction}},
|
||||||
|
{Action{Kind: SetTweakAction}},
|
||||||
|
|
||||||
|
{Action{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}},
|
||||||
|
{Action{Kind: SetTweakAction, Tweak: HighlightTweak}},
|
||||||
|
{Action{Kind: SetTweakAction, Tweak: HighlightTweak, Value: "false"}},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(fmt.Sprintf("%+v", tst.Want), func(t *testing.T) {
|
||||||
|
bs, err := json.Marshal(&tst.Want)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Marshal failed: %v", err)
|
||||||
|
}
|
||||||
|
var got Action
|
||||||
|
if err := json.Unmarshal(bs, &got); err != nil {
|
||||||
|
t.Fatalf("Unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(tst.Want, got); diff != "" {
|
||||||
|
t.Errorf("+got -want:\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
49
internal/pushrules/condition.go
Normal file
49
internal/pushrules/condition.go
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
package pushrules
|
||||||
|
|
||||||
|
// A Condition dictates extra conditions for a matching rules. See
|
||||||
|
// ConditionKind.
|
||||||
|
type Condition struct {
|
||||||
|
// Kind is the primary discriminator for the condition
|
||||||
|
// type. Required.
|
||||||
|
Kind ConditionKind `json:"kind"`
|
||||||
|
|
||||||
|
// Key indicates the dot-separated path of Event fields to
|
||||||
|
// match. Required for EventMatchCondition and
|
||||||
|
// SenderNotificationPermissionCondition.
|
||||||
|
Key string `json:"key,omitempty"`
|
||||||
|
|
||||||
|
// Pattern indicates the value pattern that must match. Required
|
||||||
|
// for EventMatchCondition.
|
||||||
|
Pattern string `json:"pattern,omitempty"`
|
||||||
|
|
||||||
|
// Is indicates the condition that must be fulfilled. Required for
|
||||||
|
// RoomMemberCountCondition.
|
||||||
|
Is string `json:"is,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConditionKind represents a kind of condition.
|
||||||
|
//
|
||||||
|
// SPEC: Unrecognised conditions MUST NOT match any events,
|
||||||
|
// effectively making the push rule disabled.
|
||||||
|
type ConditionKind string
|
||||||
|
|
||||||
|
const (
|
||||||
|
UnknownCondition ConditionKind = ""
|
||||||
|
|
||||||
|
// EventMatchCondition indicates the condition looks for a key
|
||||||
|
// path and matches a pattern. How paths that don't reference a
|
||||||
|
// simple value match against rules is implementation-specific.
|
||||||
|
EventMatchCondition ConditionKind = "event_match"
|
||||||
|
|
||||||
|
// ContainsDisplayNameCondition indicates the current user's
|
||||||
|
// display name must be found in the content body.
|
||||||
|
ContainsDisplayNameCondition ConditionKind = "contains_display_name"
|
||||||
|
|
||||||
|
// RoomMemberCountCondition matches a simple arithmetic comparison
|
||||||
|
// against the total number of members in a room.
|
||||||
|
RoomMemberCountCondition ConditionKind = "room_member_count"
|
||||||
|
|
||||||
|
// SenderNotificationPermissionCondition compares power level for
|
||||||
|
// the sender in the event's room.
|
||||||
|
SenderNotificationPermissionCondition ConditionKind = "sender_notification_permission"
|
||||||
|
)
|
23
internal/pushrules/default.go
Normal file
23
internal/pushrules/default.go
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
package pushrules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultAccountRuleSets is the complete set of default push rules
|
||||||
|
// for an account.
|
||||||
|
func DefaultAccountRuleSets(localpart string, serverName gomatrixserverlib.ServerName) *AccountRuleSets {
|
||||||
|
return &AccountRuleSets{
|
||||||
|
Global: *DefaultGlobalRuleSet(localpart, serverName),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultGlobalRuleSet returns the default ruleset for a given (fully
|
||||||
|
// qualified) MXID.
|
||||||
|
func DefaultGlobalRuleSet(localpart string, serverName gomatrixserverlib.ServerName) *RuleSet {
|
||||||
|
return &RuleSet{
|
||||||
|
Override: defaultOverrideRules("@" + localpart + ":" + string(serverName)),
|
||||||
|
Content: defaultContentRules(localpart),
|
||||||
|
Underride: defaultUnderrideRules,
|
||||||
|
}
|
||||||
|
}
|
33
internal/pushrules/default_content.go
Normal file
33
internal/pushrules/default_content.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package pushrules
|
||||||
|
|
||||||
|
func defaultContentRules(localpart string) []*Rule {
|
||||||
|
return []*Rule{
|
||||||
|
mRuleContainsUserNameDefinition(localpart),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
MRuleContainsUserName = ".m.rule.contains_user_name"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mRuleContainsUserNameDefinition(localpart string) *Rule {
|
||||||
|
return &Rule{
|
||||||
|
RuleID: MRuleContainsUserName,
|
||||||
|
Default: true,
|
||||||
|
Enabled: true,
|
||||||
|
Pattern: localpart,
|
||||||
|
Actions: []*Action{
|
||||||
|
{Kind: NotifyAction},
|
||||||
|
{
|
||||||
|
Kind: SetTweakAction,
|
||||||
|
Tweak: SoundTweak,
|
||||||
|
Value: "default",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Kind: SetTweakAction,
|
||||||
|
Tweak: HighlightTweak,
|
||||||
|
Value: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
165
internal/pushrules/default_override.go
Normal file
165
internal/pushrules/default_override.go
Normal file
|
@ -0,0 +1,165 @@
|
||||||
|
package pushrules
|
||||||
|
|
||||||
|
func defaultOverrideRules(userID string) []*Rule {
|
||||||
|
return []*Rule{
|
||||||
|
&mRuleMasterDefinition,
|
||||||
|
&mRuleSuppressNoticesDefinition,
|
||||||
|
mRuleInviteForMeDefinition(userID),
|
||||||
|
&mRuleMemberEventDefinition,
|
||||||
|
&mRuleContainsDisplayNameDefinition,
|
||||||
|
&mRuleTombstoneDefinition,
|
||||||
|
&mRuleRoomNotifDefinition,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
MRuleMaster = ".m.rule.master"
|
||||||
|
MRuleSuppressNotices = ".m.rule.suppress_notices"
|
||||||
|
MRuleInviteForMe = ".m.rule.invite_for_me"
|
||||||
|
MRuleMemberEvent = ".m.rule.member_event"
|
||||||
|
MRuleContainsDisplayName = ".m.rule.contains_display_name"
|
||||||
|
MRuleTombstone = ".m.rule.tombstone"
|
||||||
|
MRuleRoomNotif = ".m.rule.roomnotif"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
mRuleMasterDefinition = Rule{
|
||||||
|
RuleID: MRuleMaster,
|
||||||
|
Default: true,
|
||||||
|
Enabled: false,
|
||||||
|
Conditions: []*Condition{},
|
||||||
|
Actions: []*Action{{Kind: DontNotifyAction}},
|
||||||
|
}
|
||||||
|
mRuleSuppressNoticesDefinition = Rule{
|
||||||
|
RuleID: MRuleSuppressNotices,
|
||||||
|
Default: true,
|
||||||
|
Enabled: true,
|
||||||
|
Conditions: []*Condition{
|
||||||
|
{
|
||||||
|
Kind: EventMatchCondition,
|
||||||
|
Key: "content.msgtype",
|
||||||
|
Pattern: "m.notice",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Actions: []*Action{{Kind: DontNotifyAction}},
|
||||||
|
}
|
||||||
|
mRuleMemberEventDefinition = Rule{
|
||||||
|
RuleID: MRuleMemberEvent,
|
||||||
|
Default: true,
|
||||||
|
Enabled: true,
|
||||||
|
Conditions: []*Condition{
|
||||||
|
{
|
||||||
|
Kind: EventMatchCondition,
|
||||||
|
Key: "type",
|
||||||
|
Pattern: "m.room.member",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Actions: []*Action{{Kind: DontNotifyAction}},
|
||||||
|
}
|
||||||
|
mRuleContainsDisplayNameDefinition = Rule{
|
||||||
|
RuleID: MRuleContainsDisplayName,
|
||||||
|
Default: true,
|
||||||
|
Enabled: true,
|
||||||
|
Conditions: []*Condition{{Kind: ContainsDisplayNameCondition}},
|
||||||
|
Actions: []*Action{
|
||||||
|
{Kind: NotifyAction},
|
||||||
|
{
|
||||||
|
Kind: SetTweakAction,
|
||||||
|
Tweak: SoundTweak,
|
||||||
|
Value: "default",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Kind: SetTweakAction,
|
||||||
|
Tweak: HighlightTweak,
|
||||||
|
Value: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mRuleTombstoneDefinition = Rule{
|
||||||
|
RuleID: MRuleTombstone,
|
||||||
|
Default: true,
|
||||||
|
Enabled: true,
|
||||||
|
Conditions: []*Condition{
|
||||||
|
{
|
||||||
|
Kind: EventMatchCondition,
|
||||||
|
Key: "type",
|
||||||
|
Pattern: "m.room.tombstone",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Kind: EventMatchCondition,
|
||||||
|
Key: "state_key",
|
||||||
|
Pattern: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Actions: []*Action{
|
||||||
|
{Kind: NotifyAction},
|
||||||
|
{
|
||||||
|
Kind: SetTweakAction,
|
||||||
|
Tweak: HighlightTweak,
|
||||||
|
Value: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mRuleRoomNotifDefinition = Rule{
|
||||||
|
RuleID: MRuleRoomNotif,
|
||||||
|
Default: true,
|
||||||
|
Enabled: true,
|
||||||
|
Conditions: []*Condition{
|
||||||
|
{
|
||||||
|
Kind: EventMatchCondition,
|
||||||
|
Key: "content.body",
|
||||||
|
Pattern: "@room",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Kind: SenderNotificationPermissionCondition,
|
||||||
|
Key: "room",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Actions: []*Action{
|
||||||
|
{Kind: NotifyAction},
|
||||||
|
{
|
||||||
|
Kind: SetTweakAction,
|
||||||
|
Tweak: HighlightTweak,
|
||||||
|
Value: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func mRuleInviteForMeDefinition(userID string) *Rule {
|
||||||
|
return &Rule{
|
||||||
|
RuleID: MRuleInviteForMe,
|
||||||
|
Default: true,
|
||||||
|
Enabled: true,
|
||||||
|
Conditions: []*Condition{
|
||||||
|
{
|
||||||
|
Kind: EventMatchCondition,
|
||||||
|
Key: "type",
|
||||||
|
Pattern: "m.room.member",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Kind: EventMatchCondition,
|
||||||
|
Key: "content.membership",
|
||||||
|
Pattern: "invite",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Kind: EventMatchCondition,
|
||||||
|
Key: "state_key",
|
||||||
|
Pattern: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Actions: []*Action{
|
||||||
|
{Kind: NotifyAction},
|
||||||
|
{
|
||||||
|
Kind: SetTweakAction,
|
||||||
|
Tweak: SoundTweak,
|
||||||
|
Value: "default",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Kind: SetTweakAction,
|
||||||
|
Tweak: HighlightTweak,
|
||||||
|
Value: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
119
internal/pushrules/default_underride.go
Normal file
119
internal/pushrules/default_underride.go
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
package pushrules
|
||||||
|
|
||||||
|
const (
|
||||||
|
MRuleCall = ".m.rule.call"
|
||||||
|
MRuleEncryptedRoomOneToOne = ".m.rule.encrypted_room_one_to_one"
|
||||||
|
MRuleRoomOneToOne = ".m.rule.room_one_to_one"
|
||||||
|
MRuleMessage = ".m.rule.message"
|
||||||
|
MRuleEncrypted = ".m.rule.encrypted"
|
||||||
|
)
|
||||||
|
|
||||||
|
var defaultUnderrideRules = []*Rule{
|
||||||
|
&mRuleCallDefinition,
|
||||||
|
&mRuleEncryptedRoomOneToOneDefinition,
|
||||||
|
&mRuleRoomOneToOneDefinition,
|
||||||
|
&mRuleMessageDefinition,
|
||||||
|
&mRuleEncryptedDefinition,
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
mRuleCallDefinition = Rule{
|
||||||
|
RuleID: MRuleCall,
|
||||||
|
Default: true,
|
||||||
|
Enabled: true,
|
||||||
|
Conditions: []*Condition{
|
||||||
|
{
|
||||||
|
Kind: EventMatchCondition,
|
||||||
|
Key: "type",
|
||||||
|
Pattern: "m.call.invite",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Actions: []*Action{
|
||||||
|
{Kind: NotifyAction},
|
||||||
|
{
|
||||||
|
Kind: SetTweakAction,
|
||||||
|
Tweak: SoundTweak,
|
||||||
|
Value: "ring",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Kind: SetTweakAction,
|
||||||
|
Tweak: HighlightTweak,
|
||||||
|
Value: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mRuleEncryptedRoomOneToOneDefinition = Rule{
|
||||||
|
RuleID: MRuleEncryptedRoomOneToOne,
|
||||||
|
Default: true,
|
||||||
|
Enabled: true,
|
||||||
|
Conditions: []*Condition{
|
||||||
|
{
|
||||||
|
Kind: RoomMemberCountCondition,
|
||||||
|
Is: "2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Kind: EventMatchCondition,
|
||||||
|
Key: "type",
|
||||||
|
Pattern: "m.room.encrypted",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Actions: []*Action{
|
||||||
|
{Kind: NotifyAction},
|
||||||
|
{
|
||||||
|
Kind: SetTweakAction,
|
||||||
|
Tweak: HighlightTweak,
|
||||||
|
Value: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mRuleRoomOneToOneDefinition = Rule{
|
||||||
|
RuleID: MRuleRoomOneToOne,
|
||||||
|
Default: true,
|
||||||
|
Enabled: true,
|
||||||
|
Conditions: []*Condition{
|
||||||
|
{
|
||||||
|
Kind: RoomMemberCountCondition,
|
||||||
|
Is: "2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Kind: EventMatchCondition,
|
||||||
|
Key: "type",
|
||||||
|
Pattern: "m.room.message",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Actions: []*Action{
|
||||||
|
{Kind: NotifyAction},
|
||||||
|
{
|
||||||
|
Kind: SetTweakAction,
|
||||||
|
Tweak: HighlightTweak,
|
||||||
|
Value: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mRuleMessageDefinition = Rule{
|
||||||
|
RuleID: MRuleMessage,
|
||||||
|
Default: true,
|
||||||
|
Enabled: true,
|
||||||
|
Conditions: []*Condition{
|
||||||
|
{
|
||||||
|
Kind: EventMatchCondition,
|
||||||
|
Key: "type",
|
||||||
|
Pattern: "m.room.message",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Actions: []*Action{{Kind: NotifyAction}},
|
||||||
|
}
|
||||||
|
mRuleEncryptedDefinition = Rule{
|
||||||
|
RuleID: MRuleEncrypted,
|
||||||
|
Default: true,
|
||||||
|
Enabled: true,
|
||||||
|
Conditions: []*Condition{
|
||||||
|
{
|
||||||
|
Kind: EventMatchCondition,
|
||||||
|
Key: "type",
|
||||||
|
Pattern: "m.room.encrypted",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Actions: []*Action{{Kind: NotifyAction}},
|
||||||
|
}
|
||||||
|
)
|
165
internal/pushrules/evaluate.go
Normal file
165
internal/pushrules/evaluate.go
Normal file
|
@ -0,0 +1,165 @@
|
||||||
|
package pushrules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A RuleSetEvaluator encapsulates context to evaluate an event
|
||||||
|
// against a rule set.
|
||||||
|
type RuleSetEvaluator struct {
|
||||||
|
ec EvaluationContext
|
||||||
|
ruleSet []kindAndRules
|
||||||
|
}
|
||||||
|
|
||||||
|
// An EvaluationContext gives a RuleSetEvaluator access to the
|
||||||
|
// environment, for rules that require that.
|
||||||
|
type EvaluationContext interface {
|
||||||
|
// UserDisplayName returns the current user's display name.
|
||||||
|
UserDisplayName() string
|
||||||
|
|
||||||
|
// RoomMemberCount returns the number of members in the room of
|
||||||
|
// the current event.
|
||||||
|
RoomMemberCount() (int, error)
|
||||||
|
|
||||||
|
// HasPowerLevel returns whether the user has at least the given
|
||||||
|
// power in the room of the current event.
|
||||||
|
HasPowerLevel(userID, levelKey string) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A kindAndRules is just here to simplify iteration of the (ordered)
|
||||||
|
// kinds of rules.
|
||||||
|
type kindAndRules struct {
|
||||||
|
Kind Kind
|
||||||
|
Rules []*Rule
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRuleSetEvaluator creates a new evaluator for the given rule set.
|
||||||
|
func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluator {
|
||||||
|
return &RuleSetEvaluator{
|
||||||
|
ec: ec,
|
||||||
|
ruleSet: []kindAndRules{
|
||||||
|
{OverrideKind, ruleSet.Override},
|
||||||
|
{ContentKind, ruleSet.Content},
|
||||||
|
{RoomKind, ruleSet.Room},
|
||||||
|
{SenderKind, ruleSet.Sender},
|
||||||
|
{UnderrideKind, ruleSet.Underride},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatchEvent returns the first matching rule. Returns nil if there
|
||||||
|
// was no match rule.
|
||||||
|
func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule, error) {
|
||||||
|
// TODO: server-default rules have lower priority than user rules,
|
||||||
|
// but they are stored together with the user rules. It's a bit
|
||||||
|
// unclear what the specification (11.14.1.4 Predefined rules)
|
||||||
|
// means the ordering should be.
|
||||||
|
//
|
||||||
|
// The most reasonable interpretation is that default overrides
|
||||||
|
// still have lower priority than user content rules, so we
|
||||||
|
// iterate twice.
|
||||||
|
for _, rsat := range rse.ruleSet {
|
||||||
|
for _, defRules := range []bool{false, true} {
|
||||||
|
for _, rule := range rsat.Rules {
|
||||||
|
if rule.Default != defRules {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No matching rule.
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
|
||||||
|
if !rule.Enabled {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch kind {
|
||||||
|
case OverrideKind, UnderrideKind:
|
||||||
|
for _, cond := range rule.Conditions {
|
||||||
|
ok, err := conditionMatches(cond, event, ec)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
|
||||||
|
case ContentKind:
|
||||||
|
// TODO: "These configure behaviour for (unencrypted) messages
|
||||||
|
// that match certain patterns." - Does that mean "content.body"?
|
||||||
|
return patternMatches("content.body", rule.Pattern, event)
|
||||||
|
|
||||||
|
case RoomKind:
|
||||||
|
return rule.RuleID == event.RoomID(), nil
|
||||||
|
|
||||||
|
case SenderKind:
|
||||||
|
return rule.RuleID == event.Sender(), nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
|
||||||
|
switch cond.Kind {
|
||||||
|
case EventMatchCondition:
|
||||||
|
return patternMatches(cond.Key, cond.Pattern, event)
|
||||||
|
|
||||||
|
case ContainsDisplayNameCondition:
|
||||||
|
return patternMatches("content.body", ec.UserDisplayName(), event)
|
||||||
|
|
||||||
|
case RoomMemberCountCondition:
|
||||||
|
cmp, err := parseRoomMemberCountCondition(cond.Is)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("parsing room_member_count condition: %w", err)
|
||||||
|
}
|
||||||
|
n, err := ec.RoomMemberCount()
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("RoomMemberCount failed: %w", err)
|
||||||
|
}
|
||||||
|
return cmp(n), nil
|
||||||
|
|
||||||
|
case SenderNotificationPermissionCondition:
|
||||||
|
return ec.HasPowerLevel(event.Sender(), cond.Key)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) {
|
||||||
|
re, err := globToRegexp(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var eventMap map[string]interface{}
|
||||||
|
if err = json.Unmarshal(event.JSON(), &eventMap); err != nil {
|
||||||
|
return false, fmt.Errorf("parsing event: %w", err)
|
||||||
|
}
|
||||||
|
v, err := lookupMapPath(strings.Split(key, "."), eventMap)
|
||||||
|
if err != nil {
|
||||||
|
// An unknown path is a benign error that shouldn't stop rule
|
||||||
|
// processing. It's just a non-match.
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return re.MatchString(fmt.Sprint(v)), nil
|
||||||
|
}
|
189
internal/pushrules/evaluate_test.go
Normal file
189
internal/pushrules/evaluate_test.go
Normal file
|
@ -0,0 +1,189 @@
|
||||||
|
package pushrules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
|
||||||
|
ev := mustEventFromJSON(t, `{}`)
|
||||||
|
defaultEnabled := &Rule{
|
||||||
|
RuleID: ".default.enabled",
|
||||||
|
Default: true,
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
userEnabled := &Rule{
|
||||||
|
RuleID: ".user.enabled",
|
||||||
|
Default: false,
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
userEnabled2 := &Rule{
|
||||||
|
RuleID: ".user.enabled.2",
|
||||||
|
Default: false,
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
RuleSet RuleSet
|
||||||
|
Want *Rule
|
||||||
|
}{
|
||||||
|
{"empty", RuleSet{}, nil},
|
||||||
|
{"defaultCanWin", RuleSet{Override: []*Rule{defaultEnabled}}, defaultEnabled},
|
||||||
|
{"userWins", RuleSet{Override: []*Rule{defaultEnabled, userEnabled}}, userEnabled},
|
||||||
|
{"defaultOverrideWins", RuleSet{Override: []*Rule{defaultEnabled}, Underride: []*Rule{userEnabled}}, defaultEnabled},
|
||||||
|
{"overrideContent", RuleSet{Override: []*Rule{userEnabled}, Content: []*Rule{userEnabled2}}, userEnabled},
|
||||||
|
{"overrideRoom", RuleSet{Override: []*Rule{userEnabled}, Room: []*Rule{userEnabled2}}, userEnabled},
|
||||||
|
{"overrideSender", RuleSet{Override: []*Rule{userEnabled}, Sender: []*Rule{userEnabled2}}, userEnabled},
|
||||||
|
{"overrideUnderride", RuleSet{Override: []*Rule{userEnabled}, Underride: []*Rule{userEnabled2}}, userEnabled},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
rse := NewRuleSetEvaluator(nil, &tst.RuleSet)
|
||||||
|
got, err := rse.MatchEvent(ev)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MatchEvent failed: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(tst.Want, got); diff != "" {
|
||||||
|
t.Errorf("MatchEvent rule: +got -want:\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRuleMatches(t *testing.T) {
|
||||||
|
emptyRule := Rule{Enabled: true}
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Kind Kind
|
||||||
|
Rule Rule
|
||||||
|
EventJSON string
|
||||||
|
Want bool
|
||||||
|
}{
|
||||||
|
{"emptyOverride", OverrideKind, emptyRule, `{}`, true},
|
||||||
|
{"emptyContent", ContentKind, emptyRule, `{}`, false},
|
||||||
|
{"emptyRoom", RoomKind, emptyRule, `{}`, true},
|
||||||
|
{"emptySender", SenderKind, emptyRule, `{}`, true},
|
||||||
|
{"emptyUnderride", UnderrideKind, emptyRule, `{}`, true},
|
||||||
|
|
||||||
|
{"disabled", OverrideKind, Rule{}, `{}`, false},
|
||||||
|
|
||||||
|
{"overrideConditionMatch", OverrideKind, Rule{Enabled: true}, `{}`, true},
|
||||||
|
{"overrideConditionNoMatch", OverrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
|
||||||
|
|
||||||
|
{"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true},
|
||||||
|
{"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
|
||||||
|
|
||||||
|
{"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true},
|
||||||
|
{"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false},
|
||||||
|
|
||||||
|
{"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true},
|
||||||
|
{"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false},
|
||||||
|
|
||||||
|
{"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true},
|
||||||
|
{"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ruleMatches failed: %v", err)
|
||||||
|
}
|
||||||
|
if got != tst.Want {
|
||||||
|
t.Errorf("ruleMatches: got %v, want %v", got, tst.Want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConditionMatches(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Cond Condition
|
||||||
|
EventJSON string
|
||||||
|
Want bool
|
||||||
|
}{
|
||||||
|
{"empty", Condition{}, `{}`, false},
|
||||||
|
{"empty", Condition{Kind: "unknownstring"}, `{}`, false},
|
||||||
|
|
||||||
|
{"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, true},
|
||||||
|
|
||||||
|
{"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false},
|
||||||
|
{"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true},
|
||||||
|
|
||||||
|
{"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false},
|
||||||
|
{"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true},
|
||||||
|
{"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false},
|
||||||
|
{"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true},
|
||||||
|
{"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false},
|
||||||
|
{"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true},
|
||||||
|
{"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false},
|
||||||
|
{"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true},
|
||||||
|
{"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false},
|
||||||
|
{"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true},
|
||||||
|
|
||||||
|
{"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true},
|
||||||
|
{"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conditionMatches failed: %v", err)
|
||||||
|
}
|
||||||
|
if got != tst.Want {
|
||||||
|
t.Errorf("conditionMatches: got %v, want %v", got, tst.Want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeEvaluationContext struct{}
|
||||||
|
|
||||||
|
func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" }
|
||||||
|
func (fakeEvaluationContext) RoomMemberCount() (int, error) { return 2, nil }
|
||||||
|
func (fakeEvaluationContext) HasPowerLevel(userID, levelKey string) (bool, error) {
|
||||||
|
return userID == "@poweruser:example.com" && levelKey == "powerlevel", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternMatches(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Key string
|
||||||
|
Pattern string
|
||||||
|
EventJSON string
|
||||||
|
Want bool
|
||||||
|
}{
|
||||||
|
{"empty", "", "", `{}`, false},
|
||||||
|
|
||||||
|
// Note that an empty pattern contains no wildcard characters,
|
||||||
|
// which implicitly means "*".
|
||||||
|
{"patternEmpty", "content", "", `{"content":{}}`, true},
|
||||||
|
|
||||||
|
{"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true},
|
||||||
|
{"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true},
|
||||||
|
{"singlePattern", "content.creator", "acr?ator", `{"content":{"creator":"acreator"}}`, true},
|
||||||
|
{"multiPattern", "content.creator", "a*ea*r", `{"content":{"creator":"acreator"}}`, true},
|
||||||
|
{"patternNoSubstring", "content.creator", "r*t", `{"content":{"creator":"acreator"}}`, false},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
got, err := patternMatches(tst.Key, tst.Pattern, mustEventFromJSON(t, tst.EventJSON))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("patternMatches failed: %v", err)
|
||||||
|
}
|
||||||
|
if got != tst.Want {
|
||||||
|
t.Errorf("patternMatches: got %v, want %v", got, tst.Want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustEventFromJSON(t *testing.T, json string) *gomatrixserverlib.Event {
|
||||||
|
ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(json), false, gomatrixserverlib.RoomVersionV7)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return ev
|
||||||
|
}
|
71
internal/pushrules/pushrules.go
Normal file
71
internal/pushrules/pushrules.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
package pushrules
|
||||||
|
|
||||||
|
// An AccountRuleSets carries the rule sets associated with an
|
||||||
|
// account.
|
||||||
|
type AccountRuleSets struct {
|
||||||
|
Global RuleSet `json:"global"` // Required
|
||||||
|
}
|
||||||
|
|
||||||
|
// A RuleSet contains all the various push rules for an
|
||||||
|
// account. Listed in decreasing order of priority.
|
||||||
|
type RuleSet struct {
|
||||||
|
Override []*Rule `json:"override,omitempty"`
|
||||||
|
Content []*Rule `json:"content,omitempty"`
|
||||||
|
Room []*Rule `json:"room,omitempty"`
|
||||||
|
Sender []*Rule `json:"sender,omitempty"`
|
||||||
|
Underride []*Rule `json:"underride,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// A Rule contains matchers, conditions and final actions. While
|
||||||
|
// evaluating, at most one rule is considered matching.
|
||||||
|
//
|
||||||
|
// Kind and scope are part of the push rules request/responses, but
|
||||||
|
// not of the core data model.
|
||||||
|
type Rule struct {
|
||||||
|
// RuleID is either a free identifier, or the sender's MXID for
|
||||||
|
// SenderKind. Required.
|
||||||
|
RuleID string `json:"rule_id"`
|
||||||
|
|
||||||
|
// Default indicates whether this is a server-defined default, or
|
||||||
|
// a user-provided rule. Required.
|
||||||
|
//
|
||||||
|
// The server-default rules have the lowest priority.
|
||||||
|
Default bool `json:"default"`
|
||||||
|
|
||||||
|
// Enabled allows the user to disable rules while keeping them
|
||||||
|
// around. Required.
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
|
||||||
|
// Actions describe the desired outcome, should the rule
|
||||||
|
// match. Required.
|
||||||
|
Actions []*Action `json:"actions"`
|
||||||
|
|
||||||
|
// Conditions provide the rule's conditions for OverrideKind and
|
||||||
|
// UnderrideKind. Not allowed for other kinds.
|
||||||
|
Conditions []*Condition `json:"conditions"`
|
||||||
|
|
||||||
|
// Pattern is the body pattern to match for ContentKind. Required
|
||||||
|
// for that kind. The interpretation is the same as that of
|
||||||
|
// Condition.Pattern.
|
||||||
|
Pattern string `json:"pattern"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scope only has one valid value. See also AccountRuleSets.
|
||||||
|
type Scope string
|
||||||
|
|
||||||
|
const (
|
||||||
|
UnknownScope Scope = ""
|
||||||
|
GlobalScope Scope = "global"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Kind is the type of push rule. See also RuleSet.
|
||||||
|
type Kind string
|
||||||
|
|
||||||
|
const (
|
||||||
|
UnknownKind Kind = ""
|
||||||
|
OverrideKind Kind = "override"
|
||||||
|
ContentKind Kind = "content"
|
||||||
|
RoomKind Kind = "room"
|
||||||
|
SenderKind Kind = "sender"
|
||||||
|
UnderrideKind Kind = "underride"
|
||||||
|
)
|
125
internal/pushrules/util.go
Normal file
125
internal/pushrules/util.go
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
package pushrules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ActionsToTweaks converts a list of actions into a primary action
|
||||||
|
// kind and a tweaks map. Returns a nil map if it would have been
|
||||||
|
// empty.
|
||||||
|
func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) {
|
||||||
|
var kind ActionKind
|
||||||
|
tweaks := map[string]interface{}{}
|
||||||
|
|
||||||
|
for _, a := range as {
|
||||||
|
if a.Kind == SetTweakAction {
|
||||||
|
tweaks[string(a.Tweak)] = a.Value
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if kind != UnknownAction {
|
||||||
|
return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind)
|
||||||
|
}
|
||||||
|
kind = a.Kind
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tweaks) == 0 {
|
||||||
|
tweaks = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return kind, tweaks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BoolTweakOr returns the named tweak as a boolean, and returns `def`
|
||||||
|
// on failure.
|
||||||
|
func BoolTweakOr(tweaks map[string]interface{}, key TweakKey, def bool) bool {
|
||||||
|
v, ok := tweaks[string(key)]
|
||||||
|
if !ok {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
b, ok := v.(bool)
|
||||||
|
if !ok {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// globToRegexp converts a Matrix glob-style pattern to a Regular expression.
|
||||||
|
func globToRegexp(pattern string) (*regexp.Regexp, error) {
|
||||||
|
// TODO: It's unclear which glob characters are supported. The only
|
||||||
|
// place this is discussed is for the unrelated "m.policy.rule.*"
|
||||||
|
// events. Assuming, the same: /[*?]/
|
||||||
|
if !strings.ContainsAny(pattern, "*?") {
|
||||||
|
pattern = "*" + pattern + "*"
|
||||||
|
}
|
||||||
|
|
||||||
|
// The defined syntax doesn't allow escaping the glob wildcard
|
||||||
|
// characters, which makes this a straight-forward
|
||||||
|
// replace-after-quote.
|
||||||
|
pattern = globNonMetaRegexp.ReplaceAllStringFunc(pattern, regexp.QuoteMeta)
|
||||||
|
pattern = strings.Replace(pattern, "*", ".*", -1)
|
||||||
|
pattern = strings.Replace(pattern, "?", ".", -1)
|
||||||
|
return regexp.Compile("^(" + pattern + ")$")
|
||||||
|
}
|
||||||
|
|
||||||
|
// globNonMetaRegexp are the characters that are not considered glob
|
||||||
|
// meta-characters (i.e. may need escaping).
|
||||||
|
var globNonMetaRegexp = regexp.MustCompile("[^*?]+")
|
||||||
|
|
||||||
|
// lookupMapPath traverses a hierarchical map structure, like the one
|
||||||
|
// produced by json.Unmarshal, to return the leaf value. Traversing
|
||||||
|
// arrays/slices is not supported, only objects/maps.
|
||||||
|
func lookupMapPath(path []string, m map[string]interface{}) (interface{}, error) {
|
||||||
|
if len(path) == 0 {
|
||||||
|
return nil, fmt.Errorf("empty path")
|
||||||
|
}
|
||||||
|
|
||||||
|
var v interface{} = m
|
||||||
|
for i, key := range path {
|
||||||
|
m, ok := v.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("expected an object for path %q, but got %T", strings.Join(path[:i+1], "."), v)
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok = m[key]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("path not found: %s", strings.Join(path[:i+1], "."))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRoomMemberCountCondition parses a string like "2", "==2", "<2"
|
||||||
|
// into a function that checks if the argument to it fulfils the
|
||||||
|
// condition.
|
||||||
|
func parseRoomMemberCountCondition(s string) (func(int) bool, error) {
|
||||||
|
var b int
|
||||||
|
var cmp = func(a int) bool { return a == b }
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(s, "<="):
|
||||||
|
cmp = func(a int) bool { return a <= b }
|
||||||
|
s = s[2:]
|
||||||
|
case strings.HasPrefix(s, ">="):
|
||||||
|
cmp = func(a int) bool { return a >= b }
|
||||||
|
s = s[2:]
|
||||||
|
case strings.HasPrefix(s, "<"):
|
||||||
|
cmp = func(a int) bool { return a < b }
|
||||||
|
s = s[1:]
|
||||||
|
case strings.HasPrefix(s, ">"):
|
||||||
|
cmp = func(a int) bool { return a > b }
|
||||||
|
s = s[1:]
|
||||||
|
case strings.HasPrefix(s, "=="):
|
||||||
|
// Same cmp as the default.
|
||||||
|
s = s[2:]
|
||||||
|
}
|
||||||
|
|
||||||
|
v, err := strconv.ParseInt(s, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
b = int(v)
|
||||||
|
return cmp, nil
|
||||||
|
}
|
169
internal/pushrules/util_test.go
Normal file
169
internal/pushrules/util_test.go
Normal file
|
@ -0,0 +1,169 @@
|
||||||
|
package pushrules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestActionsToTweaks(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Input []*Action
|
||||||
|
WantKind ActionKind
|
||||||
|
WantTweaks map[string]interface{}
|
||||||
|
}{
|
||||||
|
{"empty", nil, UnknownAction, nil},
|
||||||
|
{"zero", []*Action{{}}, UnknownAction, nil},
|
||||||
|
{"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil},
|
||||||
|
{"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}},
|
||||||
|
{"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}},
|
||||||
|
{
|
||||||
|
"all",
|
||||||
|
[]*Action{
|
||||||
|
{Kind: CoalesceAction},
|
||||||
|
{Kind: SetTweakAction, Tweak: HighlightTweak},
|
||||||
|
{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"},
|
||||||
|
},
|
||||||
|
CoalesceAction,
|
||||||
|
map[string]interface{}{"highlight": nil, "sound": "default"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
gotKind, gotTweaks, err := ActionsToTweaks(tst.Input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ActionsToTweaks failed: %v", err)
|
||||||
|
}
|
||||||
|
if gotKind != tst.WantKind {
|
||||||
|
t.Errorf("kind: got %v, want %v", gotKind, tst.WantKind)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(tst.WantTweaks, gotTweaks); diff != "" {
|
||||||
|
t.Errorf("tweaks: +got -want:\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBoolTweakOr(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Input map[string]interface{}
|
||||||
|
Def bool
|
||||||
|
Want bool
|
||||||
|
}{
|
||||||
|
{"nil", nil, false, false},
|
||||||
|
{"nilValue", map[string]interface{}{"highlight": nil}, true, true},
|
||||||
|
{"false", map[string]interface{}{"highlight": false}, true, false},
|
||||||
|
{"true", map[string]interface{}{"highlight": true}, false, true},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
got := BoolTweakOr(tst.Input, HighlightTweak, tst.Def)
|
||||||
|
if got != tst.Want {
|
||||||
|
t.Errorf("BoolTweakOr: got %v, want %v", got, tst.Want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlobToRegexp(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Input string
|
||||||
|
Want string
|
||||||
|
}{
|
||||||
|
{"", "^(.*.*)$"},
|
||||||
|
{"a", "^(.*a.*)$"},
|
||||||
|
{"a.b", "^(.*a\\.b.*)$"},
|
||||||
|
{"a?b", "^(a.b)$"},
|
||||||
|
{"a*b*", "^(a.*b.*)$"},
|
||||||
|
{"a*b?", "^(a.*b.)$"},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Want, func(t *testing.T) {
|
||||||
|
got, err := globToRegexp(tst.Input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("globToRegexp failed: %v", err)
|
||||||
|
}
|
||||||
|
if got.String() != tst.Want {
|
||||||
|
t.Errorf("got %v, want %v", got.String(), tst.Want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLookupMapPath(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Path []string
|
||||||
|
Root map[string]interface{}
|
||||||
|
Want interface{}
|
||||||
|
}{
|
||||||
|
{[]string{"a"}, map[string]interface{}{"a": "b"}, "b"},
|
||||||
|
{[]string{"a"}, map[string]interface{}{"a": 42}, 42},
|
||||||
|
{[]string{"a", "b"}, map[string]interface{}{"a": map[string]interface{}{"b": "c"}}, "c"},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(fmt.Sprint(tst.Path, "/", tst.Want), func(t *testing.T) {
|
||||||
|
got, err := lookupMapPath(tst.Path, tst.Root)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("lookupMapPath failed: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(tst.Want, got); diff != "" {
|
||||||
|
t.Errorf("+got -want:\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLookupMapPathInvalid(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Path []string
|
||||||
|
Root map[string]interface{}
|
||||||
|
}{
|
||||||
|
{nil, nil},
|
||||||
|
{[]string{"a"}, nil},
|
||||||
|
{[]string{"a", "b"}, map[string]interface{}{"a": "c"}},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(fmt.Sprint(tst.Path), func(t *testing.T) {
|
||||||
|
got, err := lookupMapPath(tst.Path, tst.Root)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("lookupMapPath succeeded with %#v, but want failure", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRoomMemberCountCondition(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Input string
|
||||||
|
WantTrue []int
|
||||||
|
WantFalse []int
|
||||||
|
}{
|
||||||
|
{"1", []int{1}, []int{0, 2}},
|
||||||
|
{"==1", []int{1}, []int{0, 2}},
|
||||||
|
{"<1", []int{0}, []int{1, 2}},
|
||||||
|
{"<=1", []int{0, 1}, []int{2}},
|
||||||
|
{">1", []int{2}, []int{0, 1}},
|
||||||
|
{">=42", []int{42, 43}, []int{41}},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Input, func(t *testing.T) {
|
||||||
|
got, err := parseRoomMemberCountCondition(tst.Input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseRoomMemberCountCondition failed: %v", err)
|
||||||
|
}
|
||||||
|
for _, v := range tst.WantTrue {
|
||||||
|
if !got(v) {
|
||||||
|
t.Errorf("parseRoomMemberCountCondition(%q)(%d): got false, want true", tst.Input, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, v := range tst.WantFalse {
|
||||||
|
if got(v) {
|
||||||
|
t.Errorf("parseRoomMemberCountCondition(%q)(%d): got true, want false", tst.Input, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
85
internal/pushrules/validate.go
Normal file
85
internal/pushrules/validate.go
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
package pushrules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidateRule checks the rule for errors. These follow from Sytests
|
||||||
|
// and the specification.
|
||||||
|
func ValidateRule(kind Kind, rule *Rule) []error {
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
if !validRuleIDRE.MatchString(rule.RuleID) {
|
||||||
|
errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rule.Actions) == 0 {
|
||||||
|
errs = append(errs, fmt.Errorf("missing actions"))
|
||||||
|
}
|
||||||
|
for _, action := range rule.Actions {
|
||||||
|
errs = append(errs, validateAction(action)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cond := range rule.Conditions {
|
||||||
|
errs = append(errs, validateCondition(cond)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch kind {
|
||||||
|
case OverrideKind, UnderrideKind:
|
||||||
|
// The empty list is allowed, but for JSON-encoding reasons,
|
||||||
|
// it must not be nil.
|
||||||
|
if rule.Conditions == nil {
|
||||||
|
errs = append(errs, fmt.Errorf("missing rule conditions"))
|
||||||
|
}
|
||||||
|
|
||||||
|
case ContentKind:
|
||||||
|
if rule.Pattern == "" {
|
||||||
|
errs = append(errs, fmt.Errorf("missing content rule pattern"))
|
||||||
|
}
|
||||||
|
|
||||||
|
case RoomKind, SenderKind:
|
||||||
|
// Do nothing.
|
||||||
|
|
||||||
|
default:
|
||||||
|
errs = append(errs, fmt.Errorf("invalid rule kind: %s", kind))
|
||||||
|
}
|
||||||
|
|
||||||
|
return errs
|
||||||
|
}
|
||||||
|
|
||||||
|
// validRuleIDRE is a regexp for valid IDs.
|
||||||
|
//
|
||||||
|
// TODO: the specification doesn't seem to say what the rule ID syntax
|
||||||
|
// is. A Sytest fails if it contains a backslash.
|
||||||
|
var validRuleIDRE = regexp.MustCompile(`^([^\\]+)$`)
|
||||||
|
|
||||||
|
// validateAction returns issues with an Action.
|
||||||
|
func validateAction(action *Action) []error {
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
switch action.Kind {
|
||||||
|
case NotifyAction, DontNotifyAction, CoalesceAction, SetTweakAction:
|
||||||
|
// Do nothing.
|
||||||
|
|
||||||
|
default:
|
||||||
|
errs = append(errs, fmt.Errorf("invalid rule action kind: %s", action.Kind))
|
||||||
|
}
|
||||||
|
|
||||||
|
return errs
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateCondition returns issues with a Condition.
|
||||||
|
func validateCondition(cond *Condition) []error {
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
switch cond.Kind {
|
||||||
|
case EventMatchCondition, ContainsDisplayNameCondition, RoomMemberCountCondition, SenderNotificationPermissionCondition:
|
||||||
|
// Do nothing.
|
||||||
|
|
||||||
|
default:
|
||||||
|
errs = append(errs, fmt.Errorf("invalid rule condition kind: %s", cond.Kind))
|
||||||
|
}
|
||||||
|
|
||||||
|
return errs
|
||||||
|
}
|
163
internal/pushrules/validate_test.go
Normal file
163
internal/pushrules/validate_test.go
Normal file
|
@ -0,0 +1,163 @@
|
||||||
|
package pushrules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateRuleNegatives(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Kind Kind
|
||||||
|
Rule Rule
|
||||||
|
WantErrString string
|
||||||
|
}{
|
||||||
|
{"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"},
|
||||||
|
{"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"},
|
||||||
|
{"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"},
|
||||||
|
{"noActions", OverrideKind, Rule{}, "missing actions"},
|
||||||
|
{"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"},
|
||||||
|
{"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"},
|
||||||
|
{"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"},
|
||||||
|
{"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"},
|
||||||
|
{"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
errs := ValidateRule(tst.Kind, &tst.Rule)
|
||||||
|
var foundErr error
|
||||||
|
for _, err := range errs {
|
||||||
|
t.Logf("Got error %#v", err)
|
||||||
|
if strings.Contains(err.Error(), tst.WantErrString) {
|
||||||
|
foundErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if foundErr == nil {
|
||||||
|
t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRulePositives(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Kind Kind
|
||||||
|
Rule Rule
|
||||||
|
WantNoErrString string
|
||||||
|
}{
|
||||||
|
{"invalidKind", OverrideKind, Rule{}, "invalid rule kind"},
|
||||||
|
{"invalidActionNoActions", OverrideKind, Rule{}, "invalid rule action kind"},
|
||||||
|
{"invalidConditionNoConditions", OverrideKind, Rule{}, "invalid rule condition kind"},
|
||||||
|
{"contentNoCondition", ContentKind, Rule{}, "missing rule conditions"},
|
||||||
|
{"roomNoCondition", RoomKind, Rule{}, "missing rule conditions"},
|
||||||
|
{"senderNoCondition", SenderKind, Rule{}, "missing rule conditions"},
|
||||||
|
{"overrideNoPattern", OverrideKind, Rule{}, "missing content rule pattern"},
|
||||||
|
{"overrideEmptyConditions", OverrideKind, Rule{Conditions: []*Condition{}}, "missing rule conditions"},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
errs := ValidateRule(tst.Kind, &tst.Rule)
|
||||||
|
for _, err := range errs {
|
||||||
|
t.Logf("Got error %#v", err)
|
||||||
|
if strings.Contains(err.Error(), tst.WantNoErrString) {
|
||||||
|
t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateActionNegatives(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Action Action
|
||||||
|
WantErrString string
|
||||||
|
}{
|
||||||
|
{"emptyKind", Action{}, "invalid rule action kind"},
|
||||||
|
{"invalidKind", Action{Kind: ActionKind("something else")}, "invalid rule action kind"},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
errs := validateAction(&tst.Action)
|
||||||
|
var foundErr error
|
||||||
|
for _, err := range errs {
|
||||||
|
t.Logf("Got error %#v", err)
|
||||||
|
if strings.Contains(err.Error(), tst.WantErrString) {
|
||||||
|
foundErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if foundErr == nil {
|
||||||
|
t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateActionPositives(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Action Action
|
||||||
|
WantNoErrString string
|
||||||
|
}{
|
||||||
|
{"invalidKind", Action{Kind: NotifyAction}, "invalid rule action kind"},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
errs := validateAction(&tst.Action)
|
||||||
|
for _, err := range errs {
|
||||||
|
t.Logf("Got error %#v", err)
|
||||||
|
if strings.Contains(err.Error(), tst.WantNoErrString) {
|
||||||
|
t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateConditionNegatives(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Condition Condition
|
||||||
|
WantErrString string
|
||||||
|
}{
|
||||||
|
{"emptyKind", Condition{}, "invalid rule condition kind"},
|
||||||
|
{"invalidKind", Condition{Kind: ConditionKind("something else")}, "invalid rule condition kind"},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
errs := validateCondition(&tst.Condition)
|
||||||
|
var foundErr error
|
||||||
|
for _, err := range errs {
|
||||||
|
t.Logf("Got error %#v", err)
|
||||||
|
if strings.Contains(err.Error(), tst.WantErrString) {
|
||||||
|
foundErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if foundErr == nil {
|
||||||
|
t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateConditionPositives(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Condition Condition
|
||||||
|
WantNoErrString string
|
||||||
|
}{
|
||||||
|
{"invalidKind", Condition{Kind: EventMatchCondition}, "invalid rule condition kind"},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
errs := validateCondition(&tst.Condition)
|
||||||
|
for _, err := range errs {
|
||||||
|
t.Logf("Got error %#v", err)
|
||||||
|
if strings.Contains(err.Error(), tst.WantNoErrString) {
|
||||||
|
t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -163,6 +163,7 @@ type StatementList []struct {
|
||||||
func (s StatementList) Prepare(db *sql.DB) (err error) {
|
func (s StatementList) Prepare(db *sql.DB) (err error) {
|
||||||
for _, statement := range s {
|
for _, statement := range s {
|
||||||
if *statement.Statement, err = db.Prepare(statement.SQL); err != nil {
|
if *statement.Statement, err = db.Prepare(statement.SQL); err != nil {
|
||||||
|
err = fmt.Errorf("Error %q while preparing statement: %s", err, statement.SQL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -166,26 +166,53 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
|
||||||
}
|
}
|
||||||
|
|
||||||
// We can't have a self-signing or user-signing key without a master
|
// We can't have a self-signing or user-signing key without a master
|
||||||
// key, so make sure we have one of those.
|
// key, so make sure we have one of those. We will also only actually do
|
||||||
if !hasMasterKey {
|
// something if any of the specified keys in the request are different
|
||||||
existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID)
|
// to what we've got in the database, to avoid generating key change
|
||||||
if err != nil {
|
// notifications unnecessarily.
|
||||||
res.Error = &api.KeyError{
|
existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID)
|
||||||
Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
|
if err != nil {
|
||||||
}
|
res.Error = &api.KeyError{
|
||||||
return
|
Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
|
||||||
}
|
}
|
||||||
|
return
|
||||||
_, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we still can't find a master key for the user then stop the upload.
|
// If we still can't find a master key for the user then stop the upload.
|
||||||
// This satisfies the "Fails to upload self-signing key without master key" test.
|
// This satisfies the "Fails to upload self-signing key without master key" test.
|
||||||
if !hasMasterKey {
|
if !hasMasterKey {
|
||||||
res.Error = &api.KeyError{
|
if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey {
|
||||||
Err: "No master key was found",
|
res.Error = &api.KeyError{
|
||||||
IsMissingParam: true,
|
Err: "No master key was found",
|
||||||
|
IsMissingParam: true,
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if anything actually changed compared to what we have in the database.
|
||||||
|
changed := false
|
||||||
|
for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{
|
||||||
|
gomatrixserverlib.CrossSigningKeyPurposeMaster,
|
||||||
|
gomatrixserverlib.CrossSigningKeyPurposeSelfSigning,
|
||||||
|
gomatrixserverlib.CrossSigningKeyPurposeUserSigning,
|
||||||
|
} {
|
||||||
|
old, gotOld := existingKeys[purpose]
|
||||||
|
new, gotNew := toStore[purpose]
|
||||||
|
if gotOld != gotNew {
|
||||||
|
// A new key purpose has been specified that we didn't know before,
|
||||||
|
// or one has been removed.
|
||||||
|
changed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !bytes.Equal(old, new) {
|
||||||
|
// One of the existing keys for a purpose we already knew about has
|
||||||
|
// changed.
|
||||||
|
changed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !changed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -48,13 +48,19 @@ type mockDeviceListUpdaterDatabase struct {
|
||||||
staleUsers map[string]bool
|
staleUsers map[string]bool
|
||||||
prevIDsExist func(string, []int) bool
|
prevIDsExist func(string, []int) bool
|
||||||
storedKeys []api.DeviceMessage
|
storedKeys []api.DeviceMessage
|
||||||
|
mu sync.Mutex // protect staleUsers
|
||||||
}
|
}
|
||||||
|
|
||||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||||
// If no domains are given, all user IDs with stale device lists are returned.
|
// If no domains are given, all user IDs with stale device lists are returned.
|
||||||
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
var result []string
|
var result []string
|
||||||
for userID := range d.staleUsers {
|
for userID, isStale := range d.staleUsers {
|
||||||
|
if !isStale {
|
||||||
|
continue
|
||||||
|
}
|
||||||
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
|
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -75,10 +81,18 @@ func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, do
|
||||||
|
|
||||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||||
func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
|
func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
d.staleUsers[userID] = isStale
|
d.staleUsers[userID] = isStale
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *mockDeviceListUpdaterDatabase) isStale(userID string) bool {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
return d.staleUsers[userID]
|
||||||
|
}
|
||||||
|
|
||||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||||
// for this (user, device). Does not modify the stream ID for keys.
|
// for this (user, device). Does not modify the stream ID for keys.
|
||||||
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
|
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
|
||||||
|
@ -161,7 +175,7 @@ func TestUpdateHavePrevID(t *testing.T) {
|
||||||
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
|
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
|
||||||
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
|
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
|
||||||
}
|
}
|
||||||
if db.staleUsers[event.UserID] {
|
if db.isStale(event.UserID) {
|
||||||
t.Errorf("%s incorrectly marked as stale", event.UserID)
|
t.Errorf("%s incorrectly marked as stale", event.UserID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -235,7 +249,7 @@ func TestUpdateNoPrevID(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// Now we should have a fresh list and the keys and emitted something
|
// Now we should have a fresh list and the keys and emitted something
|
||||||
if db.staleUsers[event.UserID] {
|
if db.isStale(event.UserID) {
|
||||||
t.Errorf("%s still marked as stale", event.UserID)
|
t.Errorf("%s still marked as stale", event.UserID)
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
||||||
|
@ -247,3 +261,83 @@ func TestUpdateNoPrevID(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the
|
||||||
|
// update is still ongoing.
|
||||||
|
func TestDebounce(t *testing.T) {
|
||||||
|
t.Skipf("panic on closed channel on GHA")
|
||||||
|
db := &mockDeviceListUpdaterDatabase{
|
||||||
|
staleUsers: make(map[string]bool),
|
||||||
|
prevIDsExist: func(string, []int) bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ap := &mockDeviceListUpdaterAPI{}
|
||||||
|
producer := &mockKeyChangeProducer{}
|
||||||
|
fedCh := make(chan *http.Response, 1)
|
||||||
|
srv := gomatrixserverlib.ServerName("example.com")
|
||||||
|
userID := "@alice:example.com"
|
||||||
|
keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
|
||||||
|
incomingFedReq := make(chan struct{})
|
||||||
|
fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
|
||||||
|
if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) {
|
||||||
|
return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
|
||||||
|
}
|
||||||
|
close(incomingFedReq)
|
||||||
|
return <-fedCh, nil
|
||||||
|
})
|
||||||
|
updater := NewDeviceListUpdater(db, ap, producer, fedClient, 1)
|
||||||
|
if err := updater.Start(); err != nil {
|
||||||
|
t.Fatalf("failed to start updater: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// hit this 5 times
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(5)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil {
|
||||||
|
t.Errorf("ManualUpdate: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait until the updater hits federation
|
||||||
|
select {
|
||||||
|
case <-incomingFedReq:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("timed out waiting for updater to hit federation")
|
||||||
|
}
|
||||||
|
|
||||||
|
// user should be marked as stale
|
||||||
|
if !db.isStale(userID) {
|
||||||
|
t.Errorf("user %s not marked as stale", userID)
|
||||||
|
}
|
||||||
|
// now send the response over federation
|
||||||
|
fedCh <- &http.Response{
|
||||||
|
StatusCode: 200,
|
||||||
|
Body: ioutil.NopCloser(strings.NewReader(`
|
||||||
|
{
|
||||||
|
"user_id": "` + userID + `",
|
||||||
|
"stream_id": 5,
|
||||||
|
"devices": [
|
||||||
|
{
|
||||||
|
"device_id": "JLAFKJWSCS",
|
||||||
|
"keys": ` + keyJSON + `,
|
||||||
|
"device_display_name": "Mobile Phone"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
`)),
|
||||||
|
}
|
||||||
|
close(fedCh)
|
||||||
|
// wait until all 5 ManualUpdates return. If we hit federation again we won't send a response
|
||||||
|
// and should panic with read on a closed channel
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// user is no longer stale now
|
||||||
|
if db.isStale(userID) {
|
||||||
|
t.Errorf("user %s is marked as stale", userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -53,8 +53,7 @@ func Setup(
|
||||||
) {
|
) {
|
||||||
rateLimits := httputil.NewRateLimits(rateLimit)
|
rateLimits := httputil.NewRateLimits(rateLimit)
|
||||||
|
|
||||||
r0mux := publicAPIMux.PathPrefix("/r0").Subrouter()
|
v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v1|v3)}/").Subrouter()
|
||||||
v1mux := publicAPIMux.PathPrefix("/v1").Subrouter()
|
|
||||||
|
|
||||||
activeThumbnailGeneration := &types.ActiveThumbnailGeneration{
|
activeThumbnailGeneration := &types.ActiveThumbnailGeneration{
|
||||||
PathToResult: map[string]*types.ThumbnailGenerationResult{},
|
PathToResult: map[string]*types.ThumbnailGenerationResult{},
|
||||||
|
@ -77,21 +76,18 @@ func Setup(
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
r0mux.Handle("/upload", uploadHandler).Methods(http.MethodPost, http.MethodOptions)
|
v3mux.Handle("/upload", uploadHandler).Methods(http.MethodPost, http.MethodOptions)
|
||||||
r0mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions)
|
v3mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions)
|
||||||
v1mux.Handle("/upload", uploadHandler).Methods(http.MethodPost, http.MethodOptions)
|
|
||||||
|
|
||||||
activeRemoteRequests := &types.ActiveRemoteRequests{
|
activeRemoteRequests := &types.ActiveRemoteRequests{
|
||||||
MXCToResult: map[string]*types.RemoteRequestResult{},
|
MXCToResult: map[string]*types.RemoteRequestResult{},
|
||||||
}
|
}
|
||||||
|
|
||||||
downloadHandler := makeDownloadAPI("download", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration)
|
downloadHandler := makeDownloadAPI("download", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration)
|
||||||
r0mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
|
v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
|
||||||
r0mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
|
v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
|
||||||
v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) // TODO: remove when synapse is fixed
|
|
||||||
v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) // TODO: remove when synapse is fixed
|
|
||||||
|
|
||||||
r0mux.Handle("/thumbnail/{serverName}/{mediaId}",
|
v3mux.Handle("/thumbnail/{serverName}/{mediaId}",
|
||||||
makeDownloadAPI("thumbnail", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration),
|
makeDownloadAPI("thumbnail", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
}
|
}
|
||||||
|
|
|
@ -269,6 +269,7 @@ type QueryAuthChainResponse struct {
|
||||||
|
|
||||||
type QuerySharedUsersRequest struct {
|
type QuerySharedUsersRequest struct {
|
||||||
UserID string
|
UserID string
|
||||||
|
OtherUserIDs []string
|
||||||
ExcludeRoomIDs []string
|
ExcludeRoomIDs []string
|
||||||
IncludeRoomIDs []string
|
IncludeRoomIDs []string
|
||||||
}
|
}
|
||||||
|
@ -312,7 +313,10 @@ type QueryBulkStateContentResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryCurrentStateRequest struct {
|
type QueryCurrentStateRequest struct {
|
||||||
RoomID string
|
RoomID string
|
||||||
|
AllowWildcards bool
|
||||||
|
// State key tuples. If a state_key has '*' and AllowWidlcards is true, returns all matching
|
||||||
|
// state events with that event type.
|
||||||
StateTuples []gomatrixserverlib.StateKeyTuple
|
StateTuples []gomatrixserverlib.StateKeyTuple
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,12 +51,8 @@ func SendEventWithState(
|
||||||
state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent,
|
state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent,
|
||||||
origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool,
|
origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool,
|
||||||
) error {
|
) error {
|
||||||
outliers, err := state.Events(event.RoomVersion)
|
outliers := state.Events(event.RoomVersion)
|
||||||
if err != nil {
|
ires := make([]InputRoomEvent, 0, len(outliers))
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var ires []InputRoomEvent
|
|
||||||
for _, outlier := range outliers {
|
for _, outlier := range outliers {
|
||||||
if haveEventIDs[outlier.EventID()] {
|
if haveEventIDs[outlier.EventID()] {
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -23,6 +23,21 @@ type parsedRespState struct {
|
||||||
StateEvents []*gomatrixserverlib.Event
|
StateEvents []*gomatrixserverlib.Event
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *parsedRespState) Events() []*gomatrixserverlib.Event {
|
||||||
|
eventsByID := make(map[string]*gomatrixserverlib.Event, len(p.AuthEvents)+len(p.StateEvents))
|
||||||
|
for i, event := range p.AuthEvents {
|
||||||
|
eventsByID[event.EventID()] = p.AuthEvents[i]
|
||||||
|
}
|
||||||
|
for i, event := range p.StateEvents {
|
||||||
|
eventsByID[event.EventID()] = p.StateEvents[i]
|
||||||
|
}
|
||||||
|
allEvents := make([]*gomatrixserverlib.Event, 0, len(eventsByID))
|
||||||
|
for _, event := range eventsByID {
|
||||||
|
allEvents = append(allEvents, event)
|
||||||
|
}
|
||||||
|
return gomatrixserverlib.ReverseTopologicalOrdering(allEvents, gomatrixserverlib.TopologicalOrderByAuthEvents)
|
||||||
|
}
|
||||||
|
|
||||||
type missingStateReq struct {
|
type missingStateReq struct {
|
||||||
origin gomatrixserverlib.ServerName
|
origin gomatrixserverlib.ServerName
|
||||||
db storage.Database
|
db storage.Database
|
||||||
|
@ -124,11 +139,8 @@ func (t *missingStateReq) processEventWithMissingState(
|
||||||
t.hadEventsMutex.Unlock()
|
t.hadEventsMutex.Unlock()
|
||||||
|
|
||||||
sendOutliers := func(resolvedState *parsedRespState) error {
|
sendOutliers := func(resolvedState *parsedRespState) error {
|
||||||
outliers, oerr := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion)
|
outliers := resolvedState.Events()
|
||||||
if oerr != nil {
|
outlierRoomEvents := make([]api.InputRoomEvent, 0, len(outliers))
|
||||||
return fmt.Errorf("gomatrixserverlib.OrderAuthAndStateEvents: %w", oerr)
|
|
||||||
}
|
|
||||||
var outlierRoomEvents []api.InputRoomEvent
|
|
||||||
for _, outlier := range outliers {
|
for _, outlier := range outliers {
|
||||||
if hadEvents[outlier.EventID()] {
|
if hadEvents[outlier.EventID()] {
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -621,12 +621,25 @@ func (r *Queryer) QueryPublishedRooms(
|
||||||
func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
|
func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
|
||||||
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
|
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
|
||||||
for _, tuple := range req.StateTuples {
|
for _, tuple := range req.StateTuples {
|
||||||
ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey)
|
if tuple.StateKey == "*" && req.AllowWildcards {
|
||||||
if err != nil {
|
events, err := r.DB.GetStateEventsWithEventType(ctx, req.RoomID, tuple.EventType)
|
||||||
return err
|
if err != nil {
|
||||||
}
|
return err
|
||||||
if ev != nil {
|
}
|
||||||
res.StateEvents[tuple] = ev
|
for _, e := range events {
|
||||||
|
res.StateEvents[gomatrixserverlib.StateKeyTuple{
|
||||||
|
EventType: e.Type(),
|
||||||
|
StateKey: *e.StateKey(),
|
||||||
|
}] = e
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ev != nil {
|
||||||
|
res.StateEvents[tuple] = ev
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -696,7 +709,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser
|
||||||
}
|
}
|
||||||
roomIDs = roomIDs[:j]
|
roomIDs = roomIDs[:j]
|
||||||
|
|
||||||
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs)
|
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -146,13 +146,14 @@ type Database interface {
|
||||||
// If no event could be found, returns nil
|
// If no event could be found, returns nil
|
||||||
// If there was an issue during the retrieval, returns an error
|
// If there was an issue during the retrieval, returns an error
|
||||||
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
||||||
|
GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error)
|
||||||
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
||||||
GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error)
|
GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error)
|
||||||
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
|
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
|
||||||
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
|
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
|
||||||
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
|
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
|
||||||
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
// JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms.
|
||||||
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
|
JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error)
|
||||||
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
|
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
|
||||||
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
|
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
|
||||||
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.
|
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.
|
||||||
|
|
|
@ -66,7 +66,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
|
||||||
`
|
`
|
||||||
|
|
||||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" +
|
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||||
|
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
|
||||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||||
" GROUP BY target_nid"
|
" GROUP BY target_nid"
|
||||||
|
|
||||||
|
@ -322,13 +323,10 @@ func (s *membershipStatements) SelectRoomsWithMembership(
|
||||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(
|
func (s *membershipStatements) SelectJoinedUsersSetForRooms(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomNIDs []types.RoomNID,
|
roomNIDs []types.RoomNID,
|
||||||
|
userNIDs []types.EventStateKeyNID,
|
||||||
) (map[types.EventStateKeyNID]int, error) {
|
) (map[types.EventStateKeyNID]int, error) {
|
||||||
roomIDarray := make([]int64, len(roomNIDs))
|
|
||||||
for i := range roomNIDs {
|
|
||||||
roomIDarray[i] = int64(roomNIDs[i])
|
|
||||||
}
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
|
||||||
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
|
rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -979,6 +978,62 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Same as GetStateEvent but returns all matching state events with this event type. Returns no error
|
||||||
|
// if there are no events with this event type.
|
||||||
|
func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
|
roomInfo, err := d.RoomInfo(ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if roomInfo == nil {
|
||||||
|
return nil, fmt.Errorf("room %s doesn't exist", roomID)
|
||||||
|
}
|
||||||
|
// e.g invited rooms
|
||||||
|
if roomInfo.IsStub {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
// No rooms have an event of this type, otherwise we'd have an event type NID
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var eventNIDs []types.EventNID
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.EventTypeNID == eventTypeNID {
|
||||||
|
eventNIDs = append(eventNIDs, e.EventNID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
eventIDs = map[types.EventNID]string{}
|
||||||
|
}
|
||||||
|
// return the events requested
|
||||||
|
eventPairs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(eventPairs) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var result []*gomatrixserverlib.HeaderedEvent
|
||||||
|
for _, pair := range eventPairs {
|
||||||
|
ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[pair.EventNID], pair.EventJSON, false, roomInfo.RoomVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, ev.Headered(roomInfo.RoomVersion))
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
||||||
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) {
|
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) {
|
||||||
var membershipState tables.MembershipState
|
var membershipState tables.MembershipState
|
||||||
|
@ -1106,13 +1161,23 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
|
||||||
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
|
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) {
|
||||||
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
|
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
|
userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
userNIDs := make([]types.EventStateKeyNID, 0, len(userNIDsMap))
|
||||||
|
nidToUserID := make(map[types.EventStateKeyNID]string, len(userNIDsMap))
|
||||||
|
for id, nid := range userNIDsMap {
|
||||||
|
userNIDs = append(userNIDs, nid)
|
||||||
|
nidToUserID[nid] = id
|
||||||
|
}
|
||||||
|
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -1122,13 +1187,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
|
||||||
stateKeyNIDs[i] = nid
|
stateKeyNIDs[i] = nid
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(nidToUserID) != len(userNIDToCount) {
|
|
||||||
logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID))
|
|
||||||
}
|
|
||||||
result := make(map[string]int, len(userNIDToCount))
|
result := make(map[string]int, len(userNIDToCount))
|
||||||
for nid, count := range userNIDToCount {
|
for nid, count := range userNIDToCount {
|
||||||
result[nidToUserID[nid]] = count
|
result[nidToUserID[nid]] = count
|
||||||
|
|
|
@ -42,7 +42,8 @@ const membershipSchema = `
|
||||||
`
|
`
|
||||||
|
|
||||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
|
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||||
|
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
|
||||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||||
" GROUP BY target_nid"
|
" GROUP BY target_nid"
|
||||||
|
|
||||||
|
@ -296,18 +297,22 @@ func (s *membershipStatements) SelectRoomsWithMembership(
|
||||||
return roomNIDs, nil
|
return roomNIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
|
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) {
|
||||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs))
|
||||||
for i, v := range roomNIDs {
|
for _, v := range roomNIDs {
|
||||||
iRoomNIDs[i] = v
|
params = append(params, v)
|
||||||
}
|
}
|
||||||
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
|
for _, v := range userNIDs {
|
||||||
|
params = append(params, v)
|
||||||
|
}
|
||||||
|
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||||
|
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
var err error
|
var err error
|
||||||
if txn != nil {
|
if txn != nil {
|
||||||
rows, err = txn.QueryContext(ctx, query, iRoomNIDs...)
|
rows, err = txn.QueryContext(ctx, query, params...)
|
||||||
} else {
|
} else {
|
||||||
rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...)
|
rows, err = s.db.QueryContext(ctx, query, params...)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -128,9 +128,8 @@ type Membership interface {
|
||||||
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||||
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
|
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
|
||||||
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||||
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
|
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
|
||||||
// counts of how many rooms they are joined.
|
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error)
|
||||||
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
|
|
||||||
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
||||||
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
|
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
|
||||||
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
|
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
|
||||||
|
|
63
run-sytest.sh
Executable file
63
run-sytest.sh
Executable file
|
@ -0,0 +1,63 @@
|
||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# Runs SyTest either from Docker Hub, or from ../sytest. If it's run
|
||||||
|
# locally, the Docker image is rebuilt first.
|
||||||
|
#
|
||||||
|
# Logs are stored in ../sytestout/logs.
|
||||||
|
|
||||||
|
set -e
|
||||||
|
set -o pipefail
|
||||||
|
|
||||||
|
main() {
|
||||||
|
local tag=buster
|
||||||
|
local base_image=debian:$tag
|
||||||
|
local runargs=()
|
||||||
|
|
||||||
|
cd "$(dirname "$0")"
|
||||||
|
|
||||||
|
if [ -d ../sytest ]; then
|
||||||
|
local tmpdir
|
||||||
|
tmpdir="$(mktemp -d --tmpdir run-systest.XXXXXXXXXX)"
|
||||||
|
trap "rm -r '$tmpdir'" EXIT
|
||||||
|
|
||||||
|
if [ -z "$DISABLE_BUILDING_SYTEST" ]; then
|
||||||
|
echo "Re-building ../sytest Docker images..."
|
||||||
|
|
||||||
|
local status
|
||||||
|
(
|
||||||
|
cd ../sytest
|
||||||
|
|
||||||
|
docker build -f docker/base.Dockerfile --build-arg BASE_IMAGE="$base_image" --tag matrixdotorg/sytest:"$tag" .
|
||||||
|
docker build -f docker/dendrite.Dockerfile --build-arg SYTEST_IMAGE_TAG="$tag" --tag matrixdotorg/sytest-dendrite:latest .
|
||||||
|
) &>"$tmpdir/buildlog" || status=$?
|
||||||
|
if (( status != 0 )); then
|
||||||
|
# Docker is very verbose, and we don't really care about
|
||||||
|
# building SyTest. So we accumulate and only output on
|
||||||
|
# failure.
|
||||||
|
cat "$tmpdir/buildlog" >&2
|
||||||
|
return $status
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
runargs+=( -v "$PWD/../sytest:/sytest:ro" )
|
||||||
|
fi
|
||||||
|
if [ -n "$SYTEST_POSTGRES" ]; then
|
||||||
|
runargs+=( -e POSTGRES=1 )
|
||||||
|
fi
|
||||||
|
|
||||||
|
local sytestout=$PWD/../sytestout
|
||||||
|
mkdir -p "$sytestout"/{logs,cache/go-build,cache/go-pkg}
|
||||||
|
docker run \
|
||||||
|
--rm \
|
||||||
|
--name "sytest-dendrite-${LOGNAME}" \
|
||||||
|
-e LOGS_USER=$(id -u) \
|
||||||
|
-e LOGS_GROUP=$(id -g) \
|
||||||
|
-v "$PWD:/src/:ro" \
|
||||||
|
-v "$sytestout/logs:/logs/" \
|
||||||
|
-v "$sytestout/cache/go-build:/root/.cache/go-build" \
|
||||||
|
-v "$sytestout/cache/go-pkg:/gopath/pkg" \
|
||||||
|
"${runargs[@]}" \
|
||||||
|
matrixdotorg/sytest-dendrite:latest "$@"
|
||||||
|
}
|
||||||
|
|
||||||
|
main "$@"
|
|
@ -31,6 +31,7 @@ import (
|
||||||
sentryhttp "github.com/getsentry/sentry-go/http"
|
sentryhttp "github.com/getsentry/sentry-go/http"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
|
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
|
@ -270,6 +271,11 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PushGatewayHTTPClient returns a new client for interacting with (external) Push Gateways.
|
||||||
|
func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client {
|
||||||
|
return pushgateway.NewHTTPClient(b.Cfg.UserAPI.PushGatewayDisableTLSValidation)
|
||||||
|
}
|
||||||
|
|
||||||
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
||||||
// be called once per component.
|
// be called once per component.
|
||||||
func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
|
func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
|
||||||
|
|
|
@ -205,6 +205,11 @@ user_api:
|
||||||
max_open_conns: 100
|
max_open_conns: 100
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
|
pusher_database:
|
||||||
|
connection_string: file:pushserver.db
|
||||||
|
max_open_conns: 100
|
||||||
|
max_idle_conns: 2
|
||||||
|
conn_max_lifetime: -1
|
||||||
tracing:
|
tracing:
|
||||||
enabled: false
|
enabled: false
|
||||||
jaeger:
|
jaeger:
|
||||||
|
|
|
@ -13,6 +13,9 @@ type UserAPI struct {
|
||||||
// The length of time an OpenID token is condidered valid in milliseconds
|
// The length of time an OpenID token is condidered valid in milliseconds
|
||||||
OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"`
|
OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"`
|
||||||
|
|
||||||
|
// Disable TLS validation on HTTPS calls to push gatways. NOT RECOMMENDED!
|
||||||
|
PushGatewayDisableTLSValidation bool `yaml:"push_gateway_disable_tls_validation"`
|
||||||
|
|
||||||
// The Account database stores the login details and account information
|
// The Account database stores the login details and account information
|
||||||
// for local users. It is accessed by the UserAPI.
|
// for local users. It is accessed by the UserAPI.
|
||||||
AccountDatabase DatabaseOptions `yaml:"account_database"`
|
AccountDatabase DatabaseOptions `yaml:"account_database"`
|
||||||
|
|
|
@ -18,7 +18,10 @@ var (
|
||||||
OutputKeyChangeEvent = "OutputKeyChangeEvent"
|
OutputKeyChangeEvent = "OutputKeyChangeEvent"
|
||||||
OutputTypingEvent = "OutputTypingEvent"
|
OutputTypingEvent = "OutputTypingEvent"
|
||||||
OutputClientData = "OutputClientData"
|
OutputClientData = "OutputClientData"
|
||||||
|
OutputNotificationData = "OutputNotificationData"
|
||||||
OutputReceiptEvent = "OutputReceiptEvent"
|
OutputReceiptEvent = "OutputReceiptEvent"
|
||||||
|
OutputStreamEvent = "OutputStreamEvent"
|
||||||
|
OutputReadUpdate = "OutputReadUpdate"
|
||||||
)
|
)
|
||||||
|
|
||||||
var streams = []*nats.StreamConfig{
|
var streams = []*nats.StreamConfig{
|
||||||
|
@ -58,4 +61,19 @@ var streams = []*nats.StreamConfig{
|
||||||
Retention: nats.InterestPolicy,
|
Retention: nats.InterestPolicy,
|
||||||
Storage: nats.FileStorage,
|
Storage: nats.FileStorage,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: OutputNotificationData,
|
||||||
|
Retention: nats.InterestPolicy,
|
||||||
|
Storage: nats.FileStorage,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: OutputStreamEvent,
|
||||||
|
Retention: nats.InterestPolicy,
|
||||||
|
Storage: nats.FileStorage,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: OutputReadUpdate,
|
||||||
|
Retention: nats.InterestPolicy,
|
||||||
|
Storage: nats.FileStorage,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,8 +60,8 @@ func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ss
|
||||||
csMux, synapseMux, &m.Config.ClientAPI, m.AccountDB,
|
csMux, synapseMux, &m.Config.ClientAPI, m.AccountDB,
|
||||||
m.FedClient, m.RoomserverAPI,
|
m.FedClient, m.RoomserverAPI,
|
||||||
m.EDUInternalAPI, m.AppserviceAPI, transactions.New(),
|
m.EDUInternalAPI, m.AppserviceAPI, transactions.New(),
|
||||||
m.FederationAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider,
|
m.FederationAPI, m.UserAPI, m.KeyAPI,
|
||||||
&m.Config.MSCs,
|
m.ExtPublicRoomsProvider, &m.Config.MSCs,
|
||||||
)
|
)
|
||||||
federationapi.AddPublicRoutes(
|
federationapi.AddPublicRoutes(
|
||||||
ssMux, keyMux, wkMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient,
|
ssMux, keyMux, wkMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient,
|
||||||
|
|
|
@ -654,11 +654,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo
|
||||||
AuthEvents: res.AuthChain,
|
AuthEvents: res.AuthChain,
|
||||||
StateEvents: stateEvents,
|
StateEvents: stateEvents,
|
||||||
}
|
}
|
||||||
eventsInOrder, err := respState.Events(rc.roomVersion)
|
eventsInOrder := respState.Events(rc.roomVersion)
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(rc.ctx).WithError(err).Error("failed to calculate order to send events in MSC2836EventRelationshipsResponse")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// everything gets sent as an outlier because auth chain events may be disjoint from the DAG
|
// everything gets sent as an outlier because auth chain events may be disjoint from the DAG
|
||||||
// as may the threaded events.
|
// as may the threaded events.
|
||||||
var ires []roomserver.InputRoomEvent
|
var ires []roomserver.InputRoomEvent
|
||||||
|
@ -669,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
// we've got the data by this point so use a background context
|
// we've got the data by this point so use a background context
|
||||||
err = roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false)
|
err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver")
|
util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver")
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,17 +18,19 @@ package msc2946
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
chttputil "github.com/matrix-org/dendrite/clientapi/httputil"
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
fs "github.com/matrix-org/dendrite/federationapi/api"
|
fs "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/internal/hooks"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
roomserver "github.com/matrix-org/dendrite/roomserver/api"
|
roomserver "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
|
@ -39,42 +41,27 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ConstCreateEventContentKey = "type"
|
ConstCreateEventContentKey = "type"
|
||||||
ConstSpaceChildEventType = "m.space.child"
|
ConstCreateEventContentValueSpace = "m.space"
|
||||||
ConstSpaceParentEventType = "m.space.parent"
|
ConstSpaceChildEventType = "m.space.child"
|
||||||
|
ConstSpaceParentEventType = "m.space.parent"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defaults sets the request defaults
|
type MSC2946ClientResponse struct {
|
||||||
func Defaults(r *gomatrixserverlib.MSC2946SpacesRequest) {
|
Rooms []gomatrixserverlib.MSC2946Room `json:"rooms"`
|
||||||
r.Limit = 2000
|
NextBatch string `json:"next_batch,omitempty"`
|
||||||
r.MaxRoomsPerSpace = -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enable this MSC
|
// Enable this MSC
|
||||||
func Enable(
|
func Enable(
|
||||||
base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI,
|
base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI,
|
||||||
fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier,
|
fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, cache caching.SpaceSummaryRoomsCache,
|
||||||
) error {
|
) error {
|
||||||
db, err := NewDatabase(&base.Cfg.MSCs.Database)
|
clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName))
|
||||||
if err != nil {
|
base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions)
|
||||||
return fmt.Errorf("cannot enable MSC2946: %w", err)
|
base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions)
|
||||||
}
|
|
||||||
hooks.Enable()
|
|
||||||
hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) {
|
|
||||||
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
|
|
||||||
hookErr := db.StoreReference(context.Background(), he)
|
|
||||||
if hookErr != nil {
|
|
||||||
util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error(
|
|
||||||
"failed to StoreReference",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/spaces",
|
fedAPI := httputil.MakeExternalAPI(
|
||||||
httputil.MakeAuthAPI("spaces", userAPI, base.Cfg.Global.UserConsentOptions, false, spacesHandler(db, rsAPI, fsAPI, base.Cfg.Global.ServerName)),
|
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
|
||||||
|
|
||||||
base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/spaces/{roomID}", httputil.MakeExternalAPI(
|
|
||||||
"msc2946_fed_spaces", func(req *http.Request) util.JSONResponse {
|
"msc2946_fed_spaces", func(req *http.Request) util.JSONResponse {
|
||||||
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
|
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
|
||||||
req, time.Now(), base.Cfg.Global.ServerName, keyRing,
|
req, time.Now(), base.Cfg.Global.ServerName, keyRing,
|
||||||
|
@ -88,252 +75,308 @@ func Enable(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
roomID := params["roomID"]
|
roomID := params["roomID"]
|
||||||
return federatedSpacesHandler(req.Context(), fedReq, roomID, db, rsAPI, fsAPI, base.Cfg.Global.ServerName)
|
return federatedSpacesHandler(req.Context(), fedReq, roomID, cache, rsAPI, fsAPI, base.Cfg.Global.ServerName)
|
||||||
},
|
},
|
||||||
)).Methods(http.MethodPost, http.MethodOptions)
|
)
|
||||||
|
base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet)
|
||||||
|
base.PublicFederationAPIMux.Handle("/v1/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func federatedSpacesHandler(
|
func federatedSpacesHandler(
|
||||||
ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, db Database,
|
ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string,
|
||||||
|
cache caching.SpaceSummaryRoomsCache,
|
||||||
rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
|
rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
|
||||||
thisServer gomatrixserverlib.ServerName,
|
thisServer gomatrixserverlib.ServerName,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
inMemoryBatchCache := make(map[string]set)
|
u, err := url.Parse(fedReq.RequestURI())
|
||||||
var r gomatrixserverlib.MSC2946SpacesRequest
|
if err != nil {
|
||||||
Defaults(&r)
|
|
||||||
if err := json.Unmarshal(fedReq.Content(), &r); err != nil {
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: 400,
|
||||||
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
|
JSON: jsonerror.InvalidParam("bad request uri"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
w := walker{
|
|
||||||
req: &r,
|
|
||||||
rootRoomID: roomID,
|
|
||||||
serverName: fedReq.Origin(),
|
|
||||||
thisServer: thisServer,
|
|
||||||
ctx: ctx,
|
|
||||||
|
|
||||||
db: db,
|
w := walker{
|
||||||
rsAPI: rsAPI,
|
rootRoomID: roomID,
|
||||||
fsAPI: fsAPI,
|
serverName: fedReq.Origin(),
|
||||||
inMemoryBatchCache: inMemoryBatchCache,
|
thisServer: thisServer,
|
||||||
}
|
ctx: ctx,
|
||||||
res := w.walk()
|
cache: cache,
|
||||||
return util.JSONResponse{
|
suggestedOnly: u.Query().Get("suggested_only") == "true",
|
||||||
Code: 200,
|
limit: 1000,
|
||||||
JSON: res,
|
// The main difference is that it does not recurse into spaces and does not support pagination.
|
||||||
|
// This is somewhat equivalent to a Client-Server request with a max_depth=1.
|
||||||
|
maxDepth: 1,
|
||||||
|
|
||||||
|
rsAPI: rsAPI,
|
||||||
|
fsAPI: fsAPI,
|
||||||
|
// inline cache as we don't have pagination in federation mode
|
||||||
|
paginationCache: make(map[string]paginationInfo),
|
||||||
}
|
}
|
||||||
|
return w.walk()
|
||||||
}
|
}
|
||||||
|
|
||||||
func spacesHandler(
|
func spacesHandler(
|
||||||
db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
|
rsAPI roomserver.RoomserverInternalAPI,
|
||||||
|
fsAPI fs.FederationInternalAPI,
|
||||||
|
cache caching.SpaceSummaryRoomsCache,
|
||||||
thisServer gomatrixserverlib.ServerName,
|
thisServer gomatrixserverlib.ServerName,
|
||||||
) func(*http.Request, *userapi.Device) util.JSONResponse {
|
) func(*http.Request, *userapi.Device) util.JSONResponse {
|
||||||
|
// declared outside the returned handler so it persists between calls
|
||||||
|
// TODO: clear based on... time?
|
||||||
|
paginationCache := make(map[string]paginationInfo)
|
||||||
|
|
||||||
return func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
return func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
inMemoryBatchCache := make(map[string]set)
|
|
||||||
// Extract the room ID from the request. Sanity check request data.
|
// Extract the room ID from the request. Sanity check request data.
|
||||||
params, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
params, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
roomID := params["roomID"]
|
roomID := params["roomID"]
|
||||||
var r gomatrixserverlib.MSC2946SpacesRequest
|
|
||||||
Defaults(&r)
|
|
||||||
if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil {
|
|
||||||
return *resErr
|
|
||||||
}
|
|
||||||
w := walker{
|
w := walker{
|
||||||
req: &r,
|
suggestedOnly: req.URL.Query().Get("suggested_only") == "true",
|
||||||
rootRoomID: roomID,
|
limit: parseInt(req.URL.Query().Get("limit"), 1000),
|
||||||
caller: device,
|
maxDepth: parseInt(req.URL.Query().Get("max_depth"), -1),
|
||||||
thisServer: thisServer,
|
paginationToken: req.URL.Query().Get("from"),
|
||||||
ctx: req.Context(),
|
rootRoomID: roomID,
|
||||||
|
caller: device,
|
||||||
|
thisServer: thisServer,
|
||||||
|
ctx: req.Context(),
|
||||||
|
cache: cache,
|
||||||
|
|
||||||
db: db,
|
rsAPI: rsAPI,
|
||||||
rsAPI: rsAPI,
|
fsAPI: fsAPI,
|
||||||
fsAPI: fsAPI,
|
paginationCache: paginationCache,
|
||||||
inMemoryBatchCache: inMemoryBatchCache,
|
|
||||||
}
|
|
||||||
res := w.walk()
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: 200,
|
|
||||||
JSON: res,
|
|
||||||
}
|
}
|
||||||
|
return w.walk()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type paginationInfo struct {
|
||||||
|
processed set
|
||||||
|
unvisited []roomVisit
|
||||||
|
}
|
||||||
|
|
||||||
type walker struct {
|
type walker struct {
|
||||||
req *gomatrixserverlib.MSC2946SpacesRequest
|
rootRoomID string
|
||||||
rootRoomID string
|
caller *userapi.Device
|
||||||
caller *userapi.Device
|
serverName gomatrixserverlib.ServerName
|
||||||
serverName gomatrixserverlib.ServerName
|
thisServer gomatrixserverlib.ServerName
|
||||||
thisServer gomatrixserverlib.ServerName
|
rsAPI roomserver.RoomserverInternalAPI
|
||||||
db Database
|
fsAPI fs.FederationInternalAPI
|
||||||
rsAPI roomserver.RoomserverInternalAPI
|
ctx context.Context
|
||||||
fsAPI fs.FederationInternalAPI
|
cache caching.SpaceSummaryRoomsCache
|
||||||
ctx context.Context
|
suggestedOnly bool
|
||||||
|
limit int
|
||||||
|
maxDepth int
|
||||||
|
paginationToken string
|
||||||
|
|
||||||
// user ID|device ID|batch_num => event/room IDs sent to client
|
paginationCache map[string]paginationInfo
|
||||||
inMemoryBatchCache map[string]set
|
mu sync.Mutex
|
||||||
mu sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *walker) roomIsExcluded(roomID string) bool {
|
func (w *walker) newPaginationCache() (string, paginationInfo) {
|
||||||
for _, exclRoom := range w.req.ExcludeRooms {
|
p := paginationInfo{
|
||||||
if exclRoom == roomID {
|
processed: make(set),
|
||||||
return true
|
unvisited: nil,
|
||||||
|
}
|
||||||
|
tok := uuid.NewString()
|
||||||
|
return tok, p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *walker) loadPaginationCache(paginationToken string) *paginationInfo {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
p := w.paginationCache[paginationToken]
|
||||||
|
return &p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *walker) storePaginationCache(paginationToken string, cache paginationInfo) {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
w.paginationCache[paginationToken] = cache
|
||||||
|
}
|
||||||
|
|
||||||
|
type roomVisit struct {
|
||||||
|
roomID string
|
||||||
|
depth int
|
||||||
|
vias []string // vias to query this room by
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *walker) walk() util.JSONResponse {
|
||||||
|
if !w.authorised(w.rootRoomID) {
|
||||||
|
if w.caller != nil {
|
||||||
|
// CS API format
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 403,
|
||||||
|
JSON: jsonerror.Forbidden("room is unknown/forbidden"),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// SS API format
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 404,
|
||||||
|
JSON: jsonerror.NotFound("room is unknown/forbidden"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *walker) callerID() string {
|
var discoveredRooms []gomatrixserverlib.MSC2946Room
|
||||||
if w.caller != nil {
|
|
||||||
return w.caller.UserID + "|" + w.caller.ID
|
var cache *paginationInfo
|
||||||
|
if w.paginationToken != "" {
|
||||||
|
cache = w.loadPaginationCache(w.paginationToken)
|
||||||
|
if cache == nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 400,
|
||||||
|
JSON: jsonerror.InvalidArgumentValue("invalid from"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tok, c := w.newPaginationCache()
|
||||||
|
cache = &c
|
||||||
|
w.paginationToken = tok
|
||||||
|
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
|
||||||
|
c.unvisited = append(c.unvisited, roomVisit{
|
||||||
|
roomID: w.rootRoomID,
|
||||||
|
depth: 0,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return string(w.serverName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *walker) alreadySent(id string) bool {
|
processed := cache.processed
|
||||||
w.mu.Lock()
|
unvisited := cache.unvisited
|
||||||
defer w.mu.Unlock()
|
|
||||||
m, ok := w.inMemoryBatchCache[w.callerID()]
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return m[id]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *walker) markSent(id string) {
|
// Depth first -> stack data structure
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
m := w.inMemoryBatchCache[w.callerID()]
|
|
||||||
if m == nil {
|
|
||||||
m = make(set)
|
|
||||||
}
|
|
||||||
m[id] = true
|
|
||||||
w.inMemoryBatchCache[w.callerID()] = m
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *walker) walk() *gomatrixserverlib.MSC2946SpacesResponse {
|
|
||||||
var res gomatrixserverlib.MSC2946SpacesResponse
|
|
||||||
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
|
|
||||||
unvisited := []string{w.rootRoomID}
|
|
||||||
processed := make(set)
|
|
||||||
for len(unvisited) > 0 {
|
for len(unvisited) > 0 {
|
||||||
roomID := unvisited[0]
|
if len(discoveredRooms) >= w.limit {
|
||||||
unvisited = unvisited[1:]
|
break
|
||||||
// If this room has already been processed, skip. NB: do not remember this between calls
|
}
|
||||||
if processed[roomID] || roomID == "" {
|
|
||||||
|
// pop the stack
|
||||||
|
rv := unvisited[len(unvisited)-1]
|
||||||
|
unvisited = unvisited[:len(unvisited)-1]
|
||||||
|
// If this room has already been processed, skip.
|
||||||
|
// If this room exceeds the specified depth, skip.
|
||||||
|
if processed.isSet(rv.roomID) || rv.roomID == "" || (w.maxDepth > 0 && rv.depth > w.maxDepth) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark this room as processed.
|
// Mark this room as processed.
|
||||||
processed[roomID] = true
|
processed.set(rv.roomID)
|
||||||
|
|
||||||
|
// if this room is not a space room, skip.
|
||||||
|
var roomType string
|
||||||
|
create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "")
|
||||||
|
if create != nil {
|
||||||
|
// escape the `.`s so gjson doesn't think it's nested
|
||||||
|
roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
|
||||||
|
}
|
||||||
|
|
||||||
// Collect rooms/events to send back (either locally or fetched via federation)
|
// Collect rooms/events to send back (either locally or fetched via federation)
|
||||||
var discoveredRooms []gomatrixserverlib.MSC2946Room
|
var discoveredChildEvents []gomatrixserverlib.MSC2946StrippedEvent
|
||||||
var discoveredEvents []gomatrixserverlib.MSC2946StrippedEvent
|
|
||||||
|
|
||||||
// If we know about this room and the caller is authorised (joined/world_readable) then pull
|
// If we know about this room and the caller is authorised (joined/world_readable) then pull
|
||||||
// events locally
|
// events locally
|
||||||
if w.roomExists(roomID) && w.authorised(roomID) {
|
if w.roomExists(rv.roomID) && w.authorised(rv.roomID) {
|
||||||
// Get all `m.space.child` and `m.space.parent` state events for the room. *In addition*, get
|
// Get all `m.space.child` state events for this room
|
||||||
// all `m.space.child` and `m.space.parent` state events which *point to* (via `state_key` or `content.room_id`)
|
events, err := w.childReferences(rv.roomID)
|
||||||
// this room. This requires servers to store reverse lookups.
|
|
||||||
events, err := w.references(roomID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Error("failed to extract references for room")
|
util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Error("failed to extract references for room")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
discoveredEvents = events
|
discoveredChildEvents = events
|
||||||
|
|
||||||
pubRoom := w.publicRoomsChunk(roomID)
|
pubRoom := w.publicRoomsChunk(rv.roomID)
|
||||||
roomType := ""
|
|
||||||
create := w.stateEvent(roomID, gomatrixserverlib.MRoomCreate, "")
|
|
||||||
if create != nil {
|
|
||||||
// escape the `.`s so gjson doesn't think it's nested
|
|
||||||
roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the total number of events to `PublicRoomsChunk` under `num_refs`. Add `PublicRoomsChunk` to `rooms`.
|
|
||||||
discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{
|
discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{
|
||||||
PublicRoom: *pubRoom,
|
PublicRoom: *pubRoom,
|
||||||
NumRefs: len(discoveredEvents),
|
RoomType: roomType,
|
||||||
RoomType: roomType,
|
ChildrenState: events,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
// attempt to query this room over federation, as either we've never heard of it before
|
// attempt to query this room over federation, as either we've never heard of it before
|
||||||
// or we've left it and hence are not authorised (but info may be exposed regardless)
|
// or we've left it and hence are not authorised (but info may be exposed regardless)
|
||||||
fedRes, err := w.federatedRoomInfo(roomID)
|
fedRes, err := w.federatedRoomInfo(rv.roomID, rv.vias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Errorf("failed to query federated spaces")
|
util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Errorf("failed to query federated spaces")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if fedRes != nil {
|
if fedRes != nil {
|
||||||
discoveredRooms = fedRes.Rooms
|
discoveredChildEvents = fedRes.Room.ChildrenState
|
||||||
discoveredEvents = fedRes.Events
|
discoveredRooms = append(discoveredRooms, fedRes.Room)
|
||||||
|
if len(fedRes.Children) > 0 {
|
||||||
|
discoveredRooms = append(discoveredRooms, fedRes.Children...)
|
||||||
|
}
|
||||||
|
// mark this room as a space room as the federated server responded.
|
||||||
|
// we need to do this so we add the children of this room to the unvisited stack
|
||||||
|
// as these children may be rooms we do know about.
|
||||||
|
roomType = ConstCreateEventContentValueSpace
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this room has not ever been in `rooms` (across multiple requests), send it now
|
// don't walk the children
|
||||||
for _, room := range discoveredRooms {
|
// if the parent is not a space room
|
||||||
if !w.alreadySent(room.RoomID) && !w.roomIsExcluded(room.RoomID) {
|
if roomType != ConstCreateEventContentValueSpace {
|
||||||
res.Rooms = append(res.Rooms, room)
|
continue
|
||||||
w.markSent(room.RoomID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
uniqueRooms := make(set)
|
// For each referenced room ID in the child events being returned to the caller
|
||||||
|
|
||||||
// If this is the root room from the original request, insert all these events into `events` if
|
|
||||||
// they haven't been added before (across multiple requests).
|
|
||||||
if w.rootRoomID == roomID {
|
|
||||||
for _, ev := range discoveredEvents {
|
|
||||||
if !w.alreadySent(eventKey(&ev)) {
|
|
||||||
res.Events = append(res.Events, ev)
|
|
||||||
uniqueRooms[ev.RoomID] = true
|
|
||||||
uniqueRooms[spaceTargetStripped(&ev)] = true
|
|
||||||
w.markSent(eventKey(&ev))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Else add them to `events` honouring the `limit` and `max_rooms_per_space` values. If either
|
|
||||||
// are exceeded, stop adding events. If the event has already been added, do not add it again.
|
|
||||||
numAdded := 0
|
|
||||||
for _, ev := range discoveredEvents {
|
|
||||||
if w.req.Limit > 0 && len(res.Events) >= w.req.Limit {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if w.req.MaxRoomsPerSpace > 0 && numAdded >= w.req.MaxRoomsPerSpace {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if w.alreadySent(eventKey(&ev)) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Skip the room if it's part of exclude_rooms but ONLY IF the source matches, as we still
|
|
||||||
// want to catch arrows which point to excluded rooms.
|
|
||||||
if w.roomIsExcluded(ev.RoomID) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
res.Events = append(res.Events, ev)
|
|
||||||
uniqueRooms[ev.RoomID] = true
|
|
||||||
uniqueRooms[spaceTargetStripped(&ev)] = true
|
|
||||||
w.markSent(eventKey(&ev))
|
|
||||||
// we don't distinguish between child state events and parent state events for the purposes of
|
|
||||||
// max_rooms_per_space, maybe we should?
|
|
||||||
numAdded++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// For each referenced room ID in the events being returned to the caller (both parent and child)
|
|
||||||
// add the room ID to the queue of unvisited rooms. Loop from the beginning.
|
// add the room ID to the queue of unvisited rooms. Loop from the beginning.
|
||||||
for roomID := range uniqueRooms {
|
// We need to invert the order here because the child events are lo->hi on the timestamp,
|
||||||
unvisited = append(unvisited, roomID)
|
// so we need to ensure we pop in the same lo->hi order, which won't be the case if we
|
||||||
|
// insert the highest timestamp last in a stack.
|
||||||
|
for i := len(discoveredChildEvents) - 1; i >= 0; i-- {
|
||||||
|
spaceContent := struct {
|
||||||
|
Via []string `json:"via"`
|
||||||
|
}{}
|
||||||
|
ev := discoveredChildEvents[i]
|
||||||
|
_ = json.Unmarshal(ev.Content, &spaceContent)
|
||||||
|
unvisited = append(unvisited, roomVisit{
|
||||||
|
roomID: ev.StateKey,
|
||||||
|
depth: rv.depth + 1,
|
||||||
|
vias: spaceContent.Via,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &res
|
|
||||||
|
if len(unvisited) > 0 {
|
||||||
|
// we still have more rooms so we need to send back a pagination token,
|
||||||
|
// we probably hit a room limit
|
||||||
|
cache.processed = processed
|
||||||
|
cache.unvisited = unvisited
|
||||||
|
w.storePaginationCache(w.paginationToken, *cache)
|
||||||
|
} else {
|
||||||
|
// clear the pagination token so we don't send it back to the client
|
||||||
|
// Note we do NOT nuke the cache just in case this response is lost
|
||||||
|
// and the client retries it.
|
||||||
|
w.paginationToken = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.caller != nil {
|
||||||
|
// return CS API format
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: MSC2946ClientResponse{
|
||||||
|
Rooms: discoveredRooms,
|
||||||
|
NextBatch: w.paginationToken,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// return SS API format
|
||||||
|
// the first discovered room will be the room asked for, and subsequent ones the depth=1 children
|
||||||
|
if len(discoveredRooms) == 0 {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 404,
|
||||||
|
JSON: jsonerror.NotFound("room is unknown/forbidden"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: gomatrixserverlib.MSC2946SpacesResponse{
|
||||||
|
Room: discoveredRooms[0],
|
||||||
|
Children: discoveredRooms[1:],
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent {
|
func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent {
|
||||||
|
@ -366,46 +409,41 @@ func (w *walker) publicRoomsChunk(roomID string) *gomatrixserverlib.PublicRoom {
|
||||||
|
|
||||||
// federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was
|
// federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was
|
||||||
// unsuccessful.
|
// unsuccessful.
|
||||||
func (w *walker) federatedRoomInfo(roomID string) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
|
func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
|
||||||
// only do federated requests for client requests
|
// only do federated requests for client requests
|
||||||
if w.caller == nil {
|
if w.caller == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
// extract events which point to this room ID and extract their vias
|
resp, ok := w.cache.GetSpaceSummary(roomID)
|
||||||
events, err := w.db.References(w.ctx, roomID)
|
if ok {
|
||||||
if err != nil {
|
util.GetLogger(w.ctx).Debugf("Returning cached response for %s", roomID)
|
||||||
return nil, fmt.Errorf("failed to get References events: %w", err)
|
return &resp, nil
|
||||||
}
|
}
|
||||||
vias := make(set)
|
util.GetLogger(w.ctx).Debugf("Querying %s via %+v", roomID, vias)
|
||||||
for _, ev := range events {
|
|
||||||
if ev.StateKeyEquals(roomID) {
|
|
||||||
// event points at this room, extract vias
|
|
||||||
content := struct {
|
|
||||||
Vias []string `json:"via"`
|
|
||||||
}{}
|
|
||||||
if err = json.Unmarshal(ev.Content(), &content); err != nil {
|
|
||||||
continue // silently ignore corrupted state events
|
|
||||||
}
|
|
||||||
for _, v := range content.Vias {
|
|
||||||
vias[v] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
util.GetLogger(w.ctx).Infof("Querying federatedRoomInfo via %+v", vias)
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
// query more of the spaces graph using these servers
|
// query more of the spaces graph using these servers
|
||||||
for serverName := range vias {
|
for _, serverName := range vias {
|
||||||
if serverName == string(w.thisServer) {
|
if serverName == string(w.thisServer) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, gomatrixserverlib.MSC2946SpacesRequest{
|
res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly)
|
||||||
Limit: w.req.Limit,
|
|
||||||
MaxRoomsPerSpace: w.req.MaxRoomsPerSpace,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName)
|
util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// ensure nil slices are empty as we send this to the client sometimes
|
||||||
|
if res.Room.ChildrenState == nil {
|
||||||
|
res.Room.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{}
|
||||||
|
}
|
||||||
|
for i := 0; i < len(res.Children); i++ {
|
||||||
|
child := res.Children[i]
|
||||||
|
if child.ChildrenState == nil {
|
||||||
|
child.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{}
|
||||||
|
}
|
||||||
|
res.Children[i] = child
|
||||||
|
}
|
||||||
|
w.cache.StoreSpaceSummary(roomID, res)
|
||||||
|
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -501,7 +539,7 @@ func (w *walker) authorisedUser(roomID string) bool {
|
||||||
hisVisEv := queryRes.StateEvents[hisVisTuple]
|
hisVisEv := queryRes.StateEvents[hisVisTuple]
|
||||||
if memberEv != nil {
|
if memberEv != nil {
|
||||||
membership, _ := memberEv.Membership()
|
membership, _ := memberEv.Membership()
|
||||||
if membership == gomatrixserverlib.Join {
|
if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -514,40 +552,85 @@ func (w *walker) authorisedUser(roomID string) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// references returns all references pointing to or from this room.
|
// references returns all child references pointing to or from this room.
|
||||||
func (w *walker) references(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) {
|
func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) {
|
||||||
events, err := w.db.References(w.ctx, roomID)
|
createTuple := gomatrixserverlib.StateKeyTuple{
|
||||||
|
EventType: gomatrixserverlib.MRoomCreate,
|
||||||
|
StateKey: "",
|
||||||
|
}
|
||||||
|
var res roomserver.QueryCurrentStateResponse
|
||||||
|
err := w.rsAPI.QueryCurrentState(context.Background(), &roomserver.QueryCurrentStateRequest{
|
||||||
|
RoomID: roomID,
|
||||||
|
AllowWildcards: true,
|
||||||
|
StateTuples: []gomatrixserverlib.StateKeyTuple{
|
||||||
|
createTuple, {
|
||||||
|
EventType: ConstSpaceChildEventType,
|
||||||
|
StateKey: "*",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, &res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(events))
|
|
||||||
for _, ev := range events {
|
// don't return any child refs if the room is not a space room
|
||||||
|
if res.StateEvents[createTuple] != nil {
|
||||||
|
// escape the `.`s so gjson doesn't think it's nested
|
||||||
|
roomType := gjson.GetBytes(res.StateEvents[createTuple].Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
|
||||||
|
if roomType != ConstCreateEventContentValueSpace {
|
||||||
|
return []gomatrixserverlib.MSC2946StrippedEvent{}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(res.StateEvents, createTuple)
|
||||||
|
|
||||||
|
el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(res.StateEvents))
|
||||||
|
for _, ev := range res.StateEvents {
|
||||||
|
content := gjson.ParseBytes(ev.Content())
|
||||||
// only return events that have a `via` key as per MSC1772
|
// only return events that have a `via` key as per MSC1772
|
||||||
// else we'll incorrectly walk redacted events (as the link
|
// else we'll incorrectly walk redacted events (as the link
|
||||||
// is in the state_key)
|
// is in the state_key)
|
||||||
if gjson.GetBytes(ev.Content(), "via").Exists() {
|
if content.Get("via").Exists() {
|
||||||
strip := stripped(ev.Event)
|
strip := stripped(ev.Event)
|
||||||
if strip == nil {
|
if strip == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// if suggested only and this child isn't suggested, skip it.
|
||||||
|
// if suggested only = false we include everything so don't need to check the content.
|
||||||
|
if w.suggestedOnly && !content.Get("suggested").Bool() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
el = append(el, *strip)
|
el = append(el, *strip)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// sort by origin_server_ts as per MSC2946
|
||||||
|
sort.Slice(el, func(i, j int) bool {
|
||||||
|
return el[i].OriginServerTS < el[j].OriginServerTS
|
||||||
|
})
|
||||||
|
|
||||||
return el, nil
|
return el, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type set map[string]bool
|
type set map[string]struct{}
|
||||||
|
|
||||||
|
func (s set) set(val string) {
|
||||||
|
s[val] = struct{}{}
|
||||||
|
}
|
||||||
|
func (s set) isSet(val string) bool {
|
||||||
|
_, ok := s[val]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEvent {
|
func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEvent {
|
||||||
if ev.StateKey() == nil {
|
if ev.StateKey() == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &gomatrixserverlib.MSC2946StrippedEvent{
|
return &gomatrixserverlib.MSC2946StrippedEvent{
|
||||||
Type: ev.Type(),
|
Type: ev.Type(),
|
||||||
StateKey: *ev.StateKey(),
|
StateKey: *ev.StateKey(),
|
||||||
Content: ev.Content(),
|
Content: ev.Content(),
|
||||||
Sender: ev.Sender(),
|
Sender: ev.Sender(),
|
||||||
RoomID: ev.RoomID(),
|
RoomID: ev.RoomID(),
|
||||||
|
OriginServerTS: ev.OriginServerTS(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -567,3 +650,11 @@ func spaceTargetStripped(event *gomatrixserverlib.MSC2946StrippedEvent) string {
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseInt(intstr string, defaultVal int) int {
|
||||||
|
i, err := strconv.ParseInt(intstr, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||
|
return int(i)
|
||||||
|
}
|
||||||
|
|
|
@ -1,464 +0,0 @@
|
||||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package msc2946_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"crypto/ed25519"
|
|
||||||
"encoding/json"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
|
||||||
"github.com/matrix-org/dendrite/internal/hooks"
|
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
|
||||||
roomserver "github.com/matrix-org/dendrite/roomserver/api"
|
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
"github.com/matrix-org/dendrite/setup/mscs/msc2946"
|
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
client = &http.Client{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
}
|
|
||||||
roomVer = gomatrixserverlib.RoomVersionV6
|
|
||||||
)
|
|
||||||
|
|
||||||
// Basic sanity check of MSC2946 logic. Tests a single room with a few state events
|
|
||||||
// and a bit of recursion to subspaces. Makes a graph like:
|
|
||||||
// Root
|
|
||||||
// ____|_____
|
|
||||||
// | | |
|
|
||||||
// R1 R2 S1
|
|
||||||
// |_________
|
|
||||||
// | | |
|
|
||||||
// R3 R4 S2
|
|
||||||
// | <-- this link is just a parent, not a child
|
|
||||||
// R5
|
|
||||||
//
|
|
||||||
// Alice is not joined to R4, but R4 is "world_readable".
|
|
||||||
func TestMSC2946(t *testing.T) {
|
|
||||||
alice := "@alice:localhost"
|
|
||||||
// give access token to alice
|
|
||||||
nopUserAPI := &testUserAPI{
|
|
||||||
accessTokens: make(map[string]userapi.Device),
|
|
||||||
}
|
|
||||||
nopUserAPI.accessTokens["alice"] = userapi.Device{
|
|
||||||
AccessToken: "alice",
|
|
||||||
DisplayName: "Alice",
|
|
||||||
UserID: alice,
|
|
||||||
}
|
|
||||||
rootSpace := "!rootspace:localhost"
|
|
||||||
subSpaceS1 := "!subspaceS1:localhost"
|
|
||||||
subSpaceS2 := "!subspaceS2:localhost"
|
|
||||||
room1 := "!room1:localhost"
|
|
||||||
room2 := "!room2:localhost"
|
|
||||||
room3 := "!room3:localhost"
|
|
||||||
room4 := "!room4:localhost"
|
|
||||||
empty := ""
|
|
||||||
room5 := "!room5:localhost"
|
|
||||||
allRooms := []string{
|
|
||||||
rootSpace, subSpaceS1, subSpaceS2,
|
|
||||||
room1, room2, room3, room4, room5,
|
|
||||||
}
|
|
||||||
rootToR1 := mustCreateEvent(t, fledglingEvent{
|
|
||||||
RoomID: rootSpace,
|
|
||||||
Sender: alice,
|
|
||||||
Type: msc2946.ConstSpaceChildEventType,
|
|
||||||
StateKey: &room1,
|
|
||||||
Content: map[string]interface{}{
|
|
||||||
"via": []string{"localhost"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
rootToR2 := mustCreateEvent(t, fledglingEvent{
|
|
||||||
RoomID: rootSpace,
|
|
||||||
Sender: alice,
|
|
||||||
Type: msc2946.ConstSpaceChildEventType,
|
|
||||||
StateKey: &room2,
|
|
||||||
Content: map[string]interface{}{
|
|
||||||
"via": []string{"localhost"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
rootToS1 := mustCreateEvent(t, fledglingEvent{
|
|
||||||
RoomID: rootSpace,
|
|
||||||
Sender: alice,
|
|
||||||
Type: msc2946.ConstSpaceChildEventType,
|
|
||||||
StateKey: &subSpaceS1,
|
|
||||||
Content: map[string]interface{}{
|
|
||||||
"via": []string{"localhost"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
s1ToR3 := mustCreateEvent(t, fledglingEvent{
|
|
||||||
RoomID: subSpaceS1,
|
|
||||||
Sender: alice,
|
|
||||||
Type: msc2946.ConstSpaceChildEventType,
|
|
||||||
StateKey: &room3,
|
|
||||||
Content: map[string]interface{}{
|
|
||||||
"via": []string{"localhost"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
s1ToR4 := mustCreateEvent(t, fledglingEvent{
|
|
||||||
RoomID: subSpaceS1,
|
|
||||||
Sender: alice,
|
|
||||||
Type: msc2946.ConstSpaceChildEventType,
|
|
||||||
StateKey: &room4,
|
|
||||||
Content: map[string]interface{}{
|
|
||||||
"via": []string{"localhost"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
s1ToS2 := mustCreateEvent(t, fledglingEvent{
|
|
||||||
RoomID: subSpaceS1,
|
|
||||||
Sender: alice,
|
|
||||||
Type: msc2946.ConstSpaceChildEventType,
|
|
||||||
StateKey: &subSpaceS2,
|
|
||||||
Content: map[string]interface{}{
|
|
||||||
"via": []string{"localhost"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
// This is a parent link only
|
|
||||||
s2ToR5 := mustCreateEvent(t, fledglingEvent{
|
|
||||||
RoomID: room5,
|
|
||||||
Sender: alice,
|
|
||||||
Type: msc2946.ConstSpaceParentEventType,
|
|
||||||
StateKey: &subSpaceS2,
|
|
||||||
Content: map[string]interface{}{
|
|
||||||
"via": []string{"localhost"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
// history visibility for R4
|
|
||||||
r4HisVis := mustCreateEvent(t, fledglingEvent{
|
|
||||||
RoomID: room4,
|
|
||||||
Sender: "@someone:localhost",
|
|
||||||
Type: gomatrixserverlib.MRoomHistoryVisibility,
|
|
||||||
StateKey: &empty,
|
|
||||||
Content: map[string]interface{}{
|
|
||||||
"history_visibility": "world_readable",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
var joinEvents []*gomatrixserverlib.HeaderedEvent
|
|
||||||
for _, roomID := range allRooms {
|
|
||||||
if roomID == room4 {
|
|
||||||
continue // not joined to that room
|
|
||||||
}
|
|
||||||
joinEvents = append(joinEvents, mustCreateEvent(t, fledglingEvent{
|
|
||||||
RoomID: roomID,
|
|
||||||
Sender: alice,
|
|
||||||
StateKey: &alice,
|
|
||||||
Type: gomatrixserverlib.MRoomMember,
|
|
||||||
Content: map[string]interface{}{
|
|
||||||
"membership": "join",
|
|
||||||
},
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
roomNameTuple := gomatrixserverlib.StateKeyTuple{
|
|
||||||
EventType: "m.room.name",
|
|
||||||
StateKey: "",
|
|
||||||
}
|
|
||||||
hisVisTuple := gomatrixserverlib.StateKeyTuple{
|
|
||||||
EventType: "m.room.history_visibility",
|
|
||||||
StateKey: "",
|
|
||||||
}
|
|
||||||
nopRsAPI := &testRoomserverAPI{
|
|
||||||
joinEvents: joinEvents,
|
|
||||||
events: map[string]*gomatrixserverlib.HeaderedEvent{
|
|
||||||
rootToR1.EventID(): rootToR1,
|
|
||||||
rootToR2.EventID(): rootToR2,
|
|
||||||
rootToS1.EventID(): rootToS1,
|
|
||||||
s1ToR3.EventID(): s1ToR3,
|
|
||||||
s1ToR4.EventID(): s1ToR4,
|
|
||||||
s1ToS2.EventID(): s1ToS2,
|
|
||||||
s2ToR5.EventID(): s2ToR5,
|
|
||||||
r4HisVis.EventID(): r4HisVis,
|
|
||||||
},
|
|
||||||
pubRoomState: map[string]map[gomatrixserverlib.StateKeyTuple]string{
|
|
||||||
rootSpace: {
|
|
||||||
roomNameTuple: "Root",
|
|
||||||
hisVisTuple: "shared",
|
|
||||||
},
|
|
||||||
subSpaceS1: {
|
|
||||||
roomNameTuple: "Sub-Space 1",
|
|
||||||
hisVisTuple: "joined",
|
|
||||||
},
|
|
||||||
subSpaceS2: {
|
|
||||||
roomNameTuple: "Sub-Space 2",
|
|
||||||
hisVisTuple: "shared",
|
|
||||||
},
|
|
||||||
room1: {
|
|
||||||
hisVisTuple: "joined",
|
|
||||||
},
|
|
||||||
room2: {
|
|
||||||
hisVisTuple: "joined",
|
|
||||||
},
|
|
||||||
room3: {
|
|
||||||
hisVisTuple: "joined",
|
|
||||||
},
|
|
||||||
room4: {
|
|
||||||
hisVisTuple: "world_readable",
|
|
||||||
},
|
|
||||||
room5: {
|
|
||||||
hisVisTuple: "joined",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
allEvents := []*gomatrixserverlib.HeaderedEvent{
|
|
||||||
rootToR1, rootToR2, rootToS1,
|
|
||||||
s1ToR3, s1ToR4, s1ToS2,
|
|
||||||
s2ToR5, r4HisVis,
|
|
||||||
}
|
|
||||||
allEvents = append(allEvents, joinEvents...)
|
|
||||||
router := injectEvents(t, nopUserAPI, nopRsAPI, allEvents)
|
|
||||||
cancel := runServer(t, router)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
t.Run("returns no events for unknown rooms", func(t *testing.T) {
|
|
||||||
res := postSpaces(t, 200, "alice", "!unknown:localhost", newReq(t, map[string]interface{}{}))
|
|
||||||
if len(res.Events) > 0 {
|
|
||||||
t.Errorf("got %d events, want 0", len(res.Events))
|
|
||||||
}
|
|
||||||
if len(res.Rooms) > 0 {
|
|
||||||
t.Errorf("got %d rooms, want 0", len(res.Rooms))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
t.Run("returns the entire graph", func(t *testing.T) {
|
|
||||||
res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{}))
|
|
||||||
if len(res.Events) != 7 {
|
|
||||||
t.Errorf("got %d events, want 7", len(res.Events))
|
|
||||||
}
|
|
||||||
if len(res.Rooms) != len(allRooms) {
|
|
||||||
t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
t.Run("can update the graph", func(t *testing.T) {
|
|
||||||
// remove R3 from the graph
|
|
||||||
rmS1ToR3 := mustCreateEvent(t, fledglingEvent{
|
|
||||||
RoomID: subSpaceS1,
|
|
||||||
Sender: alice,
|
|
||||||
Type: msc2946.ConstSpaceChildEventType,
|
|
||||||
StateKey: &room3,
|
|
||||||
Content: map[string]interface{}{}, // redacted
|
|
||||||
})
|
|
||||||
nopRsAPI.events[rmS1ToR3.EventID()] = rmS1ToR3
|
|
||||||
hooks.Run(hooks.KindNewEventPersisted, rmS1ToR3)
|
|
||||||
|
|
||||||
res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{}))
|
|
||||||
if len(res.Events) != 6 { // one less since we don't return redacted events
|
|
||||||
t.Errorf("got %d events, want 6", len(res.Events))
|
|
||||||
}
|
|
||||||
if len(res.Rooms) != (len(allRooms) - 1) { // one less due to lack of R3
|
|
||||||
t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)-1)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func newReq(t *testing.T, jsonBody map[string]interface{}) *gomatrixserverlib.MSC2946SpacesRequest {
|
|
||||||
t.Helper()
|
|
||||||
b, err := json.Marshal(jsonBody)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to marshal request: %s", err)
|
|
||||||
}
|
|
||||||
var r gomatrixserverlib.MSC2946SpacesRequest
|
|
||||||
if err := json.Unmarshal(b, &r); err != nil {
|
|
||||||
t.Fatalf("Failed to unmarshal request: %s", err)
|
|
||||||
}
|
|
||||||
return &r
|
|
||||||
}
|
|
||||||
|
|
||||||
func runServer(t *testing.T, router *mux.Router) func() {
|
|
||||||
t.Helper()
|
|
||||||
externalServ := &http.Server{
|
|
||||||
Addr: string(":8010"),
|
|
||||||
WriteTimeout: 60 * time.Second,
|
|
||||||
Handler: router,
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
externalServ.ListenAndServe()
|
|
||||||
}()
|
|
||||||
// wait to listen on the port
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
return func() {
|
|
||||||
externalServ.Shutdown(context.TODO())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *gomatrixserverlib.MSC2946SpacesRequest) *gomatrixserverlib.MSC2946SpacesResponse {
|
|
||||||
t.Helper()
|
|
||||||
var r gomatrixserverlib.MSC2946SpacesRequest
|
|
||||||
msc2946.Defaults(&r)
|
|
||||||
data, err := json.Marshal(req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to marshal request: %s", err)
|
|
||||||
}
|
|
||||||
httpReq, err := http.NewRequest(
|
|
||||||
"POST", "http://localhost:8010/_matrix/client/unstable/org.matrix.msc2946/rooms/"+url.PathEscape(roomID)+"/spaces",
|
|
||||||
bytes.NewBuffer(data),
|
|
||||||
)
|
|
||||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to prepare request: %s", err)
|
|
||||||
}
|
|
||||||
res, err := client.Do(httpReq)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to do request: %s", err)
|
|
||||||
}
|
|
||||||
if res.StatusCode != expectCode {
|
|
||||||
body, _ := ioutil.ReadAll(res.Body)
|
|
||||||
t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body))
|
|
||||||
}
|
|
||||||
if res.StatusCode == 200 {
|
|
||||||
var result gomatrixserverlib.MSC2946SpacesResponse
|
|
||||||
body, err := ioutil.ReadAll(res.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("response 200 OK but failed to read response body: %s", err)
|
|
||||||
}
|
|
||||||
t.Logf("Body: %s", string(body))
|
|
||||||
if err := json.Unmarshal(body, &result); err != nil {
|
|
||||||
t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body))
|
|
||||||
}
|
|
||||||
return &result
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type testUserAPI struct {
|
|
||||||
userapi.UserInternalAPITrace
|
|
||||||
accessTokens map[string]userapi.Device
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error {
|
|
||||||
dev, ok := u.accessTokens[req.AccessToken]
|
|
||||||
if !ok {
|
|
||||||
res.Err = "unknown token"
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
res.Device = &dev
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type testRoomserverAPI struct {
|
|
||||||
// use a trace API as it implements method stubs so we don't need to have them here.
|
|
||||||
// We'll override the functions we care about.
|
|
||||||
roomserver.RoomserverInternalAPITrace
|
|
||||||
joinEvents []*gomatrixserverlib.HeaderedEvent
|
|
||||||
events map[string]*gomatrixserverlib.HeaderedEvent
|
|
||||||
pubRoomState map[string]map[gomatrixserverlib.StateKeyTuple]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *testRoomserverAPI) QueryServerJoinedToRoom(ctx context.Context, req *roomserver.QueryServerJoinedToRoomRequest, res *roomserver.QueryServerJoinedToRoomResponse) error {
|
|
||||||
res.IsInRoom = true
|
|
||||||
res.RoomExists = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *roomserver.QueryBulkStateContentRequest, res *roomserver.QueryBulkStateContentResponse) error {
|
|
||||||
res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string)
|
|
||||||
for _, roomID := range req.RoomIDs {
|
|
||||||
pubRoomData, ok := r.pubRoomState[roomID]
|
|
||||||
if ok {
|
|
||||||
res.Rooms[roomID] = pubRoomData
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *roomserver.QueryCurrentStateRequest, res *roomserver.QueryCurrentStateResponse) error {
|
|
||||||
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
|
|
||||||
checkEvent := func(he *gomatrixserverlib.HeaderedEvent) {
|
|
||||||
if he.RoomID() != req.RoomID {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if he.StateKey() == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tuple := gomatrixserverlib.StateKeyTuple{
|
|
||||||
EventType: he.Type(),
|
|
||||||
StateKey: *he.StateKey(),
|
|
||||||
}
|
|
||||||
for _, t := range req.StateTuples {
|
|
||||||
if t == tuple {
|
|
||||||
res.StateEvents[t] = he
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, he := range r.joinEvents {
|
|
||||||
checkEvent(he)
|
|
||||||
}
|
|
||||||
for _, he := range r.events {
|
|
||||||
checkEvent(he)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router {
|
|
||||||
t.Helper()
|
|
||||||
cfg := &config.Dendrite{}
|
|
||||||
cfg.Defaults(true)
|
|
||||||
cfg.Global.ServerName = "localhost"
|
|
||||||
cfg.MSCs.Database.ConnectionString = "file:msc2946_test.db"
|
|
||||||
cfg.MSCs.MSCs = []string{"msc2946"}
|
|
||||||
base := &base.BaseDendrite{
|
|
||||||
Cfg: cfg,
|
|
||||||
PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(),
|
|
||||||
PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := msc2946.Enable(base, rsAPI, userAPI, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to enable MSC2946: %s", err)
|
|
||||||
}
|
|
||||||
for _, ev := range events {
|
|
||||||
hooks.Run(hooks.KindNewEventPersisted, ev)
|
|
||||||
}
|
|
||||||
return base.PublicClientAPIMux
|
|
||||||
}
|
|
||||||
|
|
||||||
type fledglingEvent struct {
|
|
||||||
Type string
|
|
||||||
StateKey *string
|
|
||||||
Content interface{}
|
|
||||||
Sender string
|
|
||||||
RoomID string
|
|
||||||
}
|
|
||||||
|
|
||||||
func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) {
|
|
||||||
t.Helper()
|
|
||||||
seed := make([]byte, ed25519.SeedSize) // zero seed
|
|
||||||
key := ed25519.NewKeyFromSeed(seed)
|
|
||||||
eb := gomatrixserverlib.EventBuilder{
|
|
||||||
Sender: ev.Sender,
|
|
||||||
Depth: 999,
|
|
||||||
Type: ev.Type,
|
|
||||||
StateKey: ev.StateKey,
|
|
||||||
RoomID: ev.RoomID,
|
|
||||||
}
|
|
||||||
err := eb.SetContent(ev.Content)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content)
|
|
||||||
}
|
|
||||||
// make sure the origin_server_ts changes so we can test recency
|
|
||||||
time.Sleep(1 * time.Millisecond)
|
|
||||||
signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
|
|
||||||
}
|
|
||||||
h := signedEvent.Headered(roomVer)
|
|
||||||
return h
|
|
||||||
}
|
|
|
@ -1,182 +0,0 @@
|
||||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package msc2946
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
relTypes = map[string]int{
|
|
||||||
ConstSpaceChildEventType: 1,
|
|
||||||
ConstSpaceParentEventType: 2,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
type Database interface {
|
|
||||||
// StoreReference persists a child or parent space mapping.
|
|
||||||
StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error
|
|
||||||
// References returns all events which have the given roomID as a parent or child space.
|
|
||||||
References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type DB struct {
|
|
||||||
db *sql.DB
|
|
||||||
writer sqlutil.Writer
|
|
||||||
insertEdgeStmt *sql.Stmt
|
|
||||||
selectEdgesStmt *sql.Stmt
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDatabase loads the database for msc2836
|
|
||||||
func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
|
||||||
if dbOpts.ConnectionString.IsPostgres() {
|
|
||||||
return newPostgresDatabase(dbOpts)
|
|
||||||
}
|
|
||||||
return newSQLiteDatabase(dbOpts)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
|
||||||
d := DB{
|
|
||||||
writer: sqlutil.NewDummyWriter(),
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
if d.db, err = sqlutil.Open(dbOpts); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
_, err = d.db.Exec(`
|
|
||||||
CREATE TABLE IF NOT EXISTS msc2946_edges (
|
|
||||||
room_version TEXT NOT NULL,
|
|
||||||
-- the room ID of the event, the source of the arrow
|
|
||||||
source_room_id TEXT NOT NULL,
|
|
||||||
-- the target room ID, the arrow destination
|
|
||||||
dest_room_id TEXT NOT NULL,
|
|
||||||
-- the kind of relation, either child or parent (1,2)
|
|
||||||
rel_type SMALLINT NOT NULL,
|
|
||||||
event_json TEXT NOT NULL,
|
|
||||||
CONSTRAINT msc2946_edges_uniq UNIQUE (source_room_id, dest_room_id, rel_type)
|
|
||||||
);
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if d.insertEdgeStmt, err = d.db.Prepare(`
|
|
||||||
INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json)
|
|
||||||
VALUES($1, $2, $3, $4, $5)
|
|
||||||
ON CONFLICT ON CONSTRAINT msc2946_edges_uniq DO UPDATE SET event_json = $5
|
|
||||||
`); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if d.selectEdgesStmt, err = d.db.Prepare(`
|
|
||||||
SELECT room_version, event_json FROM msc2946_edges
|
|
||||||
WHERE source_room_id = $1 OR dest_room_id = $2
|
|
||||||
`); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &d, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
|
||||||
d := DB{
|
|
||||||
writer: sqlutil.NewExclusiveWriter(),
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
if d.db, err = sqlutil.Open(dbOpts); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
_, err = d.db.Exec(`
|
|
||||||
CREATE TABLE IF NOT EXISTS msc2946_edges (
|
|
||||||
room_version TEXT NOT NULL,
|
|
||||||
-- the room ID of the event, the source of the arrow
|
|
||||||
source_room_id TEXT NOT NULL,
|
|
||||||
-- the target room ID, the arrow destination
|
|
||||||
dest_room_id TEXT NOT NULL,
|
|
||||||
-- the kind of relation, either child or parent (1,2)
|
|
||||||
rel_type SMALLINT NOT NULL,
|
|
||||||
event_json TEXT NOT NULL,
|
|
||||||
UNIQUE (source_room_id, dest_room_id, rel_type)
|
|
||||||
);
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if d.insertEdgeStmt, err = d.db.Prepare(`
|
|
||||||
INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json)
|
|
||||||
VALUES($1, $2, $3, $4, $5)
|
|
||||||
ON CONFLICT (source_room_id, dest_room_id, rel_type) DO UPDATE SET event_json = $5
|
|
||||||
`); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if d.selectEdgesStmt, err = d.db.Prepare(`
|
|
||||||
SELECT room_version, event_json FROM msc2946_edges
|
|
||||||
WHERE source_room_id = $1 OR dest_room_id = $2
|
|
||||||
`); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &d, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error {
|
|
||||||
target := SpaceTarget(he)
|
|
||||||
if target == "" {
|
|
||||||
return nil // malformed event
|
|
||||||
}
|
|
||||||
relType := relTypes[he.Type()]
|
|
||||||
_, err := d.insertEdgeStmt.ExecContext(ctx, he.RoomVersion, he.RoomID(), target, relType, he.JSON())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DB) References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
|
||||||
rows, err := d.selectEdgesStmt.QueryContext(ctx, roomID, roomID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "failed to close References")
|
|
||||||
refs := make([]*gomatrixserverlib.HeaderedEvent, 0)
|
|
||||||
for rows.Next() {
|
|
||||||
var roomVer string
|
|
||||||
var jsonBytes []byte
|
|
||||||
if err := rows.Scan(&roomVer, &jsonBytes); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(jsonBytes, false, gomatrixserverlib.RoomVersion(roomVer))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
he := ev.Headered(gomatrixserverlib.RoomVersion(roomVer))
|
|
||||||
refs = append(refs, he)
|
|
||||||
}
|
|
||||||
return refs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SpaceTarget returns the destination room ID for the space event. This is either a child or a parent
|
|
||||||
// depending on the event type.
|
|
||||||
func SpaceTarget(he *gomatrixserverlib.HeaderedEvent) string {
|
|
||||||
if he.StateKey() == nil {
|
|
||||||
return "" // no-op
|
|
||||||
}
|
|
||||||
switch he.Type() {
|
|
||||||
case ConstSpaceParentEventType:
|
|
||||||
return *he.StateKey()
|
|
||||||
case ConstSpaceChildEventType:
|
|
||||||
return *he.StateKey()
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
|
@ -42,7 +42,7 @@ func EnableMSC(base *base.BaseDendrite, monolith *setup.Monolith, msc string) er
|
||||||
case "msc2836":
|
case "msc2836":
|
||||||
return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing)
|
return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing)
|
||||||
case "msc2946":
|
case "msc2946":
|
||||||
return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing)
|
return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing, base.Caches)
|
||||||
case "msc2444": // enabled inside federationapi
|
case "msc2444": // enabled inside federationapi
|
||||||
case "msc2753": // enabled inside clientapi
|
case "msc2753": // enabled inside clientapi
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -16,7 +16,9 @@ package consumers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
|
@ -24,21 +26,26 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/producers"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OutputClientDataConsumer consumes events that originated in the client API server.
|
// OutputClientDataConsumer consumes events that originated in the client API server.
|
||||||
type OutputClientDataConsumer struct {
|
type OutputClientDataConsumer struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
jetstream nats.JetStreamContext
|
jetstream nats.JetStreamContext
|
||||||
durable string
|
durable string
|
||||||
topic string
|
topic string
|
||||||
db storage.Database
|
db storage.Database
|
||||||
stream types.StreamProvider
|
stream types.StreamProvider
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
producer *producers.UserAPIReadProducer
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers.
|
// NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers.
|
||||||
|
@ -49,15 +56,18 @@ func NewOutputClientDataConsumer(
|
||||||
store storage.Database,
|
store storage.Database,
|
||||||
notifier *notifier.Notifier,
|
notifier *notifier.Notifier,
|
||||||
stream types.StreamProvider,
|
stream types.StreamProvider,
|
||||||
|
producer *producers.UserAPIReadProducer,
|
||||||
) *OutputClientDataConsumer {
|
) *OutputClientDataConsumer {
|
||||||
return &OutputClientDataConsumer{
|
return &OutputClientDataConsumer{
|
||||||
ctx: process.Context(),
|
ctx: process.Context(),
|
||||||
jetstream: js,
|
jetstream: js,
|
||||||
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData),
|
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData),
|
||||||
durable: cfg.Matrix.JetStream.Durable("SyncAPIClientAPIConsumer"),
|
durable: cfg.Matrix.JetStream.Durable("SyncAPIClientAPIConsumer"),
|
||||||
db: store,
|
db: store,
|
||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
stream: stream,
|
stream: stream,
|
||||||
|
serverName: cfg.Matrix.ServerName,
|
||||||
|
producer: producer,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,8 +110,48 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg)
|
||||||
}).Panicf("could not save account data")
|
}).Panicf("could not save account data")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = s.sendReadUpdate(ctx, userID, output); err != nil {
|
||||||
|
log.WithError(err).WithFields(logrus.Fields{
|
||||||
|
"user_id": userID,
|
||||||
|
"room_id": output.RoomID,
|
||||||
|
}).Errorf("Failed to generate read update")
|
||||||
|
sentry.CaptureException(err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
s.stream.Advance(streamPos)
|
s.stream.Advance(streamPos)
|
||||||
s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos})
|
s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos})
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error {
|
||||||
|
if output.Type != "m.fully_read" || output.ReadMarker == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||||
|
}
|
||||||
|
if serverName != s.serverName {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var readPos types.StreamPosition
|
||||||
|
var fullyReadPos types.StreamPosition
|
||||||
|
if output.ReadMarker.Read != "" {
|
||||||
|
if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil && err != sql.ErrNoRows {
|
||||||
|
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if output.ReadMarker.FullyRead != "" {
|
||||||
|
if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil && err != sql.ErrNoRows {
|
||||||
|
return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if readPos > 0 || fullyReadPos > 0 {
|
||||||
|
if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil {
|
||||||
|
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -16,7 +16,9 @@ package consumers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
"github.com/matrix-org/dendrite/eduserver/api"
|
"github.com/matrix-org/dendrite/eduserver/api"
|
||||||
|
@ -24,21 +26,26 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/producers"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OutputReceiptEventConsumer consumes events that originated in the EDU server.
|
// OutputReceiptEventConsumer consumes events that originated in the EDU server.
|
||||||
type OutputReceiptEventConsumer struct {
|
type OutputReceiptEventConsumer struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
jetstream nats.JetStreamContext
|
jetstream nats.JetStreamContext
|
||||||
durable string
|
durable string
|
||||||
topic string
|
topic string
|
||||||
db storage.Database
|
db storage.Database
|
||||||
stream types.StreamProvider
|
stream types.StreamProvider
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
producer *producers.UserAPIReadProducer
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
|
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
|
||||||
|
@ -50,15 +57,18 @@ func NewOutputReceiptEventConsumer(
|
||||||
store storage.Database,
|
store storage.Database,
|
||||||
notifier *notifier.Notifier,
|
notifier *notifier.Notifier,
|
||||||
stream types.StreamProvider,
|
stream types.StreamProvider,
|
||||||
|
producer *producers.UserAPIReadProducer,
|
||||||
) *OutputReceiptEventConsumer {
|
) *OutputReceiptEventConsumer {
|
||||||
return &OutputReceiptEventConsumer{
|
return &OutputReceiptEventConsumer{
|
||||||
ctx: process.Context(),
|
ctx: process.Context(),
|
||||||
jetstream: js,
|
jetstream: js,
|
||||||
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent),
|
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent),
|
||||||
durable: cfg.Matrix.JetStream.Durable("SyncAPIEDUServerReceiptConsumer"),
|
durable: cfg.Matrix.JetStream.Durable("SyncAPIEDUServerReceiptConsumer"),
|
||||||
db: store,
|
db: store,
|
||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
stream: stream,
|
stream: stream,
|
||||||
|
serverName: cfg.Matrix.ServerName,
|
||||||
|
producer: producer,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,8 +102,42 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Ms
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = s.sendReadUpdate(ctx, output); err != nil {
|
||||||
|
log.WithError(err).WithFields(logrus.Fields{
|
||||||
|
"user_id": output.UserID,
|
||||||
|
"room_id": output.RoomID,
|
||||||
|
}).Errorf("Failed to generate read update")
|
||||||
|
sentry.CaptureException(err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
s.stream.Advance(streamPos)
|
s.stream.Advance(streamPos)
|
||||||
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
|
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output api.OutputReceiptEvent) error {
|
||||||
|
if output.Type != "m.read" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, serverName, err := gomatrixserverlib.SplitID('@', output.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||||
|
}
|
||||||
|
if serverName != s.serverName {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var readPos types.StreamPosition
|
||||||
|
if output.EventID != "" {
|
||||||
|
if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil && err != sql.ErrNoRows {
|
||||||
|
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if readPos > 0 {
|
||||||
|
if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil {
|
||||||
|
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/producers"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -45,6 +46,7 @@ type OutputRoomEventConsumer struct {
|
||||||
pduStream types.StreamProvider
|
pduStream types.StreamProvider
|
||||||
inviteStream types.StreamProvider
|
inviteStream types.StreamProvider
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
|
producer *producers.UserAPIStreamEventProducer
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
|
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
|
||||||
|
@ -57,6 +59,7 @@ func NewOutputRoomEventConsumer(
|
||||||
pduStream types.StreamProvider,
|
pduStream types.StreamProvider,
|
||||||
inviteStream types.StreamProvider,
|
inviteStream types.StreamProvider,
|
||||||
rsAPI api.RoomserverInternalAPI,
|
rsAPI api.RoomserverInternalAPI,
|
||||||
|
producer *producers.UserAPIStreamEventProducer,
|
||||||
) *OutputRoomEventConsumer {
|
) *OutputRoomEventConsumer {
|
||||||
return &OutputRoomEventConsumer{
|
return &OutputRoomEventConsumer{
|
||||||
ctx: process.Context(),
|
ctx: process.Context(),
|
||||||
|
@ -69,6 +72,7 @@ func NewOutputRoomEventConsumer(
|
||||||
pduStream: pduStream,
|
pduStream: pduStream,
|
||||||
inviteStream: inviteStream,
|
inviteStream: inviteStream,
|
||||||
rsAPI: rsAPI,
|
rsAPI: rsAPI,
|
||||||
|
producer: producer,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -194,6 +198,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil {
|
||||||
|
log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID())
|
||||||
|
sentry.CaptureException(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
|
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
|
||||||
log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
|
log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
|
||||||
sentry.CaptureException(err)
|
sentry.CaptureException(err)
|
||||||
|
|
110
syncapi/consumers/userapi.go
Normal file
110
syncapi/consumers/userapi.go
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package consumers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/getsentry/sentry-go"
|
||||||
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
"github.com/nats-io/nats.go"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OutputNotificationDataConsumer consumes events that originated in
|
||||||
|
// the Push server.
|
||||||
|
type OutputNotificationDataConsumer struct {
|
||||||
|
ctx context.Context
|
||||||
|
jetstream nats.JetStreamContext
|
||||||
|
durable string
|
||||||
|
topic string
|
||||||
|
db storage.Database
|
||||||
|
notifier *notifier.Notifier
|
||||||
|
stream types.StreamProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOutputNotificationDataConsumer creates a new consumer. Call
|
||||||
|
// Start() to begin consuming.
|
||||||
|
func NewOutputNotificationDataConsumer(
|
||||||
|
process *process.ProcessContext,
|
||||||
|
cfg *config.SyncAPI,
|
||||||
|
js nats.JetStreamContext,
|
||||||
|
store storage.Database,
|
||||||
|
notifier *notifier.Notifier,
|
||||||
|
stream types.StreamProvider,
|
||||||
|
) *OutputNotificationDataConsumer {
|
||||||
|
s := &OutputNotificationDataConsumer{
|
||||||
|
ctx: process.Context(),
|
||||||
|
jetstream: js,
|
||||||
|
durable: cfg.Matrix.JetStream.Durable("SyncAPINotificationDataConsumer"),
|
||||||
|
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData),
|
||||||
|
db: store,
|
||||||
|
notifier: notifier,
|
||||||
|
stream: stream,
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts consumption.
|
||||||
|
func (s *OutputNotificationDataConsumer) Start() error {
|
||||||
|
return jetstream.JetStreamConsumer(
|
||||||
|
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||||
|
nats.DeliverAll(), nats.ManualAck(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// onMessage is called when the Sync server receives a new event from
|
||||||
|
// the push server. It is not safe for this function to be called from
|
||||||
|
// multiple goroutines, or else the sync stream position may race and
|
||||||
|
// be incorrectly calculated.
|
||||||
|
func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||||
|
userID := string(msg.Header.Get(jetstream.UserID))
|
||||||
|
|
||||||
|
// Parse out the event JSON
|
||||||
|
var data eventutil.NotificationData
|
||||||
|
if err := json.Unmarshal(msg.Data, &data); err != nil {
|
||||||
|
sentry.CaptureException(err)
|
||||||
|
log.WithField("user_id", userID).WithError(err).Error("user API consumer: message parse failure")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
streamPos, err := s.db.UpsertRoomUnreadNotificationCounts(ctx, userID, data.RoomID, data.UnreadNotificationCount, data.UnreadHighlightCount)
|
||||||
|
if err != nil {
|
||||||
|
sentry.CaptureException(err)
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"user_id": userID,
|
||||||
|
"room_id": data.RoomID,
|
||||||
|
}).WithError(err).Error("Could not save notification counts")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
s.stream.Advance(streamPos)
|
||||||
|
s.notifier.OnNewNotificationData(userID, types.StreamingToken{NotificationDataPosition: streamPos})
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"user_id": userID,
|
||||||
|
"room_id": data.RoomID,
|
||||||
|
"streamPos": streamPos,
|
||||||
|
}).Trace("Received notification data from user API")
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
|
@ -82,7 +82,16 @@ func DeviceListCatchup(
|
||||||
util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed")
|
util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed")
|
||||||
return to, hasNew, nil
|
return to, hasNew, nil
|
||||||
}
|
}
|
||||||
// QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user.
|
|
||||||
|
// Work out which user IDs we care about — that includes those in the original request,
|
||||||
|
// the response from QueryKeyChanges (which includes ALL users who have changed keys)
|
||||||
|
// as well as every user who has a join or leave event in the current sync response. We
|
||||||
|
// will request information about which rooms these users are joined to, so that we can
|
||||||
|
// see if we still share any rooms with them.
|
||||||
|
joinUserIDs, leaveUserIDs := membershipEvents(res)
|
||||||
|
queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...)
|
||||||
|
queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...)
|
||||||
|
queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs)
|
||||||
var sharedUsersMap map[string]int
|
var sharedUsersMap map[string]int
|
||||||
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs)
|
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs)
|
||||||
util.GetLogger(ctx).Debugf(
|
util.GetLogger(ctx).Debugf(
|
||||||
|
@ -100,9 +109,8 @@ func DeviceListCatchup(
|
||||||
userSet[userID] = true
|
userSet[userID] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// if the response has any join/leave events, add them now.
|
// Finally, add in users who have joined or left.
|
||||||
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
|
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
|
||||||
joinUserIDs, leaveUserIDs := membershipEvents(res)
|
|
||||||
for _, userID := range joinUserIDs {
|
for _, userID := range joinUserIDs {
|
||||||
if !userSet[userID] {
|
if !userSet[userID] {
|
||||||
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
||||||
|
@ -213,7 +221,8 @@ func filterSharedUsers(
|
||||||
var result []string
|
var result []string
|
||||||
var sharedUsersRes roomserverAPI.QuerySharedUsersResponse
|
var sharedUsersRes roomserverAPI.QuerySharedUsersResponse
|
||||||
err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{
|
err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
|
OtherUserIDs: usersWithChangedKeys,
|
||||||
}, &sharedUsersRes)
|
}, &sharedUsersRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// default to all users so we do needless queries rather than miss some important device update
|
// default to all users so we do needless queries rather than miss some important device update
|
||||||
|
|
|
@ -217,6 +217,17 @@ func (n *Notifier) OnNewInvite(
|
||||||
n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
|
n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) OnNewNotificationData(
|
||||||
|
userID string,
|
||||||
|
posUpdate types.StreamingToken,
|
||||||
|
) {
|
||||||
|
n.streamLock.Lock()
|
||||||
|
defer n.streamLock.Unlock()
|
||||||
|
|
||||||
|
n.currPos.ApplyUpdates(posUpdate)
|
||||||
|
n.wakeupUsers([]string{userID}, nil, n.currPos)
|
||||||
|
}
|
||||||
|
|
||||||
// GetListener returns a UserStreamListener that can be used to wait for
|
// GetListener returns a UserStreamListener that can be used to wait for
|
||||||
// updates for a user. Must be closed.
|
// updates for a user. Must be closed.
|
||||||
// notify for anything before sincePos
|
// notify for anything before sincePos
|
||||||
|
|
|
@ -219,7 +219,7 @@ func TestEDUWakeup(t *testing.T) {
|
||||||
go func() {
|
go func() {
|
||||||
pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter))
|
pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("TestNewInviteEventForUser error: %w", err)
|
t.Errorf("TestNewInviteEventForUser error: %v", err)
|
||||||
}
|
}
|
||||||
mustEqualPositions(t, pos, syncPositionNewEDU)
|
mustEqualPositions(t, pos, syncPositionNewEDU)
|
||||||
wg.Done()
|
wg.Done()
|
||||||
|
|
62
syncapi/producers/userapi_readupdate.go
Normal file
62
syncapi/producers/userapi_readupdate.go
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package producers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
"github.com/nats-io/nats.go"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAPIProducer produces events for the user API server to consume
|
||||||
|
type UserAPIReadProducer struct {
|
||||||
|
Topic string
|
||||||
|
JetStream nats.JetStreamContext
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendData sends account data to the user API server
|
||||||
|
func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error {
|
||||||
|
m := &nats.Msg{
|
||||||
|
Subject: p.Topic,
|
||||||
|
Header: nats.Header{},
|
||||||
|
}
|
||||||
|
m.Header.Set(jetstream.UserID, userID)
|
||||||
|
m.Header.Set(jetstream.RoomID, roomID)
|
||||||
|
|
||||||
|
data := types.ReadUpdate{
|
||||||
|
UserID: userID,
|
||||||
|
RoomID: roomID,
|
||||||
|
Read: readPos,
|
||||||
|
FullyRead: fullyReadPos,
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
m.Data, err = json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"user_id": userID,
|
||||||
|
"room_id": roomID,
|
||||||
|
"read_pos": readPos,
|
||||||
|
"fully_read_pos": fullyReadPos,
|
||||||
|
}).Tracef("Producing to topic '%s'", p.Topic)
|
||||||
|
|
||||||
|
_, err = p.JetStream.PublishMsg(m)
|
||||||
|
return err
|
||||||
|
}
|
60
syncapi/producers/userapi_streamevent.go
Normal file
60
syncapi/producers/userapi_streamevent.go
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package producers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/nats-io/nats.go"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAPIProducer produces events for the user API server to consume
|
||||||
|
type UserAPIStreamEventProducer struct {
|
||||||
|
Topic string
|
||||||
|
JetStream nats.JetStreamContext
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendData sends account data to the user API server
|
||||||
|
func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error {
|
||||||
|
m := &nats.Msg{
|
||||||
|
Subject: p.Topic,
|
||||||
|
Header: nats.Header{},
|
||||||
|
}
|
||||||
|
m.Header.Set(jetstream.RoomID, roomID)
|
||||||
|
|
||||||
|
data := types.StreamedEvent{
|
||||||
|
Event: event,
|
||||||
|
StreamPosition: pos,
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
m.Data, err = json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"room_id": roomID,
|
||||||
|
"event_id": event.EventID(),
|
||||||
|
"event_type": event.Type(),
|
||||||
|
"stream_pos": pos,
|
||||||
|
}).Tracef("Producing to topic '%s'", p.Topic)
|
||||||
|
|
||||||
|
_, err = p.JetStream.PublishMsg(m)
|
||||||
|
return err
|
||||||
|
}
|
|
@ -17,6 +17,7 @@ package routing
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
@ -44,7 +45,7 @@ func Context(
|
||||||
syncDB storage.Database,
|
syncDB storage.Database,
|
||||||
roomID, eventID string,
|
roomID, eventID string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
filter, err := parseContextParams(req)
|
filter, err := parseRoomEventFilter(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := ""
|
errMsg := ""
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
|
@ -102,6 +103,12 @@ func Context(
|
||||||
|
|
||||||
id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID)
|
id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusNotFound,
|
||||||
|
JSON: jsonerror.NotFound(fmt.Sprintf("Event %s not found", eventID)),
|
||||||
|
}
|
||||||
|
}
|
||||||
logrus.WithError(err).WithField("eventID", eventID).Error("unable to find requested event")
|
logrus.WithError(err).WithField("eventID", eventID).Error("unable to find requested event")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
@ -164,7 +171,7 @@ func applyLazyLoadMembers(filter *gomatrixserverlib.RoomEventFilter, eventsAfter
|
||||||
return newState
|
return newState
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseContextParams(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) {
|
func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) {
|
||||||
// Default room filter
|
// Default room filter
|
||||||
filter := &gomatrixserverlib.RoomEventFilter{Limit: 10}
|
filter := &gomatrixserverlib.RoomEventFilter{Limit: 10}
|
||||||
|
|
||||||
|
|
|
@ -55,13 +55,13 @@ func Test_parseContextParams(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
gotFilter, err := parseContextParams(tt.req)
|
gotFilter, err := parseRoomEventFilter(tt.req)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("parseContextParams() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("parseRoomEventFilter() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(gotFilter, tt.wantFilter) {
|
if !reflect.DeepEqual(gotFilter, tt.wantFilter) {
|
||||||
t.Errorf("parseContextParams() gotFilter = %v, want %v", gotFilter, tt.wantFilter)
|
t.Errorf("parseRoomEventFilter() gotFilter = %v, want %v", gotFilter, tt.wantFilter)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
@ -45,8 +44,8 @@ type messagesReq struct {
|
||||||
fromStream *types.StreamingToken
|
fromStream *types.StreamingToken
|
||||||
device *userapi.Device
|
device *userapi.Device
|
||||||
wasToProvided bool
|
wasToProvided bool
|
||||||
limit int
|
|
||||||
backwardOrdering bool
|
backwardOrdering bool
|
||||||
|
filter *gomatrixserverlib.RoomEventFilter
|
||||||
}
|
}
|
||||||
|
|
||||||
type messagesResp struct {
|
type messagesResp struct {
|
||||||
|
@ -54,10 +53,9 @@ type messagesResp struct {
|
||||||
StartStream string `json:"start_stream,omitempty"` // NOTSPEC: so clients can hit /messages then immediately /sync with a latest sync token
|
StartStream string `json:"start_stream,omitempty"` // NOTSPEC: so clients can hit /messages then immediately /sync with a latest sync token
|
||||||
End string `json:"end"`
|
End string `json:"end"`
|
||||||
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
|
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
|
||||||
|
State []gomatrixserverlib.ClientEvent `json:"state"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultMessagesLimit = 10
|
|
||||||
|
|
||||||
// OnIncomingMessagesRequest implements the /messages endpoint from the
|
// OnIncomingMessagesRequest implements the /messages endpoint from the
|
||||||
// client-server API.
|
// client-server API.
|
||||||
// See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages
|
// See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages
|
||||||
|
@ -83,6 +81,14 @@ func OnIncomingMessagesRequest(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
filter, err := parseRoomEventFilter(req)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidArgumentValue("unable to parse filter"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Extract parameters from the request's URL.
|
// Extract parameters from the request's URL.
|
||||||
// Pagination tokens.
|
// Pagination tokens.
|
||||||
var fromStream *types.StreamingToken
|
var fromStream *types.StreamingToken
|
||||||
|
@ -143,18 +149,6 @@ func OnIncomingMessagesRequest(
|
||||||
wasToProvided = false
|
wasToProvided = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Maximum number of events to return; defaults to 10.
|
|
||||||
limit := defaultMessagesLimit
|
|
||||||
if len(req.URL.Query().Get("limit")) > 0 {
|
|
||||||
limit, err = strconv.Atoi(req.URL.Query().Get("limit"))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.InvalidArgumentValue("limit could not be parsed into an integer: " + err.Error()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// TODO: Implement filtering (#587)
|
// TODO: Implement filtering (#587)
|
||||||
|
|
||||||
// Check the room ID's format.
|
// Check the room ID's format.
|
||||||
|
@ -176,7 +170,7 @@ func OnIncomingMessagesRequest(
|
||||||
to: &to,
|
to: &to,
|
||||||
fromStream: fromStream,
|
fromStream: fromStream,
|
||||||
wasToProvided: wasToProvided,
|
wasToProvided: wasToProvided,
|
||||||
limit: limit,
|
filter: filter,
|
||||||
backwardOrdering: backwardOrdering,
|
backwardOrdering: backwardOrdering,
|
||||||
device: device,
|
device: device,
|
||||||
}
|
}
|
||||||
|
@ -187,10 +181,27 @@ func OnIncomingMessagesRequest(
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// at least fetch the membership events for the users returned in chunk if LazyLoadMembers is set
|
||||||
|
state := []gomatrixserverlib.ClientEvent{}
|
||||||
|
if filter.LazyLoadMembers {
|
||||||
|
memberShipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent)
|
||||||
|
for _, evt := range clientEvents {
|
||||||
|
memberShip, err := db.GetStateEvent(req.Context(), roomID, gomatrixserverlib.MRoomMember, evt.Sender)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("failed to get membership event for user")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
memberShipToUser[evt.Sender] = memberShip
|
||||||
|
}
|
||||||
|
for _, evt := range memberShipToUser {
|
||||||
|
state = append(state, gomatrixserverlib.HeaderedToClientEvent(evt, gomatrixserverlib.FormatAll))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
util.GetLogger(req.Context()).WithFields(logrus.Fields{
|
util.GetLogger(req.Context()).WithFields(logrus.Fields{
|
||||||
"from": from.String(),
|
"from": from.String(),
|
||||||
"to": to.String(),
|
"to": to.String(),
|
||||||
"limit": limit,
|
"limit": filter.Limit,
|
||||||
"backwards": backwardOrdering,
|
"backwards": backwardOrdering,
|
||||||
"return_start": start.String(),
|
"return_start": start.String(),
|
||||||
"return_end": end.String(),
|
"return_end": end.String(),
|
||||||
|
@ -200,6 +211,7 @@ func OnIncomingMessagesRequest(
|
||||||
Chunk: clientEvents,
|
Chunk: clientEvents,
|
||||||
Start: start.String(),
|
Start: start.String(),
|
||||||
End: end.String(),
|
End: end.String(),
|
||||||
|
State: state,
|
||||||
}
|
}
|
||||||
if emptyFromSupplied {
|
if emptyFromSupplied {
|
||||||
res.StartStream = fromStream.String()
|
res.StartStream = fromStream.String()
|
||||||
|
@ -234,19 +246,18 @@ func (r *messagesReq) retrieveEvents() (
|
||||||
clientEvents []gomatrixserverlib.ClientEvent, start,
|
clientEvents []gomatrixserverlib.ClientEvent, start,
|
||||||
end types.TopologyToken, err error,
|
end types.TopologyToken, err error,
|
||||||
) {
|
) {
|
||||||
eventFilter := gomatrixserverlib.DefaultRoomEventFilter()
|
eventFilter := r.filter
|
||||||
eventFilter.Limit = r.limit
|
|
||||||
|
|
||||||
// Retrieve the events from the local database.
|
// Retrieve the events from the local database.
|
||||||
var streamEvents []types.StreamEvent
|
var streamEvents []types.StreamEvent
|
||||||
if r.fromStream != nil {
|
if r.fromStream != nil {
|
||||||
toStream := r.to.StreamToken()
|
toStream := r.to.StreamToken()
|
||||||
streamEvents, err = r.db.GetEventsInStreamingRange(
|
streamEvents, err = r.db.GetEventsInStreamingRange(
|
||||||
r.ctx, r.fromStream, &toStream, r.roomID, &eventFilter, r.backwardOrdering,
|
r.ctx, r.fromStream, &toStream, r.roomID, eventFilter, r.backwardOrdering,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
streamEvents, err = r.db.GetEventsInTopologicalRange(
|
streamEvents, err = r.db.GetEventsInTopologicalRange(
|
||||||
r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering,
|
r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -434,7 +445,7 @@ func (r *messagesReq) handleEmptyEventsSlice() (
|
||||||
// Check if we have backward extremities for this room.
|
// Check if we have backward extremities for this room.
|
||||||
if len(backwardExtremities) > 0 {
|
if len(backwardExtremities) > 0 {
|
||||||
// If so, retrieve as much events as needed through backfilling.
|
// If so, retrieve as much events as needed through backfilling.
|
||||||
events, err = r.backfill(r.roomID, backwardExtremities, r.limit)
|
events, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -456,7 +467,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
|
||||||
events []*gomatrixserverlib.HeaderedEvent, err error,
|
events []*gomatrixserverlib.HeaderedEvent, err error,
|
||||||
) {
|
) {
|
||||||
// Check if we have enough events.
|
// Check if we have enough events.
|
||||||
isSetLargeEnough := len(streamEvents) >= r.limit
|
isSetLargeEnough := len(streamEvents) >= r.filter.Limit
|
||||||
if !isSetLargeEnough {
|
if !isSetLargeEnough {
|
||||||
// it might be fine we don't have up to 'limit' events, let's find out
|
// it might be fine we don't have up to 'limit' events, let's find out
|
||||||
if r.backwardOrdering {
|
if r.backwardOrdering {
|
||||||
|
@ -483,7 +494,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
|
||||||
if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering {
|
if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering {
|
||||||
var pdus []*gomatrixserverlib.HeaderedEvent
|
var pdus []*gomatrixserverlib.HeaderedEvent
|
||||||
// Only ask the remote server for enough events to reach the limit.
|
// Only ask the remote server for enough events to reach the limit.
|
||||||
pdus, err = r.backfill(r.roomID, backwardExtremities, r.limit-len(streamEvents))
|
pdus, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit-len(streamEvents))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
eduAPI "github.com/matrix-org/dendrite/eduserver/api"
|
eduAPI "github.com/matrix-org/dendrite/eduserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
@ -31,6 +32,7 @@ type Database interface {
|
||||||
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
|
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
|
||||||
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
|
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
|
||||||
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
|
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
|
||||||
|
MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error)
|
||||||
|
|
||||||
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
|
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
|
||||||
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
|
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
|
||||||
|
@ -138,6 +140,12 @@ type Database interface {
|
||||||
// GetRoomReceipts gets all receipts for a given roomID
|
// GetRoomReceipts gets all receipts for a given roomID
|
||||||
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error)
|
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error)
|
||||||
|
|
||||||
|
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
|
||||||
|
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
|
||||||
|
|
||||||
|
// GetUserUnreadNotificationCounts returns statistics per room a user is interested in.
|
||||||
|
GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error)
|
||||||
|
|
||||||
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
|
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
|
||||||
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
|
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
|
||||||
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
|
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
|
||||||
|
|
108
syncapi/storage/postgres/notification_data_table.go
Normal file
108
syncapi/storage/postgres/notification_data_table.go
Normal file
|
@ -0,0 +1,108 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, error) {
|
||||||
|
_, err := db.Exec(notificationDataSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
r := ¬ificationDataStatements{}
|
||||||
|
return r, sqlutil.StatementList{
|
||||||
|
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
|
||||||
|
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
|
||||||
|
{&r.selectMaxID, selectMaxNotificationIDSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
type notificationDataStatements struct {
|
||||||
|
upsertRoomUnreadCounts *sql.Stmt
|
||||||
|
selectUserUnreadCounts *sql.Stmt
|
||||||
|
selectMaxID *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
const notificationDataSchema = `
|
||||||
|
CREATE TABLE IF NOT EXISTS syncapi_notification_data (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
notification_count BIGINT NOT NULL DEFAULT 0,
|
||||||
|
highlight_count BIGINT NOT NULL DEFAULT 0,
|
||||||
|
CONSTRAINT syncapi_notification_data_unique UNIQUE (user_id, room_id)
|
||||||
|
);`
|
||||||
|
|
||||||
|
const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_data
|
||||||
|
(user_id, room_id, notification_count, highlight_count)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
ON CONFLICT (user_id, room_id)
|
||||||
|
DO UPDATE SET notification_count = $3, highlight_count = $4
|
||||||
|
RETURNING id`
|
||||||
|
|
||||||
|
const selectUserUnreadNotificationCountsSQL = `SELECT
|
||||||
|
id, room_id, notification_count, highlight_count
|
||||||
|
FROM syncapi_notification_data
|
||||||
|
WHERE
|
||||||
|
user_id = $1 AND
|
||||||
|
id BETWEEN $2 + 1 AND $3`
|
||||||
|
|
||||||
|
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
|
||||||
|
|
||||||
|
func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
|
||||||
|
err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
|
||||||
|
rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed")
|
||||||
|
|
||||||
|
roomCounts := map[string]*eventutil.NotificationData{}
|
||||||
|
for rows.Next() {
|
||||||
|
var id types.StreamPosition
|
||||||
|
var roomID string
|
||||||
|
var notificationCount, highlightCount int
|
||||||
|
|
||||||
|
if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
roomCounts[roomID] = &eventutil.NotificationData{
|
||||||
|
RoomID: roomID,
|
||||||
|
UnreadNotificationCount: notificationCount,
|
||||||
|
UnreadHighlightCount: highlightCount,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return roomCounts, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) {
|
||||||
|
var id int64
|
||||||
|
err := r.selectMaxID.QueryRowContext(ctx).Scan(&id)
|
||||||
|
return id, err
|
||||||
|
}
|
|
@ -96,7 +96,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
|
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
|
||||||
lastPos := streamPos
|
var lastPos types.StreamPosition
|
||||||
rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos)
|
rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
|
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue