mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-01 11:13:12 -06:00
git subtree pull dendrite
This commit is contained in:
commit
4bd22ff88b
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
|
|
@ -63,7 +63,7 @@ jobs:
|
|||
# Run Complement
|
||||
- run: |
|
||||
set -o pipefail &&
|
||||
go test -v -p 1 -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt
|
||||
go test -v -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt
|
||||
shell: bash
|
||||
name: Run Complement Tests
|
||||
env:
|
||||
|
|
|
|||
8
.gitignore
vendored
8
.gitignore
vendored
|
|
@ -54,7 +54,7 @@ dendrite.yaml
|
|||
*.db
|
||||
|
||||
# Log files
|
||||
*.log*
|
||||
*.log*
|
||||
|
||||
# Generated code
|
||||
cmd/dendrite-demo-yggdrasil/embed/fs*.go
|
||||
|
|
@ -62,7 +62,7 @@ cmd/dendrite-demo-yggdrasil/embed/fs*.go
|
|||
# Test dependencies
|
||||
test/wasm/node_modules
|
||||
|
||||
media_store/
|
||||
# Ignore complement folder when running locally
|
||||
complement/
|
||||
|
||||
# Debug
|
||||
**/__debug_bin
|
||||
media_store/
|
||||
|
|
|
|||
34
CHANGES.md
34
CHANGES.md
|
|
@ -1,5 +1,31 @@
|
|||
# Changelog
|
||||
|
||||
## Dendrite 0.6.5 (2022-03-04)
|
||||
|
||||
### Features
|
||||
|
||||
* Early support for push notifications has been added, with support for push rules, pushers, HTTP push gateways and the `/notifications` endpoint (contributions by [danpe](https://github.com/danpe), [PiotrKozimor](https://github.com/PiotrKozimor) and [tommie](https://github.com/tommie))
|
||||
* Spaces Summary (MSC2946) is now correctly supported (when `msc2946` is enabled in the config)
|
||||
* All media API endpoints are now available under the `/v3` namespace
|
||||
* Profile updates (display name and avatar) are now sent asynchronously so they shouldn't block the client for a very long time
|
||||
* State resolution v2 has been optimised further to considerably reduce the number of memory allocations
|
||||
* State resolution v2 will no longer duplicate events unnecessarily when calculating the auth difference
|
||||
* The `create-account` tool now has a `-reset-password` option for resetting the passwords of existing accounts
|
||||
* The `/sync` endpoint now calculates device list changes much more quickly with less RAM used
|
||||
* The `/messages` endpoint now lazy-loads members correctly
|
||||
|
||||
### Fixes
|
||||
|
||||
* Read receipts now work correctly by correcting bugs in the stream positions and receipt coalescing
|
||||
* Topological sorting of state and join responses has been corrected, which should help to reduce the number of auth problems when joining new federated rooms
|
||||
* Media thumbnails should now work properly after having unnecessarily strict rate limiting removed
|
||||
* The roomserver no longer holds transactions for as long when processing input events
|
||||
* Uploading device keys and cross-signing keys will now correctly no-op if there were no changes
|
||||
* Parameters are now remembered correctly during registration
|
||||
* Devices can now only be deleted within the appropriate UIA flow
|
||||
* The `/context` endpoint now returns 404 instead of 500 if the event was not found
|
||||
* SQLite mode will no longer leak memory as a result of not closing prepared statements
|
||||
|
||||
## Dendrite 0.6.4 (2022-02-21)
|
||||
|
||||
### Features
|
||||
|
|
@ -210,9 +236,9 @@
|
|||
|
||||
### Fixes
|
||||
|
||||
- **SECURITY:** A bug in SQLite mode which could cause the registration flow to complete unexpectedly for existing accounts has been fixed (PostgreSQL deployments are not affected)
|
||||
- A panic in the federation sender has been fixed when shutting down destination queues
|
||||
- The `/keys/upload` endpoint now correctly returns the number of one-time keys in response to an empty upload request
|
||||
* **SECURITY:** A bug in SQLite mode which could cause the registration flow to complete unexpectedly for existing accounts has been fixed (PostgreSQL deployments are not affected)
|
||||
* A panic in the federation sender has been fixed when shutting down destination queues
|
||||
* The `/keys/upload` endpoint now correctly returns the number of one-time keys in response to an empty upload request
|
||||
|
||||
## Dendrite 0.3.10 (2021-02-17)
|
||||
|
||||
|
|
@ -534,4 +560,4 @@ First versioned release of Dendrite.
|
|||
* Typing: Yes.
|
||||
* Presence: No.
|
||||
* Receipts: No.
|
||||
* OpenID: No.
|
||||
* OpenID: No.
|
||||
|
|
|
|||
69
README.md
69
README.md
|
|
@ -2,26 +2,29 @@
|
|||
|
||||
Dendrite is a second-generation Matrix homeserver written in Go.
|
||||
It intends to provide an **efficient**, **reliable** and **scalable** alternative to [Synapse](https://github.com/matrix-org/synapse):
|
||||
- Efficient: A small memory footprint with better baseline performance than an out-of-the-box Synapse.
|
||||
- Reliable: Implements the Matrix specification as written, using the
|
||||
|
||||
- Efficient: A small memory footprint with better baseline performance than an out-of-the-box Synapse.
|
||||
- Reliable: Implements the Matrix specification as written, using the
|
||||
[same test suite](https://github.com/matrix-org/sytest) as Synapse as well as
|
||||
a [brand new Go test suite](https://github.com/matrix-org/complement).
|
||||
- Scalable: can run on multiple machines and eventually scale to massive homeserver deployments.
|
||||
- Scalable: can run on multiple machines and eventually scale to massive homeserver deployments.
|
||||
|
||||
As of October 2020, Dendrite has now entered **beta** which means:
|
||||
|
||||
- Dendrite is ready for early adopters. We recommend running in Monolith mode with a PostgreSQL database.
|
||||
- Dendrite has periodic semver releases. We intend to release new versions as we land significant features.
|
||||
- Dendrite supports database schema upgrades between releases. This means you should never lose your messages when upgrading Dendrite.
|
||||
- Breaking changes will not occur on minor releases. This means you can safely upgrade Dendrite without modifying your database or config file.
|
||||
|
||||
This does not mean:
|
||||
- Dendrite is bug-free. It has not yet been battle-tested in the real world and so will be error prone initially.
|
||||
- All of the CS/Federation APIs are implemented. We are tracking progress via a script called 'Are We Synapse Yet?'. In particular,
|
||||
|
||||
- Dendrite is bug-free. It has not yet been battle-tested in the real world and so will be error prone initially.
|
||||
- All of the CS/Federation APIs are implemented. We are tracking progress via a script called 'Are We Synapse Yet?'. In particular,
|
||||
presence and push notifications are entirely missing from Dendrite. See [CHANGES.md](CHANGES.md) for updates.
|
||||
- Dendrite is ready for massive homeserver deployments. You cannot shard each microservice, only run each one on a different machine.
|
||||
- Dendrite is ready for massive homeserver deployments. You cannot shard each microservice, only run each one on a different machine.
|
||||
|
||||
Currently, we expect Dendrite to function well for small (10s/100s of users) homeserver deployments as well as P2P Matrix nodes in-browser or on mobile devices.
|
||||
In the future, we will be able to scale up to gigantic servers (equivalent to matrix.org) via polylith mode.
|
||||
In the future, we will be able to scale up to gigantic servers (equivalent to matrix.org) via polylith mode.
|
||||
|
||||
If you have further questions, please take a look at [our FAQ](docs/FAQ.md) or join us in:
|
||||
|
||||
|
|
@ -31,14 +34,16 @@ If you have further questions, please take a look at [our FAQ](docs/FAQ.md) or j
|
|||
|
||||
## Requirements
|
||||
|
||||
To build Dendrite, you will need Go 1.16 or later.
|
||||
To build Dendrite, you will need Go 1.16 or later.
|
||||
|
||||
For a usable federating Dendrite deployment, you will also need:
|
||||
- A domain name (or subdomain)
|
||||
|
||||
- A domain name (or subdomain)
|
||||
- A valid TLS certificate issued by a trusted authority for that domain
|
||||
- SRV records or a well-known file pointing to your deployment
|
||||
|
||||
Also recommended are:
|
||||
|
||||
- A PostgreSQL database engine, which will perform better than SQLite with many users and/or larger rooms
|
||||
- A reverse proxy server, such as nginx, configured [like this sample](https://github.com/matrix-org/dendrite/blob/master/docs/nginx/monolith-sample.conf)
|
||||
|
||||
|
|
@ -76,30 +81,32 @@ Then point your favourite Matrix client at `http://localhost:8008` or `https://l
|
|||
|
||||
We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver
|
||||
test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it
|
||||
updates with CI. As of January 2022 we're at around 65% CS API coverage and 92% Federation coverage, though check
|
||||
updates with CI. As of March 2022 we're at around 76% CS API coverage and 95% Federation coverage, though check
|
||||
CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse
|
||||
servers such as matrix.org reasonably well. There's a long list of features that are not implemented, notably:
|
||||
- Push
|
||||
- Search and Context
|
||||
- User Directory
|
||||
- Presence
|
||||
- Guests
|
||||
|
||||
- Search
|
||||
- User Directory
|
||||
- Presence
|
||||
|
||||
We are prioritising features that will benefit single-user homeservers first (e.g Receipts, E2E) rather
|
||||
than features that massive deployments may be interested in (User Directory, OpenID, Guests, Admin APIs, AS API).
|
||||
This means Dendrite supports amongst others:
|
||||
- Core room functionality (creating rooms, invites, auth rules)
|
||||
- Federation in rooms v1-v7
|
||||
- Backfilling locally and via federation
|
||||
- Accounts, Profiles and Devices
|
||||
- Published room lists
|
||||
- Typing
|
||||
- Media APIs
|
||||
- Redaction
|
||||
- Tagging
|
||||
- E2E keys and device lists
|
||||
- Receipts
|
||||
|
||||
- Core room functionality (creating rooms, invites, auth rules)
|
||||
- Federation in rooms v1-v7
|
||||
- Backfilling locally and via federation
|
||||
- Accounts, Profiles and Devices
|
||||
- Published room lists
|
||||
- Typing
|
||||
- Media APIs
|
||||
- Redaction
|
||||
- Tagging
|
||||
- Context
|
||||
- E2E keys and device lists
|
||||
- Receipts
|
||||
- Push
|
||||
- Guests
|
||||
|
||||
## Contributing
|
||||
|
||||
|
|
@ -112,6 +119,7 @@ For example, if the test `Local device key changes get to remote servers` was ma
|
|||
test file (e.g via `grep` or via the
|
||||
[CI log output](https://buildkite.com/matrix-dot-org/dendrite/builds/2826#39cff5de-e032-4ad0-ad26-f819e6919c42)
|
||||
it's `tests/50federation/40devicelists.pl` ) then to run Sytest:
|
||||
|
||||
```
|
||||
docker run --rm --name sytest
|
||||
-v "/Users/kegan/github/sytest:/sytest"
|
||||
|
|
@ -121,10 +129,12 @@ docker run --rm --name sytest
|
|||
-e "POSTGRES=1" -e "DENDRITE_TRACE_HTTP=1"
|
||||
matrixdotorg/sytest-dendrite:latest tests/50federation/40devicelists.pl
|
||||
```
|
||||
|
||||
See [sytest.md](docs/sytest.md) for the full description of these flags.
|
||||
|
||||
You can try running sytest outside of docker for faster runs, but the dependencies can be temperamental
|
||||
and we recommend using docker where possible.
|
||||
|
||||
```
|
||||
cd sytest
|
||||
export PERL5LIB=$HOME/lib/perl5
|
||||
|
|
@ -149,8 +159,9 @@ Dendrite in Monolith + SQLite works in a range of environments including iOS and
|
|||
|
||||
For small homeserver installations joined on ~10s rooms on matrix.org with ~100s of users in those rooms, including some
|
||||
encrypted rooms:
|
||||
- Memory: uses around 100MB of RAM, with peaks at around 200MB.
|
||||
- Disk space: After a few months of usage, the database grew to around 2GB (in Monolith mode).
|
||||
- CPU: Brief spikes when processing events, typically idles at 1% CPU.
|
||||
|
||||
- Memory: uses around 100MB of RAM, with peaks at around 200MB.
|
||||
- Disk space: After a few months of usage, the database grew to around 2GB (in Monolith mode).
|
||||
- CPU: Brief spikes when processing events, typically idles at 1% CPU.
|
||||
|
||||
This means Dendrite should comfortably work on things like Raspberry Pis.
|
||||
|
|
|
|||
|
|
@ -88,7 +88,16 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg)
|
|||
}
|
||||
|
||||
events := []*gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event}
|
||||
events = append(events, output.NewRoomEvent.AddStateEvents...)
|
||||
if len(output.NewRoomEvent.AddsStateEventIDs) > 0 {
|
||||
eventsReq := &api.QueryEventsByIDRequest{
|
||||
EventIDs: output.NewRoomEvent.AddsStateEventIDs,
|
||||
}
|
||||
eventsRes := &api.QueryEventsByIDResponse{}
|
||||
if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil {
|
||||
return false
|
||||
}
|
||||
events = append(events, eventsRes.Events...)
|
||||
}
|
||||
|
||||
// Send event to any relevant application services
|
||||
if err := s.filterRoomserverEvents(context.TODO(), events); err != nil {
|
||||
|
|
|
|||
|
|
@ -318,6 +318,17 @@ user_api:
|
|||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
||||
# Configuration for the Push Server API.
|
||||
push_server:
|
||||
internal_api:
|
||||
listen: http://localhost:7782
|
||||
connect: http://localhost:7782
|
||||
database:
|
||||
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_pushserver?sslmode=disable
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
||||
# Configuration for Opentracing.
|
||||
# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on
|
||||
# how this works and how to set it up.
|
||||
|
|
|
|||
|
|
@ -312,7 +312,7 @@ func (m *DendriteMonolith) Start() {
|
|||
)
|
||||
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
||||
m.userAPI = userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI)
|
||||
m.userAPI = userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(m.userAPI)
|
||||
|
||||
eduInputAPI := eduserver.NewInternalAPI(
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ func (m *DendriteMonolith) Start() {
|
|||
)
|
||||
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI)
|
||||
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
eduInputAPI := eduserver.NewInternalAPI(
|
||||
|
|
|
|||
|
|
@ -144,21 +144,23 @@ func (u *UserInteractive) AddCompletedStage(sessionID, authType string) {
|
|||
delete(u.Sessions, sessionID)
|
||||
}
|
||||
|
||||
type Challenge struct {
|
||||
Completed []string `json:"completed"`
|
||||
Flows []userInteractiveFlow `json:"flows"`
|
||||
Session string `json:"session"`
|
||||
// TODO: Return any additional `params`
|
||||
Params map[string]interface{} `json:"params"`
|
||||
}
|
||||
|
||||
// Challenge returns an HTTP 401 with the supported flows for authenticating
|
||||
func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
|
||||
return &util.JSONResponse{
|
||||
Code: 401,
|
||||
JSON: struct {
|
||||
Completed []string `json:"completed"`
|
||||
Flows []userInteractiveFlow `json:"flows"`
|
||||
Session string `json:"session"`
|
||||
// TODO: Return any additional `params`
|
||||
Params map[string]interface{} `json:"params"`
|
||||
}{
|
||||
u.Completed,
|
||||
u.Flows,
|
||||
sessionID,
|
||||
make(map[string]interface{}),
|
||||
JSON: Challenge{
|
||||
Completed: u.Completed,
|
||||
Flows: u.Flows,
|
||||
Session: sessionID,
|
||||
Params: make(map[string]interface{}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ func AddPublicRoutes(
|
|||
routing.Setup(
|
||||
router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI,
|
||||
accountsDB, userAPI, federation,
|
||||
syncProducer, transactionsCache, fsAPI, keyAPI, extRoomsProvider, mscCfg,
|
||||
syncProducer, transactionsCache, fsAPI, keyAPI,
|
||||
extRoomsProvider, mscCfg,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,6 +58,11 @@ func BadJSON(msg string) *MatrixError {
|
|||
return &MatrixError{"M_BAD_JSON", msg}
|
||||
}
|
||||
|
||||
// BadAlias is an error when the client supplies a bad alias.
|
||||
func BadAlias(msg string) *MatrixError {
|
||||
return &MatrixError{"M_BAD_ALIAS", msg}
|
||||
}
|
||||
|
||||
// NotJSON is an error when the client supplies something that is not JSON
|
||||
// to a JSON endpoint.
|
||||
func NotJSON(msg string) *MatrixError {
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ type SyncAPIProducer struct {
|
|||
}
|
||||
|
||||
// SendData sends account data to the sync API server
|
||||
func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string) error {
|
||||
func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string, readMarker *eventutil.ReadMarkerJSON) error {
|
||||
m := &nats.Msg{
|
||||
Subject: p.Topic,
|
||||
Header: nats.Header{},
|
||||
|
|
@ -38,8 +38,9 @@ func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string
|
|||
m.Header.Set(jetstream.UserID, userID)
|
||||
|
||||
data := eventutil.AccountData{
|
||||
RoomID: roomID,
|
||||
Type: dataType,
|
||||
RoomID: roomID,
|
||||
Type: dataType,
|
||||
ReadMarker: readMarker,
|
||||
}
|
||||
var err error
|
||||
m.Data, err = json.Marshal(data)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/clientapi/producers"
|
||||
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
|
||||
|
|
@ -127,7 +128,7 @@ func SaveAccountData(
|
|||
}
|
||||
|
||||
// TODO: user API should do this since it's account data
|
||||
if err := syncProducer.SendData(userID, roomID, dataType); err != nil {
|
||||
if err := syncProducer.SendData(userID, roomID, dataType, nil); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
|
@ -138,11 +139,6 @@ func SaveAccountData(
|
|||
}
|
||||
}
|
||||
|
||||
type readMarkerJSON struct {
|
||||
FullyRead string `json:"m.fully_read"`
|
||||
Read string `json:"m.read"`
|
||||
}
|
||||
|
||||
type fullyReadEvent struct {
|
||||
EventID string `json:"event_id"`
|
||||
}
|
||||
|
|
@ -159,7 +155,7 @@ func SaveReadMarker(
|
|||
return *resErr
|
||||
}
|
||||
|
||||
var r readMarkerJSON
|
||||
var r eventutil.ReadMarkerJSON
|
||||
resErr = httputil.UnmarshalJSONRequest(req, &r)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
|
|
@ -189,7 +185,7 @@ func SaveReadMarker(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
|
||||
if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read"); err != nil {
|
||||
if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read", &r); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-devices
|
||||
|
|
@ -163,6 +164,15 @@ func DeleteDeviceById(
|
|||
req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.UserInternalAPI, device *api.Device,
|
||||
deviceID string,
|
||||
) util.JSONResponse {
|
||||
var (
|
||||
deleteOK bool
|
||||
sessionID string
|
||||
)
|
||||
defer func() {
|
||||
if deleteOK {
|
||||
sessions.deleteSession(sessionID)
|
||||
}
|
||||
}()
|
||||
ctx := req.Context()
|
||||
defer req.Body.Close() // nolint:errcheck
|
||||
bodyBytes, err := ioutil.ReadAll(req.Body)
|
||||
|
|
@ -172,8 +182,29 @@ func DeleteDeviceById(
|
|||
JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()),
|
||||
}
|
||||
}
|
||||
|
||||
// check that we know this session, and it matches with the device to delete
|
||||
s := gjson.GetBytes(bodyBytes, "auth.session").Str
|
||||
if dev, ok := sessions.getDeviceToDelete(s); ok {
|
||||
if dev != deviceID {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: jsonerror.Forbidden("session & device mismatch"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s != "" {
|
||||
sessionID = s
|
||||
}
|
||||
|
||||
login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, device)
|
||||
if errRes != nil {
|
||||
switch data := errRes.JSON.(type) {
|
||||
case auth.Challenge:
|
||||
sessions.addDeviceToDelete(data.Session, deviceID)
|
||||
default:
|
||||
}
|
||||
return *errRes
|
||||
}
|
||||
|
||||
|
|
@ -201,6 +232,8 @@ func DeleteDeviceById(
|
|||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
deleteOK = true
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: struct{}{},
|
||||
|
|
|
|||
|
|
@ -139,11 +139,17 @@ func SetLocalAlias(
|
|||
// TODO: This code should eventually be refactored with:
|
||||
// 1. The new method for checking for things matching an AS's namespace
|
||||
// 2. Using an overall Regex object for all AS's just like we did for usernames
|
||||
|
||||
reqUserID, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON("User ID must be in the form '@localpart:domain'"),
|
||||
}
|
||||
}
|
||||
for _, appservice := range cfg.Derived.ApplicationServices {
|
||||
// Don't prevent AS from creating aliases in its own namespace
|
||||
// Note that Dendrite uses SenderLocalpart as UserID for AS users
|
||||
if device.UserID != appservice.SenderLocalpart {
|
||||
if reqUserID != appservice.SenderLocalpart {
|
||||
if aliasNamespaces, ok := appservice.NamespaceMap["aliases"]; ok {
|
||||
for _, namespace := range aliasNamespaces {
|
||||
if namespace.Exclusive && namespace.RegexpObject.MatchString(alias) {
|
||||
|
|
|
|||
63
clientapi/routing/notification.go
Normal file
63
clientapi/routing/notification.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// GetNotifications handles /_matrix/client/r0/notifications
|
||||
func GetNotifications(
|
||||
req *http.Request, device *userapi.Device,
|
||||
userAPI userapi.UserInternalAPI,
|
||||
) util.JSONResponse {
|
||||
var limit int64
|
||||
if limitStr := req.URL.Query().Get("limit"); limitStr != "" {
|
||||
var err error
|
||||
limit, err = strconv.ParseInt(limitStr, 10, 64)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("ParseInt(limit) failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
}
|
||||
|
||||
var queryRes userapi.QueryNotificationsResponse
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
|
||||
Localpart: localpart,
|
||||
From: req.URL.Query().Get("from"),
|
||||
Limit: int(limit),
|
||||
Only: req.URL.Query().Get("only"),
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
util.GetLogger(req.Context()).WithField("from", req.URL.Query().Get("from")).WithField("limit", limit).WithField("only", req.URL.Query().Get("only")).WithField("next", queryRes.NextToken).Infof("QueryNotifications: len %d", len(queryRes.Notifications))
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: queryRes,
|
||||
}
|
||||
}
|
||||
|
|
@ -12,6 +12,7 @@ import (
|
|||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type newPasswordRequest struct {
|
||||
|
|
@ -37,6 +38,11 @@ func Password(
|
|||
var r newPasswordRequest
|
||||
r.LogoutDevices = true
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"sessionId": device.SessionID,
|
||||
"userId": device.UserID,
|
||||
}).Debug("Changing password")
|
||||
|
||||
// Unmarshal the request.
|
||||
resErr := httputil.UnmarshalJSONRequest(req, &r)
|
||||
if resErr != nil {
|
||||
|
|
@ -116,6 +122,15 @@ func Password(
|
|||
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
pushersReq := &api.PerformPusherDeletionRequest{
|
||||
Localpart: localpart,
|
||||
SessionID: device.SessionID,
|
||||
}
|
||||
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
}
|
||||
|
||||
// Return a success code.
|
||||
|
|
|
|||
|
|
@ -286,7 +286,7 @@ func SetDisplayName(
|
|||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, false); err != nil {
|
||||
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
|
|
|||
114
clientapi/routing/pusher.go
Normal file
114
clientapi/routing/pusher.go
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// GetPushers handles /_matrix/client/r0/pushers
|
||||
func GetPushers(
|
||||
req *http.Request, device *userapi.Device,
|
||||
userAPI userapi.UserInternalAPI,
|
||||
) util.JSONResponse {
|
||||
var queryRes userapi.QueryPushersResponse
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
|
||||
Localpart: localpart,
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
for i := range queryRes.Pushers {
|
||||
queryRes.Pushers[i].SessionID = 0
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: queryRes,
|
||||
}
|
||||
}
|
||||
|
||||
// SetPusher handles /_matrix/client/r0/pushers/set
|
||||
// This endpoint allows the creation, modification and deletion of pushers for this user ID.
|
||||
// The behaviour of this endpoint varies depending on the values in the JSON body.
|
||||
func SetPusher(
|
||||
req *http.Request, device *userapi.Device,
|
||||
userAPI userapi.UserInternalAPI,
|
||||
) util.JSONResponse {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
body := userapi.PerformPusherSetRequest{}
|
||||
if resErr := httputil.UnmarshalJSONRequest(req, &body); resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
if len(body.AppID) > 64 {
|
||||
return invalidParam("length of app_id must be no more than 64 characters")
|
||||
}
|
||||
if len(body.PushKey) > 512 {
|
||||
return invalidParam("length of pushkey must be no more than 512 bytes")
|
||||
}
|
||||
uInt := body.Data["url"]
|
||||
if uInt != nil {
|
||||
u, ok := uInt.(string)
|
||||
if !ok {
|
||||
return invalidParam("url must be string")
|
||||
}
|
||||
if u != "" {
|
||||
var pushUrl *url.URL
|
||||
pushUrl, err = url.Parse(u)
|
||||
if err != nil {
|
||||
return invalidParam("malformed url passed")
|
||||
}
|
||||
if pushUrl.Scheme != "https" {
|
||||
return invalidParam("only https scheme is allowed")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
body.Localpart = localpart
|
||||
body.SessionID = device.SessionID
|
||||
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherSet failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: struct{}{},
|
||||
}
|
||||
}
|
||||
|
||||
func invalidParam(msg string) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidParam(msg),
|
||||
}
|
||||
}
|
||||
386
clientapi/routing/pushrules.go
Normal file
386
clientapi/routing/pushrules.go
Normal file
|
|
@ -0,0 +1,386 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
func errorResponse(ctx context.Context, err error, msg string, args ...interface{}) util.JSONResponse {
|
||||
if eerr, ok := err.(*jsonerror.MatrixError); ok {
|
||||
var status int
|
||||
switch eerr.ErrCode {
|
||||
case "M_INVALID_ARGUMENT_VALUE":
|
||||
status = http.StatusBadRequest
|
||||
case "M_NOT_FOUND":
|
||||
status = http.StatusNotFound
|
||||
default:
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
return util.MatrixErrorResponse(status, eerr.ErrCode, eerr.Err)
|
||||
}
|
||||
util.GetLogger(ctx).WithError(err).Errorf(msg, args...)
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRulesJSON failed")
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: ruleSets,
|
||||
}
|
||||
}
|
||||
|
||||
func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRulesJSON failed")
|
||||
}
|
||||
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||
if ruleSet == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: ruleSet,
|
||||
}
|
||||
}
|
||||
|
||||
func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRules failed")
|
||||
}
|
||||
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||
if ruleSet == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||
}
|
||||
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||
if rulesPtr == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: *rulesPtr,
|
||||
}
|
||||
}
|
||||
|
||||
func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRules failed")
|
||||
}
|
||||
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||
if ruleSet == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||
}
|
||||
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||
if rulesPtr == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||
}
|
||||
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||
if i < 0 {
|
||||
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: (*rulesPtr)[i],
|
||||
}
|
||||
}
|
||||
|
||||
func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
var newRule pushrules.Rule
|
||||
if err := json.NewDecoder(body).Decode(&newRule); err != nil {
|
||||
return errorResponse(ctx, err, "JSON Decode failed")
|
||||
}
|
||||
newRule.RuleID = ruleID
|
||||
|
||||
errs := pushrules.ValidateRule(pushrules.Kind(kind), &newRule)
|
||||
if len(errs) > 0 {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs)
|
||||
}
|
||||
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRules failed")
|
||||
}
|
||||
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||
if ruleSet == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||
}
|
||||
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||
if rulesPtr == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||
}
|
||||
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||
if i >= 0 && afterRuleID == "" && beforeRuleID == "" {
|
||||
// Modify rule at the same index.
|
||||
|
||||
// TODO: The spec does not say what to do in this case, but
|
||||
// this feels reasonable.
|
||||
*((*rulesPtr)[i]) = newRule
|
||||
util.GetLogger(ctx).Infof("Modified existing push rule at %d", i)
|
||||
} else {
|
||||
if i >= 0 {
|
||||
// Delete old rule.
|
||||
*rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
|
||||
util.GetLogger(ctx).Infof("Deleted old push rule at %d", i)
|
||||
} else {
|
||||
// SPEC: When creating push rules, they MUST be enabled by default.
|
||||
//
|
||||
// TODO: it's unclear if we must reject disabled rules, or force
|
||||
// the value to true. Sytests fail if we don't force it.
|
||||
newRule.Enabled = true
|
||||
}
|
||||
|
||||
// Add new rule.
|
||||
i, err := findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "findPushRuleInsertionIndex failed")
|
||||
}
|
||||
|
||||
*rulesPtr = append((*rulesPtr)[:i], append([]*pushrules.Rule{&newRule}, (*rulesPtr)[i:]...)...)
|
||||
util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i)
|
||||
}
|
||||
|
||||
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
|
||||
return errorResponse(ctx, err, "putPushRules failed")
|
||||
}
|
||||
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
|
||||
}
|
||||
|
||||
func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRules failed")
|
||||
}
|
||||
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||
if ruleSet == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||
}
|
||||
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||
if rulesPtr == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||
}
|
||||
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||
if i < 0 {
|
||||
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
|
||||
}
|
||||
|
||||
*rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
|
||||
|
||||
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
|
||||
return errorResponse(ctx, err, "putPushRules failed")
|
||||
}
|
||||
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
|
||||
}
|
||||
|
||||
func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
attrGet, err := pushRuleAttrGetter(attr)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "pushRuleAttrGetter failed")
|
||||
}
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRules failed")
|
||||
}
|
||||
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||
if ruleSet == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||
}
|
||||
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||
if rulesPtr == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||
}
|
||||
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||
if i < 0 {
|
||||
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: map[string]interface{}{
|
||||
attr: attrGet((*rulesPtr)[i]),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
|
||||
var newPartialRule pushrules.Rule
|
||||
if err := json.NewDecoder(body).Decode(&newPartialRule); err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON(err.Error()),
|
||||
}
|
||||
}
|
||||
if newPartialRule.Actions == nil {
|
||||
// This ensures json.Marshal encodes the empty list as [] rather than null.
|
||||
newPartialRule.Actions = []*pushrules.Action{}
|
||||
}
|
||||
|
||||
attrGet, err := pushRuleAttrGetter(attr)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "pushRuleAttrGetter failed")
|
||||
}
|
||||
attrSet, err := pushRuleAttrSetter(attr)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "pushRuleAttrSetter failed")
|
||||
}
|
||||
|
||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
||||
if err != nil {
|
||||
return errorResponse(ctx, err, "queryPushRules failed")
|
||||
}
|
||||
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
|
||||
if ruleSet == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||
}
|
||||
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||
if rulesPtr == nil {
|
||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||
}
|
||||
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||
if i < 0 {
|
||||
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) {
|
||||
attrSet((*rulesPtr)[i], &newPartialRule)
|
||||
|
||||
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
|
||||
return errorResponse(ctx, err, "putPushRules failed")
|
||||
}
|
||||
}
|
||||
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
|
||||
}
|
||||
|
||||
func queryPushRules(ctx context.Context, userID string, userAPI userapi.UserInternalAPI) (*pushrules.AccountRuleSets, error) {
|
||||
var res userapi.QueryPushRulesResponse
|
||||
if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed")
|
||||
return nil, err
|
||||
}
|
||||
return res.RuleSets, nil
|
||||
}
|
||||
|
||||
func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.UserInternalAPI) error {
|
||||
req := userapi.PerformPushRulesPutRequest{
|
||||
UserID: userID,
|
||||
RuleSets: ruleSets,
|
||||
}
|
||||
var res struct{}
|
||||
if err := userAPI.PerformPushRulesPut(ctx, &req, &res); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("userAPI.PerformPushRulesPut failed")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet {
|
||||
switch scope {
|
||||
case pushrules.GlobalScope:
|
||||
return &ruleSets.Global
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func pushRuleSetKindPointer(ruleSet *pushrules.RuleSet, kind pushrules.Kind) *[]*pushrules.Rule {
|
||||
switch kind {
|
||||
case pushrules.OverrideKind:
|
||||
return &ruleSet.Override
|
||||
case pushrules.ContentKind:
|
||||
return &ruleSet.Content
|
||||
case pushrules.RoomKind:
|
||||
return &ruleSet.Room
|
||||
case pushrules.SenderKind:
|
||||
return &ruleSet.Sender
|
||||
case pushrules.UnderrideKind:
|
||||
return &ruleSet.Underride
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func pushRuleIndexByID(rules []*pushrules.Rule, id string) int {
|
||||
for i, rule := range rules {
|
||||
if rule.RuleID == id {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func pushRuleAttrGetter(attr string) (func(*pushrules.Rule) interface{}, error) {
|
||||
switch attr {
|
||||
case "actions":
|
||||
return func(rule *pushrules.Rule) interface{} { return rule.Actions }, nil
|
||||
case "enabled":
|
||||
return func(rule *pushrules.Rule) interface{} { return rule.Enabled }, nil
|
||||
default:
|
||||
return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute")
|
||||
}
|
||||
}
|
||||
|
||||
func pushRuleAttrSetter(attr string) (func(dest, src *pushrules.Rule), error) {
|
||||
switch attr {
|
||||
case "actions":
|
||||
return func(dest, src *pushrules.Rule) { dest.Actions = src.Actions }, nil
|
||||
case "enabled":
|
||||
return func(dest, src *pushrules.Rule) { dest.Enabled = src.Enabled }, nil
|
||||
default:
|
||||
return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute")
|
||||
}
|
||||
}
|
||||
|
||||
func findPushRuleInsertionIndex(rules []*pushrules.Rule, afterID, beforeID string) (int, error) {
|
||||
var i int
|
||||
|
||||
if afterID != "" {
|
||||
for ; i < len(rules); i++ {
|
||||
if rules[i].RuleID == afterID {
|
||||
break
|
||||
}
|
||||
}
|
||||
if i == len(rules) {
|
||||
return 0, jsonerror.NotFound("after: rule ID not found")
|
||||
}
|
||||
if rules[i].Default {
|
||||
return 0, jsonerror.NotFound("after: rule ID must not be a default rule")
|
||||
}
|
||||
// We stopped on the "after" match to differentiate
|
||||
// not-found from is-last-entry. Now we move to the earliest
|
||||
// insertion point.
|
||||
i++
|
||||
}
|
||||
|
||||
if beforeID != "" {
|
||||
for ; i < len(rules); i++ {
|
||||
if rules[i].RuleID == beforeID {
|
||||
break
|
||||
}
|
||||
}
|
||||
if i == len(rules) {
|
||||
return 0, jsonerror.NotFound("before: rule ID not found")
|
||||
}
|
||||
if rules[i].Default {
|
||||
return 0, jsonerror.NotFound("before: rule ID must not be a default rule")
|
||||
}
|
||||
}
|
||||
|
||||
// UNSPEC: The spec does not say what to do if no after/before is
|
||||
// given. Sytest fails if it doesn't go first.
|
||||
return i, nil
|
||||
}
|
||||
|
|
@ -76,6 +76,10 @@ type sessionsDict struct {
|
|||
sessions map[string][]authtypes.LoginType
|
||||
params map[string]registerRequest
|
||||
timer map[string]*time.Timer
|
||||
// deleteSessionToDeviceID protects requests to DELETE /devices/{deviceID} from being abused.
|
||||
// If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
|
||||
// the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
|
||||
deleteSessionToDeviceID map[string]string
|
||||
}
|
||||
|
||||
// defaultTimeout is the timeout used to clean up sessions
|
||||
|
|
@ -115,6 +119,7 @@ func (d *sessionsDict) deleteSession(sessionID string) {
|
|||
defer d.Unlock()
|
||||
delete(d.params, sessionID)
|
||||
delete(d.sessions, sessionID)
|
||||
delete(d.deleteSessionToDeviceID, sessionID)
|
||||
// stop the timer, e.g. because the registration was completed
|
||||
if t, ok := d.timer[sessionID]; ok {
|
||||
if !t.Stop() {
|
||||
|
|
@ -129,9 +134,10 @@ func (d *sessionsDict) deleteSession(sessionID string) {
|
|||
|
||||
func newSessionsDict() *sessionsDict {
|
||||
return &sessionsDict{
|
||||
sessions: make(map[string][]authtypes.LoginType),
|
||||
params: make(map[string]registerRequest),
|
||||
timer: make(map[string]*time.Timer),
|
||||
sessions: make(map[string][]authtypes.LoginType),
|
||||
params: make(map[string]registerRequest),
|
||||
timer: make(map[string]*time.Timer),
|
||||
deleteSessionToDeviceID: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -165,6 +171,20 @@ func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtype
|
|||
d.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
|
||||
}
|
||||
|
||||
func (d *sessionsDict) addDeviceToDelete(sessionID, deviceID string) {
|
||||
d.startTimer(defaultTimeOut, sessionID)
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
d.deleteSessionToDeviceID[sessionID] = deviceID
|
||||
}
|
||||
|
||||
func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
deviceID, ok := d.deleteSessionToDeviceID[sessionID]
|
||||
return deviceID, ok
|
||||
}
|
||||
|
||||
var (
|
||||
sessions = newSessionsDict()
|
||||
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
||||
|
|
|
|||
|
|
@ -214,19 +214,19 @@ func TestSessionCleanUp(t *testing.T) {
|
|||
s := newSessionsDict()
|
||||
|
||||
t.Run("session is cleaned up after a while", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// t.Parallel()
|
||||
dummySession := "helloWorld"
|
||||
// manually added, as s.addParams() would start the timer with the default timeout
|
||||
s.params[dummySession] = registerRequest{Username: "Testing"}
|
||||
s.startTimer(time.Millisecond, dummySession)
|
||||
time.Sleep(time.Millisecond * 2)
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
if data, ok := s.getParams(dummySession); ok {
|
||||
t.Errorf("expected session to be deleted: %+v", data)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("session is deleted, once the registration completed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// t.Parallel()
|
||||
dummySession := "helloWorld2"
|
||||
s.startTimer(time.Minute, dummySession)
|
||||
s.deleteSession(dummySession)
|
||||
|
|
@ -236,18 +236,28 @@ func TestSessionCleanUp(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("session timer is restarted after second call", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// t.Parallel()
|
||||
dummySession := "helloWorld3"
|
||||
// the following will start a timer with the default timeout of 5min
|
||||
s.addParams(dummySession, registerRequest{Username: "Testing"})
|
||||
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha)
|
||||
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy)
|
||||
s.addDeviceToDelete(dummySession, "dummyDevice")
|
||||
s.getCompletedStages(dummySession)
|
||||
// reset the timer with a lower timeout
|
||||
s.startTimer(time.Millisecond, dummySession)
|
||||
time.Sleep(time.Millisecond * 2)
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
if data, ok := s.getParams(dummySession); ok {
|
||||
t.Errorf("expected session to be deleted: %+v", data)
|
||||
}
|
||||
if _, ok := s.timer[dummySession]; ok {
|
||||
t.Error("expected timer to be delete")
|
||||
}
|
||||
if _, ok := s.sessions[dummySession]; ok {
|
||||
t.Error("expected session to be delete")
|
||||
}
|
||||
if _, ok := s.getDeviceToDelete(dummySession); ok {
|
||||
t.Error("expected session to device to be delete")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ func PutTag(
|
|||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
|
||||
if err = syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil {
|
||||
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
|
||||
}
|
||||
|
||||
|
|
@ -151,7 +151,7 @@ func DeleteTag(
|
|||
}
|
||||
|
||||
// TODO: user API should do this since it's account data
|
||||
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
|
||||
if err := syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil {
|
||||
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ package routing
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
|
|
@ -561,25 +560,142 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/",
|
||||
httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse {
|
||||
// TODO: Implement push rules API
|
||||
res := json.RawMessage(`{
|
||||
"global": {
|
||||
"content": [],
|
||||
"override": [],
|
||||
"room": [],
|
||||
"sender": [],
|
||||
"underride": []
|
||||
}
|
||||
}`)
|
||||
// Push rules
|
||||
|
||||
v3mux.Handle("/pushrules",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: &res,
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("missing trailing slash"),
|
||||
}
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetAllPushRules(req.Context(), device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("scope, kind and rule ID must be specified"),
|
||||
}
|
||||
}),
|
||||
).Methods(http.MethodPut)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetPushRulesByScope(req.Context(), vars["scope"], device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("missing trailing slash after scope"),
|
||||
}
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope:[^/]+/?}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("kind and rule ID must be specified"),
|
||||
}
|
||||
}),
|
||||
).Methods(http.MethodPut)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetPushRulesByKind(req.Context(), vars["scope"], vars["kind"], device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("missing trailing slash after kind"),
|
||||
}
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind:[^/]+/?}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("rule ID must be specified"),
|
||||
}
|
||||
}),
|
||||
).Methods(http.MethodPut)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
}
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
query := req.URL.Query()
|
||||
return PutPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], query.Get("after"), query.Get("before"), req.Body, device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodPut)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return DeletePushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodDelete)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return PutPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], req.Body, device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodPut)
|
||||
|
||||
// Element user settings
|
||||
|
||||
v3mux.Handle("/profile/{userID}",
|
||||
|
|
@ -885,6 +1001,27 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/notifications",
|
||||
httputil.MakeAuthAPI("get_notifications", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetNotifications(req, device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushers",
|
||||
httputil.MakeAuthAPI("get_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetPushers(req, device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/pushers/set",
|
||||
httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
}
|
||||
return SetPusher(req, device, userAPI)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
// Stub implementations for sytest
|
||||
v3mux.Handle("/events",
|
||||
httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ package routing
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
|
@ -120,6 +122,40 @@ func SendEvent(
|
|||
}
|
||||
timeToGenerateEvent := time.Since(startedGeneratingEvent)
|
||||
|
||||
// validate that the aliases exists
|
||||
if eventType == gomatrixserverlib.MRoomCanonicalAlias && stateKey != nil && *stateKey == "" {
|
||||
aliasReq := api.AliasEvent{}
|
||||
if err = json.Unmarshal(e.Content(), &aliasReq); err != nil {
|
||||
return util.ErrorResponse(fmt.Errorf("unable to parse alias event: %w", err))
|
||||
}
|
||||
if !aliasReq.Valid() {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidParam("Request contains invalid aliases."),
|
||||
}
|
||||
}
|
||||
aliasRes := &api.GetAliasesForRoomIDResponse{}
|
||||
if err = rsAPI.GetAliasesForRoomID(req.Context(), &api.GetAliasesForRoomIDRequest{RoomID: roomID}, aliasRes); err != nil {
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
var found int
|
||||
requestAliases := append(aliasReq.AltAliases, aliasReq.Alias)
|
||||
for _, alias := range aliasRes.Aliases {
|
||||
for _, altAlias := range requestAliases {
|
||||
if altAlias == alias {
|
||||
found++
|
||||
}
|
||||
}
|
||||
}
|
||||
// check that we found at least the same amount of existing aliases as are in the request
|
||||
if aliasReq.Alias != "" && found < len(requestAliases) {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadAlias("No matching alias found."),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var txnAndSessionID *api.TransactionID
|
||||
if txnID != nil {
|
||||
txnAndSessionID = &api.TransactionID{
|
||||
|
|
|
|||
|
|
@ -48,24 +48,27 @@ Example:
|
|||
# read password from stdin
|
||||
%s --config dendrite.yaml -username alice -passwordstdin < my.pass
|
||||
cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin
|
||||
# reset password for a user, can be used with a combination above to read the password
|
||||
%s --config dendrite.yaml -reset-password -username alice -password foobarbaz
|
||||
|
||||
Arguments:
|
||||
|
||||
`
|
||||
|
||||
var (
|
||||
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
|
||||
password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)")
|
||||
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
|
||||
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
|
||||
askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
|
||||
isAdmin = flag.Bool("admin", false, "Create an admin account")
|
||||
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
|
||||
password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)")
|
||||
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
|
||||
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
|
||||
askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
|
||||
isAdmin = flag.Bool("admin", false, "Create an admin account")
|
||||
resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username")
|
||||
)
|
||||
|
||||
func main() {
|
||||
name := os.Args[0]
|
||||
flag.Usage = func() {
|
||||
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name)
|
||||
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name, name)
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
cfg := setup.ParseFlags(true)
|
||||
|
|
@ -93,6 +96,19 @@ func main() {
|
|||
if *isAdmin {
|
||||
accType = api.AccountTypeAdmin
|
||||
}
|
||||
|
||||
if *resetPassword {
|
||||
err = accountDB.SetPassword(context.Background(), *username, pass)
|
||||
if err != nil {
|
||||
logrus.Fatalf("Failed to update password for user %s: %s", *username, err.Error())
|
||||
}
|
||||
if _, err = accountDB.RemoveAllDevices(context.Background(), *username, ""); err != nil {
|
||||
logrus.Fatalf("Failed to remove all devices: %s", err.Error())
|
||||
}
|
||||
logrus.Infof("Updated password for user %s and invalidated all logins\n", *username)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType)
|
||||
if err != nil {
|
||||
logrus.Fatalln("Failed to create the account:", err.Error())
|
||||
|
|
|
|||
|
|
@ -144,12 +144,14 @@ func main() {
|
|||
accountDB := base.Base.CreateAccountsDB()
|
||||
federation := createFederationClient(base)
|
||||
keyAPI := keyserver.NewInternalAPI(&base.Base, &base.Base.Cfg.KeyServer, federation)
|
||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
rsAPI := roomserver.NewInternalAPI(
|
||||
&base.Base,
|
||||
)
|
||||
|
||||
userAPI := userapi.NewInternalAPI(&base.Base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.Base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
eduInputAPI := eduserver.NewInternalAPI(
|
||||
&base.Base, cache.New(), userAPI,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -187,7 +187,7 @@ func main() {
|
|||
)
|
||||
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
|
||||
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
eduInputAPI := eduserver.NewInternalAPI(
|
||||
|
|
|
|||
|
|
@ -111,14 +111,15 @@ func main() {
|
|||
keyRing := serverKeyAPI.KeyRing()
|
||||
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
rsComponent := roomserver.NewInternalAPI(
|
||||
base,
|
||||
)
|
||||
rsAPI := rsComponent
|
||||
|
||||
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
eduInputAPI := eduserver.NewInternalAPI(
|
||||
base, cache.New(), userAPI,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -106,7 +106,8 @@ func main() {
|
|||
keyAPI = base.KeyServerHTTPClient()
|
||||
}
|
||||
|
||||
userImpl := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI)
|
||||
pgClient := base.PushGatewayHTTPClient()
|
||||
userImpl := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient)
|
||||
userAPI := userImpl
|
||||
if base.UseHTTPAPIs {
|
||||
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,11 @@ import (
|
|||
func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
|
||||
accountDB := base.CreateAccountsDB()
|
||||
|
||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient())
|
||||
userAPI := userapi.NewInternalAPI(
|
||||
base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices,
|
||||
base.KeyServerHTTPClient(), base.RoomserverHTTPClient(),
|
||||
base.PushGatewayHTTPClient(),
|
||||
)
|
||||
|
||||
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)
|
||||
|
||||
|
|
|
|||
|
|
@ -184,13 +184,15 @@ func startup() {
|
|||
accountDB := base.CreateAccountsDB()
|
||||
federation := conn.CreateFederationClient(base, pSessions)
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
||||
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
serverKeyAPI := &signing.YggdrasilKeys{}
|
||||
keyRing := serverKeyAPI.KeyRing()
|
||||
|
||||
rsAPI := roomserver.NewInternalAPI(base)
|
||||
|
||||
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI)
|
||||
asQuery := appservice.NewInternalAPI(
|
||||
base, userAPI, rsAPI,
|
||||
|
|
|
|||
|
|
@ -212,6 +212,8 @@ func main() {
|
|||
rsAPI.SetFederationAPI(fedSenderAPI, keyRing)
|
||||
p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation)
|
||||
|
||||
psAPI := pushserver.NewInternalAPI(base)
|
||||
|
||||
monolith := setup.Monolith{
|
||||
Config: base.Cfg,
|
||||
AccountDB: accountDB,
|
||||
|
|
@ -225,6 +227,7 @@ func main() {
|
|||
RoomserverAPI: rsAPI,
|
||||
UserAPI: userAPI,
|
||||
KeyAPI: keyAPI,
|
||||
PushserverAPI: psAPI,
|
||||
//ServerKeyAPI: serverKeyAPI,
|
||||
ExtPublicRoomsProvider: p2pPublicRoomProvider,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ global:
|
|||
# in monolith mode. It is required to specify the address of at least one
|
||||
# NATS Server node if running in polylith mode.
|
||||
addresses:
|
||||
# - localhost:4222
|
||||
# - localhost:4222
|
||||
|
||||
# Keep all NATS streams in memory, rather than persisting it to the storage
|
||||
# path below. This option is present primarily for integration testing and
|
||||
|
|
@ -346,11 +346,6 @@ user_api:
|
|||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
device_database:
|
||||
connection_string: file:userapi_devices.db
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
# The length of time that a token issued for a relying party from
|
||||
# /_matrix/client/r0/user/{userId}/openid/request_token endpoint
|
||||
# is considered to be valid in milliseconds.
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ type FederationClient interface {
|
|||
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
|
||||
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
|
||||
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
|
||||
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
|
||||
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
|
||||
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
|
||||
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
|
||||
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
|
||||
|
|
|
|||
|
|
@ -146,7 +146,28 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee
|
|||
// processMessage updates the list of currently joined hosts in the room
|
||||
// and then sends the event to the hosts that were joined before the event.
|
||||
func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error {
|
||||
addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(ore.AddsState()))
|
||||
eventsRes := &api.QueryEventsByIDResponse{}
|
||||
if len(ore.AddsStateEventIDs) > 0 {
|
||||
eventsReq := &api.QueryEventsByIDRequest{
|
||||
EventIDs: ore.AddsStateEventIDs,
|
||||
}
|
||||
if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil {
|
||||
return fmt.Errorf("s.rsAPI.QueryEventsByID: %w", err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, event := range eventsRes.Events {
|
||||
if event.EventID() == ore.Event.EventID() {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
eventsRes.Events = append(eventsRes.Events, ore.Event)
|
||||
}
|
||||
}
|
||||
|
||||
addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(eventsRes.Events))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
|
|||
}
|
||||
|
||||
func (a *FederationInternalAPI) MSC2946Spaces(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
|
||||
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||
defer cancel()
|
||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||
return a.federation.MSC2946Spaces(ctx, s, roomID, r)
|
||||
return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly)
|
||||
})
|
||||
if err != nil {
|
||||
return res, err
|
||||
|
|
|
|||
|
|
@ -526,23 +526,23 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships(
|
|||
}
|
||||
|
||||
type spacesReq struct {
|
||||
S gomatrixserverlib.ServerName
|
||||
Req gomatrixserverlib.MSC2946SpacesRequest
|
||||
RoomID string
|
||||
Res gomatrixserverlib.MSC2946SpacesResponse
|
||||
Err *api.FederationClientError
|
||||
S gomatrixserverlib.ServerName
|
||||
SuggestedOnly bool
|
||||
RoomID string
|
||||
Res gomatrixserverlib.MSC2946SpacesResponse
|
||||
Err *api.FederationClientError
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) MSC2946Spaces(
|
||||
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
|
||||
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
|
||||
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces")
|
||||
defer span.Finish()
|
||||
|
||||
request := spacesReq{
|
||||
S: dst,
|
||||
Req: r,
|
||||
RoomID: roomID,
|
||||
S: dst,
|
||||
SuggestedOnly: suggestedOnly,
|
||||
RoomID: roomID,
|
||||
}
|
||||
var response spacesReq
|
||||
apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath
|
||||
|
|
|
|||
|
|
@ -378,7 +378,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
|
|||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.Req)
|
||||
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly)
|
||||
if err != nil {
|
||||
ferr, ok := err.(*api.FederationClientError)
|
||||
if ok {
|
||||
|
|
|
|||
11
go.mod
11
go.mod
|
|
@ -18,12 +18,13 @@ require (
|
|||
github.com/frankban/quicktest v1.14.0 // indirect
|
||||
github.com/getsentry/sentry-go v0.12.0
|
||||
github.com/gologme/log v1.3.0
|
||||
github.com/google/go-cmp v0.5.6
|
||||
github.com/google/uuid v1.2.0
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/gorilla/websocket v1.4.2
|
||||
github.com/h2non/filetype v1.1.3 // indirect
|
||||
github.com/hashicorp/golang-lru v0.5.4
|
||||
github.com/juju/testing v0.0.0-20211215003918-77eb13d6cad2 // indirect
|
||||
github.com/klauspost/compress v1.14.2 // indirect
|
||||
github.com/lib/pq v1.10.4
|
||||
github.com/libp2p/go-libp2p v0.13.0
|
||||
github.com/libp2p/go-libp2p-circuit v0.4.0
|
||||
|
|
@ -39,12 +40,12 @@ require (
|
|||
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
|
||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902
|
||||
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||
github.com/mattn/go-sqlite3 v1.14.10
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/nats-io/nats-server/v2 v2.3.2
|
||||
github.com/nats-io/nats-server/v2 v2.7.3
|
||||
github.com/nats-io/nats.go v1.13.1-0.20220121202836-972a071d373d
|
||||
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||
|
|
@ -61,11 +62,11 @@ require (
|
|||
github.com/uber/jaeger-lib v2.4.1+incompatible
|
||||
github.com/yggdrasil-network/yggdrasil-go v0.4.2
|
||||
go.uber.org/atomic v1.9.0
|
||||
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a
|
||||
golang.org/x/crypto v0.0.0-20220214200702-86341886e292
|
||||
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
|
||||
golang.org/x/mobile v0.0.0-20220112015953-858099ff7816
|
||||
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
|
||||
golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect
|
||||
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
|
||||
gopkg.in/h2non/bimg.v1 v1.1.5
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
|
|
|
|||
16
go.sum
16
go.sum
|
|
@ -480,6 +480,7 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
|
|||
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
|
||||
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
|
||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/gologme/log v1.2.0/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U=
|
||||
github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo=
|
||||
|
|
@ -712,9 +713,8 @@ github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0
|
|||
github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||
github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||
github.com/klauspost/compress v1.11.13/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||
github.com/klauspost/compress v1.13.4 h1:0zhec2I8zGnjWcKyLl6i3gPqKANCCn5e9xmviEEeX6s=
|
||||
github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
|
||||
github.com/klauspost/compress v1.14.2 h1:S0OHlFk/Gbon/yauFJ4FfJJF5V0fc5HbBTJazi28pRw=
|
||||
github.com/klauspost/compress v1.14.2/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
|
||||
github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
|
|
@ -983,8 +983,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1
|
|||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052 h1:+4Q/JQ3fGgA7sIHaLMlqREX8yEpsI+HlVoW9WId7SNc=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 h1:WHlrE8BYh/hzn1RKwq3YMAlhHivX47jQKAjZFtkJyPE=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo=
|
||||
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk=
|
||||
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk=
|
||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||
|
|
@ -1510,8 +1510,8 @@ golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5
|
|||
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo=
|
||||
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE=
|
||||
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
|
|
@ -1737,8 +1737,8 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc=
|
||||
golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs=
|
||||
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ const (
|
|||
FederationEventCacheName = "federation_event"
|
||||
FederationEventCacheMaxEntries = 256
|
||||
FederationEventCacheMutable = true // to allow use of Unset only
|
||||
FederationEventCacheMaxAge = CacheNoMaxAge
|
||||
)
|
||||
|
||||
// FederationCache contains the subset of functions needed for
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
package caching
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
|
|
@ -16,6 +18,7 @@ const (
|
|||
RoomInfoCacheName = "roominfo"
|
||||
RoomInfoCacheMaxEntries = 1024
|
||||
RoomInfoCacheMutable = true
|
||||
RoomInfoCacheMaxAge = time.Minute * 5
|
||||
)
|
||||
|
||||
// RoomInfosCache contains the subset of functions needed for
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ const (
|
|||
RoomServerRoomIDsCacheName = "roomserver_room_ids"
|
||||
RoomServerRoomIDsCacheMaxEntries = 1024
|
||||
RoomServerRoomIDsCacheMutable = false
|
||||
RoomServerRoomIDsCacheMaxAge = CacheNoMaxAge
|
||||
)
|
||||
|
||||
type RoomServerCaches interface {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ const (
|
|||
RoomVersionCacheName = "room_versions"
|
||||
RoomVersionCacheMaxEntries = 1024
|
||||
RoomVersionCacheMutable = false
|
||||
RoomVersionCacheMaxAge = CacheNoMaxAge
|
||||
)
|
||||
|
||||
// RoomVersionsCache contains the subset of functions needed for
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ const (
|
|||
ServerKeyCacheName = "server_key"
|
||||
ServerKeyCacheMaxEntries = 4096
|
||||
ServerKeyCacheMutable = true
|
||||
ServerKeyCacheMaxAge = CacheNoMaxAge
|
||||
)
|
||||
|
||||
// ServerKeyCache contains the subset of functions needed for
|
||||
|
|
|
|||
33
internal/caching/cache_space_rooms.go
Normal file
33
internal/caching/cache_space_rooms.go
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
package caching
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const (
|
||||
SpaceSummaryRoomsCacheName = "space_summary_rooms"
|
||||
SpaceSummaryRoomsCacheMaxEntries = 100
|
||||
SpaceSummaryRoomsCacheMutable = true
|
||||
SpaceSummaryRoomsCacheMaxAge = time.Minute * 5
|
||||
)
|
||||
|
||||
type SpaceSummaryRoomsCache interface {
|
||||
GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool)
|
||||
StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse)
|
||||
}
|
||||
|
||||
func (c Caches) GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool) {
|
||||
val, found := c.SpaceSummaryRooms.Get(roomID)
|
||||
if found && val != nil {
|
||||
if resp, ok := val.(gomatrixserverlib.MSC2946SpacesResponse); ok {
|
||||
return resp, true
|
||||
}
|
||||
}
|
||||
return r, false
|
||||
}
|
||||
|
||||
func (c Caches) StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse) {
|
||||
c.SpaceSummaryRooms.Set(roomID, r)
|
||||
}
|
||||
|
|
@ -1,5 +1,7 @@
|
|||
package caching
|
||||
|
||||
import "time"
|
||||
|
||||
// Caches contains a set of references to caches. They may be
|
||||
// different implementations as long as they satisfy the Cache
|
||||
// interface.
|
||||
|
|
@ -10,6 +12,7 @@ type Caches struct {
|
|||
RoomServerRoomIDs Cache // RoomServerNIDsCache
|
||||
RoomInfos Cache // RoomInfoCache
|
||||
FederationEvents Cache // FederationEventsCache
|
||||
SpaceSummaryRooms Cache // SpaceSummaryRoomsCache
|
||||
}
|
||||
|
||||
// Cache is the interface that an implementation must satisfy.
|
||||
|
|
@ -18,3 +21,5 @@ type Cache interface {
|
|||
Set(key string, value interface{})
|
||||
Unset(key string)
|
||||
}
|
||||
|
||||
const CacheNoMaxAge = time.Duration(0)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
RoomVersionCacheName,
|
||||
RoomVersionCacheMutable,
|
||||
RoomVersionCacheMaxEntries,
|
||||
RoomVersionCacheMaxAge,
|
||||
enablePrometheus,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -23,6 +24,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
ServerKeyCacheName,
|
||||
ServerKeyCacheMutable,
|
||||
ServerKeyCacheMaxEntries,
|
||||
ServerKeyCacheMaxAge,
|
||||
enablePrometheus,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -32,6 +34,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
RoomServerRoomIDsCacheName,
|
||||
RoomServerRoomIDsCacheMutable,
|
||||
RoomServerRoomIDsCacheMaxEntries,
|
||||
RoomServerRoomIDsCacheMaxAge,
|
||||
enablePrometheus,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -41,6 +44,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
RoomInfoCacheName,
|
||||
RoomInfoCacheMutable,
|
||||
RoomInfoCacheMaxEntries,
|
||||
RoomInfoCacheMaxAge,
|
||||
enablePrometheus,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -50,6 +54,17 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
FederationEventCacheName,
|
||||
FederationEventCacheMutable,
|
||||
FederationEventCacheMaxEntries,
|
||||
FederationEventCacheMaxAge,
|
||||
enablePrometheus,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
spaceRooms, err := NewInMemoryLRUCachePartition(
|
||||
SpaceSummaryRoomsCacheName,
|
||||
SpaceSummaryRoomsCacheMutable,
|
||||
SpaceSummaryRoomsCacheMaxEntries,
|
||||
SpaceSummaryRoomsCacheMaxAge,
|
||||
enablePrometheus,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -57,7 +72,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
}
|
||||
go cacheCleaner(
|
||||
roomVersions, serverKeys, roomServerRoomIDs,
|
||||
roomInfos, federationEvents,
|
||||
roomInfos, federationEvents, spaceRooms,
|
||||
)
|
||||
return &Caches{
|
||||
RoomVersions: roomVersions,
|
||||
|
|
@ -65,6 +80,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
RoomServerRoomIDs: roomServerRoomIDs,
|
||||
RoomInfos: roomInfos,
|
||||
FederationEvents: federationEvents,
|
||||
SpaceSummaryRooms: spaceRooms,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -86,15 +102,22 @@ type InMemoryLRUCachePartition struct {
|
|||
name string
|
||||
mutable bool
|
||||
maxEntries int
|
||||
maxAge time.Duration
|
||||
lru *lru.Cache
|
||||
}
|
||||
|
||||
func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, enablePrometheus bool) (*InMemoryLRUCachePartition, error) {
|
||||
type inMemoryLRUCacheEntry struct {
|
||||
value interface{}
|
||||
created time.Time
|
||||
}
|
||||
|
||||
func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, maxAge time.Duration, enablePrometheus bool) (*InMemoryLRUCachePartition, error) {
|
||||
var err error
|
||||
cache := InMemoryLRUCachePartition{
|
||||
name: name,
|
||||
mutable: mutable,
|
||||
maxEntries: maxEntries,
|
||||
maxAge: maxAge,
|
||||
}
|
||||
cache.lru, err = lru.New(maxEntries)
|
||||
if err != nil {
|
||||
|
|
@ -114,11 +137,16 @@ func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, ena
|
|||
|
||||
func (c *InMemoryLRUCachePartition) Set(key string, value interface{}) {
|
||||
if !c.mutable {
|
||||
if peek, ok := c.lru.Peek(key); ok && peek != value {
|
||||
panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key))
|
||||
if peek, ok := c.lru.Peek(key); ok {
|
||||
if entry, ok := peek.(*inMemoryLRUCacheEntry); ok && entry.value != value {
|
||||
panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key))
|
||||
}
|
||||
}
|
||||
}
|
||||
c.lru.Add(key, value)
|
||||
c.lru.Add(key, &inMemoryLRUCacheEntry{
|
||||
value: value,
|
||||
created: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *InMemoryLRUCachePartition) Unset(key string) {
|
||||
|
|
@ -129,5 +157,20 @@ func (c *InMemoryLRUCachePartition) Unset(key string) {
|
|||
}
|
||||
|
||||
func (c *InMemoryLRUCachePartition) Get(key string) (value interface{}, ok bool) {
|
||||
return c.lru.Get(key)
|
||||
v, ok := c.lru.Get(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
entry, ok := v.(*inMemoryLRUCacheEntry)
|
||||
switch {
|
||||
case ok && c.maxAge == CacheNoMaxAge:
|
||||
return entry.value, ok // There's no maximum age policy
|
||||
case ok && time.Since(entry.created) < c.maxAge:
|
||||
return entry.value, ok // The value for the key isn't stale
|
||||
default:
|
||||
// Either the key was found and it was stale, or the key
|
||||
// wasn't found at all
|
||||
c.lru.Remove(key)
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,8 +26,30 @@ var ErrProfileNoExists = errors.New("no known profile for given user ID")
|
|||
// AccountData represents account data sent from the client API server to the
|
||||
// sync API server
|
||||
type AccountData struct {
|
||||
RoomID string `json:"room_id"`
|
||||
Type string `json:"type"`
|
||||
ReadMarker *ReadMarkerJSON `json:"read_marker,omitempty"` // optional
|
||||
}
|
||||
|
||||
type ReadMarkerJSON struct {
|
||||
FullyRead string `json:"m.fully_read"`
|
||||
Read string `json:"m.read"`
|
||||
}
|
||||
|
||||
// NotificationData contains statistics about notifications, sent from
|
||||
// the Push Server to the Sync API server.
|
||||
type NotificationData struct {
|
||||
// RoomID identifies the scope of the statistics, together with
|
||||
// MXID (which is encoded in the Kafka key).
|
||||
RoomID string `json:"room_id"`
|
||||
Type string `json:"type"`
|
||||
|
||||
// HighlightCount is the number of unread notifications with the
|
||||
// highlight tweak.
|
||||
UnreadHighlightCount int `json:"unread_highlight_count"`
|
||||
|
||||
// UnreadNotificationCount is the total number of unread
|
||||
// notifications.
|
||||
UnreadNotificationCount int `json:"unread_notification_count"`
|
||||
}
|
||||
|
||||
// ProfileResponse is a struct containing all known user profile data
|
||||
|
|
|
|||
66
internal/pushgateway/client.go
Normal file
66
internal/pushgateway/client.go
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
package pushgateway
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
)
|
||||
|
||||
type httpClient struct {
|
||||
hc *http.Client
|
||||
}
|
||||
|
||||
// NewHTTPClient creates a new Push Gateway client.
|
||||
func NewHTTPClient(disableTLSValidation bool) Client {
|
||||
hc := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: disableTLSValidation,
|
||||
},
|
||||
},
|
||||
}
|
||||
return &httpClient{hc: hc}
|
||||
}
|
||||
|
||||
func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "Notify")
|
||||
defer span.Finish()
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hreq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
hresp, err := h.hc.Do(hreq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//nolint:errcheck
|
||||
defer hresp.Body.Close()
|
||||
|
||||
if hresp.StatusCode == http.StatusOK {
|
||||
return json.NewDecoder(hresp.Body).Decode(resp)
|
||||
}
|
||||
|
||||
var errorBody struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.NewDecoder(hresp.Body).Decode(&errorBody); err == nil {
|
||||
return fmt.Errorf("push gateway: %d from %s: %s", hresp.StatusCode, url, errorBody.Message)
|
||||
}
|
||||
return fmt.Errorf("push gateway: %d from %s", hresp.StatusCode, url)
|
||||
}
|
||||
62
internal/pushgateway/pushgateway.go
Normal file
62
internal/pushgateway/pushgateway.go
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
package pushgateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// A Client is how interactions with a Push Gateway is done.
|
||||
type Client interface {
|
||||
// Notify sends a notification to the gateway at the given URL.
|
||||
Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error
|
||||
}
|
||||
|
||||
type NotifyRequest struct {
|
||||
Notification Notification `json:"notification"` // Required
|
||||
}
|
||||
|
||||
type NotifyResponse struct {
|
||||
// Rejected is the list of device push keys that were rejected
|
||||
// during the push. The caller should remove the push keys so they
|
||||
// are not used again.
|
||||
Rejected []string `json:"rejected"` // Required
|
||||
}
|
||||
|
||||
type Notification struct {
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
Counts *Counts `json:"counts,omitempty"`
|
||||
Devices []*Device `json:"devices"` // Required
|
||||
EventID string `json:"event_id,omitempty"`
|
||||
ID string `json:"id,omitempty"` // Deprecated name for EventID.
|
||||
Membership string `json:"membership,omitempty"` // UNSPEC: required for Sytest.
|
||||
Prio Prio `json:"prio,omitempty"`
|
||||
RoomAlias string `json:"room_alias,omitempty"`
|
||||
RoomID string `json:"room_id,omitempty"`
|
||||
RoomName string `json:"room_name,omitempty"`
|
||||
Sender string `json:"sender,omitempty"`
|
||||
SenderDisplayName string `json:"sender_display_name,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
UserIsTarget bool `json:"user_is_target,omitempty"`
|
||||
}
|
||||
|
||||
type Counts struct {
|
||||
MissedCalls int `json:"missed_calls,omitempty"`
|
||||
Unread int `json:"unread"` // TODO: UNSPEC: the spec says zero must be omitted, but Sytest 61push/01message-pushed.pl requires it.
|
||||
}
|
||||
|
||||
type Device struct {
|
||||
AppID string `json:"app_id"` // Required
|
||||
Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys.
|
||||
PushKey string `json:"pushkey"` // Required
|
||||
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
|
||||
Tweaks map[string]interface{} `json:"tweaks,omitempty"`
|
||||
}
|
||||
|
||||
type Prio string
|
||||
|
||||
const (
|
||||
HighPrio Prio = "high"
|
||||
LowPrio Prio = "low"
|
||||
)
|
||||
102
internal/pushrules/action.go
Normal file
102
internal/pushrules/action.go
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// An Action is (part of) an outcome of a rule. There are
|
||||
// (unofficially) terminal actions, and modifier actions.
|
||||
type Action struct {
|
||||
// Kind is the type of action. Has custom encoding in JSON.
|
||||
Kind ActionKind `json:"-"`
|
||||
|
||||
// Tweak is the property to tweak. Has custom encoding in JSON.
|
||||
Tweak TweakKey `json:"-"`
|
||||
|
||||
// Value is some value interpreted according to Kind and Tweak.
|
||||
Value interface{} `json:"value"`
|
||||
}
|
||||
|
||||
func (a *Action) MarshalJSON() ([]byte, error) {
|
||||
if a.Tweak == UnknownTweak && a.Value == nil {
|
||||
return json.Marshal(a.Kind)
|
||||
}
|
||||
|
||||
if a.Kind != SetTweakAction {
|
||||
return nil, fmt.Errorf("only set_tweak actions may have a value, but got kind %q", a.Kind)
|
||||
}
|
||||
|
||||
m := map[string]interface{}{
|
||||
string(a.Kind): a.Tweak,
|
||||
}
|
||||
if a.Value != nil {
|
||||
m["value"] = a.Value
|
||||
}
|
||||
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
func (a *Action) UnmarshalJSON(bs []byte) error {
|
||||
if bytes.HasPrefix(bs, []byte("\"")) {
|
||||
return json.Unmarshal(bs, &a.Kind)
|
||||
}
|
||||
|
||||
var raw struct {
|
||||
SetTweak TweakKey `json:"set_tweak"`
|
||||
Value interface{} `json:"value"`
|
||||
}
|
||||
if err := json.Unmarshal(bs, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
if raw.SetTweak == UnknownTweak {
|
||||
return fmt.Errorf("got unknown action JSON: %s", string(bs))
|
||||
}
|
||||
a.Kind = SetTweakAction
|
||||
a.Tweak = raw.SetTweak
|
||||
if raw.Value != nil {
|
||||
a.Value = raw.Value
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ActionKind is the primary discriminator for actions.
|
||||
type ActionKind string
|
||||
|
||||
const (
|
||||
UnknownAction ActionKind = ""
|
||||
|
||||
// NotifyAction indicates the clients should show a notification.
|
||||
NotifyAction ActionKind = "notify"
|
||||
|
||||
// DontNotifyAction indicates the clients should not show a notification.
|
||||
DontNotifyAction ActionKind = "dont_notify"
|
||||
|
||||
// CoalesceAction tells the clients to show a notification, and
|
||||
// tells both servers and clients that multiple events can be
|
||||
// coalesced into a single notification. The behaviour is
|
||||
// implementation-specific.
|
||||
CoalesceAction ActionKind = "coalesce"
|
||||
|
||||
// SetTweakAction uses the Tweak and Value fields to add a
|
||||
// tweak. Multiple SetTweakAction can be provided in a rule,
|
||||
// combined with NotifyAction or CoalesceAction.
|
||||
SetTweakAction ActionKind = "set_tweak"
|
||||
)
|
||||
|
||||
// A TweakKey describes a property to be modified/tweaked for events
|
||||
// that match the rule.
|
||||
type TweakKey string
|
||||
|
||||
const (
|
||||
UnknownTweak TweakKey = ""
|
||||
|
||||
// SoundTweak describes which sound to play. Using "default" means
|
||||
// "enable sound".
|
||||
SoundTweak TweakKey = "sound"
|
||||
|
||||
// HighlightTweak asks the clients to highlight the conversation.
|
||||
HighlightTweak TweakKey = "highlight"
|
||||
)
|
||||
39
internal/pushrules/action_test.go
Normal file
39
internal/pushrules/action_test.go
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestActionJSON(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Want Action
|
||||
}{
|
||||
{Action{Kind: NotifyAction}},
|
||||
{Action{Kind: DontNotifyAction}},
|
||||
{Action{Kind: CoalesceAction}},
|
||||
{Action{Kind: SetTweakAction}},
|
||||
|
||||
{Action{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}},
|
||||
{Action{Kind: SetTweakAction, Tweak: HighlightTweak}},
|
||||
{Action{Kind: SetTweakAction, Tweak: HighlightTweak, Value: "false"}},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(fmt.Sprintf("%+v", tst.Want), func(t *testing.T) {
|
||||
bs, err := json.Marshal(&tst.Want)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
var got Action
|
||||
if err := json.Unmarshal(bs, &got); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tst.Want, got); diff != "" {
|
||||
t.Errorf("+got -want:\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
49
internal/pushrules/condition.go
Normal file
49
internal/pushrules/condition.go
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
package pushrules
|
||||
|
||||
// A Condition dictates extra conditions for a matching rules. See
|
||||
// ConditionKind.
|
||||
type Condition struct {
|
||||
// Kind is the primary discriminator for the condition
|
||||
// type. Required.
|
||||
Kind ConditionKind `json:"kind"`
|
||||
|
||||
// Key indicates the dot-separated path of Event fields to
|
||||
// match. Required for EventMatchCondition and
|
||||
// SenderNotificationPermissionCondition.
|
||||
Key string `json:"key,omitempty"`
|
||||
|
||||
// Pattern indicates the value pattern that must match. Required
|
||||
// for EventMatchCondition.
|
||||
Pattern string `json:"pattern,omitempty"`
|
||||
|
||||
// Is indicates the condition that must be fulfilled. Required for
|
||||
// RoomMemberCountCondition.
|
||||
Is string `json:"is,omitempty"`
|
||||
}
|
||||
|
||||
// ConditionKind represents a kind of condition.
|
||||
//
|
||||
// SPEC: Unrecognised conditions MUST NOT match any events,
|
||||
// effectively making the push rule disabled.
|
||||
type ConditionKind string
|
||||
|
||||
const (
|
||||
UnknownCondition ConditionKind = ""
|
||||
|
||||
// EventMatchCondition indicates the condition looks for a key
|
||||
// path and matches a pattern. How paths that don't reference a
|
||||
// simple value match against rules is implementation-specific.
|
||||
EventMatchCondition ConditionKind = "event_match"
|
||||
|
||||
// ContainsDisplayNameCondition indicates the current user's
|
||||
// display name must be found in the content body.
|
||||
ContainsDisplayNameCondition ConditionKind = "contains_display_name"
|
||||
|
||||
// RoomMemberCountCondition matches a simple arithmetic comparison
|
||||
// against the total number of members in a room.
|
||||
RoomMemberCountCondition ConditionKind = "room_member_count"
|
||||
|
||||
// SenderNotificationPermissionCondition compares power level for
|
||||
// the sender in the event's room.
|
||||
SenderNotificationPermissionCondition ConditionKind = "sender_notification_permission"
|
||||
)
|
||||
23
internal/pushrules/default.go
Normal file
23
internal/pushrules/default.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// DefaultAccountRuleSets is the complete set of default push rules
|
||||
// for an account.
|
||||
func DefaultAccountRuleSets(localpart string, serverName gomatrixserverlib.ServerName) *AccountRuleSets {
|
||||
return &AccountRuleSets{
|
||||
Global: *DefaultGlobalRuleSet(localpart, serverName),
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultGlobalRuleSet returns the default ruleset for a given (fully
|
||||
// qualified) MXID.
|
||||
func DefaultGlobalRuleSet(localpart string, serverName gomatrixserverlib.ServerName) *RuleSet {
|
||||
return &RuleSet{
|
||||
Override: defaultOverrideRules("@" + localpart + ":" + string(serverName)),
|
||||
Content: defaultContentRules(localpart),
|
||||
Underride: defaultUnderrideRules,
|
||||
}
|
||||
}
|
||||
33
internal/pushrules/default_content.go
Normal file
33
internal/pushrules/default_content.go
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
package pushrules
|
||||
|
||||
func defaultContentRules(localpart string) []*Rule {
|
||||
return []*Rule{
|
||||
mRuleContainsUserNameDefinition(localpart),
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
MRuleContainsUserName = ".m.rule.contains_user_name"
|
||||
)
|
||||
|
||||
func mRuleContainsUserNameDefinition(localpart string) *Rule {
|
||||
return &Rule{
|
||||
RuleID: MRuleContainsUserName,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Pattern: localpart,
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: SoundTweak,
|
||||
Value: "default",
|
||||
},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
165
internal/pushrules/default_override.go
Normal file
165
internal/pushrules/default_override.go
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
package pushrules
|
||||
|
||||
func defaultOverrideRules(userID string) []*Rule {
|
||||
return []*Rule{
|
||||
&mRuleMasterDefinition,
|
||||
&mRuleSuppressNoticesDefinition,
|
||||
mRuleInviteForMeDefinition(userID),
|
||||
&mRuleMemberEventDefinition,
|
||||
&mRuleContainsDisplayNameDefinition,
|
||||
&mRuleTombstoneDefinition,
|
||||
&mRuleRoomNotifDefinition,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
MRuleMaster = ".m.rule.master"
|
||||
MRuleSuppressNotices = ".m.rule.suppress_notices"
|
||||
MRuleInviteForMe = ".m.rule.invite_for_me"
|
||||
MRuleMemberEvent = ".m.rule.member_event"
|
||||
MRuleContainsDisplayName = ".m.rule.contains_display_name"
|
||||
MRuleTombstone = ".m.rule.tombstone"
|
||||
MRuleRoomNotif = ".m.rule.roomnotif"
|
||||
)
|
||||
|
||||
var (
|
||||
mRuleMasterDefinition = Rule{
|
||||
RuleID: MRuleMaster,
|
||||
Default: true,
|
||||
Enabled: false,
|
||||
Conditions: []*Condition{},
|
||||
Actions: []*Action{{Kind: DontNotifyAction}},
|
||||
}
|
||||
mRuleSuppressNoticesDefinition = Rule{
|
||||
RuleID: MRuleSuppressNotices,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "content.msgtype",
|
||||
Pattern: "m.notice",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{{Kind: DontNotifyAction}},
|
||||
}
|
||||
mRuleMemberEventDefinition = Rule{
|
||||
RuleID: MRuleMemberEvent,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.room.member",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{{Kind: DontNotifyAction}},
|
||||
}
|
||||
mRuleContainsDisplayNameDefinition = Rule{
|
||||
RuleID: MRuleContainsDisplayName,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{{Kind: ContainsDisplayNameCondition}},
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: SoundTweak,
|
||||
Value: "default",
|
||||
},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
mRuleTombstoneDefinition = Rule{
|
||||
RuleID: MRuleTombstone,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.room.tombstone",
|
||||
},
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "state_key",
|
||||
Pattern: "",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
mRuleRoomNotifDefinition = Rule{
|
||||
RuleID: MRuleRoomNotif,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "content.body",
|
||||
Pattern: "@room",
|
||||
},
|
||||
{
|
||||
Kind: SenderNotificationPermissionCondition,
|
||||
Key: "room",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func mRuleInviteForMeDefinition(userID string) *Rule {
|
||||
return &Rule{
|
||||
RuleID: MRuleInviteForMe,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.room.member",
|
||||
},
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "content.membership",
|
||||
Pattern: "invite",
|
||||
},
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "state_key",
|
||||
Pattern: userID,
|
||||
},
|
||||
},
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: SoundTweak,
|
||||
Value: "default",
|
||||
},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
119
internal/pushrules/default_underride.go
Normal file
119
internal/pushrules/default_underride.go
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
package pushrules
|
||||
|
||||
const (
|
||||
MRuleCall = ".m.rule.call"
|
||||
MRuleEncryptedRoomOneToOne = ".m.rule.encrypted_room_one_to_one"
|
||||
MRuleRoomOneToOne = ".m.rule.room_one_to_one"
|
||||
MRuleMessage = ".m.rule.message"
|
||||
MRuleEncrypted = ".m.rule.encrypted"
|
||||
)
|
||||
|
||||
var defaultUnderrideRules = []*Rule{
|
||||
&mRuleCallDefinition,
|
||||
&mRuleEncryptedRoomOneToOneDefinition,
|
||||
&mRuleRoomOneToOneDefinition,
|
||||
&mRuleMessageDefinition,
|
||||
&mRuleEncryptedDefinition,
|
||||
}
|
||||
|
||||
var (
|
||||
mRuleCallDefinition = Rule{
|
||||
RuleID: MRuleCall,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.call.invite",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: SoundTweak,
|
||||
Value: "ring",
|
||||
},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
mRuleEncryptedRoomOneToOneDefinition = Rule{
|
||||
RuleID: MRuleEncryptedRoomOneToOne,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: RoomMemberCountCondition,
|
||||
Is: "2",
|
||||
},
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.room.encrypted",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
mRuleRoomOneToOneDefinition = Rule{
|
||||
RuleID: MRuleRoomOneToOne,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: RoomMemberCountCondition,
|
||||
Is: "2",
|
||||
},
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.room.message",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{
|
||||
{Kind: NotifyAction},
|
||||
{
|
||||
Kind: SetTweakAction,
|
||||
Tweak: HighlightTweak,
|
||||
Value: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
mRuleMessageDefinition = Rule{
|
||||
RuleID: MRuleMessage,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.room.message",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{{Kind: NotifyAction}},
|
||||
}
|
||||
mRuleEncryptedDefinition = Rule{
|
||||
RuleID: MRuleEncrypted,
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
Conditions: []*Condition{
|
||||
{
|
||||
Kind: EventMatchCondition,
|
||||
Key: "type",
|
||||
Pattern: "m.room.encrypted",
|
||||
},
|
||||
},
|
||||
Actions: []*Action{{Kind: NotifyAction}},
|
||||
}
|
||||
)
|
||||
165
internal/pushrules/evaluate.go
Normal file
165
internal/pushrules/evaluate.go
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// A RuleSetEvaluator encapsulates context to evaluate an event
|
||||
// against a rule set.
|
||||
type RuleSetEvaluator struct {
|
||||
ec EvaluationContext
|
||||
ruleSet []kindAndRules
|
||||
}
|
||||
|
||||
// An EvaluationContext gives a RuleSetEvaluator access to the
|
||||
// environment, for rules that require that.
|
||||
type EvaluationContext interface {
|
||||
// UserDisplayName returns the current user's display name.
|
||||
UserDisplayName() string
|
||||
|
||||
// RoomMemberCount returns the number of members in the room of
|
||||
// the current event.
|
||||
RoomMemberCount() (int, error)
|
||||
|
||||
// HasPowerLevel returns whether the user has at least the given
|
||||
// power in the room of the current event.
|
||||
HasPowerLevel(userID, levelKey string) (bool, error)
|
||||
}
|
||||
|
||||
// A kindAndRules is just here to simplify iteration of the (ordered)
|
||||
// kinds of rules.
|
||||
type kindAndRules struct {
|
||||
Kind Kind
|
||||
Rules []*Rule
|
||||
}
|
||||
|
||||
// NewRuleSetEvaluator creates a new evaluator for the given rule set.
|
||||
func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluator {
|
||||
return &RuleSetEvaluator{
|
||||
ec: ec,
|
||||
ruleSet: []kindAndRules{
|
||||
{OverrideKind, ruleSet.Override},
|
||||
{ContentKind, ruleSet.Content},
|
||||
{RoomKind, ruleSet.Room},
|
||||
{SenderKind, ruleSet.Sender},
|
||||
{UnderrideKind, ruleSet.Underride},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// MatchEvent returns the first matching rule. Returns nil if there
|
||||
// was no match rule.
|
||||
func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule, error) {
|
||||
// TODO: server-default rules have lower priority than user rules,
|
||||
// but they are stored together with the user rules. It's a bit
|
||||
// unclear what the specification (11.14.1.4 Predefined rules)
|
||||
// means the ordering should be.
|
||||
//
|
||||
// The most reasonable interpretation is that default overrides
|
||||
// still have lower priority than user content rules, so we
|
||||
// iterate twice.
|
||||
for _, rsat := range rse.ruleSet {
|
||||
for _, defRules := range []bool{false, true} {
|
||||
for _, rule := range rsat.Rules {
|
||||
if rule.Default != defRules {
|
||||
continue
|
||||
}
|
||||
ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
return rule, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No matching rule.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
|
||||
if !rule.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case OverrideKind, UnderrideKind:
|
||||
for _, cond := range rule.Conditions {
|
||||
ok, err := conditionMatches(cond, event, ec)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
|
||||
case ContentKind:
|
||||
// TODO: "These configure behaviour for (unencrypted) messages
|
||||
// that match certain patterns." - Does that mean "content.body"?
|
||||
return patternMatches("content.body", rule.Pattern, event)
|
||||
|
||||
case RoomKind:
|
||||
return rule.RuleID == event.RoomID(), nil
|
||||
|
||||
case SenderKind:
|
||||
return rule.RuleID == event.Sender(), nil
|
||||
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
|
||||
switch cond.Kind {
|
||||
case EventMatchCondition:
|
||||
return patternMatches(cond.Key, cond.Pattern, event)
|
||||
|
||||
case ContainsDisplayNameCondition:
|
||||
return patternMatches("content.body", ec.UserDisplayName(), event)
|
||||
|
||||
case RoomMemberCountCondition:
|
||||
cmp, err := parseRoomMemberCountCondition(cond.Is)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing room_member_count condition: %w", err)
|
||||
}
|
||||
n, err := ec.RoomMemberCount()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("RoomMemberCount failed: %w", err)
|
||||
}
|
||||
return cmp(n), nil
|
||||
|
||||
case SenderNotificationPermissionCondition:
|
||||
return ec.HasPowerLevel(event.Sender(), cond.Key)
|
||||
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) {
|
||||
re, err := globToRegexp(pattern)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var eventMap map[string]interface{}
|
||||
if err = json.Unmarshal(event.JSON(), &eventMap); err != nil {
|
||||
return false, fmt.Errorf("parsing event: %w", err)
|
||||
}
|
||||
v, err := lookupMapPath(strings.Split(key, "."), eventMap)
|
||||
if err != nil {
|
||||
// An unknown path is a benign error that shouldn't stop rule
|
||||
// processing. It's just a non-match.
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return re.MatchString(fmt.Sprint(v)), nil
|
||||
}
|
||||
189
internal/pushrules/evaluate_test.go
Normal file
189
internal/pushrules/evaluate_test.go
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
|
||||
ev := mustEventFromJSON(t, `{}`)
|
||||
defaultEnabled := &Rule{
|
||||
RuleID: ".default.enabled",
|
||||
Default: true,
|
||||
Enabled: true,
|
||||
}
|
||||
userEnabled := &Rule{
|
||||
RuleID: ".user.enabled",
|
||||
Default: false,
|
||||
Enabled: true,
|
||||
}
|
||||
userEnabled2 := &Rule{
|
||||
RuleID: ".user.enabled.2",
|
||||
Default: false,
|
||||
Enabled: true,
|
||||
}
|
||||
tsts := []struct {
|
||||
Name string
|
||||
RuleSet RuleSet
|
||||
Want *Rule
|
||||
}{
|
||||
{"empty", RuleSet{}, nil},
|
||||
{"defaultCanWin", RuleSet{Override: []*Rule{defaultEnabled}}, defaultEnabled},
|
||||
{"userWins", RuleSet{Override: []*Rule{defaultEnabled, userEnabled}}, userEnabled},
|
||||
{"defaultOverrideWins", RuleSet{Override: []*Rule{defaultEnabled}, Underride: []*Rule{userEnabled}}, defaultEnabled},
|
||||
{"overrideContent", RuleSet{Override: []*Rule{userEnabled}, Content: []*Rule{userEnabled2}}, userEnabled},
|
||||
{"overrideRoom", RuleSet{Override: []*Rule{userEnabled}, Room: []*Rule{userEnabled2}}, userEnabled},
|
||||
{"overrideSender", RuleSet{Override: []*Rule{userEnabled}, Sender: []*Rule{userEnabled2}}, userEnabled},
|
||||
{"overrideUnderride", RuleSet{Override: []*Rule{userEnabled}, Underride: []*Rule{userEnabled2}}, userEnabled},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
rse := NewRuleSetEvaluator(nil, &tst.RuleSet)
|
||||
got, err := rse.MatchEvent(ev)
|
||||
if err != nil {
|
||||
t.Fatalf("MatchEvent failed: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tst.Want, got); diff != "" {
|
||||
t.Errorf("MatchEvent rule: +got -want:\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleMatches(t *testing.T) {
|
||||
emptyRule := Rule{Enabled: true}
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Kind Kind
|
||||
Rule Rule
|
||||
EventJSON string
|
||||
Want bool
|
||||
}{
|
||||
{"emptyOverride", OverrideKind, emptyRule, `{}`, true},
|
||||
{"emptyContent", ContentKind, emptyRule, `{}`, false},
|
||||
{"emptyRoom", RoomKind, emptyRule, `{}`, true},
|
||||
{"emptySender", SenderKind, emptyRule, `{}`, true},
|
||||
{"emptyUnderride", UnderrideKind, emptyRule, `{}`, true},
|
||||
|
||||
{"disabled", OverrideKind, Rule{}, `{}`, false},
|
||||
|
||||
{"overrideConditionMatch", OverrideKind, Rule{Enabled: true}, `{}`, true},
|
||||
{"overrideConditionNoMatch", OverrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
|
||||
|
||||
{"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true},
|
||||
{"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
|
||||
|
||||
{"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true},
|
||||
{"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false},
|
||||
|
||||
{"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true},
|
||||
{"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false},
|
||||
|
||||
{"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true},
|
||||
{"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ruleMatches failed: %v", err)
|
||||
}
|
||||
if got != tst.Want {
|
||||
t.Errorf("ruleMatches: got %v, want %v", got, tst.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConditionMatches(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Cond Condition
|
||||
EventJSON string
|
||||
Want bool
|
||||
}{
|
||||
{"empty", Condition{}, `{}`, false},
|
||||
{"empty", Condition{Kind: "unknownstring"}, `{}`, false},
|
||||
|
||||
{"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, true},
|
||||
|
||||
{"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false},
|
||||
{"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true},
|
||||
|
||||
{"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false},
|
||||
{"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true},
|
||||
{"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false},
|
||||
{"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true},
|
||||
{"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false},
|
||||
{"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true},
|
||||
{"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false},
|
||||
{"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true},
|
||||
{"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false},
|
||||
{"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true},
|
||||
|
||||
{"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true},
|
||||
{"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{})
|
||||
if err != nil {
|
||||
t.Fatalf("conditionMatches failed: %v", err)
|
||||
}
|
||||
if got != tst.Want {
|
||||
t.Errorf("conditionMatches: got %v, want %v", got, tst.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type fakeEvaluationContext struct{}
|
||||
|
||||
func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" }
|
||||
func (fakeEvaluationContext) RoomMemberCount() (int, error) { return 2, nil }
|
||||
func (fakeEvaluationContext) HasPowerLevel(userID, levelKey string) (bool, error) {
|
||||
return userID == "@poweruser:example.com" && levelKey == "powerlevel", nil
|
||||
}
|
||||
|
||||
func TestPatternMatches(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Key string
|
||||
Pattern string
|
||||
EventJSON string
|
||||
Want bool
|
||||
}{
|
||||
{"empty", "", "", `{}`, false},
|
||||
|
||||
// Note that an empty pattern contains no wildcard characters,
|
||||
// which implicitly means "*".
|
||||
{"patternEmpty", "content", "", `{"content":{}}`, true},
|
||||
|
||||
{"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true},
|
||||
{"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true},
|
||||
{"singlePattern", "content.creator", "acr?ator", `{"content":{"creator":"acreator"}}`, true},
|
||||
{"multiPattern", "content.creator", "a*ea*r", `{"content":{"creator":"acreator"}}`, true},
|
||||
{"patternNoSubstring", "content.creator", "r*t", `{"content":{"creator":"acreator"}}`, false},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
got, err := patternMatches(tst.Key, tst.Pattern, mustEventFromJSON(t, tst.EventJSON))
|
||||
if err != nil {
|
||||
t.Fatalf("patternMatches failed: %v", err)
|
||||
}
|
||||
if got != tst.Want {
|
||||
t.Errorf("patternMatches: got %v, want %v", got, tst.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustEventFromJSON(t *testing.T, json string) *gomatrixserverlib.Event {
|
||||
ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(json), false, gomatrixserverlib.RoomVersionV7)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return ev
|
||||
}
|
||||
71
internal/pushrules/pushrules.go
Normal file
71
internal/pushrules/pushrules.go
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
package pushrules
|
||||
|
||||
// An AccountRuleSets carries the rule sets associated with an
|
||||
// account.
|
||||
type AccountRuleSets struct {
|
||||
Global RuleSet `json:"global"` // Required
|
||||
}
|
||||
|
||||
// A RuleSet contains all the various push rules for an
|
||||
// account. Listed in decreasing order of priority.
|
||||
type RuleSet struct {
|
||||
Override []*Rule `json:"override,omitempty"`
|
||||
Content []*Rule `json:"content,omitempty"`
|
||||
Room []*Rule `json:"room,omitempty"`
|
||||
Sender []*Rule `json:"sender,omitempty"`
|
||||
Underride []*Rule `json:"underride,omitempty"`
|
||||
}
|
||||
|
||||
// A Rule contains matchers, conditions and final actions. While
|
||||
// evaluating, at most one rule is considered matching.
|
||||
//
|
||||
// Kind and scope are part of the push rules request/responses, but
|
||||
// not of the core data model.
|
||||
type Rule struct {
|
||||
// RuleID is either a free identifier, or the sender's MXID for
|
||||
// SenderKind. Required.
|
||||
RuleID string `json:"rule_id"`
|
||||
|
||||
// Default indicates whether this is a server-defined default, or
|
||||
// a user-provided rule. Required.
|
||||
//
|
||||
// The server-default rules have the lowest priority.
|
||||
Default bool `json:"default"`
|
||||
|
||||
// Enabled allows the user to disable rules while keeping them
|
||||
// around. Required.
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// Actions describe the desired outcome, should the rule
|
||||
// match. Required.
|
||||
Actions []*Action `json:"actions"`
|
||||
|
||||
// Conditions provide the rule's conditions for OverrideKind and
|
||||
// UnderrideKind. Not allowed for other kinds.
|
||||
Conditions []*Condition `json:"conditions"`
|
||||
|
||||
// Pattern is the body pattern to match for ContentKind. Required
|
||||
// for that kind. The interpretation is the same as that of
|
||||
// Condition.Pattern.
|
||||
Pattern string `json:"pattern"`
|
||||
}
|
||||
|
||||
// Scope only has one valid value. See also AccountRuleSets.
|
||||
type Scope string
|
||||
|
||||
const (
|
||||
UnknownScope Scope = ""
|
||||
GlobalScope Scope = "global"
|
||||
)
|
||||
|
||||
// Kind is the type of push rule. See also RuleSet.
|
||||
type Kind string
|
||||
|
||||
const (
|
||||
UnknownKind Kind = ""
|
||||
OverrideKind Kind = "override"
|
||||
ContentKind Kind = "content"
|
||||
RoomKind Kind = "room"
|
||||
SenderKind Kind = "sender"
|
||||
UnderrideKind Kind = "underride"
|
||||
)
|
||||
125
internal/pushrules/util.go
Normal file
125
internal/pushrules/util.go
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ActionsToTweaks converts a list of actions into a primary action
|
||||
// kind and a tweaks map. Returns a nil map if it would have been
|
||||
// empty.
|
||||
func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) {
|
||||
var kind ActionKind
|
||||
tweaks := map[string]interface{}{}
|
||||
|
||||
for _, a := range as {
|
||||
if a.Kind == SetTweakAction {
|
||||
tweaks[string(a.Tweak)] = a.Value
|
||||
continue
|
||||
}
|
||||
if kind != UnknownAction {
|
||||
return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind)
|
||||
}
|
||||
kind = a.Kind
|
||||
}
|
||||
|
||||
if len(tweaks) == 0 {
|
||||
tweaks = nil
|
||||
}
|
||||
|
||||
return kind, tweaks, nil
|
||||
}
|
||||
|
||||
// BoolTweakOr returns the named tweak as a boolean, and returns `def`
|
||||
// on failure.
|
||||
func BoolTweakOr(tweaks map[string]interface{}, key TweakKey, def bool) bool {
|
||||
v, ok := tweaks[string(key)]
|
||||
if !ok {
|
||||
return def
|
||||
}
|
||||
b, ok := v.(bool)
|
||||
if !ok {
|
||||
return def
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// globToRegexp converts a Matrix glob-style pattern to a Regular expression.
|
||||
func globToRegexp(pattern string) (*regexp.Regexp, error) {
|
||||
// TODO: It's unclear which glob characters are supported. The only
|
||||
// place this is discussed is for the unrelated "m.policy.rule.*"
|
||||
// events. Assuming, the same: /[*?]/
|
||||
if !strings.ContainsAny(pattern, "*?") {
|
||||
pattern = "*" + pattern + "*"
|
||||
}
|
||||
|
||||
// The defined syntax doesn't allow escaping the glob wildcard
|
||||
// characters, which makes this a straight-forward
|
||||
// replace-after-quote.
|
||||
pattern = globNonMetaRegexp.ReplaceAllStringFunc(pattern, regexp.QuoteMeta)
|
||||
pattern = strings.Replace(pattern, "*", ".*", -1)
|
||||
pattern = strings.Replace(pattern, "?", ".", -1)
|
||||
return regexp.Compile("^(" + pattern + ")$")
|
||||
}
|
||||
|
||||
// globNonMetaRegexp are the characters that are not considered glob
|
||||
// meta-characters (i.e. may need escaping).
|
||||
var globNonMetaRegexp = regexp.MustCompile("[^*?]+")
|
||||
|
||||
// lookupMapPath traverses a hierarchical map structure, like the one
|
||||
// produced by json.Unmarshal, to return the leaf value. Traversing
|
||||
// arrays/slices is not supported, only objects/maps.
|
||||
func lookupMapPath(path []string, m map[string]interface{}) (interface{}, error) {
|
||||
if len(path) == 0 {
|
||||
return nil, fmt.Errorf("empty path")
|
||||
}
|
||||
|
||||
var v interface{} = m
|
||||
for i, key := range path {
|
||||
m, ok := v.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected an object for path %q, but got %T", strings.Join(path[:i+1], "."), v)
|
||||
}
|
||||
|
||||
v, ok = m[key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("path not found: %s", strings.Join(path[:i+1], "."))
|
||||
}
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// parseRoomMemberCountCondition parses a string like "2", "==2", "<2"
|
||||
// into a function that checks if the argument to it fulfils the
|
||||
// condition.
|
||||
func parseRoomMemberCountCondition(s string) (func(int) bool, error) {
|
||||
var b int
|
||||
var cmp = func(a int) bool { return a == b }
|
||||
switch {
|
||||
case strings.HasPrefix(s, "<="):
|
||||
cmp = func(a int) bool { return a <= b }
|
||||
s = s[2:]
|
||||
case strings.HasPrefix(s, ">="):
|
||||
cmp = func(a int) bool { return a >= b }
|
||||
s = s[2:]
|
||||
case strings.HasPrefix(s, "<"):
|
||||
cmp = func(a int) bool { return a < b }
|
||||
s = s[1:]
|
||||
case strings.HasPrefix(s, ">"):
|
||||
cmp = func(a int) bool { return a > b }
|
||||
s = s[1:]
|
||||
case strings.HasPrefix(s, "=="):
|
||||
// Same cmp as the default.
|
||||
s = s[2:]
|
||||
}
|
||||
|
||||
v, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b = int(v)
|
||||
return cmp, nil
|
||||
}
|
||||
169
internal/pushrules/util_test.go
Normal file
169
internal/pushrules/util_test.go
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestActionsToTweaks(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Input []*Action
|
||||
WantKind ActionKind
|
||||
WantTweaks map[string]interface{}
|
||||
}{
|
||||
{"empty", nil, UnknownAction, nil},
|
||||
{"zero", []*Action{{}}, UnknownAction, nil},
|
||||
{"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil},
|
||||
{"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}},
|
||||
{"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}},
|
||||
{
|
||||
"all",
|
||||
[]*Action{
|
||||
{Kind: CoalesceAction},
|
||||
{Kind: SetTweakAction, Tweak: HighlightTweak},
|
||||
{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"},
|
||||
},
|
||||
CoalesceAction,
|
||||
map[string]interface{}{"highlight": nil, "sound": "default"},
|
||||
},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
gotKind, gotTweaks, err := ActionsToTweaks(tst.Input)
|
||||
if err != nil {
|
||||
t.Fatalf("ActionsToTweaks failed: %v", err)
|
||||
}
|
||||
if gotKind != tst.WantKind {
|
||||
t.Errorf("kind: got %v, want %v", gotKind, tst.WantKind)
|
||||
}
|
||||
if diff := cmp.Diff(tst.WantTweaks, gotTweaks); diff != "" {
|
||||
t.Errorf("tweaks: +got -want:\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolTweakOr(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Input map[string]interface{}
|
||||
Def bool
|
||||
Want bool
|
||||
}{
|
||||
{"nil", nil, false, false},
|
||||
{"nilValue", map[string]interface{}{"highlight": nil}, true, true},
|
||||
{"false", map[string]interface{}{"highlight": false}, true, false},
|
||||
{"true", map[string]interface{}{"highlight": true}, false, true},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
got := BoolTweakOr(tst.Input, HighlightTweak, tst.Def)
|
||||
if got != tst.Want {
|
||||
t.Errorf("BoolTweakOr: got %v, want %v", got, tst.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobToRegexp(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Input string
|
||||
Want string
|
||||
}{
|
||||
{"", "^(.*.*)$"},
|
||||
{"a", "^(.*a.*)$"},
|
||||
{"a.b", "^(.*a\\.b.*)$"},
|
||||
{"a?b", "^(a.b)$"},
|
||||
{"a*b*", "^(a.*b.*)$"},
|
||||
{"a*b?", "^(a.*b.)$"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Want, func(t *testing.T) {
|
||||
got, err := globToRegexp(tst.Input)
|
||||
if err != nil {
|
||||
t.Fatalf("globToRegexp failed: %v", err)
|
||||
}
|
||||
if got.String() != tst.Want {
|
||||
t.Errorf("got %v, want %v", got.String(), tst.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupMapPath(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Path []string
|
||||
Root map[string]interface{}
|
||||
Want interface{}
|
||||
}{
|
||||
{[]string{"a"}, map[string]interface{}{"a": "b"}, "b"},
|
||||
{[]string{"a"}, map[string]interface{}{"a": 42}, 42},
|
||||
{[]string{"a", "b"}, map[string]interface{}{"a": map[string]interface{}{"b": "c"}}, "c"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(fmt.Sprint(tst.Path, "/", tst.Want), func(t *testing.T) {
|
||||
got, err := lookupMapPath(tst.Path, tst.Root)
|
||||
if err != nil {
|
||||
t.Fatalf("lookupMapPath failed: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tst.Want, got); diff != "" {
|
||||
t.Errorf("+got -want:\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupMapPathInvalid(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Path []string
|
||||
Root map[string]interface{}
|
||||
}{
|
||||
{nil, nil},
|
||||
{[]string{"a"}, nil},
|
||||
{[]string{"a", "b"}, map[string]interface{}{"a": "c"}},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(fmt.Sprint(tst.Path), func(t *testing.T) {
|
||||
got, err := lookupMapPath(tst.Path, tst.Root)
|
||||
if err == nil {
|
||||
t.Fatalf("lookupMapPath succeeded with %#v, but want failure", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRoomMemberCountCondition(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Input string
|
||||
WantTrue []int
|
||||
WantFalse []int
|
||||
}{
|
||||
{"1", []int{1}, []int{0, 2}},
|
||||
{"==1", []int{1}, []int{0, 2}},
|
||||
{"<1", []int{0}, []int{1, 2}},
|
||||
{"<=1", []int{0, 1}, []int{2}},
|
||||
{">1", []int{2}, []int{0, 1}},
|
||||
{">=42", []int{42, 43}, []int{41}},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Input, func(t *testing.T) {
|
||||
got, err := parseRoomMemberCountCondition(tst.Input)
|
||||
if err != nil {
|
||||
t.Fatalf("parseRoomMemberCountCondition failed: %v", err)
|
||||
}
|
||||
for _, v := range tst.WantTrue {
|
||||
if !got(v) {
|
||||
t.Errorf("parseRoomMemberCountCondition(%q)(%d): got false, want true", tst.Input, v)
|
||||
}
|
||||
}
|
||||
for _, v := range tst.WantFalse {
|
||||
if got(v) {
|
||||
t.Errorf("parseRoomMemberCountCondition(%q)(%d): got true, want false", tst.Input, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
85
internal/pushrules/validate.go
Normal file
85
internal/pushrules/validate.go
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// ValidateRule checks the rule for errors. These follow from Sytests
|
||||
// and the specification.
|
||||
func ValidateRule(kind Kind, rule *Rule) []error {
|
||||
var errs []error
|
||||
|
||||
if !validRuleIDRE.MatchString(rule.RuleID) {
|
||||
errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID))
|
||||
}
|
||||
|
||||
if len(rule.Actions) == 0 {
|
||||
errs = append(errs, fmt.Errorf("missing actions"))
|
||||
}
|
||||
for _, action := range rule.Actions {
|
||||
errs = append(errs, validateAction(action)...)
|
||||
}
|
||||
|
||||
for _, cond := range rule.Conditions {
|
||||
errs = append(errs, validateCondition(cond)...)
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case OverrideKind, UnderrideKind:
|
||||
// The empty list is allowed, but for JSON-encoding reasons,
|
||||
// it must not be nil.
|
||||
if rule.Conditions == nil {
|
||||
errs = append(errs, fmt.Errorf("missing rule conditions"))
|
||||
}
|
||||
|
||||
case ContentKind:
|
||||
if rule.Pattern == "" {
|
||||
errs = append(errs, fmt.Errorf("missing content rule pattern"))
|
||||
}
|
||||
|
||||
case RoomKind, SenderKind:
|
||||
// Do nothing.
|
||||
|
||||
default:
|
||||
errs = append(errs, fmt.Errorf("invalid rule kind: %s", kind))
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
// validRuleIDRE is a regexp for valid IDs.
|
||||
//
|
||||
// TODO: the specification doesn't seem to say what the rule ID syntax
|
||||
// is. A Sytest fails if it contains a backslash.
|
||||
var validRuleIDRE = regexp.MustCompile(`^([^\\]+)$`)
|
||||
|
||||
// validateAction returns issues with an Action.
|
||||
func validateAction(action *Action) []error {
|
||||
var errs []error
|
||||
|
||||
switch action.Kind {
|
||||
case NotifyAction, DontNotifyAction, CoalesceAction, SetTweakAction:
|
||||
// Do nothing.
|
||||
|
||||
default:
|
||||
errs = append(errs, fmt.Errorf("invalid rule action kind: %s", action.Kind))
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
// validateCondition returns issues with a Condition.
|
||||
func validateCondition(cond *Condition) []error {
|
||||
var errs []error
|
||||
|
||||
switch cond.Kind {
|
||||
case EventMatchCondition, ContainsDisplayNameCondition, RoomMemberCountCondition, SenderNotificationPermissionCondition:
|
||||
// Do nothing.
|
||||
|
||||
default:
|
||||
errs = append(errs, fmt.Errorf("invalid rule condition kind: %s", cond.Kind))
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
163
internal/pushrules/validate_test.go
Normal file
163
internal/pushrules/validate_test.go
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
package pushrules
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateRuleNegatives(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Kind Kind
|
||||
Rule Rule
|
||||
WantErrString string
|
||||
}{
|
||||
{"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"},
|
||||
{"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"},
|
||||
{"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"},
|
||||
{"noActions", OverrideKind, Rule{}, "missing actions"},
|
||||
{"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"},
|
||||
{"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"},
|
||||
{"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"},
|
||||
{"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"},
|
||||
{"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
errs := ValidateRule(tst.Kind, &tst.Rule)
|
||||
var foundErr error
|
||||
for _, err := range errs {
|
||||
t.Logf("Got error %#v", err)
|
||||
if strings.Contains(err.Error(), tst.WantErrString) {
|
||||
foundErr = err
|
||||
}
|
||||
}
|
||||
if foundErr == nil {
|
||||
t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRulePositives(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Kind Kind
|
||||
Rule Rule
|
||||
WantNoErrString string
|
||||
}{
|
||||
{"invalidKind", OverrideKind, Rule{}, "invalid rule kind"},
|
||||
{"invalidActionNoActions", OverrideKind, Rule{}, "invalid rule action kind"},
|
||||
{"invalidConditionNoConditions", OverrideKind, Rule{}, "invalid rule condition kind"},
|
||||
{"contentNoCondition", ContentKind, Rule{}, "missing rule conditions"},
|
||||
{"roomNoCondition", RoomKind, Rule{}, "missing rule conditions"},
|
||||
{"senderNoCondition", SenderKind, Rule{}, "missing rule conditions"},
|
||||
{"overrideNoPattern", OverrideKind, Rule{}, "missing content rule pattern"},
|
||||
{"overrideEmptyConditions", OverrideKind, Rule{Conditions: []*Condition{}}, "missing rule conditions"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
errs := ValidateRule(tst.Kind, &tst.Rule)
|
||||
for _, err := range errs {
|
||||
t.Logf("Got error %#v", err)
|
||||
if strings.Contains(err.Error(), tst.WantNoErrString) {
|
||||
t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateActionNegatives(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Action Action
|
||||
WantErrString string
|
||||
}{
|
||||
{"emptyKind", Action{}, "invalid rule action kind"},
|
||||
{"invalidKind", Action{Kind: ActionKind("something else")}, "invalid rule action kind"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
errs := validateAction(&tst.Action)
|
||||
var foundErr error
|
||||
for _, err := range errs {
|
||||
t.Logf("Got error %#v", err)
|
||||
if strings.Contains(err.Error(), tst.WantErrString) {
|
||||
foundErr = err
|
||||
}
|
||||
}
|
||||
if foundErr == nil {
|
||||
t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateActionPositives(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Action Action
|
||||
WantNoErrString string
|
||||
}{
|
||||
{"invalidKind", Action{Kind: NotifyAction}, "invalid rule action kind"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
errs := validateAction(&tst.Action)
|
||||
for _, err := range errs {
|
||||
t.Logf("Got error %#v", err)
|
||||
if strings.Contains(err.Error(), tst.WantNoErrString) {
|
||||
t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConditionNegatives(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Condition Condition
|
||||
WantErrString string
|
||||
}{
|
||||
{"emptyKind", Condition{}, "invalid rule condition kind"},
|
||||
{"invalidKind", Condition{Kind: ConditionKind("something else")}, "invalid rule condition kind"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
errs := validateCondition(&tst.Condition)
|
||||
var foundErr error
|
||||
for _, err := range errs {
|
||||
t.Logf("Got error %#v", err)
|
||||
if strings.Contains(err.Error(), tst.WantErrString) {
|
||||
foundErr = err
|
||||
}
|
||||
}
|
||||
if foundErr == nil {
|
||||
t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConditionPositives(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Condition Condition
|
||||
WantNoErrString string
|
||||
}{
|
||||
{"invalidKind", Condition{Kind: EventMatchCondition}, "invalid rule condition kind"},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
errs := validateCondition(&tst.Condition)
|
||||
for _, err := range errs {
|
||||
t.Logf("Got error %#v", err)
|
||||
if strings.Contains(err.Error(), tst.WantNoErrString) {
|
||||
t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -163,6 +163,7 @@ type StatementList []struct {
|
|||
func (s StatementList) Prepare(db *sql.DB) (err error) {
|
||||
for _, statement := range s {
|
||||
if *statement.Statement, err = db.Prepare(statement.SQL); err != nil {
|
||||
err = fmt.Errorf("Error %q while preparing statement: %s", err, statement.SQL)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ var build string
|
|||
const (
|
||||
VersionMajor = 0
|
||||
VersionMinor = 6
|
||||
VersionPatch = 4
|
||||
VersionPatch = 5
|
||||
VersionTag = "" // example: "rc1"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -166,26 +166,53 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
|
|||
}
|
||||
|
||||
// We can't have a self-signing or user-signing key without a master
|
||||
// key, so make sure we have one of those.
|
||||
if !hasMasterKey {
|
||||
existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
|
||||
}
|
||||
return
|
||||
// key, so make sure we have one of those. We will also only actually do
|
||||
// something if any of the specified keys in the request are different
|
||||
// to what we've got in the database, to avoid generating key change
|
||||
// notifications unnecessarily.
|
||||
existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
|
||||
}
|
||||
|
||||
_, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]
|
||||
return
|
||||
}
|
||||
|
||||
// If we still can't find a master key for the user then stop the upload.
|
||||
// This satisfies the "Fails to upload self-signing key without master key" test.
|
||||
if !hasMasterKey {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "No master key was found",
|
||||
IsMissingParam: true,
|
||||
if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "No master key was found",
|
||||
IsMissingParam: true,
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if anything actually changed compared to what we have in the database.
|
||||
changed := false
|
||||
for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{
|
||||
gomatrixserverlib.CrossSigningKeyPurposeMaster,
|
||||
gomatrixserverlib.CrossSigningKeyPurposeSelfSigning,
|
||||
gomatrixserverlib.CrossSigningKeyPurposeUserSigning,
|
||||
} {
|
||||
old, gotOld := existingKeys[purpose]
|
||||
new, gotNew := toStore[purpose]
|
||||
if gotOld != gotNew {
|
||||
// A new key purpose has been specified that we didn't know before,
|
||||
// or one has been removed.
|
||||
changed = true
|
||||
break
|
||||
}
|
||||
if !bytes.Equal(old, new) {
|
||||
// One of the existing keys for a purpose we already knew about has
|
||||
// changed.
|
||||
changed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -48,13 +48,19 @@ type mockDeviceListUpdaterDatabase struct {
|
|||
staleUsers map[string]bool
|
||||
prevIDsExist func(string, []int) bool
|
||||
storedKeys []api.DeviceMessage
|
||||
mu sync.Mutex // protect staleUsers
|
||||
}
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
var result []string
|
||||
for userID := range d.staleUsers {
|
||||
for userID, isStale := range d.staleUsers {
|
||||
if !isStale {
|
||||
continue
|
||||
}
|
||||
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -75,10 +81,18 @@ func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, do
|
|||
|
||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||
func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
d.staleUsers[userID] = isStale
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *mockDeviceListUpdaterDatabase) isStale(userID string) bool {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
return d.staleUsers[userID]
|
||||
}
|
||||
|
||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device). Does not modify the stream ID for keys.
|
||||
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
|
||||
|
|
@ -161,7 +175,7 @@ func TestUpdateHavePrevID(t *testing.T) {
|
|||
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
|
||||
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
|
||||
}
|
||||
if db.staleUsers[event.UserID] {
|
||||
if db.isStale(event.UserID) {
|
||||
t.Errorf("%s incorrectly marked as stale", event.UserID)
|
||||
}
|
||||
}
|
||||
|
|
@ -235,7 +249,7 @@ func TestUpdateNoPrevID(t *testing.T) {
|
|||
},
|
||||
}
|
||||
// Now we should have a fresh list and the keys and emitted something
|
||||
if db.staleUsers[event.UserID] {
|
||||
if db.isStale(event.UserID) {
|
||||
t.Errorf("%s still marked as stale", event.UserID)
|
||||
}
|
||||
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
||||
|
|
@ -247,3 +261,83 @@ func TestUpdateNoPrevID(t *testing.T) {
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the
|
||||
// update is still ongoing.
|
||||
func TestDebounce(t *testing.T) {
|
||||
t.Skipf("panic on closed channel on GHA")
|
||||
db := &mockDeviceListUpdaterDatabase{
|
||||
staleUsers: make(map[string]bool),
|
||||
prevIDsExist: func(string, []int) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
ap := &mockDeviceListUpdaterAPI{}
|
||||
producer := &mockKeyChangeProducer{}
|
||||
fedCh := make(chan *http.Response, 1)
|
||||
srv := gomatrixserverlib.ServerName("example.com")
|
||||
userID := "@alice:example.com"
|
||||
keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
|
||||
incomingFedReq := make(chan struct{})
|
||||
fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
|
||||
if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) {
|
||||
return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
|
||||
}
|
||||
close(incomingFedReq)
|
||||
return <-fedCh, nil
|
||||
})
|
||||
updater := NewDeviceListUpdater(db, ap, producer, fedClient, 1)
|
||||
if err := updater.Start(); err != nil {
|
||||
t.Fatalf("failed to start updater: %s", err)
|
||||
}
|
||||
|
||||
// hit this 5 times
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(5)
|
||||
for i := 0; i < 5; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil {
|
||||
t.Errorf("ManualUpdate: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// wait until the updater hits federation
|
||||
select {
|
||||
case <-incomingFedReq:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timed out waiting for updater to hit federation")
|
||||
}
|
||||
|
||||
// user should be marked as stale
|
||||
if !db.isStale(userID) {
|
||||
t.Errorf("user %s not marked as stale", userID)
|
||||
}
|
||||
// now send the response over federation
|
||||
fedCh <- &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`
|
||||
{
|
||||
"user_id": "` + userID + `",
|
||||
"stream_id": 5,
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "JLAFKJWSCS",
|
||||
"keys": ` + keyJSON + `,
|
||||
"device_display_name": "Mobile Phone"
|
||||
}
|
||||
]
|
||||
}
|
||||
`)),
|
||||
}
|
||||
close(fedCh)
|
||||
// wait until all 5 ManualUpdates return. If we hit federation again we won't send a response
|
||||
// and should panic with read on a closed channel
|
||||
wg.Wait()
|
||||
|
||||
// user is no longer stale now
|
||||
if db.isStale(userID) {
|
||||
t.Errorf("user %s is marked as stale", userID)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,8 +53,7 @@ func Setup(
|
|||
) {
|
||||
rateLimits := httputil.NewRateLimits(rateLimit)
|
||||
|
||||
r0mux := publicAPIMux.PathPrefix("/r0").Subrouter()
|
||||
v1mux := publicAPIMux.PathPrefix("/v1").Subrouter()
|
||||
v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v1|v3)}/").Subrouter()
|
||||
|
||||
activeThumbnailGeneration := &types.ActiveThumbnailGeneration{
|
||||
PathToResult: map[string]*types.ThumbnailGenerationResult{},
|
||||
|
|
@ -80,21 +79,18 @@ func Setup(
|
|||
}
|
||||
})
|
||||
|
||||
r0mux.Handle("/upload", uploadHandler).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions)
|
||||
v1mux.Handle("/upload", uploadHandler).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/upload", uploadHandler).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
activeRemoteRequests := &types.ActiveRemoteRequests{
|
||||
MXCToResult: map[string]*types.RemoteRequestResult{},
|
||||
}
|
||||
|
||||
downloadHandler := makeDownloadAPI("download", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration)
|
||||
r0mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
|
||||
r0mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
|
||||
v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) // TODO: remove when synapse is fixed
|
||||
v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) // TODO: remove when synapse is fixed
|
||||
v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
|
||||
v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/thumbnail/{serverName}/{mediaId}",
|
||||
v3mux.Handle("/thumbnail/{serverName}/{mediaId}",
|
||||
makeDownloadAPI("thumbnail", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
}
|
||||
|
|
@ -124,13 +120,16 @@ func makeDownloadAPI(
|
|||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// Ratelimit requests
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
if err := json.NewEncoder(w).Encode(r); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
// NOTSPEC: The spec says everything at /media/ should be rate limited, but this causes issues with thumbnails (#2243)
|
||||
if name != "thumbnail" {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
if err := json.NewEncoder(w).Encode(r); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
vars, _ := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@
|
|||
|
||||
package api
|
||||
|
||||
import "regexp"
|
||||
|
||||
// SetRoomAliasRequest is a request to SetRoomAlias
|
||||
type SetRoomAliasRequest struct {
|
||||
// ID of the user setting the alias
|
||||
|
|
@ -84,3 +86,20 @@ type RemoveRoomAliasResponse struct {
|
|||
// Did we remove it?
|
||||
Removed bool `json:"removed"`
|
||||
}
|
||||
|
||||
type AliasEvent struct {
|
||||
Alias string `json:"alias"`
|
||||
AltAliases []string `json:"alt_aliases"`
|
||||
}
|
||||
|
||||
var validateAliasRegex = regexp.MustCompile("^#.*:.+$")
|
||||
|
||||
func (a AliasEvent) Valid() bool {
|
||||
for _, alias := range a.AltAliases {
|
||||
if !validateAliasRegex.MatchString(alias) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return a.Alias == "" || validateAliasRegex.MatchString(a.Alias)
|
||||
}
|
||||
|
||||
|
|
|
|||
62
roomserver/api/alias_test.go
Normal file
62
roomserver/api/alias_test.go
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
package api
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestAliasEvent_Valid(t *testing.T) {
|
||||
type fields struct {
|
||||
Alias string
|
||||
AltAliases []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty alias",
|
||||
fields: fields{
|
||||
Alias: "",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty alias, invalid alt aliases",
|
||||
fields: fields{
|
||||
Alias: "",
|
||||
AltAliases: []string{ "%not:valid.local"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid alias, invalid alt aliases",
|
||||
fields: fields{
|
||||
Alias: "#valid:test.local",
|
||||
AltAliases: []string{ "%not:valid.local"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty alias, invalid alt aliases",
|
||||
fields: fields{
|
||||
Alias: "",
|
||||
AltAliases: []string{ "%not:valid.local"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid alias",
|
||||
fields: fields{
|
||||
Alias: "%not:valid.local",
|
||||
AltAliases: []string{ },
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := AliasEvent{
|
||||
Alias: tt.fields.Alias,
|
||||
AltAliases: tt.fields.AltAliases,
|
||||
}
|
||||
if got := a.Valid(); got != tt.want {
|
||||
t.Errorf("Valid() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -105,7 +105,7 @@ type OutputNewRoomEvent struct {
|
|||
Event *gomatrixserverlib.HeaderedEvent `json:"event"`
|
||||
// Does the event completely rewrite the room state? If so, then AddsStateEventIDs
|
||||
// will contain the entire room state.
|
||||
RewritesState bool `json:"rewrites_state"`
|
||||
RewritesState bool `json:"rewrites_state,omitempty"`
|
||||
// The latest events in the room after this event.
|
||||
// This can be used to set the prev events for new events in the room.
|
||||
// This also can be used to get the full current state after this event.
|
||||
|
|
@ -113,16 +113,9 @@ type OutputNewRoomEvent struct {
|
|||
// The state event IDs that were added to the state of the room by this event.
|
||||
// Together with RemovesStateEventIDs this allows the receiver to keep an up to date
|
||||
// view of the current state of the room.
|
||||
AddsStateEventIDs []string `json:"adds_state_event_ids"`
|
||||
// All extra newly added state events. This is only set if there are *extra* events
|
||||
// other than `Event`. This can happen when forks get merged because state resolution
|
||||
// may decide a bunch of state events on one branch are now valid, so they will be
|
||||
// present in this list. This is useful when trying to maintain the current state of a room
|
||||
// as to do so you need to include both these events and `Event`.
|
||||
AddStateEvents []*gomatrixserverlib.HeaderedEvent `json:"adds_state_events"`
|
||||
|
||||
AddsStateEventIDs []string `json:"adds_state_event_ids,omitempty"`
|
||||
// The state event IDs that were removed from the state of the room by this event.
|
||||
RemovesStateEventIDs []string `json:"removes_state_event_ids"`
|
||||
RemovesStateEventIDs []string `json:"removes_state_event_ids,omitempty"`
|
||||
// The ID of the event that was output before this event.
|
||||
// Or the empty string if this is the first event output for this room.
|
||||
// This is used by consumers to check if they can safely update their
|
||||
|
|
@ -145,10 +138,10 @@ type OutputNewRoomEvent struct {
|
|||
//
|
||||
// The state is given as a delta against the current state because they are
|
||||
// usually either the same state, or differ by just a couple of events.
|
||||
StateBeforeAddsEventIDs []string `json:"state_before_adds_event_ids"`
|
||||
StateBeforeAddsEventIDs []string `json:"state_before_adds_event_ids,omitempty"`
|
||||
// The state event IDs that are part of the current state, but not part
|
||||
// of the state at the event.
|
||||
StateBeforeRemovesEventIDs []string `json:"state_before_removes_event_ids"`
|
||||
StateBeforeRemovesEventIDs []string `json:"state_before_removes_event_ids,omitempty"`
|
||||
// The server name to use to push this event to other servers.
|
||||
// Or empty if this event shouldn't be pushed to other servers.
|
||||
//
|
||||
|
|
@ -167,27 +160,7 @@ type OutputNewRoomEvent struct {
|
|||
SendAsServer string `json:"send_as_server"`
|
||||
// The transaction ID of the send request if sent by a local user and one
|
||||
// was specified
|
||||
TransactionID *TransactionID `json:"transaction_id"`
|
||||
}
|
||||
|
||||
// AddsState returns all added state events from this event.
|
||||
//
|
||||
// This function is needed because `AddStateEvents` will not include a copy of
|
||||
// the original event to save space, so you cannot use that slice alone.
|
||||
// Instead, use this function which will add the original event if it is present
|
||||
// in `AddsStateEventIDs`.
|
||||
func (ore *OutputNewRoomEvent) AddsState() []*gomatrixserverlib.HeaderedEvent {
|
||||
includeOutputEvent := false
|
||||
for _, id := range ore.AddsStateEventIDs {
|
||||
if id == ore.Event.EventID() {
|
||||
includeOutputEvent = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !includeOutputEvent {
|
||||
return ore.AddStateEvents
|
||||
}
|
||||
return append(ore.AddStateEvents, ore.Event)
|
||||
TransactionID *TransactionID `json:"transaction_id,omitempty"`
|
||||
}
|
||||
|
||||
// An OutputOldRoomEvent is written when the roomserver receives an old event.
|
||||
|
|
|
|||
|
|
@ -269,6 +269,7 @@ type QueryAuthChainResponse struct {
|
|||
|
||||
type QuerySharedUsersRequest struct {
|
||||
UserID string
|
||||
OtherUserIDs []string
|
||||
ExcludeRoomIDs []string
|
||||
IncludeRoomIDs []string
|
||||
}
|
||||
|
|
@ -312,7 +313,10 @@ type QueryBulkStateContentResponse struct {
|
|||
}
|
||||
|
||||
type QueryCurrentStateRequest struct {
|
||||
RoomID string
|
||||
RoomID string
|
||||
AllowWildcards bool
|
||||
// State key tuples. If a state_key has '*' and AllowWidlcards is true, returns all matching
|
||||
// state events with that event type.
|
||||
StateTuples []gomatrixserverlib.StateKeyTuple
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -51,12 +51,8 @@ func SendEventWithState(
|
|||
state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent,
|
||||
origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool,
|
||||
) error {
|
||||
outliers, err := state.Events(event.RoomVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ires []InputRoomEvent
|
||||
outliers := state.Events(event.RoomVersion)
|
||||
ires := make([]InputRoomEvent, 0, len(outliers))
|
||||
for _, outlier := range outliers {
|
||||
if haveEventIDs[outlier.EventID()] {
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -16,12 +16,18 @@ package internal
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"time"
|
||||
|
||||
asAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// RoomserverInternalAPIDatabase has the storage APIs needed to implement the alias API.
|
||||
|
|
@ -183,6 +189,57 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
|
|||
}
|
||||
}
|
||||
|
||||
ev, err := r.DB.GetStateEvent(ctx, roomID, gomatrixserverlib.MRoomCanonicalAlias, "")
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
} else if ev != nil {
|
||||
stateAlias := gjson.GetBytes(ev.Content(), "alias").Str
|
||||
// the alias to remove is currently set as the canonical alias, remove it
|
||||
if stateAlias == request.Alias {
|
||||
res, err := sjson.DeleteBytes(ev.Content(), "alias")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sender := request.UserID
|
||||
if request.UserID != ev.Sender() {
|
||||
sender = ev.Sender()
|
||||
}
|
||||
|
||||
builder := &gomatrixserverlib.EventBuilder{
|
||||
Sender: sender,
|
||||
RoomID: ev.RoomID(),
|
||||
Type: ev.Type(),
|
||||
StateKey: ev.StateKey(),
|
||||
Content: res,
|
||||
}
|
||||
|
||||
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err)
|
||||
}
|
||||
if len(eventsNeeded.Tuples()) == 0 {
|
||||
return errors.New("expecting state tuples for event builder, got none")
|
||||
}
|
||||
|
||||
stateRes := &api.QueryLatestEventsAndStateResponse{}
|
||||
if err := helpers.QueryLatestEventsAndState(ctx, r.DB, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newEvent, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, time.Now(), &eventsNeeded, stateRes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = api.SendEvents(ctx, r.RSAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, r.ServerName, r.ServerName, nil, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the alias from the database
|
||||
if err := r.DB.RemoveRoomAlias(ctx, request.Alias); err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -294,7 +294,7 @@ func (r *Inputer) processRoomEvent(
|
|||
}
|
||||
|
||||
// Store the event.
|
||||
_, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected)
|
||||
_, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected || softfail)
|
||||
if err != nil {
|
||||
return fmt.Errorf("updater.StoreEvent: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -365,6 +365,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
|
|||
LastSentEventID: u.lastEventIDSent,
|
||||
LatestEventIDs: latestEventIDs,
|
||||
TransactionID: u.transactionID,
|
||||
SendAsServer: u.sendAsServer,
|
||||
}
|
||||
|
||||
eventIDMap, err := u.stateEventMap()
|
||||
|
|
@ -384,51 +385,17 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
|
|||
ore.StateBeforeAddsEventIDs = append(ore.StateBeforeAddsEventIDs, eventIDMap[entry.EventNID])
|
||||
}
|
||||
|
||||
ore.SendAsServer = u.sendAsServer
|
||||
|
||||
// include extra state events if they were added as nearly every downstream component will care about it
|
||||
// and we'd rather not have them all hit QueryEventsByID at the same time!
|
||||
if len(ore.AddsStateEventIDs) > 0 {
|
||||
var err error
|
||||
if ore.AddStateEvents, err = u.extraEventsForIDs(u.roomInfo.RoomVersion, ore.AddsStateEventIDs); err != nil {
|
||||
return nil, fmt.Errorf("failed to load add_state_events from db: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &api.OutputEvent{
|
||||
Type: api.OutputTypeNewRoomEvent,
|
||||
NewRoomEvent: &ore,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extraEventsForIDs returns the full events for the event IDs given, but does not include the current event being
|
||||
// updated.
|
||||
func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
var extraEventIDs []string
|
||||
for _, e := range eventIDs {
|
||||
if e == u.event.EventID() {
|
||||
continue
|
||||
}
|
||||
extraEventIDs = append(extraEventIDs, e)
|
||||
}
|
||||
if len(extraEventIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
extraEvents, err := u.updater.UnsentEventsFromIDs(u.ctx, extraEventIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var h []*gomatrixserverlib.HeaderedEvent
|
||||
for _, e := range extraEvents {
|
||||
h = append(h, e.Headered(roomVersion))
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// retrieve an event nid -> event ID map for all events that need updating
|
||||
func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) {
|
||||
var stateEventNIDs []types.EventNID
|
||||
var allStateEntries []types.StateEntry
|
||||
cap := len(u.added) + len(u.removed) + len(u.stateBeforeEventRemoves) + len(u.stateBeforeEventAdds)
|
||||
stateEventNIDs := make(types.EventNIDs, 0, cap)
|
||||
allStateEntries := make([]types.StateEntry, 0, cap)
|
||||
allStateEntries = append(allStateEntries, u.added...)
|
||||
allStateEntries = append(allStateEntries, u.removed...)
|
||||
allStateEntries = append(allStateEntries, u.stateBeforeEventRemoves...)
|
||||
|
|
@ -436,12 +403,6 @@ func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error)
|
|||
for _, entry := range allStateEntries {
|
||||
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
|
||||
}
|
||||
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
|
||||
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(stateEventNIDs)]
|
||||
return u.updater.EventIDs(u.ctx, stateEventNIDs)
|
||||
}
|
||||
|
||||
type eventNIDSorter []types.EventNID
|
||||
|
||||
func (s eventNIDSorter) Len() int { return len(s) }
|
||||
func (s eventNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
|
||||
func (s eventNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
|
|
|||
|
|
@ -23,6 +23,21 @@ type parsedRespState struct {
|
|||
StateEvents []*gomatrixserverlib.Event
|
||||
}
|
||||
|
||||
func (p *parsedRespState) Events() []*gomatrixserverlib.Event {
|
||||
eventsByID := make(map[string]*gomatrixserverlib.Event, len(p.AuthEvents)+len(p.StateEvents))
|
||||
for i, event := range p.AuthEvents {
|
||||
eventsByID[event.EventID()] = p.AuthEvents[i]
|
||||
}
|
||||
for i, event := range p.StateEvents {
|
||||
eventsByID[event.EventID()] = p.StateEvents[i]
|
||||
}
|
||||
allEvents := make([]*gomatrixserverlib.Event, 0, len(eventsByID))
|
||||
for _, event := range eventsByID {
|
||||
allEvents = append(allEvents, event)
|
||||
}
|
||||
return gomatrixserverlib.ReverseTopologicalOrdering(allEvents, gomatrixserverlib.TopologicalOrderByAuthEvents)
|
||||
}
|
||||
|
||||
type missingStateReq struct {
|
||||
origin gomatrixserverlib.ServerName
|
||||
db storage.Database
|
||||
|
|
@ -124,11 +139,8 @@ func (t *missingStateReq) processEventWithMissingState(
|
|||
t.hadEventsMutex.Unlock()
|
||||
|
||||
sendOutliers := func(resolvedState *parsedRespState) error {
|
||||
outliers, oerr := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion)
|
||||
if oerr != nil {
|
||||
return fmt.Errorf("gomatrixserverlib.OrderAuthAndStateEvents: %w", oerr)
|
||||
}
|
||||
var outlierRoomEvents []api.InputRoomEvent
|
||||
outliers := resolvedState.Events()
|
||||
outlierRoomEvents := make([]api.InputRoomEvent, 0, len(outliers))
|
||||
for _, outlier := range outliers {
|
||||
if hadEvents[outlier.EventID()] {
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -621,12 +621,25 @@ func (r *Queryer) QueryPublishedRooms(
|
|||
func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
|
||||
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
|
||||
for _, tuple := range req.StateTuples {
|
||||
ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ev != nil {
|
||||
res.StateEvents[tuple] = ev
|
||||
if tuple.StateKey == "*" && req.AllowWildcards {
|
||||
events, err := r.DB.GetStateEventsWithEventType(ctx, req.RoomID, tuple.EventType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, e := range events {
|
||||
res.StateEvents[gomatrixserverlib.StateKeyTuple{
|
||||
EventType: e.Type(),
|
||||
StateKey: *e.StateKey(),
|
||||
}] = e
|
||||
}
|
||||
} else {
|
||||
ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ev != nil {
|
||||
res.StateEvents[tuple] = ev
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
|
@ -696,7 +709,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser
|
|||
}
|
||||
roomIDs = roomIDs[:j]
|
||||
|
||||
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs)
|
||||
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -146,13 +146,14 @@ type Database interface {
|
|||
// If no event could be found, returns nil
|
||||
// If there was an issue during the retrieval, returns an error
|
||||
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
||||
GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error)
|
||||
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
||||
GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error)
|
||||
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
|
||||
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
|
||||
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
|
||||
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
||||
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
|
||||
// JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms.
|
||||
JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error)
|
||||
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
|
||||
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
|
||||
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.
|
||||
|
|
|
|||
|
|
@ -66,7 +66,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
|
|||
`
|
||||
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
|
|
@ -306,13 +307,10 @@ func (s *membershipStatements) SelectRoomsWithMembership(
|
|||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNIDs []types.RoomNID,
|
||||
userNIDs []types.EventStateKeyNID,
|
||||
) (map[types.EventStateKeyNID]int, error) {
|
||||
roomIDarray := make([]int64, len(roomNIDs))
|
||||
for i := range roomNIDs {
|
||||
roomIDarray[i] = int64(roomNIDs[i])
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
|
||||
rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import (
|
|||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
|
|
@ -979,6 +978,62 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
// Same as GetStateEvent but returns all matching state events with this event type. Returns no error
|
||||
// if there are no events with this event type.
|
||||
func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
roomInfo, err := d.RoomInfo(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if roomInfo == nil {
|
||||
return nil, fmt.Errorf("room %s doesn't exist", roomID)
|
||||
}
|
||||
// e.g invited rooms
|
||||
if roomInfo.IsStub {
|
||||
return nil, nil
|
||||
}
|
||||
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
|
||||
if err == sql.ErrNoRows {
|
||||
// No rooms have an event of this type, otherwise we'd have an event type NID
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var eventNIDs []types.EventNID
|
||||
for _, e := range entries {
|
||||
if e.EventTypeNID == eventTypeNID {
|
||||
eventNIDs = append(eventNIDs, e.EventNID)
|
||||
}
|
||||
}
|
||||
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
|
||||
if err != nil {
|
||||
eventIDs = map[types.EventNID]string{}
|
||||
}
|
||||
// return the events requested
|
||||
eventPairs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(eventPairs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var result []*gomatrixserverlib.HeaderedEvent
|
||||
for _, pair := range eventPairs {
|
||||
ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[pair.EventNID], pair.EventJSON, false, roomInfo.RoomVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, ev.Headered(roomInfo.RoomVersion))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
||||
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) {
|
||||
var membershipState tables.MembershipState
|
||||
|
|
@ -1104,13 +1159,23 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
||||
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
|
||||
// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
|
||||
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) {
|
||||
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
|
||||
userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userNIDs := make([]types.EventStateKeyNID, 0, len(userNIDsMap))
|
||||
nidToUserID := make(map[types.EventStateKeyNID]string, len(userNIDsMap))
|
||||
for id, nid := range userNIDsMap {
|
||||
userNIDs = append(userNIDs, nid)
|
||||
nidToUserID[nid] = id
|
||||
}
|
||||
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -1120,13 +1185,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
|
|||
stateKeyNIDs[i] = nid
|
||||
i++
|
||||
}
|
||||
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nidToUserID) != len(userNIDToCount) {
|
||||
logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID))
|
||||
}
|
||||
result := make(map[string]int, len(userNIDToCount))
|
||||
for nid, count := range userNIDToCount {
|
||||
result[nidToUserID[nid]] = count
|
||||
|
|
|
|||
|
|
@ -154,6 +154,7 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKey(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer selectPrep.Close()
|
||||
stmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -140,6 +140,7 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer selectPrep.Close()
|
||||
stmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
///////////////
|
||||
|
||||
|
|
|
|||
|
|
@ -198,11 +198,12 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
|||
iEventIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
defer selectPrep.Close() // nolint:errcheck
|
||||
selectStmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
///////////////
|
||||
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||
|
|
@ -266,11 +267,12 @@ func (s *eventStatements) BulkSelectStateEventByNID(
|
|||
}
|
||||
}
|
||||
selectOrig += " ORDER BY event_type_nid, event_state_key_nid ASC"
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s.db.Prepare: %w", err)
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
defer selectPrep.Close() // nolint:errcheck
|
||||
selectStmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
rows, err := selectStmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("selectStmt.QueryContext: %w", err)
|
||||
|
|
@ -307,11 +309,12 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
|
|||
iEventIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
defer selectPrep.Close() // nolint:errcheck
|
||||
selectStmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
///////////////
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||
if err != nil {
|
||||
|
|
@ -390,10 +393,11 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectPrep = sqlutil.TxStmt(txn, selectPrep)
|
||||
defer selectPrep.Close() // nolint:errcheck
|
||||
selectStmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
//////////////
|
||||
|
||||
rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
|
||||
rows, err := sqlutil.TxStmt(txn, selectStmt).QueryContext(ctx, iEventNIDs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlutil.TxStmt.QueryContext: %w", err)
|
||||
}
|
||||
|
|
@ -441,6 +445,7 @@ func (s *eventStatements) BulkSelectEventReference(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer selectPrep.Close() // nolint:errcheck
|
||||
///////////////
|
||||
|
||||
selectStmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
|
|
@ -471,11 +476,12 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
|
|||
iEventNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
defer selectPrep.Close() // nolint:errcheck
|
||||
selectStmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
///////////////
|
||||
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
|
||||
|
|
@ -526,11 +532,12 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
|
|||
} else {
|
||||
selectOrig = strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
|
||||
}
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
defer selectPrep.Close() // nolint:errcheck
|
||||
selectStmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
///////////////
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||
if err != nil {
|
||||
|
|
@ -560,6 +567,7 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
|
|||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer sqlPrep.Close()
|
||||
err = sqlutil.TxStmt(txn, sqlPrep).QueryRowContext(ctx, iEventIDs...).Scan(&result)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err)
|
||||
|
|
@ -575,12 +583,13 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
|
||||
defer sqlPrep.Close()
|
||||
sqlStmt := sqlutil.TxStmt(txn, sqlPrep)
|
||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||
for i, v := range eventNIDs {
|
||||
iEventNIDs[i] = v
|
||||
}
|
||||
rows, err := sqlPrep.QueryContext(ctx, iEventNIDs...)
|
||||
rows, err := sqlStmt.QueryContext(ctx, iEventNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,7 +42,8 @@ const membershipSchema = `
|
|||
`
|
||||
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
|
|
@ -280,18 +281,22 @@ func (s *membershipStatements) SelectRoomsWithMembership(
|
|||
return roomNIDs, nil
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
|
||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||
for i, v := range roomNIDs {
|
||||
iRoomNIDs[i] = v
|
||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) {
|
||||
params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs))
|
||||
for _, v := range roomNIDs {
|
||||
params = append(params, v)
|
||||
}
|
||||
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
|
||||
for _, v := range userNIDs {
|
||||
params = append(params, v)
|
||||
}
|
||||
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if txn != nil {
|
||||
rows, err = txn.QueryContext(ctx, query, iRoomNIDs...)
|
||||
rows, err = txn.QueryContext(ctx, query, params...)
|
||||
} else {
|
||||
rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...)
|
||||
rows, err = s.db.QueryContext(ctx, query, params...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -233,12 +233,13 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
|
||||
defer sqlPrep.Close() // nolint:errcheck
|
||||
sqlStmt := sqlutil.TxStmt(txn, sqlPrep)
|
||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||
for i, v := range roomNIDs {
|
||||
iRoomNIDs[i] = v
|
||||
}
|
||||
rows, err := sqlPrep.QueryContext(ctx, iRoomNIDs...)
|
||||
rows, err := sqlStmt.QueryContext(ctx, iRoomNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -108,11 +108,12 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
|||
intfs[i] = int64(stateBlockNIDs[i])
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(intfs)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
defer selectPrep.Close() // nolint:errcheck
|
||||
selectStmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
rows, err := selectStmt.QueryContext(ctx, intfs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -113,11 +113,12 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
|||
nids[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
defer selectPrep.Close() // nolint:errcheck
|
||||
selectStmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
|
||||
rows, err := selectStmt.QueryContext(ctx, nids...)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -127,9 +127,8 @@ type Membership interface {
|
|||
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
|
||||
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
|
||||
// counts of how many rooms they are joined.
|
||||
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
|
||||
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
|
||||
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error)
|
||||
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
||||
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
|
||||
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
|
||||
|
|
|
|||
63
run-sytest.sh
Executable file
63
run-sytest.sh
Executable file
|
|
@ -0,0 +1,63 @@
|
|||
#!/bin/bash
|
||||
#
|
||||
# Runs SyTest either from Docker Hub, or from ../sytest. If it's run
|
||||
# locally, the Docker image is rebuilt first.
|
||||
#
|
||||
# Logs are stored in ../sytestout/logs.
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
main() {
|
||||
local tag=buster
|
||||
local base_image=debian:$tag
|
||||
local runargs=()
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
if [ -d ../sytest ]; then
|
||||
local tmpdir
|
||||
tmpdir="$(mktemp -d --tmpdir run-systest.XXXXXXXXXX)"
|
||||
trap "rm -r '$tmpdir'" EXIT
|
||||
|
||||
if [ -z "$DISABLE_BUILDING_SYTEST" ]; then
|
||||
echo "Re-building ../sytest Docker images..."
|
||||
|
||||
local status
|
||||
(
|
||||
cd ../sytest
|
||||
|
||||
docker build -f docker/base.Dockerfile --build-arg BASE_IMAGE="$base_image" --tag matrixdotorg/sytest:"$tag" .
|
||||
docker build -f docker/dendrite.Dockerfile --build-arg SYTEST_IMAGE_TAG="$tag" --tag matrixdotorg/sytest-dendrite:latest .
|
||||
) &>"$tmpdir/buildlog" || status=$?
|
||||
if (( status != 0 )); then
|
||||
# Docker is very verbose, and we don't really care about
|
||||
# building SyTest. So we accumulate and only output on
|
||||
# failure.
|
||||
cat "$tmpdir/buildlog" >&2
|
||||
return $status
|
||||
fi
|
||||
fi
|
||||
|
||||
runargs+=( -v "$PWD/../sytest:/sytest:ro" )
|
||||
fi
|
||||
if [ -n "$SYTEST_POSTGRES" ]; then
|
||||
runargs+=( -e POSTGRES=1 )
|
||||
fi
|
||||
|
||||
local sytestout=$PWD/../sytestout
|
||||
mkdir -p "$sytestout"/{logs,cache/go-build,cache/go-pkg}
|
||||
docker run \
|
||||
--rm \
|
||||
--name "sytest-dendrite-${LOGNAME}" \
|
||||
-e LOGS_USER=$(id -u) \
|
||||
-e LOGS_GROUP=$(id -g) \
|
||||
-v "$PWD:/src/:ro" \
|
||||
-v "$sytestout/logs:/logs/" \
|
||||
-v "$sytestout/cache/go-build:/root/.cache/go-build" \
|
||||
-v "$sytestout/cache/go-pkg:/gopath/pkg" \
|
||||
"${runargs[@]}" \
|
||||
matrixdotorg/sytest-dendrite:latest "$@"
|
||||
}
|
||||
|
||||
main "$@"
|
||||
|
|
@ -30,6 +30,7 @@ import (
|
|||
sentryhttp "github.com/getsentry/sentry-go/http"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.uber.org/atomic"
|
||||
|
|
@ -271,6 +272,11 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
|
|||
return f
|
||||
}
|
||||
|
||||
// PushGatewayHTTPClient returns a new client for interacting with (external) Push Gateways.
|
||||
func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client {
|
||||
return pushgateway.NewHTTPClient(b.Cfg.UserAPI.PushGatewayDisableTLSValidation)
|
||||
}
|
||||
|
||||
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
||||
// be called once per component.
|
||||
func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
|
||||
|
|
|
|||
|
|
@ -205,6 +205,11 @@ user_api:
|
|||
max_open_conns: 100
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
pusher_database:
|
||||
connection_string: file:pushserver.db
|
||||
max_open_conns: 100
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
tracing:
|
||||
enabled: false
|
||||
jaeger:
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@ type UserAPI struct {
|
|||
// The length of time an OpenID token is condidered valid in milliseconds
|
||||
OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"`
|
||||
|
||||
// Disable TLS validation on HTTPS calls to push gatways. NOT RECOMMENDED!
|
||||
PushGatewayDisableTLSValidation bool `yaml:"push_gateway_disable_tls_validation"`
|
||||
|
||||
// The Account database stores the login details and account information
|
||||
// for local users. It is accessed by the UserAPI.
|
||||
AccountDatabase DatabaseOptions `yaml:"account_database"`
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
|
@ -29,6 +30,7 @@ func JetStreamConsumer(
|
|||
name := durable + "Pull"
|
||||
sub, err := js.PullSubscribe(subj, name, opts...)
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
return fmt.Errorf("nats.SubscribeSync: %w", err)
|
||||
}
|
||||
go func() {
|
||||
|
|
@ -55,6 +57,7 @@ func JetStreamConsumer(
|
|||
}
|
||||
} else {
|
||||
// Something else went wrong, so we'll panic.
|
||||
sentry.CaptureException(err)
|
||||
logrus.WithContext(ctx).WithField("subject", subj).Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
@ -64,15 +67,18 @@ func JetStreamConsumer(
|
|||
msg := msgs[0]
|
||||
if err = msg.InProgress(); err != nil {
|
||||
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err))
|
||||
sentry.CaptureException(err)
|
||||
continue
|
||||
}
|
||||
if f(ctx, msg) {
|
||||
if err = msg.Ack(); err != nil {
|
||||
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Ack: %w", err))
|
||||
sentry.CaptureException(err)
|
||||
}
|
||||
} else {
|
||||
if err = msg.Nak(); err != nil {
|
||||
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err))
|
||||
sentry.CaptureException(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,10 @@ var (
|
|||
OutputKeyChangeEvent = "OutputKeyChangeEvent"
|
||||
OutputTypingEvent = "OutputTypingEvent"
|
||||
OutputClientData = "OutputClientData"
|
||||
OutputNotificationData = "OutputNotificationData"
|
||||
OutputReceiptEvent = "OutputReceiptEvent"
|
||||
OutputStreamEvent = "OutputStreamEvent"
|
||||
OutputReadUpdate = "OutputReadUpdate"
|
||||
)
|
||||
|
||||
var streams = []*nats.StreamConfig{
|
||||
|
|
@ -58,4 +61,19 @@ var streams = []*nats.StreamConfig{
|
|||
Retention: nats.InterestPolicy,
|
||||
Storage: nats.FileStorage,
|
||||
},
|
||||
{
|
||||
Name: OutputNotificationData,
|
||||
Retention: nats.InterestPolicy,
|
||||
Storage: nats.FileStorage,
|
||||
},
|
||||
{
|
||||
Name: OutputStreamEvent,
|
||||
Retention: nats.InterestPolicy,
|
||||
Storage: nats.FileStorage,
|
||||
},
|
||||
{
|
||||
Name: OutputReadUpdate,
|
||||
Retention: nats.InterestPolicy,
|
||||
Storage: nats.FileStorage,
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -60,8 +60,8 @@ func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ss
|
|||
csMux, synapseMux, &m.Config.ClientAPI, m.AccountDB,
|
||||
m.FedClient, m.RoomserverAPI,
|
||||
m.EDUInternalAPI, m.AppserviceAPI, transactions.New(),
|
||||
m.FederationAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider,
|
||||
&m.Config.MSCs,
|
||||
m.FederationAPI, m.UserAPI, m.KeyAPI,
|
||||
m.ExtPublicRoomsProvider, &m.Config.MSCs,
|
||||
)
|
||||
federationapi.AddPublicRoutes(
|
||||
ssMux, keyMux, wkMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient,
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue