Merge branch 'master' into ci_multiarch_build_and_push

This commit is contained in:
Neil Alexander 2021-01-18 09:56:43 +00:00 committed by GitHub
commit 0a080dd2ec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
116 changed files with 4387 additions and 2375 deletions

View file

@ -185,6 +185,7 @@ linters:
- gocyclo - gocyclo
- goimports # Does everything gofmt does - goimports # Does everything gofmt does
- gosimple - gosimple
- govet
- ineffassign - ineffassign
- megacheck - megacheck
- misspell # Check code comments, whereas misspell in CI checks *.md files - misspell # Check code comments, whereas misspell in CI checks *.md files

View file

@ -1,5 +1,54 @@
# Changelog # Changelog
## Dendrite 0.3.5 (2021-01-11)
### Features
* All `/sync` streams are now logically separate after a refactoring exercise
## Fixes
* Event references are now deeply checked properly when calculating forward extremities, reducing the amount of forward extremities in most cases, which improves RAM utilisation and reduces the work done by state resolution
* Sync no longer sends incorrect `next_batch` tokens with old stream positions, reducing flashbacks of old messages in clients
* The federation `/send` endpoint no longer uses the request context, which could result in some events failing to be persisted if the sending server gave up the HTTP connection
* Appservices can now auth as users in their namespaces properly
## Dendrite 0.3.4 (2020-12-18)
### Features
* The stream tokens for `/sync` have been refactored, giving PDUs, typing notifications, read receipts, invites and send-to-device messages their own respective stream positions, greatly improving the correctness of sync
* A new roominfo cache has been added, which results in less database hits in the roomserver
* Prometheus metrics have been added for sync requests, destination queues and client API event send perceived latency
### Fixes
* Event IDs are no longer recalculated so often in `/sync`, which reduces CPU usage
* Sync requests are now woken up correctly for our own device list updates
* The device list stream position is no longer lost, so unnecessary device updates no longer appear in every other sync
* A crash on concurrent map read/writes has been fixed in the stream token code
* The roomserver input API no longer starts more worker goroutines than needed
* The roomserver no longer uses the request context for queued tasks which could lead to send requests failing to be processed
* A new index has been added to the sync API current state table, which improves lookup performance significantly
* The client API `/joined_rooms` endpoint no longer incorrectly returns `null` if there are 0 rooms joined
* The roomserver will now query appservices when looking up a local room alias that isn't known
* The check on registration for appservice-exclusive namespaces has been fixed
## Dendrite 0.3.3 (2020-12-09)
### Features
* Federation sender should now use considerably less CPU cycles and RAM when sending events into large rooms
* The roomserver now uses considerably less CPU cycles by not calculating event IDs so often
* Experimental support for [MSC2836](https://github.com/matrix-org/matrix-doc/pull/2836) (threading) has been merged
* Dendrite will no longer hold federation HTTP connections open unnecessarily, which should help to reduce ambient CPU/RAM usage and hold fewer long-term file descriptors
### Fixes
* A bug in the latest event updater has been fixed, which should prevent the roomserver from losing forward extremities in some rare cases
* A panic has been fixed when federation is disabled (contributed by [kraem](https://github.com/kraem))
* The response format of the `/joined_members` endpoint has been fixed (contributed by [alexkursell](https://github.com/alexkursell))
## Dendrite 0.3.2 (2020-12-02) ## Dendrite 0.3.2 (2020-12-02)
### Features ### Features

View file

@ -8,7 +8,6 @@ It intends to provide an **efficient**, **reliable** and **scalable** alternativ
a [brand new Go test suite](https://github.com/matrix-org/complement). a [brand new Go test suite](https://github.com/matrix-org/complement).
- Scalable: can run on multiple machines and eventually scale to massive homeserver deployments. - Scalable: can run on multiple machines and eventually scale to massive homeserver deployments.
As of October 2020, Dendrite has now entered **beta** which means: As of October 2020, Dendrite has now entered **beta** which means:
- Dendrite is ready for early adopters. We recommend running in Monolith mode with a PostgreSQL database. - Dendrite is ready for early adopters. We recommend running in Monolith mode with a PostgreSQL database.
- Dendrite has periodic semver releases. We intend to release new versions as we land significant features. - Dendrite has periodic semver releases. We intend to release new versions as we land significant features.
@ -24,7 +23,7 @@ This does not mean:
Currently, we expect Dendrite to function well for small (10s/100s of users) homeserver deployments as well as P2P Matrix nodes in-browser or on mobile devices. Currently, we expect Dendrite to function well for small (10s/100s of users) homeserver deployments as well as P2P Matrix nodes in-browser or on mobile devices.
In the future, we will be able to scale up to gigantic servers (equivalent to matrix.org) via polylith mode. In the future, we will be able to scale up to gigantic servers (equivalent to matrix.org) via polylith mode.
Join us in: If you have further questions, please take a look at [our FAQ](docs/FAQ.md) or join us in:
- **[#dendrite:matrix.org](https://matrix.to/#/#dendrite:matrix.org)** - General chat about the Dendrite project, for users and server admins alike - **[#dendrite:matrix.org](https://matrix.to/#/#dendrite:matrix.org)** - General chat about the Dendrite project, for users and server admins alike
- **[#dendrite-dev:matrix.org](https://matrix.to/#/#dendrite-dev:matrix.org)** - The place for developers, where all Dendrite development discussion happens - **[#dendrite-dev:matrix.org](https://matrix.to/#/#dendrite-dev:matrix.org)** - The place for developers, where all Dendrite development discussion happens

View file

@ -20,9 +20,9 @@ package api
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -109,7 +109,7 @@ func RetrieveUserProfile(
// If no user exists, return // If no user exists, return
if !userResp.UserIDExists { if !userResp.UserIDExists {
return nil, eventutil.ErrProfileNoExists return nil, errors.New("no known profile for given user ID")
} }
// Try to query the user from the local database again // Try to query the user from the local database again

View file

@ -19,4 +19,4 @@ fi
go install -trimpath -ldflags "$FLAGS" -v $PWD/`dirname $0`/cmd/... go install -trimpath -ldflags "$FLAGS" -v $PWD/`dirname $0`/cmd/...
GOOS=js GOARCH=wasm go build -trimpath -ldflags "$FLAGS" -o main.wasm ./cmd/dendritejs GOOS=js GOARCH=wasm go build -trimpath -ldflags "$FLAGS" -o bin/main.wasm ./cmd/dendritejs

View file

@ -4,8 +4,8 @@ These are Docker images for Dendrite!
They can be found on Docker Hub: They can be found on Docker Hub:
- [matrixdotorg/dendrite-monolith](https://hub.docker.com/repository/docker/matrixdotorg/dendrite-monolith) for monolith deployments - [matrixdotorg/dendrite-monolith](https://hub.docker.com/r/matrixdotorg/dendrite-monolith) for monolith deployments
- [matrixdotorg/dendrite-polylith](https://hub.docker.com/repository/docker/matrixdotorg/dendrite-polylith) for polylith deployments - [matrixdotorg/dendrite-polylith](https://hub.docker.com/r/matrixdotorg/dendrite-polylith) for polylith deployments
## Dockerfiles ## Dockerfiles

View file

@ -77,7 +77,7 @@ global:
# Naffka database options. Not required when using Kafka. # Naffka database options. Not required when using Kafka.
naffka_database: naffka_database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_naffka?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_naffka?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -98,7 +98,7 @@ app_service_api:
connect: http://appservice_api:7777 connect: http://appservice_api:7777
database: database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_appservice?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_appservice?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -173,7 +173,7 @@ federation_sender:
connect: http://federation_sender:7775 connect: http://federation_sender:7775
database: database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_federationsender?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_federationsender?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -199,7 +199,7 @@ key_server:
connect: http://key_server:7779 connect: http://key_server:7779
database: database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_keyserver?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_keyserver?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -212,7 +212,7 @@ media_api:
listen: http://0.0.0.0:8074 listen: http://0.0.0.0:8074
database: database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_mediaapi?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_mediaapi?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -248,7 +248,7 @@ room_server:
connect: http://room_server:7770 connect: http://room_server:7770
database: database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_roomserver?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_roomserver?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -259,7 +259,7 @@ signing_key_server:
connect: http://signing_key_server:7780 connect: http://signing_key_server:7780
database: database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_signingkeyserver?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_signingkeyserver?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -288,7 +288,7 @@ sync_api:
listen: http://0.0.0.0:8073 listen: http://0.0.0.0:8073
database: database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_syncapi?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_syncapi?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -299,12 +299,12 @@ user_api:
connect: http://user_api:7781 connect: http://user_api:7781
account_database: account_database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_account?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_account?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
device_database: device_database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_device?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_device?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1

View file

@ -3,7 +3,7 @@ services:
# PostgreSQL is needed for both polylith and monolith modes. # PostgreSQL is needed for both polylith and monolith modes.
postgres: postgres:
hostname: postgres hostname: postgres
image: postgres:9.6 image: postgres:11
restart: always restart: always
volumes: volumes:
- ./postgres/create_db.sh:/docker-entrypoint-initdb.d/20-create_db.sh - ./postgres/create_db.sh:/docker-entrypoint-initdb.d/20-create_db.sh

View file

@ -130,6 +130,7 @@ func (m *DendriteMonolith) Start() {
) )
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
rsAPI.SetAppserviceAPI(asAPI)
ygg.SetSessionFunc(func(address string) { ygg.SetSessionFunc(func(address string) {
req := &api.PerformServersAliveRequest{ req := &api.PerformServersAliveRequest{

View file

@ -8,7 +8,7 @@
# - `DENDRITE_LINT_CONCURRENCY` - number of concurrent linters to run, # - `DENDRITE_LINT_CONCURRENCY` - number of concurrent linters to run,
# golangci-lint defaults this to NumCPU # golangci-lint defaults this to NumCPU
# - `GOGC` - how often to perform garbage collection during golangci-lint runs. # - `GOGC` - how often to perform garbage collection during golangci-lint runs.
# Essentially a ratio of memory/speed. See https://github.com/golangci/golangci-lint#memory-usage-of-golangci-lint # Essentially a ratio of memory/speed. See https://golangci-lint.run/usage/performance/#memory-usage
# for more info. # for more info.
@ -24,8 +24,6 @@ fi
echo "Installing golangci-lint..." echo "Installing golangci-lint..."
# Make a backup of go.{mod,sum} first # Make a backup of go.{mod,sum} first
# TODO: Once go 1.13 is out, use go get's -mod=readonly option
# https://github.com/golang/go/issues/30667
cp go.mod go.mod.bak && cp go.sum go.sum.bak cp go.mod go.mod.bak && cp go.sum go.sum.bak
go get github.com/golangci/golangci-lint/cmd/golangci-lint@v1.19.1 go get github.com/golangci/golangci-lint/cmd/golangci-lint@v1.19.1

View file

@ -111,6 +111,9 @@ func GetJoinedRooms(
util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if res.RoomIDs == nil {
res.RoomIDs = []string{}
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: getJoinedRoomsResponse{res.RoomIDs}, JSON: getJoinedRoomsResponse{res.RoomIDs},

View file

@ -328,7 +328,22 @@ func UserIDIsWithinApplicationServiceNamespace(
userID string, userID string,
appservice *config.ApplicationService, appservice *config.ApplicationService,
) bool { ) bool {
var local, domain, err = gomatrixserverlib.SplitID('@', userID)
if err != nil {
// Not a valid userID
return false
}
if domain != cfg.Matrix.ServerName {
return false
}
if appservice != nil { if appservice != nil {
if appservice.SenderLocalpart == local {
return true
}
// Loop through given application service's namespaces and see if any match // Loop through given application service's namespaces and see if any match
for _, namespace := range appservice.NamespaceMap["users"] { for _, namespace := range appservice.NamespaceMap["users"] {
// AS namespaces are checked for validity in config // AS namespaces are checked for validity in config
@ -341,6 +356,9 @@ func UserIDIsWithinApplicationServiceNamespace(
// Loop through all known application service's namespaces and see if any match // Loop through all known application service's namespaces and see if any match
for _, knownAppService := range cfg.Derived.ApplicationServices { for _, knownAppService := range cfg.Derived.ApplicationServices {
if knownAppService.SenderLocalpart == local {
return true
}
for _, namespace := range knownAppService.NamespaceMap["users"] { for _, namespace := range knownAppService.NamespaceMap["users"] {
// AS namespaces are checked for validity in config // AS namespaces are checked for validity in config
if namespace.RegexpObject.MatchString(userID) { if namespace.RegexpObject.MatchString(userID) {
@ -488,17 +506,6 @@ func Register(
return *resErr return *resErr
} }
// Make sure normal user isn't registering under an exclusive application
// service namespace. Skip this check if no app services are registered.
if r.Auth.Type != authtypes.LoginTypeApplicationService &&
len(cfg.Derived.ApplicationServices) != 0 &&
UsernameMatchesExclusiveNamespaces(cfg, r.Username) {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.ASExclusive("This username is reserved by an application service."),
}
}
logger := util.GetLogger(req.Context()) logger := util.GetLogger(req.Context())
logger.WithFields(log.Fields{ logger.WithFields(log.Fields{
"username": r.Username, "username": r.Username,
@ -581,11 +588,33 @@ func handleRegistrationFlow(
// TODO: Handle mapping registrationRequest parameters into session parameters // TODO: Handle mapping registrationRequest parameters into session parameters
// TODO: email / msisdn auth types. // TODO: email / msisdn auth types.
accessToken, accessTokenErr := auth.ExtractAccessToken(req)
// Appservices are special and are not affected by disabled
// registration or user exclusivity.
if r.Auth.Type == authtypes.LoginTypeApplicationService ||
(r.Auth.Type == "" && accessTokenErr == nil) {
return handleApplicationServiceRegistration(
accessToken, accessTokenErr, req, r, cfg, userAPI,
)
}
if cfg.RegistrationDisabled && r.Auth.Type != authtypes.LoginTypeSharedSecret { if cfg.RegistrationDisabled && r.Auth.Type != authtypes.LoginTypeSharedSecret {
return util.MessageResponse(http.StatusForbidden, "Registration has been disabled") return util.MessageResponse(http.StatusForbidden, "Registration has been disabled")
} }
// Make sure normal user isn't registering under an exclusive application
// service namespace. Skip this check if no app services are registered.
// If an access token is provided, ignore this check this is an appservice
// request and we will validate in validateApplicationService
if len(cfg.Derived.ApplicationServices) != 0 &&
UsernameMatchesExclusiveNamespaces(cfg, r.Username) {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.ASExclusive("This username is reserved by an application service."),
}
}
switch r.Auth.Type { switch r.Auth.Type {
case authtypes.LoginTypeRecaptcha: case authtypes.LoginTypeRecaptcha:
// Check given captcha response // Check given captcha response
@ -611,36 +640,15 @@ func handleRegistrationFlow(
// Add SharedSecret to the list of completed registration stages // Add SharedSecret to the list of completed registration stages
AddCompletedSessionStage(sessionID, authtypes.LoginTypeSharedSecret) AddCompletedSessionStage(sessionID, authtypes.LoginTypeSharedSecret)
case "":
// Extract the access token from the request, if there's one to extract
// (which we can know by checking whether the error is nil or not).
accessToken, err := auth.ExtractAccessToken(req)
// A missing auth type can mean either the registration is performed by
// an AS or the request is made as the first step of a registration
// using the User-Interactive Authentication API. This can be determined
// by whether the request contains an access token.
if err == nil {
return handleApplicationServiceRegistration(
accessToken, err, req, r, cfg, userAPI,
)
}
case authtypes.LoginTypeApplicationService:
// Extract the access token from the request.
accessToken, err := auth.ExtractAccessToken(req)
// Let the AS registration handler handle the process from here. We
// don't need a condition on that call since the registration is clearly
// stated as being AS-related.
return handleApplicationServiceRegistration(
accessToken, err, req, r, cfg, userAPI,
)
case authtypes.LoginTypeDummy: case authtypes.LoginTypeDummy:
// there is nothing to do // there is nothing to do
// Add Dummy to the list of completed registration stages // Add Dummy to the list of completed registration stages
AddCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) AddCompletedSessionStage(sessionID, authtypes.LoginTypeDummy)
case "":
// An empty auth type means that we want to fetch the available
// flows. It can also mean that we want to register as an appservice
// but that is handed above.
default: default:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotImplemented, Code: http.StatusNotImplemented,

View file

@ -17,6 +17,7 @@ package routing
import ( import (
"net/http" "net/http"
"sync" "sync"
"time"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
@ -27,6 +28,7 @@ import (
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -40,6 +42,25 @@ var (
userRoomSendMutexes sync.Map // (roomID+userID) -> mutex. mutexes to ensure correct ordering of sendEvents userRoomSendMutexes sync.Map // (roomID+userID) -> mutex. mutexes to ensure correct ordering of sendEvents
) )
func init() {
prometheus.MustRegister(sendEventDuration)
}
var sendEventDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",
Subsystem: "clientapi",
Name: "sendevent_duration_millis",
Help: "How long it takes to build and submit a new event from the client API to the roomserver",
Buckets: []float64{ // milliseconds
5, 10, 25, 50, 75, 100, 250, 500,
1000, 2000, 3000, 4000, 5000, 6000,
7000, 8000, 9000, 10000, 15000, 20000,
},
},
[]string{"action"},
)
// SendEvent implements: // SendEvent implements:
// /rooms/{roomID}/send/{eventType} // /rooms/{roomID}/send/{eventType}
// /rooms/{roomID}/send/{eventType}/{txnID} // /rooms/{roomID}/send/{eventType}/{txnID}
@ -75,10 +96,12 @@ func SendEvent(
mutex.(*sync.Mutex).Lock() mutex.(*sync.Mutex).Lock()
defer mutex.(*sync.Mutex).Unlock() defer mutex.(*sync.Mutex).Unlock()
startedGeneratingEvent := time.Now()
e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, rsAPI) e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, rsAPI)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
timeToGenerateEvent := time.Since(startedGeneratingEvent)
var txnAndSessionID *api.TransactionID var txnAndSessionID *api.TransactionID
if txnID != nil { if txnID != nil {
@ -90,6 +113,7 @@ func SendEvent(
// pass the new event to the roomserver and receive the correct event ID // pass the new event to the roomserver and receive the correct event ID
// event ID in case of duplicate transaction is discarded // event ID in case of duplicate transaction is discarded
startedSubmittingEvent := time.Now()
if err := api.SendEvents( if err := api.SendEvents(
req.Context(), rsAPI, req.Context(), rsAPI,
api.KindNew, api.KindNew,
@ -102,6 +126,7 @@ func SendEvent(
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
timeToSubmitEvent := time.Since(startedSubmittingEvent)
util.GetLogger(req.Context()).WithFields(logrus.Fields{ util.GetLogger(req.Context()).WithFields(logrus.Fields{
"event_id": e.EventID(), "event_id": e.EventID(),
"room_id": roomID, "room_id": roomID,
@ -117,6 +142,11 @@ func SendEvent(
txnCache.AddTransaction(device.AccessToken, *txnID, &res) txnCache.AddTransaction(device.AccessToken, *txnID, &res)
} }
// Take a note of how long it took to generate the event vs submit
// it to the roomserver.
sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds()))
sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvent.Milliseconds()))
return res return res
} }

View file

@ -161,6 +161,7 @@ func main() {
&base.Base, cache.New(), userAPI, &base.Base, cache.New(), userAPI,
) )
asAPI := appservice.NewInternalAPI(&base.Base, userAPI, rsAPI) asAPI := appservice.NewInternalAPI(&base.Base, userAPI, rsAPI)
rsAPI.SetAppserviceAPI(asAPI)
fsAPI := federationsender.NewInternalAPI( fsAPI := federationsender.NewInternalAPI(
&base.Base, federation, rsAPI, keyRing, &base.Base, federation, rsAPI, keyRing,
) )

View file

@ -113,6 +113,7 @@ func main() {
) )
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
rsAPI.SetAppserviceAPI(asAPI)
fsAPI := federationsender.NewInternalAPI( fsAPI := federationsender.NewInternalAPI(
base, federation, rsAPI, keyRing, base, federation, rsAPI, keyRing,
) )

View file

@ -126,6 +126,7 @@ func main() {
appservice.AddInternalRoutes(base.InternalAPIMux, asAPI) appservice.AddInternalRoutes(base.InternalAPIMux, asAPI)
asAPI = base.AppserviceHTTPClient() asAPI = base.AppserviceHTTPClient()
} }
rsAPI.SetAppserviceAPI(asAPI)
monolith := setup.Monolith{ monolith := setup.Monolith{
Config: base.Cfg, Config: base.Cfg,

View file

@ -24,9 +24,11 @@ func RoomServer(base *setup.BaseDendrite, cfg *config.Dendrite) {
serverKeyAPI := base.SigningKeyServerHTTPClient() serverKeyAPI := base.SigningKeyServerHTTPClient()
keyRing := serverKeyAPI.KeyRing() keyRing := serverKeyAPI.KeyRing()
asAPI := base.AppserviceHTTPClient()
fsAPI := base.FederationSenderHTTPClient() fsAPI := base.FederationSenderHTTPClient()
rsAPI := roomserver.NewInternalAPI(base, keyRing) rsAPI := roomserver.NewInternalAPI(base, keyRing)
rsAPI.SetFederationSenderAPI(fsAPI) rsAPI.SetFederationSenderAPI(fsAPI)
rsAPI.SetAppserviceAPI(asAPI)
roomserver.AddInternalRoutes(base.InternalAPIMux, rsAPI) roomserver.AddInternalRoutes(base.InternalAPIMux, rsAPI)
base.SetupAndServeHTTP( base.SetupAndServeHTTP(

View file

@ -207,6 +207,7 @@ func main() {
asQuery := appservice.NewInternalAPI( asQuery := appservice.NewInternalAPI(
base, userAPI, rsAPI, base, userAPI, rsAPI,
) )
rsAPI.SetAppserviceAPI(asQuery)
fedSenderAPI := federationsender.NewInternalAPI(base, federation, rsAPI, &keyRing) fedSenderAPI := federationsender.NewInternalAPI(base, federation, rsAPI, &keyRing)
rsAPI.SetFederationSenderAPI(fedSenderAPI) rsAPI.SetFederationSenderAPI(fedSenderAPI)
p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation) p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation)

View file

@ -63,8 +63,10 @@ func main() {
if *defaultsForCI { if *defaultsForCI {
cfg.ClientAPI.RateLimiting.Enabled = false cfg.ClientAPI.RateLimiting.Enabled = false
cfg.FederationSender.DisableTLSValidation = true cfg.FederationSender.DisableTLSValidation = true
cfg.MSCs.MSCs = []string{"msc2836"} cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
cfg.Logging[0].Level = "trace" cfg.Logging[0].Level = "trace"
// don't hit matrix.org when running tests!!!
cfg.SigningKeyServer.KeyPerspectives = config.KeyPerspectives{}
} }
j, err := yaml.Marshal(cfg) j, err := yaml.Marshal(cfg)

View file

@ -89,7 +89,7 @@ global:
# Naffka database options. Not required when using Kafka. # Naffka database options. Not required when using Kafka.
naffka_database: naffka_database:
connection_string: file:naffka.db connection_string: file:naffka.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -110,7 +110,7 @@ app_service_api:
connect: http://localhost:7777 connect: http://localhost:7777
database: database:
connection_string: file:appservice.db connection_string: file:appservice.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -185,7 +185,7 @@ federation_sender:
connect: http://localhost:7775 connect: http://localhost:7775
database: database:
connection_string: file:federationsender.db connection_string: file:federationsender.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -211,7 +211,7 @@ key_server:
connect: http://localhost:7779 connect: http://localhost:7779
database: database:
connection_string: file:keyserver.db connection_string: file:keyserver.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -224,7 +224,7 @@ media_api:
listen: http://[::]:8074 listen: http://[::]:8074
database: database:
connection_string: file:mediaapi.db connection_string: file:mediaapi.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -253,6 +253,19 @@ media_api:
height: 480 height: 480
method: scale method: scale
# Configuration for experimental MSC's
mscs:
# A list of enabled MSC's
# Currently valid values are:
# - msc2836 (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836)
# - msc2946 (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946)
mscs: []
database:
connection_string: file:mscs.db
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
# Configuration for the Room Server. # Configuration for the Room Server.
room_server: room_server:
internal_api: internal_api:
@ -260,7 +273,7 @@ room_server:
connect: http://localhost:7770 connect: http://localhost:7770
database: database:
connection_string: file:roomserver.db connection_string: file:roomserver.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -271,7 +284,7 @@ signing_key_server:
connect: http://localhost:7780 connect: http://localhost:7780
database: database:
connection_string: file:signingkeyserver.db connection_string: file:signingkeyserver.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -300,7 +313,7 @@ sync_api:
listen: http://[::]:8073 listen: http://[::]:8073
database: database:
connection_string: file:syncapi.db connection_string: file:syncapi.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -316,12 +329,12 @@ user_api:
connect: http://localhost:7781 connect: http://localhost:7781
account_database: account_database:
connection_string: file:userapi_accounts.db connection_string: file:userapi_accounts.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
device_database: device_database:
connection_string: file:userapi_devices.db connection_string: file:userapi_devices.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1

View file

@ -2,13 +2,13 @@
In addition to standard Go code style (`gofmt`, `goimports`), we use `golangci-lint` In addition to standard Go code style (`gofmt`, `goimports`), we use `golangci-lint`
to run a number of linters, the exact list can be found under linters in [.golangci.yml](.golangci.yml). to run a number of linters, the exact list can be found under linters in [.golangci.yml](.golangci.yml).
[Installation](https://github.com/golangci/golangci-lint#install) and [Editor [Installation](https://github.com/golangci/golangci-lint#install-golangci-lint) and [Editor
Integration](https://github.com/golangci/golangci-lint#editor-integration) for Integration](https://golangci-lint.run/usage/integrations/#editor-integration) for
it can be found in the readme of golangci-lint. it can be found in the readme of golangci-lint.
For rare cases where a linter is giving a spurious warning, it can be disabled For rare cases where a linter is giving a spurious warning, it can be disabled
for that line or statement using a [comment for that line or statement using a [comment
directive](https://github.com/golangci/golangci-lint#nolint), e.g. `var directive](https://golangci-lint.run/usage/false-positives/#nolint), e.g. `var
bad_name int //nolint:golint,unused`. This should be used sparingly and only bad_name int //nolint:golint,unused`. This should be used sparingly and only
when its clear that the lint warning is spurious. when its clear that the lint warning is spurious.

View file

@ -12,6 +12,10 @@ No, although a good portion of the Matrix specification has been implemented. Mo
No, not at present. There will be in the future when Dendrite reaches version 1.0. No, not at present. There will be in the future when Dendrite reaches version 1.0.
### Should I run a monolith or a polylith deployment?
Monolith deployments are always preferred where possible, and at this time, are far better tested than polylith deployments are. The only reason to consider a polylith deployment is if you wish to run different Dendrite components on separate physical machines.
### I've installed Dendrite but federation isn't working ### I've installed Dendrite but federation isn't working
Check the [Federation Tester](https://federationtester.matrix.org). You need at least: Check the [Federation Tester](https://federationtester.matrix.org). You need at least:

View file

@ -80,12 +80,6 @@ brew services start kafka
## Configuration ## Configuration
### SQLite database setup
Dendrite can use the built-in SQLite database engine for small setups.
The SQLite databases do not need to be pre-built - Dendrite will
create them automatically at startup.
### PostgreSQL database setup ### PostgreSQL database setup
Assuming that PostgreSQL 9.6 (or later) is installed: Assuming that PostgreSQL 9.6 (or later) is installed:
@ -96,7 +90,23 @@ Assuming that PostgreSQL 9.6 (or later) is installed:
sudo -u postgres createuser -P dendrite sudo -u postgres createuser -P dendrite
``` ```
* Create the component databases: At this point you have a choice on whether to run all of the Dendrite
components from a single database, or for each component to have its
own database. For most deployments, running from a single database will
be sufficient, although you may wish to separate them if you plan to
split out the databases across multiple machines in the future.
On macOS, omit `sudo -u postgres` from the below commands.
* If you want to run all Dendrite components from a single database:
```bash
sudo -u postgres createdb -O dendrite dendrite
```
... in which case your connection string will look like `postgres://user:pass@database/dendrite`.
* If you want to run each Dendrite component with its own database:
```bash ```bash
for i in mediaapi syncapi roomserver signingkeyserver federationsender appservice keyserver userapi_account userapi_device naffka; do for i in mediaapi syncapi roomserver signingkeyserver federationsender appservice keyserver userapi_account userapi_device naffka; do
@ -104,14 +114,22 @@ Assuming that PostgreSQL 9.6 (or later) is installed:
done done
``` ```
(On macOS, omit `sudo -u postgres` from the above commands.) ... in which case your connection string will look like `postgres://user:pass@database/dendrite_componentname`.
### SQLite database setup
**WARNING:** SQLite is suitable for small experimental deployments only and should not be used in production - use PostgreSQL instead for any user-facing federating installation!
Dendrite can use the built-in SQLite database engine for small setups.
The SQLite databases do not need to be pre-built - Dendrite will
create them automatically at startup.
### Server key generation ### Server key generation
Each Dendrite installation requires: Each Dendrite installation requires:
- A unique Matrix signing private key * A unique Matrix signing private key
- A valid and trusted TLS certificate and private key * A valid and trusted TLS certificate and private key
To generate a Matrix signing private key: To generate a Matrix signing private key:
@ -119,7 +137,7 @@ To generate a Matrix signing private key:
./bin/generate-keys --private-key matrix_key.pem ./bin/generate-keys --private-key matrix_key.pem
``` ```
**Warning:** Make sure take a safe backup of this key! You will likely need it if you want to reinstall Dendrite, or **WARNING:** Make sure take a safe backup of this key! You will likely need it if you want to reinstall Dendrite, or
any other Matrix homeserver, on the same domain name in the future. If you lose this key, you may have trouble joining any other Matrix homeserver, on the same domain name in the future. If you lose this key, you may have trouble joining
federated rooms. federated rooms.
@ -140,7 +158,9 @@ Create config file, based on `dendrite-config.yaml`. Call it `dendrite.yaml`. Th
* The `server_name` entry to reflect the hostname of your Dendrite server * The `server_name` entry to reflect the hostname of your Dendrite server
* The `database` lines with an updated connection string based on your * The `database` lines with an updated connection string based on your
desired setup, e.g. replacing `database` with the name of the database: desired setup, e.g. replacing `database` with the name of the database:
* For Postgres: `postgres://dendrite:password@localhost/database`, e.g. `postgres://dendrite:password@localhost/dendrite_userapi_account.db` * For Postgres: `postgres://dendrite:password@localhost/database`, e.g.
* `postgres://dendrite:password@localhost/dendrite_userapi_account` to connect to PostgreSQL with SSL/TLS
* `postgres://dendrite:password@localhost/dendrite_userapi_account?sslmode=disable` to connect to PostgreSQL without SSL/TLS
* For SQLite on disk: `file:component.db` or `file:///path/to/component.db`, e.g. `file:userapi_account.db` * For SQLite on disk: `file:component.db` or `file:///path/to/component.db`, e.g. `file:userapi_account.db`
* Postgres and SQLite can be mixed and matched on different components as desired. * Postgres and SQLite can be mixed and matched on different components as desired.
* The `use_naffka` option if using Naffka in a monolith deployment * The `use_naffka` option if using Naffka in a monolith deployment
@ -295,4 +315,3 @@ amongst other things.
```bash ```bash
./bin/dendrite-polylith-multi --config=dendrite.yaml userapi ./bin/dendrite-polylith-multi --config=dendrite.yaml userapi
``` ```

View file

@ -5,6 +5,7 @@ After=network.target
After=postgresql.service After=postgresql.service
[Service] [Service]
Environment=GODEBUG=madvdontneed=1
RestartSec=2s RestartSec=2s
Type=simple Type=simple
User=dendrite User=dendrite

View file

@ -113,19 +113,6 @@ func (t *EDUCache) AddTypingUser(
return t.GetLatestSyncPosition() return t.GetLatestSyncPosition()
} }
// AddSendToDeviceMessage increases the sync position for
// send-to-device updates.
// Returns the sync position before update, as the caller
// will use this to record the current stream position
// at the time that the send-to-device message was sent.
func (t *EDUCache) AddSendToDeviceMessage() int64 {
t.Lock()
defer t.Unlock()
latestSyncPosition := t.latestSyncPosition
t.latestSyncPosition++
return latestSyncPosition
}
// addUser with mutex lock & replace the previous timer. // addUser with mutex lock & replace the previous timer.
// Returns the latest typing sync position after update. // Returns the latest typing sync position after update.
func (t *EDUCache) addUser( func (t *EDUCache) addUser(

View file

@ -84,7 +84,7 @@ func Send(
util.GetLogger(httpReq.Context()).Infof("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs)) util.GetLogger(httpReq.Context()).Infof("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs))
resp, jsonErr := t.processTransaction(httpReq.Context()) resp, jsonErr := t.processTransaction(context.Background())
if jsonErr != nil { if jsonErr != nil {
util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed") util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed")
return *jsonErr return *jsonErr
@ -1005,79 +1005,82 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
return t.lookupMissingStateViaState(ctx, roomID, eventID, roomVersion) return t.lookupMissingStateViaState(ctx, roomID, eventID, roomVersion)
} }
util.GetLogger(ctx).WithFields(logrus.Fields{ if missingCount > 0 {
"missing": missingCount, util.GetLogger(ctx).WithFields(logrus.Fields{
"event_id": eventID, "missing": missingCount,
"room_id": roomID, "event_id": eventID,
"total_state": len(stateIDs.StateEventIDs), "room_id": roomID,
"total_auth_events": len(stateIDs.AuthEventIDs), "total_state": len(stateIDs.StateEventIDs),
"concurrent_requests": concurrentRequests, "total_auth_events": len(stateIDs.AuthEventIDs),
}).Info("Fetching missing state at event") "concurrent_requests": concurrentRequests,
}).Info("Fetching missing state at event")
// Get a list of servers to fetch from. // Get a list of servers to fetch from.
servers := t.getServers(ctx, roomID) servers := t.getServers(ctx, roomID)
if len(servers) > 5 { if len(servers) > 5 {
servers = servers[:5] servers = servers[:5]
}
// Create a queue containing all of the missing event IDs that we want
// to retrieve.
pending := make(chan string, missingCount)
for missingEventID := range missing {
pending <- missingEventID
}
close(pending)
// Define how many workers we should start to do this.
if missingCount < concurrentRequests {
concurrentRequests = missingCount
}
// Create the wait group.
var fetchgroup sync.WaitGroup
fetchgroup.Add(concurrentRequests)
// This is the only place where we'll write to t.haveEvents from
// multiple goroutines, and everywhere else is blocked on this
// synchronous function anyway.
var haveEventsMutex sync.Mutex
// Define what we'll do in order to fetch the missing event ID.
fetch := func(missingEventID string) {
var h *gomatrixserverlib.HeaderedEvent
h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false, servers)
switch err.(type) {
case verifySigError:
return
case nil:
break
default:
util.GetLogger(ctx).WithFields(logrus.Fields{
"event_id": missingEventID,
"room_id": roomID,
}).Info("Failed to fetch missing event")
return
} }
haveEventsMutex.Lock()
t.haveEvents[h.EventID()] = h
haveEventsMutex.Unlock()
}
// Create the worker. // Create a queue containing all of the missing event IDs that we want
worker := func(ch <-chan string) { // to retrieve.
defer fetchgroup.Done() pending := make(chan string, missingCount)
for missingEventID := range ch { for missingEventID := range missing {
fetch(missingEventID) pending <- missingEventID
} }
close(pending)
// Define how many workers we should start to do this.
if missingCount < concurrentRequests {
concurrentRequests = missingCount
}
// Create the wait group.
var fetchgroup sync.WaitGroup
fetchgroup.Add(concurrentRequests)
// This is the only place where we'll write to t.haveEvents from
// multiple goroutines, and everywhere else is blocked on this
// synchronous function anyway.
var haveEventsMutex sync.Mutex
// Define what we'll do in order to fetch the missing event ID.
fetch := func(missingEventID string) {
var h *gomatrixserverlib.HeaderedEvent
h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false, servers)
switch err.(type) {
case verifySigError:
return
case nil:
break
default:
util.GetLogger(ctx).WithFields(logrus.Fields{
"event_id": missingEventID,
"room_id": roomID,
}).Info("Failed to fetch missing event")
return
}
haveEventsMutex.Lock()
t.haveEvents[h.EventID()] = h
haveEventsMutex.Unlock()
}
// Create the worker.
worker := func(ch <-chan string) {
defer fetchgroup.Done()
for missingEventID := range ch {
fetch(missingEventID)
}
}
// Start the workers.
for i := 0; i < concurrentRequests; i++ {
go worker(pending)
}
// Wait for the workers to finish.
fetchgroup.Wait()
} }
// Start the workers.
for i := 0; i < concurrentRequests; i++ {
go worker(pending)
}
// Wait for the workers to finish.
fetchgroup.Wait()
resp, err := t.createRespStateFromStateIDs(stateIDs) resp, err := t.createRespStateFromStateIDs(stateIDs)
return resp, err return resp, err
} }

View file

@ -35,6 +35,8 @@ import (
const ( const (
maxPDUsPerTransaction = 50 maxPDUsPerTransaction = 50
maxEDUsPerTransaction = 50 maxEDUsPerTransaction = 50
maxPDUsInMemory = 128
maxEDUsInMemory = 128
queueIdleTimeout = time.Second * 30 queueIdleTimeout = time.Second * 30
) )
@ -51,54 +53,56 @@ type destinationQueue struct {
destination gomatrixserverlib.ServerName // destination of requests destination gomatrixserverlib.ServerName // destination of requests
running atomic.Bool // is the queue worker running? running atomic.Bool // is the queue worker running?
backingOff atomic.Bool // true if we're backing off backingOff atomic.Bool // true if we're backing off
overflowed atomic.Bool // the queues exceed maxPDUsInMemory/maxEDUsInMemory, so we should consult the database for more
statistics *statistics.ServerStatistics // statistics about this remote server statistics *statistics.ServerStatistics // statistics about this remote server
transactionIDMutex sync.Mutex // protects transactionID transactionIDMutex sync.Mutex // protects transactionID
transactionID gomatrixserverlib.TransactionID // last transaction ID transactionID gomatrixserverlib.TransactionID // last transaction ID if retrying, or "" if last txn was successful
transactionCount atomic.Int32 // how many events in this transaction so far notify chan struct{} // interrupts idle wait pending PDUs/EDUs
notifyPDUs chan bool // interrupts idle wait for PDUs pendingPDUs []*queuedPDU // PDUs waiting to be sent
notifyEDUs chan bool // interrupts idle wait for EDUs pendingEDUs []*queuedEDU // EDUs waiting to be sent
pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs
interruptBackoff chan bool // interrupts backoff interruptBackoff chan bool // interrupts backoff
} }
// Send event adds the event to the pending queue for the destination. // Send event adds the event to the pending queue for the destination.
// If the queue is empty then it starts a background goroutine to // If the queue is empty then it starts a background goroutine to
// start sending events to that destination. // start sending events to that destination.
func (oq *destinationQueue) sendEvent(receipt *shared.Receipt) { func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, receipt *shared.Receipt) {
// Create a transaction ID. We'll either do this if we don't have if event == nil {
// one made up yet, or if we've exceeded the number of maximum log.Errorf("attempt to send nil PDU with destination %q", oq.destination)
// events allowed in a single tranaction. We'll reset the counter return
// when we do.
oq.transactionIDMutex.Lock()
if oq.transactionID == "" || oq.transactionCount.Load() >= maxPDUsPerTransaction {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
oq.transactionCount.Store(0)
} }
oq.transactionIDMutex.Unlock()
// Create a database entry that associates the given PDU NID with // Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU // this destination queue. We'll then be able to retrieve the PDU
// later. // later.
if err := oq.db.AssociatePDUWithDestination( if err := oq.db.AssociatePDUWithDestination(
context.TODO(), context.TODO(),
oq.transactionID, // the current transaction ID "", // TODO: remove this, as we don't need to persist the transaction ID
oq.destination, // the destination server name oq.destination, // the destination server name
receipt, // NIDs from federationsender_queue_json table receipt, // NIDs from federationsender_queue_json table
); err != nil { ); err != nil {
log.WithError(err).Errorf("failed to associate PDU receipt %q with destination %q", receipt.String(), oq.destination) log.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination)
return return
} }
// We've successfully added a PDU to the transaction so increase
// the counter.
oq.transactionCount.Add(1)
// Check if the destination is blacklisted. If it isn't then wake // Check if the destination is blacklisted. If it isn't then wake
// up the queue. // up the queue.
if !oq.statistics.Blacklisted() { if !oq.statistics.Blacklisted() {
// If there's room in memory to hold the event then add it to the
// list.
oq.pendingMutex.Lock()
if len(oq.pendingPDUs) < maxPDUsInMemory {
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{
pdu: event,
receipt: receipt,
})
} else {
oq.overflowed.Store(true)
}
oq.pendingMutex.Unlock()
// Wake up the queue if it's asleep. // Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded() oq.wakeQueueIfNeeded()
// If we're blocking on waiting PDUs then tell the queue that we
// have work to do.
select { select {
case oq.notifyPDUs <- true: case oq.notify <- struct{}{}:
default: default:
} }
} }
@ -107,7 +111,11 @@ func (oq *destinationQueue) sendEvent(receipt *shared.Receipt) {
// sendEDU adds the EDU event to the pending queue for the destination. // sendEDU adds the EDU event to the pending queue for the destination.
// If the queue is empty then it starts a background goroutine to // If the queue is empty then it starts a background goroutine to
// start sending events to that destination. // start sending events to that destination.
func (oq *destinationQueue) sendEDU(receipt *shared.Receipt) { func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *shared.Receipt) {
if event == nil {
log.Errorf("attempt to send nil EDU with destination %q", oq.destination)
return
}
// Create a database entry that associates the given PDU NID with // Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU // this destination queue. We'll then be able to retrieve the PDU
// later. // later.
@ -116,21 +124,28 @@ func (oq *destinationQueue) sendEDU(receipt *shared.Receipt) {
oq.destination, // the destination server name oq.destination, // the destination server name
receipt, // NIDs from federationsender_queue_json table receipt, // NIDs from federationsender_queue_json table
); err != nil { ); err != nil {
log.WithError(err).Errorf("failed to associate EDU receipt %q with destination %q", receipt.String(), oq.destination) log.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination)
return return
} }
// We've successfully added an EDU to the transaction so increase
// the counter.
oq.transactionCount.Add(1)
// Check if the destination is blacklisted. If it isn't then wake // Check if the destination is blacklisted. If it isn't then wake
// up the queue. // up the queue.
if !oq.statistics.Blacklisted() { if !oq.statistics.Blacklisted() {
// If there's room in memory to hold the event then add it to the
// list.
oq.pendingMutex.Lock()
if len(oq.pendingEDUs) < maxEDUsInMemory {
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{
edu: event,
receipt: receipt,
})
} else {
oq.overflowed.Store(true)
}
oq.pendingMutex.Unlock()
// Wake up the queue if it's asleep. // Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded() oq.wakeQueueIfNeeded()
// If we're blocking on waiting EDUs then tell the queue that we
// have work to do.
select { select {
case oq.notifyEDUs <- true: case oq.notify <- struct{}{}:
default: default:
} }
} }
@ -152,48 +167,71 @@ func (oq *destinationQueue) wakeQueueIfNeeded() {
} }
} }
// waitForPDUs returns a channel for pending PDUs, which will be // getPendingFromDatabase will look at the database and see if
// used in backgroundSend select. It returns a closed channel if // there are any persisted events that haven't been sent to this
// there is something pending right now, or an open channel if // destination yet. If so, they will be queued up.
// we're waiting for something. // nolint:gocyclo
func (oq *destinationQueue) waitForPDUs() chan bool { func (oq *destinationQueue) getPendingFromDatabase() {
pendingPDUs, err := oq.db.GetPendingPDUCount(context.TODO(), oq.destination) // Check to see if there's anything to do for this server
if err != nil { // in the database.
log.WithError(err).Errorf("Failed to get pending PDU count on queue %q", oq.destination) retrieved := false
} ctx := context.Background()
// If there are PDUs pending right now then we'll return a closed oq.pendingMutex.Lock()
// channel. This will mean that the backgroundSend will not block. defer oq.pendingMutex.Unlock()
if pendingPDUs > 0 {
ch := make(chan bool, 1)
close(ch)
return ch
}
// If there are no PDUs pending right now then instead we'll return
// the notify channel, so that backgroundSend can pick up normal
// notifications from sendEvent.
return oq.notifyPDUs
}
// waitForEDUs returns a channel for pending EDUs, which will be // Take a note of all of the PDUs and EDUs that we already
// used in backgroundSend select. It returns a closed channel if // have cached. We will index them based on the receipt,
// there is something pending right now, or an open channel if // which ultimately just contains the index of the PDU/EDU
// we're waiting for something. // in the database.
func (oq *destinationQueue) waitForEDUs() chan bool { gotPDUs := map[string]struct{}{}
pendingEDUs, err := oq.db.GetPendingEDUCount(context.TODO(), oq.destination) gotEDUs := map[string]struct{}{}
if err != nil { for _, pdu := range oq.pendingPDUs {
log.WithError(err).Errorf("Failed to get pending EDU count on queue %q", oq.destination) gotPDUs[pdu.receipt.String()] = struct{}{}
} }
// If there are EDUs pending right now then we'll return a closed for _, edu := range oq.pendingEDUs {
// channel. This will mean that the backgroundSend will not block. gotEDUs[edu.receipt.String()] = struct{}{}
if pendingEDUs > 0 { }
ch := make(chan bool, 1)
close(ch) if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 {
return ch // We have room in memory for some PDUs - let's request no more than that.
if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil {
for receipt, pdu := range pdus {
if _, ok := gotPDUs[receipt.String()]; ok {
continue
}
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu})
retrieved = true
}
} else {
logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination)
}
}
if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 {
// We have room in memory for some EDUs - let's request no more than that.
if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil {
for receipt, edu := range edus {
if _, ok := gotEDUs[receipt.String()]; ok {
continue
}
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu})
retrieved = true
}
} else {
logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination)
}
}
// If we've retrieved all of the events from the database with room to spare
// in memory then we'll no longer consider this queue to be overflowed.
if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory {
oq.overflowed.Store(false)
}
// If we've retrieved some events then notify the destination queue goroutine.
if retrieved {
select {
case oq.notify <- struct{}{}:
default:
}
} }
// If there are no EDUs pending right now then instead we'll return
// the notify channel, so that backgroundSend can pick up normal
// notifications from sendEvent.
return oq.notifyEDUs
} }
// backgroundSend is the worker goroutine for sending events. // backgroundSend is the worker goroutine for sending events.
@ -204,27 +242,32 @@ func (oq *destinationQueue) backgroundSend() {
if !oq.running.CAS(false, true) { if !oq.running.CAS(false, true) {
return return
} }
destinationQueueRunning.Inc()
defer destinationQueueRunning.Dec()
defer oq.running.Store(false) defer oq.running.Store(false)
// Mark the queue as overflowed, so we will consult the database
// to see if there's anything new to send.
oq.overflowed.Store(true)
for { for {
pendingPDUs, pendingEDUs := false, false // If we are overflowing memory and have sent things out to the
// database then we can look up what those things are.
if oq.overflowed.Load() {
oq.getPendingFromDatabase()
}
// If we have nothing to do then wait either for incoming events, or // If we have nothing to do then wait either for incoming events, or
// until we hit an idle timeout. // until we hit an idle timeout.
select { select {
case <-oq.waitForPDUs(): case <-oq.notify:
// We were woken up because there are new PDUs waiting in the // There's work to do, either because getPendingFromDatabase
// database. // told us there is, or because a new event has come in via
pendingPDUs = true // sendEvent/sendEDU.
case <-oq.waitForEDUs():
// We were woken up because there are new PDUs waiting in the
// database.
pendingEDUs = true
case <-time.After(queueIdleTimeout): case <-time.After(queueIdleTimeout):
// The worker is idle so stop the goroutine. It'll get // The worker is idle so stop the goroutine. It'll get
// restarted automatically the next time we have an event to // restarted automatically the next time we have an event to
// send. // send.
log.Tracef("Queue %q has been idle for %s, going to sleep", oq.destination, queueIdleTimeout)
return return
} }
@ -237,6 +280,16 @@ func (oq *destinationQueue) backgroundSend() {
// has exceeded a maximum allowable value. Clean up the in-memory // has exceeded a maximum allowable value. Clean up the in-memory
// buffers at this point. The PDU clean-up is already on a defer. // buffers at this point. The PDU clean-up is already on a defer.
log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = nil
oq.pendingEDUs = nil
oq.pendingMutex.Unlock()
return return
} }
if until != nil && until.After(time.Now()) { if until != nil && until.After(time.Now()) {
@ -244,24 +297,51 @@ func (oq *destinationQueue) backgroundSend() {
// time. // time.
duration := time.Until(*until) duration := time.Until(*until)
log.Warnf("Backing off %q for %s", oq.destination, duration) log.Warnf("Backing off %q for %s", oq.destination, duration)
oq.backingOff.Store(true)
destinationQueueBackingOff.Inc()
select { select {
case <-time.After(duration): case <-time.After(duration):
case <-oq.interruptBackoff: case <-oq.interruptBackoff:
} }
destinationQueueBackingOff.Dec()
oq.backingOff.Store(false)
} }
// Work out which PDUs/EDUs to include in the next transaction.
oq.pendingMutex.RLock()
pduCount := len(oq.pendingPDUs)
eduCount := len(oq.pendingEDUs)
if pduCount > maxPDUsPerTransaction {
pduCount = maxPDUsPerTransaction
}
if eduCount > maxEDUsPerTransaction {
eduCount = maxEDUsPerTransaction
}
toSendPDUs := oq.pendingPDUs[:pduCount]
toSendEDUs := oq.pendingEDUs[:eduCount]
oq.pendingMutex.RUnlock()
// If we have pending PDUs or EDUs then construct a transaction. // If we have pending PDUs or EDUs then construct a transaction.
if pendingPDUs || pendingEDUs { // Try sending the next transaction and see what happens.
// Try sending the next transaction and see what happens. transaction, pc, ec, terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
transaction, terr := oq.nextTransaction() if terr != nil {
if terr != nil { // We failed to send the transaction. Mark it as a failure.
// We failed to send the transaction. Mark it as a failure. oq.statistics.Failure()
oq.statistics.Failure()
} else if transaction { } else if transaction {
// If we successfully sent the transaction then clear out // If we successfully sent the transaction then clear out
// the pending events and EDUs, and wipe our transaction ID. // the pending events and EDUs, and wipe our transaction ID.
oq.statistics.Success() oq.statistics.Success()
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs[:pc] {
oq.pendingPDUs[i] = nil
} }
for i := range oq.pendingEDUs[:ec] {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = oq.pendingPDUs[pc:]
oq.pendingEDUs = oq.pendingEDUs[ec:]
oq.pendingMutex.Unlock()
} }
} }
} }
@ -270,16 +350,20 @@ func (oq *destinationQueue) backgroundSend() {
// queue and sends it. Returns true if a transaction was sent or // queue and sends it. Returns true if a transaction was sent or
// false otherwise. // false otherwise.
// nolint:gocyclo // nolint:gocyclo
func (oq *destinationQueue) nextTransaction() (bool, error) { func (oq *destinationQueue) nextTransaction(
// Before we do anything, we need to roll over the transaction pdus []*queuedPDU,
// ID that is being used to coalesce events into the next TX. edus []*queuedEDU,
// Otherwise it's possible that we'll pick up an incomplete ) (bool, int, int, error) {
// transaction and end up nuking the rest of the events at the // If there's no projected transaction ID then generate one. If
// cleanup stage. // the transaction succeeds then we'll set it back to "" so that
// we generate a new one next time. If it fails, we'll preserve
// it so that we retry with the same transaction ID.
oq.transactionIDMutex.Lock() oq.transactionIDMutex.Lock()
oq.transactionID = "" if oq.transactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
oq.transactionIDMutex.Unlock() oq.transactionIDMutex.Unlock()
oq.transactionCount.Store(0)
// Create the transaction. // Create the transaction.
t := gomatrixserverlib.Transaction{ t := gomatrixserverlib.Transaction{
@ -289,58 +373,36 @@ func (oq *destinationQueue) nextTransaction() (bool, error) {
t.Origin = oq.origin t.Origin = oq.origin
t.Destination = oq.destination t.Destination = oq.destination
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = oq.transactionID
// Ask the database for any pending PDUs from the next transaction.
// maxPDUsPerTransaction is an upper limit but we probably won't
// actually retrieve that many events.
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
txid, pdus, pduReceipt, err := oq.db.GetNextTransactionPDUs(
ctx, // context
oq.destination, // server name
maxPDUsPerTransaction, // max events to retrieve
)
if err != nil {
log.WithError(err).Errorf("failed to get next transaction PDUs for server %q", oq.destination)
return false, fmt.Errorf("oq.db.GetNextTransactionPDUs: %w", err)
}
edus, eduReceipt, err := oq.db.GetNextTransactionEDUs(
ctx, // context
oq.destination, // server name
maxEDUsPerTransaction, // max events to retrieve
)
if err != nil {
log.WithError(err).Errorf("failed to get next transaction EDUs for server %q", oq.destination)
return false, fmt.Errorf("oq.db.GetNextTransactionEDUs: %w", err)
}
// If we didn't get anything from the database and there are no // If we didn't get anything from the database and there are no
// pending EDUs then there's nothing to do - stop here. // pending EDUs then there's nothing to do - stop here.
if len(pdus) == 0 && len(edus) == 0 { if len(pdus) == 0 && len(edus) == 0 {
return false, nil return false, 0, 0, nil
} }
// Pick out the transaction ID from the database. If we didn't var pduReceipts []*shared.Receipt
// get a transaction ID (i.e. because there are no PDUs but only var eduReceipts []*shared.Receipt
// EDUs) then generate a transaction ID.
t.TransactionID = txid
if t.TransactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
// Go through PDUs that we retrieved from the database, if any, // Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction. // and add them into the transaction.
for _, pdu := range pdus { for _, pdu := range pdus {
if pdu == nil || pdu.pdu == nil {
continue
}
// Append the JSON of the event, since this is a json.RawMessage type in the // Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct // gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, (*pdu).JSON()) t.PDUs = append(t.PDUs, pdu.pdu.JSON())
pduReceipts = append(pduReceipts, pdu.receipt)
} }
// Do the same for pending EDUS in the queue. // Do the same for pending EDUS in the queue.
for _, edu := range edus { for _, edu := range edus {
t.EDUs = append(t.EDUs, *edu) if edu == nil || edu.edu == nil {
continue
}
t.EDUs = append(t.EDUs, *edu.edu)
eduReceipts = append(eduReceipts, edu.receipt)
} }
logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
@ -349,34 +411,38 @@ func (oq *destinationQueue) nextTransaction() (bool, error) {
// TODO: we should check for 500-ish fails vs 400-ish here, // TODO: we should check for 500-ish fails vs 400-ish here,
// since we shouldn't queue things indefinitely in response // since we shouldn't queue things indefinitely in response
// to a 400-ish error // to a 400-ish error
ctx, cancel = context.WithTimeout(context.Background(), time.Minute*5) ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
defer cancel() defer cancel()
_, err = oq.client.SendTransaction(ctx, t) _, err := oq.client.SendTransaction(ctx, t)
switch err.(type) { switch err.(type) {
case nil: case nil:
// Clean up the transaction in the database. // Clean up the transaction in the database.
if pduReceipt != nil { if pduReceipts != nil {
//logrus.Infof("Cleaning PDUs %q", pduReceipt.String()) //logrus.Infof("Cleaning PDUs %q", pduReceipt.String())
if err = oq.db.CleanPDUs(context.Background(), oq.destination, pduReceipt); err != nil { if err = oq.db.CleanPDUs(context.Background(), oq.destination, pduReceipts); err != nil {
log.WithError(err).Errorf("failed to clean PDUs %q for server %q", pduReceipt.String(), t.Destination) log.WithError(err).Errorf("Failed to clean PDUs for server %q", t.Destination)
} }
} }
if eduReceipt != nil { if eduReceipts != nil {
//logrus.Infof("Cleaning EDUs %q", eduReceipt.String()) //logrus.Infof("Cleaning EDUs %q", eduReceipt.String())
if err = oq.db.CleanEDUs(context.Background(), oq.destination, eduReceipt); err != nil { if err = oq.db.CleanEDUs(context.Background(), oq.destination, eduReceipts); err != nil {
log.WithError(err).Errorf("failed to clean EDUs %q for server %q", eduReceipt.String(), t.Destination) log.WithError(err).Errorf("Failed to clean EDUs for server %q", t.Destination)
} }
} }
return true, nil // Reset the transaction ID.
oq.transactionIDMutex.Lock()
oq.transactionID = ""
oq.transactionIDMutex.Unlock()
return true, len(t.PDUs), len(t.EDUs), nil
case gomatrix.HTTPError: case gomatrix.HTTPError:
// Report that we failed to send the transaction and we // Report that we failed to send the transaction and we
// will retry again, subject to backoff. // will retry again, subject to backoff.
return false, err return false, 0, 0, err
default: default:
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"destination": oq.destination, "destination": oq.destination,
log.ErrorKey: err, log.ErrorKey: err,
}).Info("problem sending transaction") }).Infof("Failed to send transaction %q", t.TransactionID)
return false, err return false, 0, 0, err
} }
} }

View file

@ -24,8 +24,10 @@ import (
"github.com/matrix-org/dendrite/federationsender/statistics" "github.com/matrix-org/dendrite/federationsender/statistics"
"github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/federationsender/storage"
"github.com/matrix-org/dendrite/federationsender/storage/shared"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@ -44,6 +46,37 @@ type OutgoingQueues struct {
queues map[gomatrixserverlib.ServerName]*destinationQueue queues map[gomatrixserverlib.ServerName]*destinationQueue
} }
func init() {
prometheus.MustRegister(
destinationQueueTotal, destinationQueueRunning,
destinationQueueBackingOff,
)
}
var destinationQueueTotal = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "federationsender",
Name: "destination_queues_total",
},
)
var destinationQueueRunning = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "federationsender",
Name: "destination_queues_running",
},
)
var destinationQueueBackingOff = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "federationsender",
Name: "destination_queues_backing_off",
},
)
// NewOutgoingQueues makes a new OutgoingQueues // NewOutgoingQueues makes a new OutgoingQueues
func NewOutgoingQueues( func NewOutgoingQueues(
db storage.Database, db storage.Database,
@ -83,8 +116,8 @@ func NewOutgoingQueues(
log.WithError(err).Error("Failed to get EDU server names for destination queue hydration") log.WithError(err).Error("Failed to get EDU server names for destination queue hydration")
} }
for serverName := range serverNames { for serverName := range serverNames {
if !queues.getQueue(serverName).statistics.Blacklisted() { if queue := queues.getQueue(serverName); !queue.statistics.Blacklisted() {
queues.getQueue(serverName).wakeQueueIfNeeded() queue.wakeQueueIfNeeded()
} }
} }
}) })
@ -100,11 +133,22 @@ type SigningInfo struct {
PrivateKey ed25519.PrivateKey PrivateKey ed25519.PrivateKey
} }
type queuedPDU struct {
receipt *shared.Receipt
pdu *gomatrixserverlib.HeaderedEvent
}
type queuedEDU struct {
receipt *shared.Receipt
edu *gomatrixserverlib.EDU
}
func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue { func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue {
oqs.queuesMutex.Lock() oqs.queuesMutex.Lock()
defer oqs.queuesMutex.Unlock() defer oqs.queuesMutex.Unlock()
oq := oqs.queues[destination] oq := oqs.queues[destination]
if oq == nil { if oq == nil {
destinationQueueTotal.Inc()
oq = &destinationQueue{ oq = &destinationQueue{
db: oqs.db, db: oqs.db,
rsAPI: oqs.rsAPI, rsAPI: oqs.rsAPI,
@ -112,8 +156,7 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d
destination: destination, destination: destination,
client: oqs.client, client: oqs.client,
statistics: oqs.statistics.ForServer(destination), statistics: oqs.statistics.ForServer(destination),
notifyPDUs: make(chan bool, 1), notify: make(chan struct{}, 1),
notifyEDUs: make(chan bool, 1),
interruptBackoff: make(chan bool), interruptBackoff: make(chan bool),
signing: oqs.signing, signing: oqs.signing,
} }
@ -188,7 +231,7 @@ func (oqs *OutgoingQueues) SendEvent(
} }
for destination := range destmap { for destination := range destmap {
oqs.getQueue(destination).sendEvent(nid) oqs.getQueue(destination).sendEvent(ev, nid)
} }
return nil return nil
@ -258,7 +301,7 @@ func (oqs *OutgoingQueues) SendEDU(
} }
for destination := range destmap { for destination := range destmap {
oqs.getQueue(destination).sendEDU(nid) oqs.getQueue(destination).sendEDU(e, nid)
} }
return nil return nil

View file

@ -36,14 +36,14 @@ type Database interface {
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error)
AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
GetNextTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, *shared.Receipt, error) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
GetNextTransactionEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) ([]*gomatrixserverlib.EDU, *shared.Receipt, error) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)

View file

@ -45,16 +45,10 @@ const insertQueuePDUSQL = "" +
const deleteQueuePDUSQL = "" + const deleteQueuePDUSQL = "" +
"DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid = ANY($2)" "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid = ANY($2)"
const selectQueuePDUNextTransactionIDSQL = "" + const selectQueuePDUsSQL = "" +
"SELECT transaction_id FROM federationsender_queue_pdus" +
" WHERE server_name = $1" +
" ORDER BY transaction_id ASC" +
" LIMIT 1"
const selectQueuePDUsByTransactionSQL = "" +
"SELECT json_nid FROM federationsender_queue_pdus" + "SELECT json_nid FROM federationsender_queue_pdus" +
" WHERE server_name = $1 AND transaction_id = $2" + " WHERE server_name = $1" +
" LIMIT $3" " LIMIT $2"
const selectQueuePDUReferenceJSONCountSQL = "" + const selectQueuePDUReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" + "SELECT COUNT(*) FROM federationsender_queue_pdus" +
@ -71,8 +65,7 @@ type queuePDUsStatements struct {
db *sql.DB db *sql.DB
insertQueuePDUStmt *sql.Stmt insertQueuePDUStmt *sql.Stmt
deleteQueuePDUsStmt *sql.Stmt deleteQueuePDUsStmt *sql.Stmt
selectQueuePDUNextTransactionIDStmt *sql.Stmt selectQueuePDUsStmt *sql.Stmt
selectQueuePDUsByTransactionStmt *sql.Stmt
selectQueuePDUReferenceJSONCountStmt *sql.Stmt selectQueuePDUReferenceJSONCountStmt *sql.Stmt
selectQueuePDUsCountStmt *sql.Stmt selectQueuePDUsCountStmt *sql.Stmt
selectQueuePDUServerNamesStmt *sql.Stmt selectQueuePDUServerNamesStmt *sql.Stmt
@ -92,10 +85,7 @@ func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
if s.deleteQueuePDUsStmt, err = s.db.Prepare(deleteQueuePDUSQL); err != nil { if s.deleteQueuePDUsStmt, err = s.db.Prepare(deleteQueuePDUSQL); err != nil {
return return
} }
if s.selectQueuePDUNextTransactionIDStmt, err = s.db.Prepare(selectQueuePDUNextTransactionIDSQL); err != nil { if s.selectQueuePDUsStmt, err = s.db.Prepare(selectQueuePDUsSQL); err != nil {
return
}
if s.selectQueuePDUsByTransactionStmt, err = s.db.Prepare(selectQueuePDUsByTransactionSQL); err != nil {
return return
} }
if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil { if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil {
@ -137,18 +127,6 @@ func (s *queuePDUsStatements) DeleteQueuePDUs(
return err return err
} }
func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (gomatrixserverlib.TransactionID, error) {
var transactionID gomatrixserverlib.TransactionID
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUNextTransactionIDStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID)
if err == sql.ErrNoRows {
return "", nil
}
return transactionID, err
}
func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
ctx context.Context, txn *sql.Tx, jsonNID int64, ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) { ) (int64, error) {
@ -182,11 +160,10 @@ func (s *queuePDUsStatements) SelectQueuePDUCount(
func (s *queuePDUsStatements) SelectQueuePDUs( func (s *queuePDUsStatements) SelectQueuePDUs(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
transactionID gomatrixserverlib.TransactionID,
limit int, limit int,
) ([]int64, error) { ) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsByTransactionStmt) stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt)
rows, err := stmt.QueryContext(ctx, serverName, transactionID, limit) rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -17,7 +17,6 @@ package shared
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"github.com/matrix-org/dendrite/federationsender/storage/tables" "github.com/matrix-org/dendrite/federationsender/storage/tables"
@ -44,16 +43,11 @@ type Database struct {
// to pass them back so that we can clean up if the transaction sends // to pass them back so that we can clean up if the transaction sends
// successfully. // successfully.
type Receipt struct { type Receipt struct {
nids []int64 nid int64
} }
func (e *Receipt) Empty() bool { func (r *Receipt) String() string {
return len(e.nids) == 0 return fmt.Sprintf("%d", r.nid)
}
func (e *Receipt) String() string {
j, _ := json.Marshal(e.nids)
return string(j)
} }
// UpdateRoom updates the joined hosts for a room and returns what the joined // UpdateRoom updates the joined hosts for a room and returns what the joined
@ -146,7 +140,7 @@ func (d *Database) StoreJSON(
return nil, fmt.Errorf("d.insertQueueJSON: %w", err) return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
} }
return &Receipt{ return &Receipt{
nids: []int64{nid}, nid: nid,
}, nil }, nil
} }

View file

@ -33,16 +33,14 @@ func (d *Database) AssociateEDUWithDestination(
receipt *Receipt, receipt *Receipt,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for _, nid := range receipt.nids { if err := d.FederationSenderQueueEDUs.InsertQueueEDU(
if err := d.FederationSenderQueueEDUs.InsertQueueEDU( ctx, // context
ctx, // context txn, // SQL transaction
txn, // SQL transaction "", // TODO: EDU type for coalescing
"", // TODO: EDU type for coalescing serverName, // destination server name
serverName, // destination server name receipt.nid, // NID from the federationsender_queue_json table
nid, // NID from the federationsender_queue_json table ); err != nil {
); err != nil { return fmt.Errorf("InsertQueueEDU: %w", err)
return fmt.Errorf("InsertQueueEDU: %w", err)
}
} }
return nil return nil
}) })
@ -50,29 +48,25 @@ func (d *Database) AssociateEDUWithDestination(
// GetNextTransactionEDUs retrieves events from the database for // GetNextTransactionEDUs retrieves events from the database for
// the next pending transaction, up to the limit specified. // the next pending transaction, up to the limit specified.
func (d *Database) GetNextTransactionEDUs( func (d *Database) GetPendingEDUs(
ctx context.Context, ctx context.Context,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
limit int, limit int,
) ( ) (
edus []*gomatrixserverlib.EDU, edus map[*Receipt]*gomatrixserverlib.EDU,
receipt *Receipt,
err error, err error,
) { ) {
edus = make(map[*Receipt]*gomatrixserverlib.EDU)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nids, err := d.FederationSenderQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit) nids, err := d.FederationSenderQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit)
if err != nil { if err != nil {
return fmt.Errorf("SelectQueueEDUs: %w", err) return fmt.Errorf("SelectQueueEDUs: %w", err)
} }
receipt = &Receipt{
nids: nids,
}
retrieve := make([]int64, 0, len(nids)) retrieve := make([]int64, 0, len(nids))
for _, nid := range nids { for _, nid := range nids {
if edu, ok := d.Cache.GetFederationSenderQueuedEDU(nid); ok { if edu, ok := d.Cache.GetFederationSenderQueuedEDU(nid); ok {
edus = append(edus, edu) edus[&Receipt{nid}] = edu
} else { } else {
retrieve = append(retrieve, nid) retrieve = append(retrieve, nid)
} }
@ -83,12 +77,12 @@ func (d *Database) GetNextTransactionEDUs(
return fmt.Errorf("SelectQueueJSON: %w", err) return fmt.Errorf("SelectQueueJSON: %w", err)
} }
for _, blob := range blobs { for nid, blob := range blobs {
var event gomatrixserverlib.EDU var event gomatrixserverlib.EDU
if err := json.Unmarshal(blob, &event); err != nil { if err := json.Unmarshal(blob, &event); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err) return fmt.Errorf("json.Unmarshal: %w", err)
} }
edus = append(edus, &event) edus[&Receipt{nid}] = &event
} }
return nil return nil
@ -101,19 +95,24 @@ func (d *Database) GetNextTransactionEDUs(
func (d *Database) CleanEDUs( func (d *Database) CleanEDUs(
ctx context.Context, ctx context.Context,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
receipt *Receipt, receipts []*Receipt,
) error { ) error {
if receipt == nil { if len(receipts) == 0 {
return errors.New("expected receipt") return errors.New("expected receipt")
} }
nids := make([]int64, len(receipts))
for i := range receipts {
nids[i] = receipts[i].nid
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationSenderQueueEDUs.DeleteQueueEDUs(ctx, txn, serverName, receipt.nids); err != nil { if err := d.FederationSenderQueueEDUs.DeleteQueueEDUs(ctx, txn, serverName, nids); err != nil {
return err return err
} }
var deleteNIDs []int64 var deleteNIDs []int64
for _, nid := range receipt.nids { for _, nid := range nids {
count, err := d.FederationSenderQueueEDUs.SelectQueueEDUReferenceJSONCount(ctx, txn, nid) count, err := d.FederationSenderQueueEDUs.SelectQueueEDUReferenceJSONCount(ctx, txn, nid)
if err != nil { if err != nil {
return fmt.Errorf("SelectQueueEDUReferenceJSONCount: %w", err) return fmt.Errorf("SelectQueueEDUReferenceJSONCount: %w", err)

View file

@ -34,16 +34,14 @@ func (d *Database) AssociatePDUWithDestination(
receipt *Receipt, receipt *Receipt,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for _, nid := range receipt.nids { if err := d.FederationSenderQueuePDUs.InsertQueuePDU(
if err := d.FederationSenderQueuePDUs.InsertQueuePDU( ctx, // context
ctx, // context txn, // SQL transaction
txn, // SQL transaction transactionID, // transaction ID
transactionID, // transaction ID serverName, // destination server name
serverName, // destination server name receipt.nid, // NID from the federationsender_queue_json table
nid, // NID from the federationsender_queue_json table ); err != nil {
); err != nil { return fmt.Errorf("InsertQueuePDU: %w", err)
return fmt.Errorf("InsertQueuePDU: %w", err)
}
} }
return nil return nil
}) })
@ -51,14 +49,12 @@ func (d *Database) AssociatePDUWithDestination(
// GetNextTransactionPDUs retrieves events from the database for // GetNextTransactionPDUs retrieves events from the database for
// the next pending transaction, up to the limit specified. // the next pending transaction, up to the limit specified.
func (d *Database) GetNextTransactionPDUs( func (d *Database) GetPendingPDUs(
ctx context.Context, ctx context.Context,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
limit int, limit int,
) ( ) (
transactionID gomatrixserverlib.TransactionID, events map[*Receipt]*gomatrixserverlib.HeaderedEvent,
events []*gomatrixserverlib.HeaderedEvent,
receipt *Receipt,
err error, err error,
) { ) {
// Strictly speaking this doesn't need to be using the writer // Strictly speaking this doesn't need to be using the writer
@ -66,29 +62,17 @@ func (d *Database) GetNextTransactionPDUs(
// a guarantee of transactional isolation, it's actually useful // a guarantee of transactional isolation, it's actually useful
// to know in SQLite mode that nothing else is trying to modify // to know in SQLite mode that nothing else is trying to modify
// the database. // the database.
events = make(map[*Receipt]*gomatrixserverlib.HeaderedEvent)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
transactionID, err = d.FederationSenderQueuePDUs.SelectQueuePDUNextTransactionID(ctx, txn, serverName) nids, err := d.FederationSenderQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit)
if err != nil {
return fmt.Errorf("SelectQueuePDUNextTransactionID: %w", err)
}
if transactionID == "" {
return nil
}
nids, err := d.FederationSenderQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, transactionID, limit)
if err != nil { if err != nil {
return fmt.Errorf("SelectQueuePDUs: %w", err) return fmt.Errorf("SelectQueuePDUs: %w", err)
} }
receipt = &Receipt{
nids: nids,
}
retrieve := make([]int64, 0, len(nids)) retrieve := make([]int64, 0, len(nids))
for _, nid := range nids { for _, nid := range nids {
if event, ok := d.Cache.GetFederationSenderQueuedPDU(nid); ok { if event, ok := d.Cache.GetFederationSenderQueuedPDU(nid); ok {
events = append(events, event) events[&Receipt{nid}] = event
} else { } else {
retrieve = append(retrieve, nid) retrieve = append(retrieve, nid)
} }
@ -104,7 +88,7 @@ func (d *Database) GetNextTransactionPDUs(
if err := json.Unmarshal(blob, &event); err != nil { if err := json.Unmarshal(blob, &event); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err) return fmt.Errorf("json.Unmarshal: %w", err)
} }
events = append(events, &event) events[&Receipt{nid}] = &event
d.Cache.StoreFederationSenderQueuedPDU(nid, &event) d.Cache.StoreFederationSenderQueuedPDU(nid, &event)
} }
@ -119,19 +103,24 @@ func (d *Database) GetNextTransactionPDUs(
func (d *Database) CleanPDUs( func (d *Database) CleanPDUs(
ctx context.Context, ctx context.Context,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
receipt *Receipt, receipts []*Receipt,
) error { ) error {
if receipt == nil { if len(receipts) == 0 {
return errors.New("expected receipt") return errors.New("expected receipt")
} }
nids := make([]int64, len(receipts))
for i := range receipts {
nids[i] = receipts[i].nid
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationSenderQueuePDUs.DeleteQueuePDUs(ctx, txn, serverName, receipt.nids); err != nil { if err := d.FederationSenderQueuePDUs.DeleteQueuePDUs(ctx, txn, serverName, nids); err != nil {
return err return err
} }
var deleteNIDs []int64 var deleteNIDs []int64
for _, nid := range receipt.nids { for _, nid := range nids {
count, err := d.FederationSenderQueuePDUs.SelectQueuePDUReferenceJSONCount(ctx, txn, nid) count, err := d.FederationSenderQueuePDUs.SelectQueuePDUReferenceJSONCount(ctx, txn, nid)
if err != nil { if err != nil {
return fmt.Errorf("SelectQueuePDUReferenceJSONCount: %w", err) return fmt.Errorf("SelectQueuePDUReferenceJSONCount: %w", err)

View file

@ -53,10 +53,10 @@ const selectQueueNextTransactionIDSQL = "" +
" ORDER BY transaction_id ASC" + " ORDER BY transaction_id ASC" +
" LIMIT 1" " LIMIT 1"
const selectQueuePDUsByTransactionSQL = "" + const selectQueuePDUsSQL = "" +
"SELECT json_nid FROM federationsender_queue_pdus" + "SELECT json_nid FROM federationsender_queue_pdus" +
" WHERE server_name = $1 AND transaction_id = $2" + " WHERE server_name = $1" +
" LIMIT $3" " LIMIT $2"
const selectQueuePDUsReferenceJSONCountSQL = "" + const selectQueuePDUsReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" + "SELECT COUNT(*) FROM federationsender_queue_pdus" +
@ -73,7 +73,7 @@ type queuePDUsStatements struct {
db *sql.DB db *sql.DB
insertQueuePDUStmt *sql.Stmt insertQueuePDUStmt *sql.Stmt
selectQueueNextTransactionIDStmt *sql.Stmt selectQueueNextTransactionIDStmt *sql.Stmt
selectQueuePDUsByTransactionStmt *sql.Stmt selectQueuePDUsStmt *sql.Stmt
selectQueueReferenceJSONCountStmt *sql.Stmt selectQueueReferenceJSONCountStmt *sql.Stmt
selectQueuePDUsCountStmt *sql.Stmt selectQueuePDUsCountStmt *sql.Stmt
selectQueueServerNamesStmt *sql.Stmt selectQueueServerNamesStmt *sql.Stmt
@ -97,7 +97,7 @@ func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil { if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil {
return return
} }
if s.selectQueuePDUsByTransactionStmt, err = db.Prepare(selectQueuePDUsByTransactionSQL); err != nil { if s.selectQueuePDUsStmt, err = db.Prepare(selectQueuePDUsSQL); err != nil {
return return
} }
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil { if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil {
@ -193,11 +193,10 @@ func (s *queuePDUsStatements) SelectQueuePDUCount(
func (s *queuePDUsStatements) SelectQueuePDUs( func (s *queuePDUsStatements) SelectQueuePDUs(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
transactionID gomatrixserverlib.TransactionID,
limit int, limit int,
) ([]int64, error) { ) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsByTransactionStmt) stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt)
rows, err := stmt.QueryContext(ctx, serverName, transactionID, limit) rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -25,10 +25,9 @@ import (
type FederationSenderQueuePDUs interface { type FederationSenderQueuePDUs interface {
InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueuePDUNextTransactionID(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (gomatrixserverlib.TransactionID, error)
SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, limit int) ([]int64, error) SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
} }

3
go.mod
View file

@ -7,7 +7,6 @@ require (
github.com/gologme/log v1.2.0 github.com/gologme/log v1.2.0
github.com/gorilla/mux v1.8.0 github.com/gorilla/mux v1.8.0
github.com/hashicorp/golang-lru v0.5.4 github.com/hashicorp/golang-lru v0.5.4
github.com/konsorten/go-windows-terminal-sequences v1.0.3 // indirect
github.com/lib/pq v1.8.0 github.com/lib/pq v1.8.0
github.com/libp2p/go-libp2p v0.11.0 github.com/libp2p/go-libp2p v0.11.0
github.com/libp2p/go-libp2p-circuit v0.3.1 github.com/libp2p/go-libp2p-circuit v0.3.1
@ -23,7 +22,7 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
github.com/matrix-org/gomatrixserverlib v0.0.0-20201204094806-e2d6b1a05ccb github.com/matrix-org/gomatrixserverlib v0.0.0-20210113173004-b1c67ac867cc
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.2 github.com/mattn/go-sqlite3 v1.14.2

12
go.sum
View file

@ -301,8 +301,6 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8=
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/koron/go-ssdp v0.0.0-20191105050749-2e1c40ed0b5d h1:68u9r4wEvL3gYg2jvAOgROwZ3H+Y3hIDk4tbbmIjcYQ= github.com/koron/go-ssdp v0.0.0-20191105050749-2e1c40ed0b5d h1:68u9r4wEvL3gYg2jvAOgROwZ3H+Y3hIDk4tbbmIjcYQ=
github.com/koron/go-ssdp v0.0.0-20191105050749-2e1c40ed0b5d/go.mod h1:5Ky9EC2xfoUKUor0Hjgi2BJhCSXJfMOFlmyYrVKGQMk= github.com/koron/go-ssdp v0.0.0-20191105050749-2e1c40ed0b5d/go.mod h1:5Ky9EC2xfoUKUor0Hjgi2BJhCSXJfMOFlmyYrVKGQMk=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
@ -569,8 +567,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg=
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20201204094806-e2d6b1a05ccb h1:cN+v/rDbGg2p5e8xyxfATtwXeW7tI4aUn7slLXvuNyw= github.com/matrix-org/gomatrixserverlib v0.0.0-20210113173004-b1c67ac867cc h1:n2Hnbg8RZ4102Qmxie1riLkIyrqeqShJUILg1miSmDI=
github.com/matrix-org/gomatrixserverlib v0.0.0-20201204094806-e2d6b1a05ccb/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/gomatrixserverlib v0.0.0-20210113173004-b1c67ac867cc/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4=
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=
@ -779,8 +777,6 @@ github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5k
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM= github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/smola/gocompat v0.2.0/go.mod h1:1B0MlxbmoZNo3h8guHp8HztB3BSYR5itql9qtVc0ypY= github.com/smola/gocompat v0.2.0/go.mod h1:1B0MlxbmoZNo3h8guHp8HztB3BSYR5itql9qtVc0ypY=
@ -828,8 +824,6 @@ github.com/tidwall/pretty v1.0.2 h1:Z7S3cePv9Jwm1KwS0513MRaoUe3S01WPbLNV40pwWZU=
github.com/tidwall/pretty v1.0.2/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/pretty v1.0.2/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tidwall/sjson v1.0.3 h1:DeF+0LZqvIt4fKYw41aPB29ZGlvwVkHKktoXJ1YW9Y8= github.com/tidwall/sjson v1.0.3 h1:DeF+0LZqvIt4fKYw41aPB29ZGlvwVkHKktoXJ1YW9Y8=
github.com/tidwall/sjson v1.0.3/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7VpF15Y= github.com/tidwall/sjson v1.0.3/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7VpF15Y=
github.com/tidwall/sjson v1.1.1 h1:7h1vk049Jnd5EH9NyzNiEuwYW4b5qgreBbqRC19AS3U=
github.com/tidwall/sjson v1.1.1/go.mod h1:yvVuSnpEQv5cYIrO+AT6kw4QVfd5SDZoGIS7/5+fZFs=
github.com/tidwall/sjson v1.1.2 h1:NC5okI+tQ8OG/oyzchvwXXxRxCV/FVdhODbPKkQ25jQ= github.com/tidwall/sjson v1.1.2 h1:NC5okI+tQ8OG/oyzchvwXXxRxCV/FVdhODbPKkQ25jQ=
github.com/tidwall/sjson v1.1.2/go.mod h1:SEzaDwxiPzKzNfUEO4HbYF/m4UCSJDsGgNqsS1LvdoY= github.com/tidwall/sjson v1.1.2/go.mod h1:SEzaDwxiPzKzNfUEO4HbYF/m4UCSJDsGgNqsS1LvdoY=
github.com/uber/jaeger-client-go v2.25.0+incompatible h1:IxcNZ7WRY1Y3G4poYlx24szfsn/3LvK9QHCq9oQw8+U= github.com/uber/jaeger-client-go v2.25.0+incompatible h1:IxcNZ7WRY1Y3G4poYlx24szfsn/3LvK9QHCq9oQw8+U=
@ -912,8 +906,6 @@ golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5 h1:Q7tZBpemrlsc2I7IyODzht
golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 h1:cg5LA/zNPRzIXIWSCxQW10Rvpy94aQh3LT/ShoCpkHw= golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 h1:cg5LA/zNPRzIXIWSCxQW10Rvpy94aQh3LT/ShoCpkHw=
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 h1:phUcVbl53swtrUN8kQEXFhUxPlIlWyBfKmidCu7P95o= golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 h1:phUcVbl53swtrUN8kQEXFhUxPlIlWyBfKmidCu7P95o=
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=

View file

@ -0,0 +1,45 @@
package caching
import (
"github.com/matrix-org/dendrite/roomserver/types"
)
// WARNING: This cache is mutable because it's entirely possible that
// the IsStub or StateSnaphotNID fields can change, even though the
// room version and room NID fields will not. This is only safe because
// the RoomInfoCache is used ONLY within the roomserver and because it
// will be kept up-to-date by the latest events updater. It MUST NOT be
// used from other components as we currently have no way to invalidate
// the cache in downstream components.
const (
RoomInfoCacheName = "roominfo"
RoomInfoCacheMaxEntries = 1024
RoomInfoCacheMutable = true
)
// RoomInfosCache contains the subset of functions needed for
// a room Info cache. It must only be used from the roomserver only
// It is not safe for use from other components.
type RoomInfoCache interface {
GetRoomInfo(roomID string) (roomInfo types.RoomInfo, ok bool)
StoreRoomInfo(roomID string, roomInfo types.RoomInfo)
}
// GetRoomInfo must only be called from the roomserver only. It is not
// safe for use from other components.
func (c Caches) GetRoomInfo(roomID string) (types.RoomInfo, bool) {
val, found := c.RoomInfos.Get(roomID)
if found && val != nil {
if roomInfo, ok := val.(types.RoomInfo); ok {
return roomInfo, true
}
}
return types.RoomInfo{}, false
}
// StoreRoomInfo must only be called from the roomserver only. It is not
// safe for use from other components.
func (c Caches) StoreRoomInfo(roomID string, roomInfo types.RoomInfo) {
c.RoomInfos.Set(roomID, roomInfo)
}

View file

@ -15,10 +15,6 @@ const (
RoomServerEventTypeNIDsCacheMaxEntries = 64 RoomServerEventTypeNIDsCacheMaxEntries = 64
RoomServerEventTypeNIDsCacheMutable = false RoomServerEventTypeNIDsCacheMutable = false
RoomServerRoomNIDsCacheName = "roomserver_room_nids"
RoomServerRoomNIDsCacheMaxEntries = 1024
RoomServerRoomNIDsCacheMutable = false
RoomServerRoomIDsCacheName = "roomserver_room_ids" RoomServerRoomIDsCacheName = "roomserver_room_ids"
RoomServerRoomIDsCacheMaxEntries = 1024 RoomServerRoomIDsCacheMaxEntries = 1024
RoomServerRoomIDsCacheMutable = false RoomServerRoomIDsCacheMutable = false
@ -27,6 +23,7 @@ const (
type RoomServerCaches interface { type RoomServerCaches interface {
RoomServerNIDsCache RoomServerNIDsCache
RoomVersionCache RoomVersionCache
RoomInfoCache
} }
// RoomServerNIDsCache contains the subset of functions needed for // RoomServerNIDsCache contains the subset of functions needed for
@ -38,9 +35,6 @@ type RoomServerNIDsCache interface {
GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool) GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool)
StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID) StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID)
GetRoomServerRoomNID(roomID string) (types.RoomNID, bool)
StoreRoomServerRoomNID(roomID string, nid types.RoomNID)
GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool)
StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string)
} }
@ -73,21 +67,6 @@ func (c Caches) StoreRoomServerEventTypeNID(eventType string, nid types.EventTyp
c.RoomServerEventTypeNIDs.Set(eventType, nid) c.RoomServerEventTypeNIDs.Set(eventType, nid)
} }
func (c Caches) GetRoomServerRoomNID(roomID string) (types.RoomNID, bool) {
val, found := c.RoomServerRoomNIDs.Get(roomID)
if found && val != nil {
if roomNID, ok := val.(types.RoomNID); ok {
return roomNID, true
}
}
return 0, false
}
func (c Caches) StoreRoomServerRoomNID(roomID string, roomNID types.RoomNID) {
c.RoomServerRoomNIDs.Set(roomID, roomNID)
c.RoomServerRoomIDs.Set(strconv.Itoa(int(roomNID)), roomID)
}
func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) { func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) {
val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID))) val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID)))
if found && val != nil { if found && val != nil {
@ -99,5 +78,5 @@ func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) {
} }
func (c Caches) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) { func (c Caches) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) {
c.StoreRoomServerRoomNID(roomID, roomNID) c.RoomServerRoomIDs.Set(strconv.Itoa(int(roomNID)), roomID)
} }

View file

@ -10,6 +10,7 @@ type Caches struct {
RoomServerEventTypeNIDs Cache // RoomServerNIDsCache RoomServerEventTypeNIDs Cache // RoomServerNIDsCache
RoomServerRoomNIDs Cache // RoomServerNIDsCache RoomServerRoomNIDs Cache // RoomServerNIDsCache
RoomServerRoomIDs Cache // RoomServerNIDsCache RoomServerRoomIDs Cache // RoomServerNIDsCache
RoomInfos Cache // RoomInfoCache
FederationEvents Cache // FederationEventsCache FederationEvents Cache // FederationEventsCache
} }

View file

@ -45,19 +45,19 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
roomServerRoomNIDs, err := NewInMemoryLRUCachePartition( roomServerRoomIDs, err := NewInMemoryLRUCachePartition(
RoomServerRoomNIDsCacheName, RoomServerRoomIDsCacheName,
RoomServerRoomNIDsCacheMutable, RoomServerRoomIDsCacheMutable,
RoomServerRoomNIDsCacheMaxEntries, RoomServerRoomIDsCacheMaxEntries,
enablePrometheus, enablePrometheus,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
roomServerRoomIDs, err := NewInMemoryLRUCachePartition( roomInfos, err := NewInMemoryLRUCachePartition(
RoomServerRoomIDsCacheName, RoomInfoCacheName,
RoomServerRoomIDsCacheMutable, RoomInfoCacheMutable,
RoomServerRoomIDsCacheMaxEntries, RoomInfoCacheMaxEntries,
enablePrometheus, enablePrometheus,
) )
if err != nil { if err != nil {
@ -77,8 +77,8 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
ServerKeys: serverKeys, ServerKeys: serverKeys,
RoomServerStateKeyNIDs: roomServerStateKeyNIDs, RoomServerStateKeyNIDs: roomServerStateKeyNIDs,
RoomServerEventTypeNIDs: roomServerEventTypeNIDs, RoomServerEventTypeNIDs: roomServerEventTypeNIDs,
RoomServerRoomNIDs: roomServerRoomNIDs,
RoomServerRoomIDs: roomServerRoomIDs, RoomServerRoomIDs: roomServerRoomIDs,
RoomInfos: roomInfos,
FederationEvents: federationEvents, FederationEvents: federationEvents,
}, nil }, nil
} }

View file

@ -17,7 +17,7 @@ var build string
const ( const (
VersionMajor = 0 VersionMajor = 0
VersionMinor = 3 VersionMinor = 3
VersionPatch = 2 VersionPatch = 5
VersionTag = "" // example: "rc1" VersionTag = "" // example: "rc1"
) )

View file

@ -319,7 +319,7 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
} }
func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) { func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) {
requestTimeout := time.Minute // max amount of time we want to spend on each request requestTimeout := time.Second * 30 // max amount of time we want to spend on each request
ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
defer cancel() defer cancel()
logger := util.GetLogger(ctx).WithField("server_name", serverName) logger := util.GetLogger(ctx).WithField("server_name", serverName)

View file

@ -82,6 +82,7 @@ func (s *keyChangesStatements) SelectKeyChanges(
if toOffset == sarama.OffsetNewest { if toOffset == sarama.OffsetNewest {
toOffset = math.MaxInt64 toOffset = math.MaxInt64
} }
latestOffset = fromOffset
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err

View file

@ -83,6 +83,7 @@ func (s *keyChangesStatements) SelectKeyChanges(
if toOffset == sarama.OffsetNewest { if toOffset == sarama.OffsetNewest {
toOffset = math.MaxInt64 toOffset = math.MaxInt64
} }
latestOffset = fromOffset
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err

View file

@ -3,6 +3,7 @@ package api
import ( import (
"context" "context"
asAPI "github.com/matrix-org/dendrite/appservice/api"
fsAPI "github.com/matrix-org/dendrite/federationsender/api" fsAPI "github.com/matrix-org/dendrite/federationsender/api"
) )
@ -11,6 +12,7 @@ type RoomserverInternalAPI interface {
// needed to avoid chicken and egg scenario when setting up the // needed to avoid chicken and egg scenario when setting up the
// interdependencies between the roomserver and other input APIs // interdependencies between the roomserver and other input APIs
SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI)
SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI)
InputRoomEvents( InputRoomEvents(
ctx context.Context, ctx context.Context,

View file

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
asAPI "github.com/matrix-org/dendrite/appservice/api"
fsAPI "github.com/matrix-org/dendrite/federationsender/api" fsAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -19,6 +20,10 @@ func (t *RoomserverInternalAPITrace) SetFederationSenderAPI(fsAPI fsAPI.Federati
t.Impl.SetFederationSenderAPI(fsAPI) t.Impl.SetFederationSenderAPI(fsAPI)
} }
func (t *RoomserverInternalAPITrace) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) {
t.Impl.SetAppserviceAPI(asAPI)
}
func (t *RoomserverInternalAPITrace) InputRoomEvents( func (t *RoomserverInternalAPITrace) InputRoomEvents(
ctx context.Context, ctx context.Context,
req *InputRoomEventsRequest, req *InputRoomEventsRequest,

View file

@ -23,6 +23,8 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
asAPI "github.com/matrix-org/dendrite/appservice/api"
) )
// RoomserverInternalAPIDatabase has the storage APIs needed to implement the alias API. // RoomserverInternalAPIDatabase has the storage APIs needed to implement the alias API.
@ -90,17 +92,13 @@ func (r *RoomserverInternalAPI) GetRoomIDForAlias(
return err return err
} }
/* if r.asAPI != nil { // appservice component is wired in
TODO: Why is this here? It creates an unnecessary dependency
from the roomserver to the appservice component, which should be
altogether optional.
if roomID == "" { if roomID == "" {
// No room found locally, try our application services by making a call to // No room found locally, try our application services by making a call to
// the appservice component // the appservice component
aliasReq := appserviceAPI.RoomAliasExistsRequest{Alias: request.Alias} aliasReq := asAPI.RoomAliasExistsRequest{Alias: request.Alias}
var aliasResp appserviceAPI.RoomAliasExistsResponse var aliasResp asAPI.RoomAliasExistsResponse
if err = r.AppserviceAPI.RoomAliasExists(ctx, &aliasReq, &aliasResp); err != nil { if err = r.asAPI.RoomAliasExists(ctx, &aliasReq, &aliasResp); err != nil {
return err return err
} }
@ -111,7 +109,7 @@ func (r *RoomserverInternalAPI) GetRoomIDForAlias(
} }
} }
} }
*/ }
response.RoomID = roomID response.RoomID = roomID
return nil return nil

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
asAPI "github.com/matrix-org/dendrite/appservice/api"
fsAPI "github.com/matrix-org/dendrite/federationsender/api" fsAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/acls"
@ -35,6 +36,7 @@ type RoomserverInternalAPI struct {
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
KeyRing gomatrixserverlib.JSONVerifier KeyRing gomatrixserverlib.JSONVerifier
fsAPI fsAPI.FederationSenderInternalAPI fsAPI fsAPI.FederationSenderInternalAPI
asAPI asAPI.AppServiceQueryAPI
OutputRoomEventTopic string // Kafka topic for new output room events OutputRoomEventTopic string // Kafka topic for new output room events
PerspectiveServerNames []gomatrixserverlib.ServerName PerspectiveServerNames []gomatrixserverlib.ServerName
} }
@ -126,6 +128,10 @@ func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSen
} }
} }
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) {
r.asAPI = asAPI
}
func (r *RoomserverInternalAPI) PerformInvite( func (r *RoomserverInternalAPI) PerformInvite(
ctx context.Context, ctx context.Context,
req *api.PerformInviteRequest, req *api.PerformInviteRequest,

View file

@ -54,10 +54,8 @@ type inputWorker struct {
input chan *inputTask input chan *inputTask
} }
// Guarded by a CAS on w.running
func (w *inputWorker) start() { func (w *inputWorker) start() {
if !w.running.CAS(false, true) {
return
}
defer w.running.Store(false) defer w.running.Store(false)
for { for {
select { select {
@ -118,7 +116,7 @@ func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) er
// InputRoomEvents implements api.RoomserverInternalAPI // InputRoomEvents implements api.RoomserverInternalAPI
func (r *Inputer) InputRoomEvents( func (r *Inputer) InputRoomEvents(
ctx context.Context, _ context.Context,
request *api.InputRoomEventsRequest, request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse, response *api.InputRoomEventsResponse,
) { ) {
@ -142,7 +140,7 @@ func (r *Inputer) InputRoomEvents(
// room - the channel will be quite small as it's just pointer types. // room - the channel will be quite small as it's just pointer types.
w, _ := r.workers.LoadOrStore(roomID, &inputWorker{ w, _ := r.workers.LoadOrStore(roomID, &inputWorker{
r: r, r: r,
input: make(chan *inputTask, 10), input: make(chan *inputTask, 32),
}) })
worker := w.(*inputWorker) worker := w.(*inputWorker)
@ -150,13 +148,15 @@ func (r *Inputer) InputRoomEvents(
// the wait group, so that the worker can notify us when this specific // the wait group, so that the worker can notify us when this specific
// task has been finished. // task has been finished.
tasks[i] = &inputTask{ tasks[i] = &inputTask{
ctx: ctx, ctx: context.Background(),
event: &request.InputRoomEvents[i], event: &request.InputRoomEvents[i],
wg: wg, wg: wg,
} }
// Send the task to the worker. // Send the task to the worker.
go worker.start() if worker.running.CAS(false, true) {
go worker.start()
}
worker.input <- tasks[i] worker.input <- tasks[i]
} }

View file

@ -20,6 +20,7 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"time"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -28,9 +29,29 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func init() {
prometheus.MustRegister(processRoomEventDuration)
}
var processRoomEventDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",
Subsystem: "roomserver",
Name: "processroomevent_duration_millis",
Help: "How long it takes the roomserver to process an event",
Buckets: []float64{ // milliseconds
5, 10, 25, 50, 75, 100, 250, 500,
1000, 2000, 3000, 4000, 5000, 6000,
7000, 8000, 9000, 10000, 15000, 20000,
},
},
[]string{"room_id"},
)
// processRoomEvent can only be called once at a time // processRoomEvent can only be called once at a time
// //
// TODO(#375): This should be rewritten to allow concurrent calls. The // TODO(#375): This should be rewritten to allow concurrent calls. The
@ -42,6 +63,15 @@ func (r *Inputer) processRoomEvent(
ctx context.Context, ctx context.Context,
input *api.InputRoomEvent, input *api.InputRoomEvent,
) (eventID string, err error) { ) (eventID string, err error) {
// Measure how long it takes to process this event.
started := time.Now()
defer func() {
timetaken := time.Since(started)
processRoomEventDuration.With(prometheus.Labels{
"room_id": input.Event.RoomID(),
}).Observe(float64(timetaken.Milliseconds()))
}()
// Parse and validate the event JSON // Parse and validate the event JSON
headered := input.Event headered := input.Event
event := headered.Unwrap() event := headered.Unwrap()

View file

@ -259,40 +259,37 @@ func (u *latestEventsUpdater) calculateLatest(
// First of all, get a list of all of the events in our current // First of all, get a list of all of the events in our current
// set of forward extremities. // set of forward extremities.
existingRefs := make(map[string]*types.StateAtEventAndReference) existingRefs := make(map[string]*types.StateAtEventAndReference)
existingNIDs := make([]types.EventNID, len(oldLatest))
for i, old := range oldLatest { for i, old := range oldLatest {
existingRefs[old.EventID] = &oldLatest[i] existingRefs[old.EventID] = &oldLatest[i]
existingNIDs[i] = old.EventNID
}
// Look up the old extremity events. This allows us to find their
// prev events.
events, err := u.api.DB.Events(u.ctx, existingNIDs)
if err != nil {
return false, fmt.Errorf("u.api.DB.Events: %w", err)
}
// Make a list of all of the prev events as referenced by all of
// the current forward extremities.
existingPrevs := make(map[string]struct{})
for _, old := range events {
for _, prevEventID := range old.PrevEventIDs() {
existingPrevs[prevEventID] = struct{}{}
}
}
// If the "new" event is already referenced by a forward extremity
// then do nothing - it's not a candidate to be a new extremity if
// it has been referenced.
if _, ok := existingPrevs[newEvent.EventID()]; ok {
return false, nil
} }
// If the "new" event is already a forward extremity then stop, as // If the "new" event is already a forward extremity then stop, as
// nothing changes. // nothing changes.
for _, event := range events { if _, ok := existingRefs[newEvent.EventID()]; ok {
if event.EventID() == newEvent.EventID() { u.latest = oldLatest
return false, nil return false, nil
}
// If the "new" event is already referenced by an existing event
// then do nothing - it's not a candidate to be a new extremity if
// it has been referenced.
if referenced, err := u.updater.IsReferenced(newEvent.EventReference()); err != nil {
return false, fmt.Errorf("u.updater.IsReferenced(new): %w", err)
} else if referenced {
u.latest = oldLatest
return false, nil
}
// Then let's see if any of the existing forward extremities now
// have entries in the previous events table. If they do then we
// will no longer include them as forward extremities.
existingPrevs := make(map[string]struct{})
for _, l := range existingRefs {
referenced, err := u.updater.IsReferenced(l.EventReference)
if err != nil {
return false, fmt.Errorf("u.updater.IsReferenced: %w", err)
} else if referenced {
existingPrevs[l.EventID] = struct{}{}
} }
} }

View file

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
asAPI "github.com/matrix-org/dendrite/appservice/api"
fsInputAPI "github.com/matrix-org/dendrite/federationsender/api" fsInputAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
@ -84,6 +85,10 @@ func NewRoomserverClient(
func (h *httpRoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsInputAPI.FederationSenderInternalAPI) { func (h *httpRoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsInputAPI.FederationSenderInternalAPI) {
} }
// SetAppserviceAPI no-ops in HTTP client mode as there is no chicken/egg scenario
func (h *httpRoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) {
}
// SetRoomAlias implements RoomserverAliasAPI // SetRoomAlias implements RoomserverAliasAPI
func (h *httpRoomserverInternalAPI) SetRoomAlias( func (h *httpRoomserverInternalAPI) SetRoomAlias(
ctx context.Context, ctx context.Context,

View file

@ -120,8 +120,8 @@ const bulkSelectEventNIDSQL = "" +
const selectMaxEventDepthSQL = "" + const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)" "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)"
const selectRoomNIDForEventNIDSQL = "" + const selectRoomNIDsForEventNIDsSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_nid = $1" "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid = ANY($1)"
type eventStatements struct { type eventStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
@ -137,7 +137,7 @@ type eventStatements struct {
bulkSelectEventIDStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt
selectMaxEventDepthStmt *sql.Stmt selectMaxEventDepthStmt *sql.Stmt
selectRoomNIDForEventNIDStmt *sql.Stmt selectRoomNIDsForEventNIDsStmt *sql.Stmt
} }
func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
@ -161,7 +161,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, {&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -432,11 +432,24 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
return result, nil return result, nil
} }
func (s *eventStatements) SelectRoomNIDForEventNID( func (s *eventStatements) SelectRoomNIDsForEventNIDs(
ctx context.Context, eventNID types.EventNID, ctx context.Context, eventNIDs []types.EventNID,
) (roomNID types.RoomNID, err error) { ) (map[types.EventNID]types.RoomNID, error) {
err = s.selectRoomNIDForEventNIDStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) rows, err := s.selectRoomNIDsForEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
return if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed")
result := make(map[types.EventNID]types.RoomNID)
for rows.Next() {
var eventNID types.EventNID
var roomNID types.RoomNID
if err = rows.Scan(&eventNID, &roomNID); err != nil {
return nil, err
}
result[eventNID] = roomNID
}
return result, nil
} }
func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array { func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array {

View file

@ -18,7 +18,6 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -69,8 +68,8 @@ const selectLatestEventNIDsForUpdateSQL = "" +
const updateLatestEventNIDsSQL = "" + const updateLatestEventNIDsSQL = "" +
"UPDATE roomserver_rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1" "UPDATE roomserver_rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1"
const selectRoomVersionForRoomNIDSQL = "" + const selectRoomVersionsForRoomNIDsSQL = "" +
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid = ANY($1)"
const selectRoomInfoSQL = "" + const selectRoomInfoSQL = "" +
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
@ -90,7 +89,7 @@ type roomStatements struct {
selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt
selectLatestEventNIDsForUpdateStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt
updateLatestEventNIDsStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt
selectRoomVersionForRoomNIDStmt *sql.Stmt selectRoomVersionsForRoomNIDsStmt *sql.Stmt
selectRoomInfoStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt
selectRoomIDsStmt *sql.Stmt selectRoomIDsStmt *sql.Stmt
bulkSelectRoomIDsStmt *sql.Stmt bulkSelectRoomIDsStmt *sql.Stmt
@ -109,7 +108,7 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, {&s.selectRoomVersionsForRoomNIDsStmt, selectRoomVersionsForRoomNIDsSQL},
{&s.selectRoomInfoStmt, selectRoomInfoSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL},
{&s.selectRoomIDsStmt, selectRoomIDsSQL}, {&s.selectRoomIDsStmt, selectRoomIDsSQL},
{&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL}, {&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL},
@ -219,15 +218,24 @@ func (s *roomStatements) UpdateLatestEventNIDs(
return err return err
} }
func (s *roomStatements) SelectRoomVersionForRoomNID( func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, roomNIDs []types.RoomNID,
) (gomatrixserverlib.RoomVersion, error) { ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
var roomVersion gomatrixserverlib.RoomVersion rows, err := s.selectRoomVersionsForRoomNIDsStmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
err := s.selectRoomVersionForRoomNIDStmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) if err != nil {
if err == sql.ErrNoRows { return nil, err
return roomVersion, errors.New("room not found")
} }
return roomVersion, err defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
for rows.Next() {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err
}
result[roomNID] = roomVersion
}
return result, nil
} }
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
@ -271,3 +279,11 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []strin
} }
return roomNIDs, nil return roomNIDs, nil
} }
func roomNIDsAsArray(roomNIDs []types.RoomNID) pq.Int64Array {
nids := make([]int64, len(roomNIDs))
for i := range roomNIDs {
nids[i] = int64(roomNIDs[i])
}
return nids
}

View file

@ -105,6 +105,13 @@ func (u *LatestEventsUpdater) SetLatestEvents(
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil { if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err) return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
} }
if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
roomInfo.StateSnapshotNID = currentStateSnapshotNID
roomInfo.IsStub = false
u.d.Cache.StoreRoomInfo(roomID, roomInfo)
}
}
return nil return nil
}) })
} }

View file

@ -124,7 +124,15 @@ func (d *Database) StateEntriesForTuples(
} }
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return d.RoomsTable.SelectRoomInfo(ctx, roomID) if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
return &roomInfo, nil
}
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, roomID)
if err == nil && roomInfo != nil {
d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID)
d.Cache.StoreRoomInfo(roomID, *roomInfo)
}
return roomInfo, err
} }
func (d *Database) AddState( func (d *Database) AddState(
@ -313,25 +321,39 @@ func (d *Database) Events(
if err != nil { if err != nil {
eventIDs = map[types.EventNID]string{} eventIDs = map[types.EventNID]string{}
} }
results := make([]types.Event, len(eventJSONs)) var roomNIDs map[types.EventNID]types.RoomNID
for i, eventJSON := range eventJSONs { roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, eventNIDs)
var roomNID types.RoomNID if err != nil {
var roomVersion gomatrixserverlib.RoomVersion return nil, err
result := &results[i] }
result.EventNID = eventJSON.EventNID uniqueRoomNIDs := make(map[types.RoomNID]struct{})
roomNID, err = d.EventsTable.SelectRoomNIDForEventNID(ctx, eventJSON.EventNID) for _, n := range roomNIDs {
if err != nil { uniqueRoomNIDs[n] = struct{}{}
return nil, err }
} roomVersions := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
if roomID, ok := d.Cache.GetRoomServerRoomID(roomNID); ok { fetchNIDList := make([]types.RoomNID, 0, len(uniqueRoomNIDs))
roomVersion, _ = d.Cache.GetRoomVersion(roomID) for n := range uniqueRoomNIDs {
} if roomID, ok := d.Cache.GetRoomServerRoomID(n); ok {
if roomVersion == "" { if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID) roomVersions[n] = roomInfo.RoomVersion
if err != nil { continue
return nil, err
} }
} }
fetchNIDList = append(fetchNIDList, n)
}
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, fetchNIDList)
if err != nil {
return nil, err
}
for n, v := range dbRoomVersions {
roomVersions[n] = v
}
results := make([]types.Event, len(eventJSONs))
for i, eventJSON := range eventJSONs {
result := &results[i]
result.EventNID = eventJSON.EventNID
roomNID := roomNIDs[result.EventNID]
roomVersion := roomVersions[roomNID]
result.Event, err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID( result.Event, err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID(
eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion, eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion,
) )
@ -552,8 +574,8 @@ func (d *Database) assignRoomNID(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion, roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (types.RoomNID, error) { ) (types.RoomNID, error) {
if roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID); ok { if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
return roomNID, nil return roomInfo.RoomNID, nil
} }
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
roomNID, err := d.RoomsTable.SelectRoomNID(ctx, txn, roomID) roomNID, err := d.RoomsTable.SelectRoomNID(ctx, txn, roomID)
@ -565,9 +587,6 @@ func (d *Database) assignRoomNID(
roomNID, err = d.RoomsTable.SelectRoomNID(ctx, txn, roomID) roomNID, err = d.RoomsTable.SelectRoomNID(ctx, txn, roomID)
} }
} }
if err == nil {
d.Cache.StoreRoomServerRoomNID(roomID, roomNID)
}
return roomNID, err return roomNID, err
} }

View file

@ -95,8 +95,8 @@ const bulkSelectEventNIDSQL = "" +
const selectMaxEventDepthSQL = "" + const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
const selectRoomNIDForEventNIDSQL = "" + const selectRoomNIDsForEventNIDsSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_nid = $1" "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)"
type eventStatements struct { type eventStatements struct {
db *sql.DB db *sql.DB
@ -112,7 +112,7 @@ type eventStatements struct {
bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt
selectRoomNIDForEventNIDStmt *sql.Stmt //selectRoomNIDsForEventNIDsStmt *sql.Stmt
} }
func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
@ -137,7 +137,7 @@ func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, //{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -480,11 +480,33 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
return result, nil return result, nil
} }
func (s *eventStatements) SelectRoomNIDForEventNID( func (s *eventStatements) SelectRoomNIDsForEventNIDs(
ctx context.Context, eventNID types.EventNID, ctx context.Context, eventNIDs []types.EventNID,
) (roomNID types.RoomNID, err error) { ) (map[types.EventNID]types.RoomNID, error) {
err = s.selectRoomNIDForEventNIDStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
return sqlPrep, err := s.db.Prepare(sqlStr)
if err != nil {
return nil, err
}
iEventNIDs := make([]interface{}, len(eventNIDs))
for i, v := range eventNIDs {
iEventNIDs[i] = v
}
rows, err := sqlPrep.QueryContext(ctx, iEventNIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed")
result := make(map[types.EventNID]types.RoomNID)
for rows.Next() {
var eventNID types.EventNID
var roomNID types.RoomNID
if err = rows.Scan(&eventNID, &roomNID); err != nil {
return nil, err
}
result[eventNID] = roomNID
}
return result, nil
} }
func eventNIDsAsArray(eventNIDs []types.EventNID) string { func eventNIDsAsArray(eventNIDs []types.EventNID) string {

View file

@ -19,7 +19,6 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strings" "strings"
@ -60,8 +59,8 @@ const selectLatestEventNIDsForUpdateSQL = "" +
const updateLatestEventNIDsSQL = "" + const updateLatestEventNIDsSQL = "" +
"UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4" "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
const selectRoomVersionForRoomNIDSQL = "" + const selectRoomVersionsForRoomNIDsSQL = "" +
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)"
const selectRoomInfoSQL = "" + const selectRoomInfoSQL = "" +
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
@ -82,9 +81,9 @@ type roomStatements struct {
selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt
selectLatestEventNIDsForUpdateStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt
updateLatestEventNIDsStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt
selectRoomVersionForRoomNIDStmt *sql.Stmt //selectRoomVersionForRoomNIDStmt *sql.Stmt
selectRoomInfoStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt
selectRoomIDsStmt *sql.Stmt selectRoomIDsStmt *sql.Stmt
} }
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
@ -101,7 +100,7 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, //{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL},
{&s.selectRoomInfoStmt, selectRoomInfoSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL},
{&s.selectRoomIDsStmt, selectRoomIDsSQL}, {&s.selectRoomIDsStmt, selectRoomIDsSQL},
}.Prepare(db) }.Prepare(db)
@ -223,15 +222,33 @@ func (s *roomStatements) UpdateLatestEventNIDs(
return err return err
} }
func (s *roomStatements) SelectRoomVersionForRoomNID( func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, roomNIDs []types.RoomNID,
) (gomatrixserverlib.RoomVersion, error) { ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
var roomVersion gomatrixserverlib.RoomVersion sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
err := s.selectRoomVersionForRoomNIDStmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) sqlPrep, err := s.db.Prepare(sqlStr)
if err == sql.ErrNoRows { if err != nil {
return roomVersion, errors.New("room not found") return nil, err
} }
return roomVersion, err iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
}
rows, err := sqlPrep.QueryContext(ctx, iRoomNIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
for rows.Next() {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err
}
result[roomNID] = roomVersion
}
return result, nil
} }
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {

View file

@ -10,8 +10,9 @@ import (
) )
type EventJSONPair struct { type EventJSONPair struct {
EventNID types.EventNID EventNID types.EventNID
EventJSON []byte RoomVersion gomatrixserverlib.RoomVersion
EventJSON []byte
} }
type EventJSON interface { type EventJSON interface {
@ -58,7 +59,7 @@ type Events interface {
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
SelectRoomNIDForEventNID(ctx context.Context, eventNID types.EventNID) (roomNID types.RoomNID, err error) SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
} }
type Rooms interface { type Rooms interface {
@ -67,7 +68,7 @@ type Rooms interface {
SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error)
SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error)
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) SelectRoomVersionsForRoomNIDs(ctx context.Context, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
SelectRoomIDs(ctx context.Context) ([]string, error) SelectRoomIDs(ctx context.Context) ([]string, error)
BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error)

View file

@ -3,7 +3,7 @@ package config
type MSCs struct { type MSCs struct {
Matrix *Global `yaml:"-"` Matrix *Global `yaml:"-"`
// The MSCs to enable, currently only `msc2836` is supported. // The MSCs to enable
MSCs []string `yaml:"mscs"` MSCs []string `yaml:"mscs"`
Database DatabaseOptions `yaml:"database"` Database DatabaseOptions `yaml:"database"`

View file

@ -0,0 +1,376 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package msc2946 'Spaces Summary' implements https://github.com/matrix-org/matrix-doc/pull/2946
package msc2946
import (
"context"
"fmt"
"net/http"
"strings"
"sync"
"github.com/gorilla/mux"
chttputil "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/internal/hooks"
"github.com/matrix-org/dendrite/internal/httputil"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/tidwall/gjson"
)
const (
ConstCreateEventContentKey = "org.matrix.msc1772.type"
ConstSpaceChildEventType = "org.matrix.msc1772.space.child"
ConstSpaceParentEventType = "org.matrix.msc1772.space.parent"
)
// SpacesRequest is the request body to POST /_matrix/client/r0/rooms/{roomID}/spaces
type SpacesRequest struct {
MaxRoomsPerSpace int `json:"max_rooms_per_space"`
Limit int `json:"limit"`
Batch string `json:"batch"`
}
// Defaults sets the request defaults
func (r *SpacesRequest) Defaults() {
r.Limit = 100
r.MaxRoomsPerSpace = -1
}
// SpacesResponse is the response body to POST /_matrix/client/r0/rooms/{roomID}/spaces
type SpacesResponse struct {
NextBatch string `json:"next_batch"`
// Rooms are nodes on the space graph.
Rooms []Room `json:"rooms"`
// Events are edges on the space graph, exclusively m.space.child or m.space.parent events
Events []gomatrixserverlib.ClientEvent `json:"events"`
}
// Room is a node on the space graph
type Room struct {
gomatrixserverlib.PublicRoom
NumRefs int `json:"num_refs"`
RoomType string `json:"room_type"`
}
// Enable this MSC
func Enable(
base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI,
) error {
db, err := NewDatabase(&base.Cfg.MSCs.Database)
if err != nil {
return fmt.Errorf("Cannot enable MSC2946: %w", err)
}
hooks.Enable()
hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) {
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
hookErr := db.StoreReference(context.Background(), he)
if hookErr != nil {
util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error(
"failed to StoreReference",
)
}
})
base.PublicClientAPIMux.Handle("/unstable/rooms/{roomID}/spaces",
httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(db, rsAPI)),
).Methods(http.MethodPost, http.MethodOptions)
return nil
}
func spacesHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse {
return func(req *http.Request, device *userapi.Device) util.JSONResponse {
inMemoryBatchCache := make(map[string]set)
// Extract the room ID from the request. Sanity check request data.
params, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
roomID := params["roomID"]
var r SpacesRequest
r.Defaults()
if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil {
return *resErr
}
if r.Limit > 100 {
r.Limit = 100
}
w := walker{
req: &r,
rootRoomID: roomID,
caller: device,
ctx: req.Context(),
db: db,
rsAPI: rsAPI,
inMemoryBatchCache: inMemoryBatchCache,
}
res := w.walk()
return util.JSONResponse{
Code: 200,
JSON: res,
}
}
}
type walker struct {
req *SpacesRequest
rootRoomID string
caller *userapi.Device
db Database
rsAPI roomserver.RoomserverInternalAPI
ctx context.Context
// user ID|device ID|batch_num => event/room IDs sent to client
inMemoryBatchCache map[string]set
mu sync.Mutex
}
func (w *walker) alreadySent(id string) bool {
w.mu.Lock()
defer w.mu.Unlock()
m, ok := w.inMemoryBatchCache[w.caller.UserID+"|"+w.caller.ID]
if !ok {
return false
}
return m[id]
}
func (w *walker) markSent(id string) {
w.mu.Lock()
defer w.mu.Unlock()
m := w.inMemoryBatchCache[w.caller.UserID+"|"+w.caller.ID]
if m == nil {
m = make(set)
}
m[id] = true
w.inMemoryBatchCache[w.caller.UserID+"|"+w.caller.ID] = m
}
// nolint:gocyclo
func (w *walker) walk() *SpacesResponse {
var res SpacesResponse
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
unvisited := []string{w.rootRoomID}
processed := make(set)
for len(unvisited) > 0 {
roomID := unvisited[0]
unvisited = unvisited[1:]
// If this room has already been processed, skip. NB: do not remember this between calls
if processed[roomID] || roomID == "" {
continue
}
// Mark this room as processed.
processed[roomID] = true
// Is the caller currently joined to the room or is the room `world_readable`
// If no, skip this room. If yes, continue.
if !w.authorised(roomID) {
continue
}
// Get all `m.space.child` and `m.space.parent` state events for the room. *In addition*, get
// all `m.space.child` and `m.space.parent` state events which *point to* (via `state_key` or `content.room_id`)
// this room. This requires servers to store reverse lookups.
refs, err := w.references(roomID)
if err != nil {
util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Error("failed to extract references for room")
continue
}
// If this room has not ever been in `rooms` (across multiple requests), extract the
// `PublicRoomsChunk` for this room.
if !w.alreadySent(roomID) {
pubRoom := w.publicRoomsChunk(roomID)
roomType := ""
create := w.stateEvent(roomID, gomatrixserverlib.MRoomCreate, "")
if create != nil {
// escape the `.`s so gjson doesn't think it's nested
roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
}
// Add the total number of events to `PublicRoomsChunk` under `num_refs`. Add `PublicRoomsChunk` to `rooms`.
res.Rooms = append(res.Rooms, Room{
PublicRoom: *pubRoom,
NumRefs: refs.len(),
RoomType: roomType,
})
}
uniqueRooms := make(set)
// If this is the root room from the original request, insert all these events into `events` if
// they haven't been added before (across multiple requests).
if w.rootRoomID == roomID {
for _, ev := range refs.events() {
if !w.alreadySent(ev.EventID()) {
res.Events = append(res.Events, gomatrixserverlib.HeaderedToClientEvent(
ev, gomatrixserverlib.FormatAll,
))
uniqueRooms[ev.RoomID()] = true
uniqueRooms[SpaceTarget(ev)] = true
w.markSent(ev.EventID())
}
}
} else {
// Else add them to `events` honouring the `limit` and `max_rooms_per_space` values. If either
// are exceeded, stop adding events. If the event has already been added, do not add it again.
numAdded := 0
for _, ev := range refs.events() {
if w.req.Limit > 0 && len(res.Events) >= w.req.Limit {
break
}
if w.req.MaxRoomsPerSpace > 0 && numAdded >= w.req.MaxRoomsPerSpace {
break
}
if w.alreadySent(ev.EventID()) {
continue
}
res.Events = append(res.Events, gomatrixserverlib.HeaderedToClientEvent(
ev, gomatrixserverlib.FormatAll,
))
uniqueRooms[ev.RoomID()] = true
uniqueRooms[SpaceTarget(ev)] = true
w.markSent(ev.EventID())
// we don't distinguish between child state events and parent state events for the purposes of
// max_rooms_per_space, maybe we should?
numAdded++
}
}
// For each referenced room ID in the events being returned to the caller (both parent and child)
// add the room ID to the queue of unvisited rooms. Loop from the beginning.
for roomID := range uniqueRooms {
unvisited = append(unvisited, roomID)
}
}
return &res
}
func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent {
var queryRes roomserver.QueryCurrentStateResponse
tuple := gomatrixserverlib.StateKeyTuple{
EventType: evType,
StateKey: stateKey,
}
err := w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{
RoomID: roomID,
StateTuples: []gomatrixserverlib.StateKeyTuple{tuple},
}, &queryRes)
if err != nil {
return nil
}
return queryRes.StateEvents[tuple]
}
func (w *walker) publicRoomsChunk(roomID string) *gomatrixserverlib.PublicRoom {
pubRooms, err := roomserver.PopulatePublicRooms(w.ctx, []string{roomID}, w.rsAPI)
if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("failed to PopulatePublicRooms")
return nil
}
if len(pubRooms) == 0 {
return nil
}
return &pubRooms[0]
}
// authorised returns true iff the user is joined this room or the room is world_readable
func (w *walker) authorised(roomID string) bool {
hisVisTuple := gomatrixserverlib.StateKeyTuple{
EventType: gomatrixserverlib.MRoomHistoryVisibility,
StateKey: "",
}
roomMemberTuple := gomatrixserverlib.StateKeyTuple{
EventType: gomatrixserverlib.MRoomMember,
StateKey: w.caller.UserID,
}
var queryRes roomserver.QueryCurrentStateResponse
err := w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{
RoomID: roomID,
StateTuples: []gomatrixserverlib.StateKeyTuple{
hisVisTuple, roomMemberTuple,
},
}, &queryRes)
if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("failed to QueryCurrentState")
return false
}
memberEv := queryRes.StateEvents[roomMemberTuple]
hisVisEv := queryRes.StateEvents[hisVisTuple]
if memberEv != nil {
membership, _ := memberEv.Membership()
if membership == gomatrixserverlib.Join {
return true
}
}
if hisVisEv != nil {
hisVis, _ := hisVisEv.HistoryVisibility()
if hisVis == "world_readable" {
return true
}
}
return false
}
// references returns all references pointing to or from this room.
func (w *walker) references(roomID string) (eventLookup, error) {
events, err := w.db.References(w.ctx, roomID)
if err != nil {
return nil, err
}
el := make(eventLookup)
for _, ev := range events {
// only return events that have a `via` key as per MSC1772
// else we'll incorrectly walk redacted events (as the link
// is in the state_key)
if gjson.GetBytes(ev.Content(), "via").Exists() {
el.set(ev)
}
}
return el, nil
}
// state event lookup across multiple rooms keyed on event type
// NOT THREAD SAFE
type eventLookup map[string][]*gomatrixserverlib.HeaderedEvent
func (el eventLookup) set(ev *gomatrixserverlib.HeaderedEvent) {
evs := el[ev.Type()]
if evs == nil {
evs = make([]*gomatrixserverlib.HeaderedEvent, 0)
}
evs = append(evs, ev)
el[ev.Type()] = evs
}
func (el eventLookup) len() int {
sum := 0
for _, evs := range el {
sum += len(evs)
}
return sum
}
func (el eventLookup) events() (events []*gomatrixserverlib.HeaderedEvent) {
for _, evs := range el {
events = append(events, evs...)
}
return
}
type set map[string]bool

View file

@ -0,0 +1,497 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msc2946_test
import (
"bytes"
"context"
"crypto/ed25519"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/hooks"
"github.com/matrix-org/dendrite/internal/httputil"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/mscs/msc2946"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
var (
client = &http.Client{
Timeout: 10 * time.Second,
}
)
// Basic sanity check of MSC2946 logic. Tests a single room with a few state events
// and a bit of recursion to subspaces. Makes a graph like:
// Root
// ____|_____
// | | |
// R1 R2 S1
// |_________
// | | |
// R3 R4 S2
// | <-- this link is just a parent, not a child
// R5
//
// Alice is not joined to R4, but R4 is "world_readable".
func TestMSC2946(t *testing.T) {
alice := "@alice:localhost"
// give access token to alice
nopUserAPI := &testUserAPI{
accessTokens: make(map[string]userapi.Device),
}
nopUserAPI.accessTokens["alice"] = userapi.Device{
AccessToken: "alice",
DisplayName: "Alice",
UserID: alice,
}
rootSpace := "!rootspace:localhost"
subSpaceS1 := "!subspaceS1:localhost"
subSpaceS2 := "!subspaceS2:localhost"
room1 := "!room1:localhost"
room2 := "!room2:localhost"
room3 := "!room3:localhost"
room4 := "!room4:localhost"
empty := ""
room5 := "!room5:localhost"
allRooms := []string{
rootSpace, subSpaceS1, subSpaceS2,
room1, room2, room3, room4, room5,
}
rootToR1 := mustCreateEvent(t, fledglingEvent{
RoomID: rootSpace,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room1,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
rootToR2 := mustCreateEvent(t, fledglingEvent{
RoomID: rootSpace,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room2,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
rootToS1 := mustCreateEvent(t, fledglingEvent{
RoomID: rootSpace,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &subSpaceS1,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
s1ToR3 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room3,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
s1ToR4 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room4,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
s1ToS2 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &subSpaceS2,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
// This is a parent link only
s2ToR5 := mustCreateEvent(t, fledglingEvent{
RoomID: room5,
Sender: alice,
Type: msc2946.ConstSpaceParentEventType,
StateKey: &subSpaceS2,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
// history visibility for R4
r4HisVis := mustCreateEvent(t, fledglingEvent{
RoomID: room4,
Sender: "@someone:localhost",
Type: gomatrixserverlib.MRoomHistoryVisibility,
StateKey: &empty,
Content: map[string]interface{}{
"history_visibility": "world_readable",
},
})
var joinEvents []*gomatrixserverlib.HeaderedEvent
for _, roomID := range allRooms {
if roomID == room4 {
continue // not joined to that room
}
joinEvents = append(joinEvents, mustCreateEvent(t, fledglingEvent{
RoomID: roomID,
Sender: alice,
StateKey: &alice,
Type: gomatrixserverlib.MRoomMember,
Content: map[string]interface{}{
"membership": "join",
},
}))
}
roomNameTuple := gomatrixserverlib.StateKeyTuple{
EventType: "m.room.name",
StateKey: "",
}
hisVisTuple := gomatrixserverlib.StateKeyTuple{
EventType: "m.room.history_visibility",
StateKey: "",
}
nopRsAPI := &testRoomserverAPI{
joinEvents: joinEvents,
events: map[string]*gomatrixserverlib.HeaderedEvent{
rootToR1.EventID(): rootToR1,
rootToR2.EventID(): rootToR2,
rootToS1.EventID(): rootToS1,
s1ToR3.EventID(): s1ToR3,
s1ToR4.EventID(): s1ToR4,
s1ToS2.EventID(): s1ToS2,
s2ToR5.EventID(): s2ToR5,
r4HisVis.EventID(): r4HisVis,
},
pubRoomState: map[string]map[gomatrixserverlib.StateKeyTuple]string{
rootSpace: {
roomNameTuple: "Root",
hisVisTuple: "shared",
},
subSpaceS1: {
roomNameTuple: "Sub-Space 1",
hisVisTuple: "joined",
},
subSpaceS2: {
roomNameTuple: "Sub-Space 2",
hisVisTuple: "shared",
},
room1: {
hisVisTuple: "joined",
},
room2: {
hisVisTuple: "joined",
},
room3: {
hisVisTuple: "joined",
},
room4: {
hisVisTuple: "world_readable",
},
room5: {
hisVisTuple: "joined",
},
},
}
allEvents := []*gomatrixserverlib.HeaderedEvent{
rootToR1, rootToR2, rootToS1,
s1ToR3, s1ToR4, s1ToS2,
s2ToR5, r4HisVis,
}
allEvents = append(allEvents, joinEvents...)
router := injectEvents(t, nopUserAPI, nopRsAPI, allEvents)
cancel := runServer(t, router)
defer cancel()
t.Run("returns no events for unknown rooms", func(t *testing.T) {
res := postSpaces(t, 200, "alice", "!unknown:localhost", newReq(t, map[string]interface{}{}))
if len(res.Events) > 0 {
t.Errorf("got %d events, want 0", len(res.Events))
}
if len(res.Rooms) > 0 {
t.Errorf("got %d rooms, want 0", len(res.Rooms))
}
})
t.Run("returns the entire graph", func(t *testing.T) {
res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{}))
if len(res.Events) != 7 {
t.Errorf("got %d events, want 7", len(res.Events))
}
if len(res.Rooms) != len(allRooms) {
t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms))
}
})
t.Run("can update the graph", func(t *testing.T) {
// remove R3 from the graph
rmS1ToR3 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room3,
Content: map[string]interface{}{}, // redacted
})
nopRsAPI.events[rmS1ToR3.EventID()] = rmS1ToR3
hooks.Run(hooks.KindNewEventPersisted, rmS1ToR3)
res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{}))
if len(res.Events) != 6 { // one less since we don't return redacted events
t.Errorf("got %d events, want 6", len(res.Events))
}
if len(res.Rooms) != (len(allRooms) - 1) { // one less due to lack of R3
t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)-1)
}
})
}
func newReq(t *testing.T, jsonBody map[string]interface{}) *msc2946.SpacesRequest {
t.Helper()
b, err := json.Marshal(jsonBody)
if err != nil {
t.Fatalf("Failed to marshal request: %s", err)
}
var r msc2946.SpacesRequest
if err := json.Unmarshal(b, &r); err != nil {
t.Fatalf("Failed to unmarshal request: %s", err)
}
return &r
}
func runServer(t *testing.T, router *mux.Router) func() {
t.Helper()
externalServ := &http.Server{
Addr: string(":8010"),
WriteTimeout: 60 * time.Second,
Handler: router,
}
go func() {
externalServ.ListenAndServe()
}()
// wait to listen on the port
time.Sleep(500 * time.Millisecond)
return func() {
externalServ.Shutdown(context.TODO())
}
}
func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *msc2946.SpacesRequest) *msc2946.SpacesResponse {
t.Helper()
var r msc2946.SpacesRequest
r.Defaults()
data, err := json.Marshal(req)
if err != nil {
t.Fatalf("failed to marshal request: %s", err)
}
httpReq, err := http.NewRequest(
"POST", "http://localhost:8010/_matrix/client/unstable/rooms/"+url.PathEscape(roomID)+"/spaces",
bytes.NewBuffer(data),
)
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
if err != nil {
t.Fatalf("failed to prepare request: %s", err)
}
res, err := client.Do(httpReq)
if err != nil {
t.Fatalf("failed to do request: %s", err)
}
if res.StatusCode != expectCode {
body, _ := ioutil.ReadAll(res.Body)
t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body))
}
if res.StatusCode == 200 {
var result msc2946.SpacesResponse
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("response 200 OK but failed to read response body: %s", err)
}
t.Logf("Body: %s", string(body))
if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body))
}
return &result
}
return nil
}
type testUserAPI struct {
accessTokens map[string]userapi.Device
}
func (u *testUserAPI) InputAccountData(ctx context.Context, req *userapi.InputAccountDataRequest, res *userapi.InputAccountDataResponse) error {
return nil
}
func (u *testUserAPI) PerformAccountCreation(ctx context.Context, req *userapi.PerformAccountCreationRequest, res *userapi.PerformAccountCreationResponse) error {
return nil
}
func (u *testUserAPI) PerformPasswordUpdate(ctx context.Context, req *userapi.PerformPasswordUpdateRequest, res *userapi.PerformPasswordUpdateResponse) error {
return nil
}
func (u *testUserAPI) PerformDeviceCreation(ctx context.Context, req *userapi.PerformDeviceCreationRequest, res *userapi.PerformDeviceCreationResponse) error {
return nil
}
func (u *testUserAPI) PerformDeviceDeletion(ctx context.Context, req *userapi.PerformDeviceDeletionRequest, res *userapi.PerformDeviceDeletionResponse) error {
return nil
}
func (u *testUserAPI) PerformDeviceUpdate(ctx context.Context, req *userapi.PerformDeviceUpdateRequest, res *userapi.PerformDeviceUpdateResponse) error {
return nil
}
func (u *testUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {
return nil
}
func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error {
return nil
}
func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error {
return nil
}
func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error {
dev, ok := u.accessTokens[req.AccessToken]
if !ok {
res.Err = fmt.Errorf("unknown token")
return nil
}
res.Device = &dev
return nil
}
func (u *testUserAPI) QueryDevices(ctx context.Context, req *userapi.QueryDevicesRequest, res *userapi.QueryDevicesResponse) error {
return nil
}
func (u *testUserAPI) QueryAccountData(ctx context.Context, req *userapi.QueryAccountDataRequest, res *userapi.QueryAccountDataResponse) error {
return nil
}
func (u *testUserAPI) QueryDeviceInfos(ctx context.Context, req *userapi.QueryDeviceInfosRequest, res *userapi.QueryDeviceInfosResponse) error {
return nil
}
func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error {
return nil
}
type testRoomserverAPI struct {
// use a trace API as it implements method stubs so we don't need to have them here.
// We'll override the functions we care about.
roomserver.RoomserverInternalAPITrace
joinEvents []*gomatrixserverlib.HeaderedEvent
events map[string]*gomatrixserverlib.HeaderedEvent
pubRoomState map[string]map[gomatrixserverlib.StateKeyTuple]string
}
func (r *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *roomserver.QueryBulkStateContentRequest, res *roomserver.QueryBulkStateContentResponse) error {
res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string)
for _, roomID := range req.RoomIDs {
pubRoomData, ok := r.pubRoomState[roomID]
if ok {
res.Rooms[roomID] = pubRoomData
}
}
return nil
}
func (r *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *roomserver.QueryCurrentStateRequest, res *roomserver.QueryCurrentStateResponse) error {
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
checkEvent := func(he *gomatrixserverlib.HeaderedEvent) {
if he.RoomID() != req.RoomID {
return
}
if he.StateKey() == nil {
return
}
tuple := gomatrixserverlib.StateKeyTuple{
EventType: he.Type(),
StateKey: *he.StateKey(),
}
for _, t := range req.StateTuples {
if t == tuple {
res.StateEvents[t] = he
}
}
}
for _, he := range r.joinEvents {
checkEvent(he)
}
for _, he := range r.events {
checkEvent(he)
}
return nil
}
func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router {
t.Helper()
cfg := &config.Dendrite{}
cfg.Defaults()
cfg.Global.ServerName = "localhost"
cfg.MSCs.Database.ConnectionString = "file:msc2946_test.db"
cfg.MSCs.MSCs = []string{"msc2946"}
base := &setup.BaseDendrite{
Cfg: cfg,
PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(),
PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(),
}
err := msc2946.Enable(base, rsAPI, userAPI)
if err != nil {
t.Fatalf("failed to enable MSC2946: %s", err)
}
for _, ev := range events {
hooks.Run(hooks.KindNewEventPersisted, ev)
}
return base.PublicClientAPIMux
}
type fledglingEvent struct {
Type string
StateKey *string
Content interface{}
Sender string
RoomID string
}
func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) {
t.Helper()
roomVer := gomatrixserverlib.RoomVersionV6
seed := make([]byte, ed25519.SeedSize) // zero seed
key := ed25519.NewKeyFromSeed(seed)
eb := gomatrixserverlib.EventBuilder{
Sender: ev.Sender,
Depth: 999,
Type: ev.Type,
StateKey: ev.StateKey,
RoomID: ev.RoomID,
}
err := eb.SetContent(ev.Content)
if err != nil {
t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content)
}
// make sure the origin_server_ts changes so we can test recency
time.Sleep(1 * time.Millisecond)
signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer)
if err != nil {
t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
}
h := signedEvent.Headered(roomVer)
return h
}

View file

@ -0,0 +1,182 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msc2946
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
var (
relTypes = map[string]int{
ConstSpaceChildEventType: 1,
ConstSpaceParentEventType: 2,
}
)
type Database interface {
// StoreReference persists a child or parent space mapping.
StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error
// References returns all events which have the given roomID as a parent or child space.
References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error)
}
type DB struct {
db *sql.DB
writer sqlutil.Writer
insertEdgeStmt *sql.Stmt
selectEdgesStmt *sql.Stmt
}
// NewDatabase loads the database for msc2836
func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
if dbOpts.ConnectionString.IsPostgres() {
return newPostgresDatabase(dbOpts)
}
return newSQLiteDatabase(dbOpts)
}
func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
d := DB{
writer: sqlutil.NewDummyWriter(),
}
var err error
if d.db, err = sqlutil.Open(dbOpts); err != nil {
return nil, err
}
_, err = d.db.Exec(`
CREATE TABLE IF NOT EXISTS msc2946_edges (
room_version TEXT NOT NULL,
-- the room ID of the event, the source of the arrow
source_room_id TEXT NOT NULL,
-- the target room ID, the arrow destination
dest_room_id TEXT NOT NULL,
-- the kind of relation, either child or parent (1,2)
rel_type SMALLINT NOT NULL,
event_json TEXT NOT NULL,
CONSTRAINT msc2946_edges_uniq UNIQUE (source_room_id, dest_room_id, rel_type)
);
`)
if err != nil {
return nil, err
}
if d.insertEdgeStmt, err = d.db.Prepare(`
INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT ON CONSTRAINT msc2946_edges_uniq DO UPDATE SET event_json = $5
`); err != nil {
return nil, err
}
if d.selectEdgesStmt, err = d.db.Prepare(`
SELECT room_version, event_json FROM msc2946_edges
WHERE source_room_id = $1 OR dest_room_id = $2
`); err != nil {
return nil, err
}
return &d, err
}
func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
d := DB{
writer: sqlutil.NewExclusiveWriter(),
}
var err error
if d.db, err = sqlutil.Open(dbOpts); err != nil {
return nil, err
}
_, err = d.db.Exec(`
CREATE TABLE IF NOT EXISTS msc2946_edges (
room_version TEXT NOT NULL,
-- the room ID of the event, the source of the arrow
source_room_id TEXT NOT NULL,
-- the target room ID, the arrow destination
dest_room_id TEXT NOT NULL,
-- the kind of relation, either child or parent (1,2)
rel_type SMALLINT NOT NULL,
event_json TEXT NOT NULL,
UNIQUE (source_room_id, dest_room_id, rel_type)
);
`)
if err != nil {
return nil, err
}
if d.insertEdgeStmt, err = d.db.Prepare(`
INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT (source_room_id, dest_room_id, rel_type) DO UPDATE SET event_json = $5
`); err != nil {
return nil, err
}
if d.selectEdgesStmt, err = d.db.Prepare(`
SELECT room_version, event_json FROM msc2946_edges
WHERE source_room_id = $1 OR dest_room_id = $2
`); err != nil {
return nil, err
}
return &d, err
}
func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error {
target := SpaceTarget(he)
if target == "" {
return nil // malformed event
}
relType := relTypes[he.Type()]
_, err := d.insertEdgeStmt.ExecContext(ctx, he.RoomVersion, he.RoomID(), target, relType, he.JSON())
return err
}
func (d *DB) References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) {
rows, err := d.selectEdgesStmt.QueryContext(ctx, roomID, roomID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "failed to close References")
refs := make([]*gomatrixserverlib.HeaderedEvent, 0)
for rows.Next() {
var roomVer string
var jsonBytes []byte
if err := rows.Scan(&roomVer, &jsonBytes); err != nil {
return nil, err
}
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(jsonBytes, false, gomatrixserverlib.RoomVersion(roomVer))
if err != nil {
return nil, err
}
he := ev.Headered(gomatrixserverlib.RoomVersion(roomVer))
refs = append(refs, he)
}
return refs, nil
}
// SpaceTarget returns the destination room ID for the space event. This is either a child or a parent
// depending on the event type.
func SpaceTarget(he *gomatrixserverlib.HeaderedEvent) string {
if he.StateKey() == nil {
return "" // no-op
}
switch he.Type() {
case ConstSpaceParentEventType:
return *he.StateKey()
case ConstSpaceChildEventType:
return *he.StateKey()
}
return ""
}

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/mscs/msc2836" "github.com/matrix-org/dendrite/setup/mscs/msc2836"
"github.com/matrix-org/dendrite/setup/mscs/msc2946"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -39,6 +40,8 @@ func EnableMSC(base *setup.BaseDendrite, monolith *setup.Monolith, msc string) e
switch msc { switch msc {
case "msc2836": case "msc2836":
return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationSenderAPI, monolith.UserAPI, monolith.KeyRing) return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationSenderAPI, monolith.UserAPI, monolith.KeyRing)
case "msc2946":
return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI)
default: default:
return fmt.Errorf("EnableMSC: unknown msc '%s'", msc) return fmt.Errorf("EnableMSC: unknown msc '%s'", msc)
} }

View file

@ -22,8 +22,8 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -32,15 +32,17 @@ import (
type OutputClientDataConsumer struct { type OutputClientDataConsumer struct {
clientAPIConsumer *internal.ContinualConsumer clientAPIConsumer *internal.ContinualConsumer
db storage.Database db storage.Database
notifier *sync.Notifier stream types.StreamProvider
notifier *notifier.Notifier
} }
// NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers. // NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers.
func NewOutputClientDataConsumer( func NewOutputClientDataConsumer(
cfg *config.SyncAPI, cfg *config.SyncAPI,
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
n *sync.Notifier,
store storage.Database, store storage.Database,
notifier *notifier.Notifier,
stream types.StreamProvider,
) *OutputClientDataConsumer { ) *OutputClientDataConsumer {
consumer := internal.ContinualConsumer{ consumer := internal.ContinualConsumer{
@ -52,7 +54,8 @@ func NewOutputClientDataConsumer(
s := &OutputClientDataConsumer{ s := &OutputClientDataConsumer{
clientAPIConsumer: &consumer, clientAPIConsumer: &consumer,
db: store, db: store,
notifier: n, notifier: notifier,
stream: stream,
} }
consumer.ProcessMessage = s.onMessage consumer.ProcessMessage = s.onMessage
@ -81,7 +84,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error
"room_id": output.RoomID, "room_id": output.RoomID,
}).Info("received data from client API server") }).Info("received data from client API server")
pduPos, err := s.db.UpsertAccountData( streamPos, err := s.db.UpsertAccountData(
context.TODO(), string(msg.Key), output.RoomID, output.Type, context.TODO(), string(msg.Key), output.RoomID, output.Type,
) )
if err != nil { if err != nil {
@ -92,7 +95,8 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error
}).Panicf("could not save account data") }).Panicf("could not save account data")
} }
s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.NewStreamToken(pduPos, 0, nil)) s.stream.Advance(streamPos)
s.notifier.OnNewAccountData(string(msg.Key), types.StreamingToken{AccountDataPosition: streamPos})
return nil return nil
} }

View file

@ -18,14 +18,13 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/types"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -33,7 +32,8 @@ import (
type OutputReceiptEventConsumer struct { type OutputReceiptEventConsumer struct {
receiptConsumer *internal.ContinualConsumer receiptConsumer *internal.ContinualConsumer
db storage.Database db storage.Database
notifier *sync.Notifier stream types.StreamProvider
notifier *notifier.Notifier
} }
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
@ -41,8 +41,9 @@ type OutputReceiptEventConsumer struct {
func NewOutputReceiptEventConsumer( func NewOutputReceiptEventConsumer(
cfg *config.SyncAPI, cfg *config.SyncAPI,
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
n *sync.Notifier,
store storage.Database, store storage.Database,
notifier *notifier.Notifier,
stream types.StreamProvider,
) *OutputReceiptEventConsumer { ) *OutputReceiptEventConsumer {
consumer := internal.ContinualConsumer{ consumer := internal.ContinualConsumer{
@ -55,7 +56,8 @@ func NewOutputReceiptEventConsumer(
s := &OutputReceiptEventConsumer{ s := &OutputReceiptEventConsumer{
receiptConsumer: &consumer, receiptConsumer: &consumer,
db: store, db: store,
notifier: n, notifier: notifier,
stream: stream,
} }
consumer.ProcessMessage = s.onMessage consumer.ProcessMessage = s.onMessage
@ -87,8 +89,9 @@ func (s *OutputReceiptEventConsumer) onMessage(msg *sarama.ConsumerMessage) erro
if err != nil { if err != nil {
return err return err
} }
// update stream position
s.notifier.OnNewReceipt(types.NewStreamToken(0, streamPos, nil)) s.stream.Advance(streamPos)
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
return nil return nil
} }

View file

@ -22,8 +22,8 @@ import (
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -35,7 +35,8 @@ type OutputSendToDeviceEventConsumer struct {
sendToDeviceConsumer *internal.ContinualConsumer sendToDeviceConsumer *internal.ContinualConsumer
db storage.Database db storage.Database
serverName gomatrixserverlib.ServerName // our server name serverName gomatrixserverlib.ServerName // our server name
notifier *sync.Notifier stream types.StreamProvider
notifier *notifier.Notifier
} }
// NewOutputSendToDeviceEventConsumer creates a new OutputSendToDeviceEventConsumer. // NewOutputSendToDeviceEventConsumer creates a new OutputSendToDeviceEventConsumer.
@ -43,8 +44,9 @@ type OutputSendToDeviceEventConsumer struct {
func NewOutputSendToDeviceEventConsumer( func NewOutputSendToDeviceEventConsumer(
cfg *config.SyncAPI, cfg *config.SyncAPI,
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
n *sync.Notifier,
store storage.Database, store storage.Database,
notifier *notifier.Notifier,
stream types.StreamProvider,
) *OutputSendToDeviceEventConsumer { ) *OutputSendToDeviceEventConsumer {
consumer := internal.ContinualConsumer{ consumer := internal.ContinualConsumer{
@ -58,7 +60,8 @@ func NewOutputSendToDeviceEventConsumer(
sendToDeviceConsumer: &consumer, sendToDeviceConsumer: &consumer,
db: store, db: store,
serverName: cfg.Matrix.ServerName, serverName: cfg.Matrix.ServerName,
notifier: n, notifier: notifier,
stream: stream,
} }
consumer.ProcessMessage = s.onMessage consumer.ProcessMessage = s.onMessage
@ -94,20 +97,19 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage)
"event_type": output.Type, "event_type": output.Type,
}).Info("sync API received send-to-device event from EDU server") }).Info("sync API received send-to-device event from EDU server")
streamPos := s.db.AddSendToDevice() streamPos, err := s.db.StoreNewSendForDeviceMessage(
context.TODO(), output.UserID, output.DeviceID, output.SendToDeviceEvent,
_, err = s.db.StoreNewSendForDeviceMessage(
context.TODO(), streamPos, output.UserID, output.DeviceID, output.SendToDeviceEvent,
) )
if err != nil { if err != nil {
log.WithError(err).Errorf("failed to store send-to-device message") log.WithError(err).Errorf("failed to store send-to-device message")
return err return err
} }
s.stream.Advance(streamPos)
s.notifier.OnNewSendToDevice( s.notifier.OnNewSendToDevice(
output.UserID, output.UserID,
[]string{output.DeviceID}, []string{output.DeviceID},
types.NewStreamToken(0, streamPos, nil), types.StreamingToken{SendToDevicePosition: streamPos},
) )
return nil return nil

View file

@ -19,10 +19,11 @@ import (
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -30,8 +31,9 @@ import (
// OutputTypingEventConsumer consumes events that originated in the EDU server. // OutputTypingEventConsumer consumes events that originated in the EDU server.
type OutputTypingEventConsumer struct { type OutputTypingEventConsumer struct {
typingConsumer *internal.ContinualConsumer typingConsumer *internal.ContinualConsumer
db storage.Database eduCache *cache.EDUCache
notifier *sync.Notifier stream types.StreamProvider
notifier *notifier.Notifier
} }
// NewOutputTypingEventConsumer creates a new OutputTypingEventConsumer. // NewOutputTypingEventConsumer creates a new OutputTypingEventConsumer.
@ -39,8 +41,10 @@ type OutputTypingEventConsumer struct {
func NewOutputTypingEventConsumer( func NewOutputTypingEventConsumer(
cfg *config.SyncAPI, cfg *config.SyncAPI,
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
n *sync.Notifier,
store storage.Database, store storage.Database,
eduCache *cache.EDUCache,
notifier *notifier.Notifier,
stream types.StreamProvider,
) *OutputTypingEventConsumer { ) *OutputTypingEventConsumer {
consumer := internal.ContinualConsumer{ consumer := internal.ContinualConsumer{
@ -52,8 +56,9 @@ func NewOutputTypingEventConsumer(
s := &OutputTypingEventConsumer{ s := &OutputTypingEventConsumer{
typingConsumer: &consumer, typingConsumer: &consumer,
db: store, eduCache: eduCache,
notifier: n, notifier: notifier,
stream: stream,
} }
consumer.ProcessMessage = s.onMessage consumer.ProcessMessage = s.onMessage
@ -63,13 +68,10 @@ func NewOutputTypingEventConsumer(
// Start consuming from EDU api // Start consuming from EDU api
func (s *OutputTypingEventConsumer) Start() error { func (s *OutputTypingEventConsumer) Start() error {
s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) { s.eduCache.SetTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) {
s.notifier.OnNewEvent( pos := types.StreamPosition(latestSyncPosition)
nil, roomID, nil, s.notifier.OnNewTyping(roomID, types.StreamingToken{TypingPosition: pos})
types.NewStreamToken(0, types.StreamPosition(latestSyncPosition), nil),
)
}) })
return s.typingConsumer.Start() return s.typingConsumer.Start()
} }
@ -90,11 +92,17 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error
var typingPos types.StreamPosition var typingPos types.StreamPosition
typingEvent := output.Event typingEvent := output.Event
if typingEvent.Typing { if typingEvent.Typing {
typingPos = s.db.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime) typingPos = types.StreamPosition(
s.eduCache.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime),
)
} else { } else {
typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID) typingPos = types.StreamPosition(
s.eduCache.RemoveUser(typingEvent.UserID, typingEvent.RoomID),
)
} }
s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.NewStreamToken(0, typingPos, nil)) s.stream.Advance(typingPos)
s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos})
return nil return nil
} }

View file

@ -23,9 +23,8 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
syncinternal "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
syncapi "github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -35,12 +34,13 @@ import (
type OutputKeyChangeEventConsumer struct { type OutputKeyChangeEventConsumer struct {
keyChangeConsumer *internal.ContinualConsumer keyChangeConsumer *internal.ContinualConsumer
db storage.Database db storage.Database
notifier *notifier.Notifier
stream types.PartitionedStreamProvider
serverName gomatrixserverlib.ServerName // our server name serverName gomatrixserverlib.ServerName // our server name
rsAPI roomserverAPI.RoomserverInternalAPI rsAPI roomserverAPI.RoomserverInternalAPI
keyAPI api.KeyInternalAPI keyAPI api.KeyInternalAPI
partitionToOffset map[int32]int64 partitionToOffset map[int32]int64
partitionToOffsetMu sync.Mutex partitionToOffsetMu sync.Mutex
notifier *syncapi.Notifier
} }
// NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer. // NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer.
@ -49,10 +49,11 @@ func NewOutputKeyChangeEventConsumer(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
topic string, topic string,
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
n *syncapi.Notifier,
keyAPI api.KeyInternalAPI, keyAPI api.KeyInternalAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
store storage.Database, store storage.Database,
notifier *notifier.Notifier,
stream types.PartitionedStreamProvider,
) *OutputKeyChangeEventConsumer { ) *OutputKeyChangeEventConsumer {
consumer := internal.ContinualConsumer{ consumer := internal.ContinualConsumer{
@ -70,7 +71,8 @@ func NewOutputKeyChangeEventConsumer(
rsAPI: rsAPI, rsAPI: rsAPI,
partitionToOffset: make(map[int32]int64), partitionToOffset: make(map[int32]int64),
partitionToOffsetMu: sync.Mutex{}, partitionToOffsetMu: sync.Mutex{},
notifier: n, notifier: notifier,
stream: stream,
} }
consumer.ProcessMessage = s.onMessage consumer.ProcessMessage = s.onMessage
@ -113,15 +115,17 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
log.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server") log.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server")
return err return err
} }
// TODO: f.e queryRes.UserIDsToCount : notify users by waking up streams // make sure we get our own key updates too!
posUpdate := types.NewStreamToken(0, 0, map[string]*types.LogPosition{ queryRes.UserIDsToCount[output.UserID] = 1
syncinternal.DeviceListLogName: { posUpdate := types.LogPosition{
Offset: msg.Offset, Offset: msg.Offset,
Partition: msg.Partition, Partition: msg.Partition,
},
})
for userID := range queryRes.UserIDsToCount {
s.notifier.OnNewKeyChange(posUpdate, userID, output.UserID)
} }
s.stream.Advance(posUpdate)
for userID := range queryRes.UserIDsToCount {
s.notifier.OnNewKeyChange(types.StreamingToken{DeviceListPosition: posUpdate}, userID, output.UserID)
}
return nil return nil
} }

View file

@ -23,29 +23,32 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
cfg *config.SyncAPI cfg *config.SyncAPI
rsAPI api.RoomserverInternalAPI rsAPI api.RoomserverInternalAPI
rsConsumer *internal.ContinualConsumer rsConsumer *internal.ContinualConsumer
db storage.Database db storage.Database
notifier *sync.Notifier pduStream types.StreamProvider
inviteStream types.StreamProvider
notifier *notifier.Notifier
} }
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
func NewOutputRoomEventConsumer( func NewOutputRoomEventConsumer(
cfg *config.SyncAPI, cfg *config.SyncAPI,
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
n *sync.Notifier,
store storage.Database, store storage.Database,
notifier *notifier.Notifier,
pduStream types.StreamProvider,
inviteStream types.StreamProvider,
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
@ -56,11 +59,13 @@ func NewOutputRoomEventConsumer(
PartitionStore: store, PartitionStore: store,
} }
s := &OutputRoomEventConsumer{ s := &OutputRoomEventConsumer{
cfg: cfg, cfg: cfg,
rsConsumer: &consumer, rsConsumer: &consumer,
db: store, db: store,
notifier: n, notifier: notifier,
rsAPI: rsAPI, pduStream: pduStream,
inviteStream: inviteStream,
rsAPI: rsAPI,
} }
consumer.ProcessMessage = s.onMessage consumer.ProcessMessage = s.onMessage
@ -177,11 +182,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
} }
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
logrus.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
return err return err
} }
s.notifier.OnNewEvent(ev, "", nil, types.NewStreamToken(pduPos, 0, nil)) s.pduStream.Advance(pduPos)
s.notifier.OnNewEvent(ev, ev.RoomID(), nil, types.StreamingToken{PDUPosition: pduPos})
return nil return nil
} }
@ -216,11 +222,12 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent(
} }
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
logrus.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
return err return err
} }
s.notifier.OnNewEvent(ev, "", nil, types.NewStreamToken(pduPos, 0, nil)) s.pduStream.Advance(pduPos)
s.notifier.OnNewEvent(ev, ev.RoomID(), nil, types.StreamingToken{PDUPosition: pduPos})
return nil return nil
} }
@ -259,6 +266,12 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gom
func (s *OutputRoomEventConsumer) onNewInviteEvent( func (s *OutputRoomEventConsumer) onNewInviteEvent(
ctx context.Context, msg api.OutputNewInviteEvent, ctx context.Context, msg api.OutputNewInviteEvent,
) error { ) error {
if msg.Event.StateKey() == nil {
log.WithFields(log.Fields{
"event": string(msg.Event.JSON()),
}).Panicf("roomserver output log: invite has no state key")
return nil
}
pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) pduPos, err := s.db.AddInviteEvent(ctx, msg.Event)
if err != nil { if err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
@ -269,14 +282,17 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
}).Panicf("roomserver output log: write invite failure") }).Panicf("roomserver output log: write invite failure")
return nil return nil
} }
s.notifier.OnNewEvent(msg.Event, "", nil, types.NewStreamToken(pduPos, 0, nil))
s.inviteStream.Advance(pduPos)
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, *msg.Event.StateKey())
return nil return nil
} }
func (s *OutputRoomEventConsumer) onRetireInviteEvent( func (s *OutputRoomEventConsumer) onRetireInviteEvent(
ctx context.Context, msg api.OutputRetireInviteEvent, ctx context.Context, msg api.OutputRetireInviteEvent,
) error { ) error {
sp, err := s.db.RetireInviteEvent(ctx, msg.EventID) pduPos, err := s.db.RetireInviteEvent(ctx, msg.EventID)
if err != nil { if err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@ -285,9 +301,11 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
}).Panicf("roomserver output log: remove invite failure") }).Panicf("roomserver output log: remove invite failure")
return nil return nil
} }
// Notify any active sync requests that the invite has been retired. // Notify any active sync requests that the invite has been retired.
// Invites share the same stream counter as PDUs s.inviteStream.Advance(pduPos)
s.notifier.OnNewEvent(nil, "", []string{msg.TargetUserID}, types.NewStreamToken(sp, 0, nil)) s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID)
return nil return nil
} }
@ -302,12 +320,13 @@ func (s *OutputRoomEventConsumer) onNewPeek(
}).Panicf("roomserver output log: write peek failure") }).Panicf("roomserver output log: write peek failure")
return nil return nil
} }
// tell the notifier about the new peek so it knows to wake up new devices
s.notifier.OnNewPeek(msg.RoomID, msg.UserID, msg.DeviceID)
// we need to wake up the users who might need to now be peeking into this room, // tell the notifier about the new peek so it knows to wake up new devices
// so we send in a dummy event to trigger a wakeup // TODO: This only works because the peeks table is reusing the same
s.notifier.OnNewEvent(nil, msg.RoomID, nil, types.NewStreamToken(sp, 0, nil)) // index as PDUs, but we should fix this
s.pduStream.Advance(sp)
s.notifier.OnNewPeek(msg.RoomID, msg.UserID, msg.DeviceID, types.StreamingToken{PDUPosition: sp})
return nil return nil
} }
@ -322,12 +341,13 @@ func (s *OutputRoomEventConsumer) onRetirePeek(
}).Panicf("roomserver output log: write peek failure") }).Panicf("roomserver output log: write peek failure")
return nil return nil
} }
// tell the notifier about the new peek so it knows to wake up new devices
s.notifier.OnRetirePeek(msg.RoomID, msg.UserID, msg.DeviceID)
// we need to wake up the users who might need to now be peeking into this room, // tell the notifier about the new peek so it knows to wake up new devices
// so we send in a dummy event to trigger a wakeup // TODO: This only works because the peeks table is reusing the same
s.notifier.OnNewEvent(nil, msg.RoomID, nil, types.NewStreamToken(sp, 0, nil)) // index as PDUs, but we should fix this
s.pduStream.Advance(sp)
s.notifier.OnRetirePeek(msg.RoomID, msg.UserID, msg.DeviceID, types.StreamingToken{PDUPosition: sp})
return nil return nil
} }

View file

@ -49,8 +49,8 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID,
// nolint:gocyclo // nolint:gocyclo
func DeviceListCatchup( func DeviceListCatchup(
ctx context.Context, keyAPI keyapi.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ctx context.Context, keyAPI keyapi.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
userID string, res *types.Response, from, to types.StreamingToken, userID string, res *types.Response, from, to types.LogPosition,
) (hasNew bool, err error) { ) (newPos types.LogPosition, hasNew bool, err error) {
// Track users who we didn't track before but now do by virtue of sharing a room with them, or not. // Track users who we didn't track before but now do by virtue of sharing a room with them, or not.
newlyJoinedRooms := joinedRooms(res, userID) newlyJoinedRooms := joinedRooms(res, userID)
@ -58,7 +58,7 @@ func DeviceListCatchup(
if len(newlyJoinedRooms) > 0 || len(newlyLeftRooms) > 0 { if len(newlyJoinedRooms) > 0 || len(newlyLeftRooms) > 0 {
changed, left, err := TrackChangedUsers(ctx, rsAPI, userID, newlyJoinedRooms, newlyLeftRooms) changed, left, err := TrackChangedUsers(ctx, rsAPI, userID, newlyJoinedRooms, newlyLeftRooms)
if err != nil { if err != nil {
return false, err return to, false, err
} }
res.DeviceLists.Changed = changed res.DeviceLists.Changed = changed
res.DeviceLists.Left = left res.DeviceLists.Left = left
@ -73,15 +73,13 @@ func DeviceListCatchup(
offset = sarama.OffsetOldest offset = sarama.OffsetOldest
// Extract partition/offset from sync token // Extract partition/offset from sync token
// TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make. // TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make.
logOffset := from.Log(DeviceListLogName) if !from.IsEmpty() {
if logOffset != nil { partition = from.Partition
partition = logOffset.Partition offset = from.Offset
offset = logOffset.Offset
} }
var toOffset int64 var toOffset int64
toOffset = sarama.OffsetNewest toOffset = sarama.OffsetNewest
toLog := to.Log(DeviceListLogName) if toLog := to; toLog.Partition == partition && toLog.Offset > 0 {
if toLog != nil && toLog.Offset > 0 {
toOffset = toLog.Offset toOffset = toLog.Offset
} }
var queryRes api.QueryKeyChangesResponse var queryRes api.QueryKeyChangesResponse
@ -93,7 +91,7 @@ func DeviceListCatchup(
if queryRes.Error != nil { if queryRes.Error != nil {
// don't fail the catchup because we may have got useful information by tracking membership // don't fail the catchup because we may have got useful information by tracking membership
util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed") util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed")
return hasNew, nil return to, hasNew, nil
} }
// QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user. // QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user.
var sharedUsersMap map[string]int var sharedUsersMap map[string]int
@ -130,13 +128,12 @@ func DeviceListCatchup(
} }
} }
// set the new token // set the new token
to.SetLog(DeviceListLogName, &types.LogPosition{ to = types.LogPosition{
Partition: queryRes.Partition, Partition: queryRes.Partition,
Offset: queryRes.Offset, Offset: queryRes.Offset,
}) }
res.NextBatch = to.String()
return hasNew, nil return to, hasNew, nil
} }
// TrackChangedUsers calculates the values of device_lists.changed|left in the /sync response. // TrackChangedUsers calculates the values of device_lists.changed|left in the /sync response.

View file

@ -16,13 +16,11 @@ import (
var ( var (
syncingUser = "@alice:localhost" syncingUser = "@alice:localhost"
emptyToken = types.NewStreamToken(0, 0, nil) emptyToken = types.LogPosition{}
newestToken = types.NewStreamToken(0, 0, map[string]*types.LogPosition{ newestToken = types.LogPosition{
DeviceListLogName: { Offset: sarama.OffsetNewest,
Offset: sarama.OffsetNewest, Partition: 0,
Partition: 0, }
},
})
) )
type mockKeyAPI struct{} type mockKeyAPI struct{}
@ -180,7 +178,7 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) {
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
} }
hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -203,7 +201,7 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) {
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
} }
hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -226,7 +224,7 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) {
"!another:room": {syncingUser, existingUser}, "!another:room": {syncingUser, existingUser},
}, },
} }
hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken)
if err != nil { if err != nil {
t.Fatalf("Catchup returned an error: %s", err) t.Fatalf("Catchup returned an error: %s", err)
} }
@ -248,7 +246,7 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) {
"!another:room": {syncingUser, existingUser}, "!another:room": {syncingUser, existingUser},
}, },
} }
hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -307,7 +305,7 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) {
roomID: {syncingUser, existingUser}, roomID: {syncingUser, existingUser},
}, },
} }
hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -335,7 +333,7 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) {
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
} }
hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken)
if err != nil { if err != nil {
t.Fatalf("Catchup returned an error: %s", err) t.Fatalf("Catchup returned an error: %s", err)
} }
@ -420,7 +418,7 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) {
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
} }
hasNew, err := DeviceListCatchup( _, hasNew, err := DeviceListCatchup(
context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken, context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken,
) )
if err != nil { if err != nil {

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package sync package notifier
import ( import (
"context" "context"
@ -48,9 +48,9 @@ type Notifier struct {
// NewNotifier creates a new notifier set to the given sync position. // NewNotifier creates a new notifier set to the given sync position.
// In order for this to be of any use, the Notifier needs to be told all rooms and // In order for this to be of any use, the Notifier needs to be told all rooms and
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
func NewNotifier(pos types.StreamingToken) *Notifier { func NewNotifier(currPos types.StreamingToken) *Notifier {
return &Notifier{ return &Notifier{
currPos: pos, currPos: currPos,
roomIDToJoinedUsers: make(map[string]userIDSet), roomIDToJoinedUsers: make(map[string]userIDSet),
roomIDToPeekingDevices: make(map[string]peekingDeviceSet), roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
userDeviceStreams: make(map[string]map[string]*UserDeviceStream), userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
@ -77,9 +77,8 @@ func (n *Notifier) OnNewEvent(
// This needs to be done PRIOR to waking up users as they will read this value. // This needs to be done PRIOR to waking up users as they will read this value.
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.currPos.ApplyUpdates(posUpdate)
n.removeEmptyUserStreams() n.removeEmptyUserStreams()
if ev != nil { if ev != nil {
@ -113,11 +112,11 @@ func (n *Notifier) OnNewEvent(
} }
} }
n.wakeupUsers(usersToNotify, peekingDevicesToNotify, latestPos) n.wakeupUsers(usersToNotify, peekingDevicesToNotify, n.currPos)
} else if roomID != "" { } else if roomID != "" {
n.wakeupUsers(n.joinedUsers(roomID), n.PeekingDevices(roomID), latestPos) n.wakeupUsers(n.joinedUsers(roomID), n.PeekingDevices(roomID), n.currPos)
} else if len(userIDs) > 0 { } else if len(userIDs) > 0 {
n.wakeupUsers(userIDs, nil, latestPos) n.wakeupUsers(userIDs, nil, n.currPos)
} else { } else {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"posUpdate": posUpdate.String, "posUpdate": posUpdate.String,
@ -125,12 +124,24 @@ func (n *Notifier) OnNewEvent(
} }
} }
func (n *Notifier) OnNewPeek( func (n *Notifier) OnNewAccountData(
roomID, userID, deviceID string, userID string, posUpdate types.StreamingToken,
) { ) {
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers([]string{userID}, nil, posUpdate)
}
func (n *Notifier) OnNewPeek(
roomID, userID, deviceID string,
posUpdate types.StreamingToken,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate)
n.addPeekingDevice(roomID, userID, deviceID) n.addPeekingDevice(roomID, userID, deviceID)
// we don't wake up devices here given the roomserver consumer will do this shortly afterwards // we don't wake up devices here given the roomserver consumer will do this shortly afterwards
@ -139,10 +150,12 @@ func (n *Notifier) OnNewPeek(
func (n *Notifier) OnRetirePeek( func (n *Notifier) OnRetirePeek(
roomID, userID, deviceID string, roomID, userID, deviceID string,
posUpdate types.StreamingToken,
) { ) {
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate)
n.removePeekingDevice(roomID, userID, deviceID) n.removePeekingDevice(roomID, userID, deviceID)
// we don't wake up devices here given the roomserver consumer will do this shortly afterwards // we don't wake up devices here given the roomserver consumer will do this shortly afterwards
@ -155,20 +168,33 @@ func (n *Notifier) OnNewSendToDevice(
) { ) {
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.wakeupUserDevice(userID, deviceIDs, latestPos) n.currPos.ApplyUpdates(posUpdate)
n.wakeupUserDevice(userID, deviceIDs, n.currPos)
} }
// OnNewReceipt updates the current position // OnNewReceipt updates the current position
func (n *Notifier) OnNewReceipt( func (n *Notifier) OnNewTyping(
roomID string,
posUpdate types.StreamingToken, posUpdate types.StreamingToken,
) { ) {
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers(n.joinedUsers(roomID), nil, n.currPos)
}
// OnNewReceipt updates the current position
func (n *Notifier) OnNewReceipt(
roomID string,
posUpdate types.StreamingToken,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers(n.joinedUsers(roomID), nil, n.currPos)
} }
func (n *Notifier) OnNewKeyChange( func (n *Notifier) OnNewKeyChange(
@ -176,15 +202,25 @@ func (n *Notifier) OnNewKeyChange(
) { ) {
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers([]string{wakeUserID}, nil, latestPos) n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
}
func (n *Notifier) OnNewInvite(
posUpdate types.StreamingToken, wakeUserID string,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
} }
// GetListener returns a UserStreamListener that can be used to wait for // GetListener returns a UserStreamListener that can be used to wait for
// updates for a user. Must be closed. // updates for a user. Must be closed.
// notify for anything before sincePos // notify for anything before sincePos
func (n *Notifier) GetListener(req syncRequest) UserDeviceStreamListener { func (n *Notifier) GetListener(req types.SyncRequest) UserDeviceStreamListener {
// Do what synapse does: https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/notifier.py#L298 // Do what synapse does: https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/notifier.py#L298
// - Bucket request into a lookup map keyed off a list of joined room IDs and separately a user ID // - Bucket request into a lookup map keyed off a list of joined room IDs and separately a user ID
// - Incoming events wake requests for a matching room ID // - Incoming events wake requests for a matching room ID
@ -198,7 +234,7 @@ func (n *Notifier) GetListener(req syncRequest) UserDeviceStreamListener {
n.removeEmptyUserStreams() n.removeEmptyUserStreams()
return n.fetchUserDeviceStream(req.device.UserID, req.device.ID, true).GetListener(req.ctx) return n.fetchUserDeviceStream(req.Device.UserID, req.Device.ID, true).GetListener(req.Context)
} }
// Load the membership states required to notify users correctly. // Load the membership states required to notify users correctly.

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package sync package notifier
import ( import (
"context" "context"
@ -32,11 +32,11 @@ var (
randomMessageEvent gomatrixserverlib.HeaderedEvent randomMessageEvent gomatrixserverlib.HeaderedEvent
aliceInviteBobEvent gomatrixserverlib.HeaderedEvent aliceInviteBobEvent gomatrixserverlib.HeaderedEvent
bobLeaveEvent gomatrixserverlib.HeaderedEvent bobLeaveEvent gomatrixserverlib.HeaderedEvent
syncPositionVeryOld = types.NewStreamToken(5, 0, nil) syncPositionVeryOld = types.StreamingToken{PDUPosition: 5}
syncPositionBefore = types.NewStreamToken(11, 0, nil) syncPositionBefore = types.StreamingToken{PDUPosition: 11}
syncPositionAfter = types.NewStreamToken(12, 0, nil) syncPositionAfter = types.StreamingToken{PDUPosition: 12}
syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition(), 1, nil) //syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition, 1, 0, 0, nil)
syncPositionAfter2 = types.NewStreamToken(13, 0, nil) syncPositionAfter2 = types.StreamingToken{PDUPosition: 13}
) )
var ( var (
@ -205,6 +205,9 @@ func TestNewInviteEventForUser(t *testing.T) {
} }
// Test an EDU-only update wakes up the request. // Test an EDU-only update wakes up the request.
// TODO: Fix this test, invites wake up with an incremented
// PDU position, not EDU position
/*
func TestEDUWakeup(t *testing.T) { func TestEDUWakeup(t *testing.T) {
n := NewNotifier(syncPositionAfter) n := NewNotifier(syncPositionAfter)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
@ -229,6 +232,7 @@ func TestEDUWakeup(t *testing.T) {
wg.Wait() wg.Wait()
} }
*/
// Test that all blocked requests get woken up on a new event. // Test that all blocked requests get woken up on a new event.
func TestMultipleRequestWakeup(t *testing.T) { func TestMultipleRequestWakeup(t *testing.T) {
@ -322,16 +326,16 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
} }
func waitForEvents(n *Notifier, req syncRequest) (types.StreamingToken, error) { func waitForEvents(n *Notifier, req types.SyncRequest) (types.StreamingToken, error) {
listener := n.GetListener(req) listener := n.GetListener(req)
defer listener.Close() defer listener.Close()
select { select {
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
return types.StreamingToken{}, fmt.Errorf( return types.StreamingToken{}, fmt.Errorf(
"waitForEvents timed out waiting for %s (pos=%v)", req.device.UserID, req.since, "waitForEvents timed out waiting for %s (pos=%v)", req.Device.UserID, req.Since,
) )
case <-listener.GetNotifyChannel(*req.since): case <-listener.GetNotifyChannel(req.Since):
p := listener.GetSyncPosition() p := listener.GetSyncPosition()
return p, nil return p, nil
} }
@ -354,17 +358,17 @@ func lockedFetchUserStream(n *Notifier, userID, deviceID string) *UserDeviceStre
return n.fetchUserDeviceStream(userID, deviceID, true) return n.fetchUserDeviceStream(userID, deviceID, true)
} }
func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) syncRequest { func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) types.SyncRequest {
return syncRequest{ return types.SyncRequest{
device: userapi.Device{ Device: &userapi.Device{
UserID: userID, UserID: userID,
ID: deviceID, ID: deviceID,
}, },
timeout: 1 * time.Minute, Timeout: 1 * time.Minute,
since: &since, Since: since,
wantFullState: false, WantFullState: false,
limit: DefaultTimelineLimit, Limit: 20,
log: util.GetLogger(context.TODO()), Log: util.GetLogger(context.TODO()),
ctx: context.TODO(), Context: context.TODO(),
} }
} }

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package sync package notifier
import ( import (
"context" "context"

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -49,9 +50,10 @@ type messagesReq struct {
} }
type messagesResp struct { type messagesResp struct {
Start string `json:"start"` Start string `json:"start"`
End string `json:"end"` StartStream string `json:"start_stream,omitempty"` // NOTSPEC: so clients can hit /messages then immediately /sync with a latest sync token
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` End string `json:"end"`
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
} }
const defaultMessagesLimit = 10 const defaultMessagesLimit = 10
@ -65,6 +67,7 @@ func OnIncomingMessagesRequest(
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
cfg *config.SyncAPI, cfg *config.SyncAPI,
srp *sync.RequestPool,
) util.JSONResponse { ) util.JSONResponse {
var err error var err error
@ -84,9 +87,18 @@ func OnIncomingMessagesRequest(
// Extract parameters from the request's URL. // Extract parameters from the request's URL.
// Pagination tokens. // Pagination tokens.
var fromStream *types.StreamingToken var fromStream *types.StreamingToken
from, err := types.NewTopologyTokenFromString(req.URL.Query().Get("from")) fromQuery := req.URL.Query().Get("from")
emptyFromSupplied := fromQuery == ""
if emptyFromSupplied {
// NOTSPEC: We will pretend they used the latest sync token if no ?from= was provided.
// We do this to allow clients to get messages without having to call `/sync` e.g Cerulean
currPos := srp.Notifier.CurrentPosition()
fromQuery = currPos.String()
}
from, err := types.NewTopologyTokenFromString(fromQuery)
if err != nil { if err != nil {
fs, err2 := types.NewStreamTokenFromString(req.URL.Query().Get("from")) fs, err2 := types.NewStreamTokenFromString(fromQuery)
fromStream = &fs fromStream = &fs
if err2 != nil { if err2 != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -185,14 +197,19 @@ func OnIncomingMessagesRequest(
"return_end": end.String(), "return_end": end.String(),
}).Info("Responding") }).Info("Responding")
res := messagesResp{
Chunk: clientEvents,
Start: start.String(),
End: end.String(),
}
if emptyFromSupplied {
res.StartStream = fromStream.String()
}
// Respond with the events. // Respond with the events.
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: messagesResp{ JSON: res,
Chunk: clientEvents,
Start: start.String(),
End: end.String(),
},
} }
} }
@ -256,6 +273,14 @@ func (r *messagesReq) retrieveEvents() (
return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil
} }
// Get the position of the first and the last event in the room's topology.
// This position is currently determined by the event's depth, so we could
// also use it instead of retrieving from the database. However, if we ever
// change the way topological positions are defined (as depth isn't the most
// reliable way to define it), it would be easier and less troublesome to
// only have to change it in one place, i.e. the database.
start, end, err = r.getStartEnd(events)
// Sort the events to ensure we send them in the right order. // Sort the events to ensure we send them in the right order.
if r.backwardOrdering { if r.backwardOrdering {
// This reverses the array from old->new to new->old // This reverses the array from old->new to new->old
@ -275,14 +300,6 @@ func (r *messagesReq) retrieveEvents() (
// Convert all of the events into client events. // Convert all of the events into client events.
clientEvents = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll) clientEvents = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll)
// Get the position of the first and the last event in the room's topology.
// This position is currently determined by the event's depth, so we could
// also use it instead of retrieving from the database. However, if we ever
// change the way topological positions are defined (as depth isn't the most
// reliable way to define it), it would be easier and less troublesome to
// only have to change it in one place, i.e. the database.
start, end, err = r.getStartEnd(events)
return clientEvents, start, end, err return clientEvents, start, end, err
} }
@ -346,7 +363,7 @@ func (r *messagesReq) filterHistoryVisible(events []*gomatrixserverlib.HeaderedE
return events // apply no filtering as it defaults to Shared. return events // apply no filtering as it defaults to Shared.
} }
hisVis, _ := hisVisEvent.HistoryVisibility() hisVis, _ := hisVisEvent.HistoryVisibility()
if hisVis == "shared" { if hisVis == "shared" || hisVis == "world_readable" {
return events // apply no filtering return events // apply no filtering
} }
if membershipEvent == nil { if membershipEvent == nil {
@ -371,26 +388,16 @@ func (r *messagesReq) filterHistoryVisible(events []*gomatrixserverlib.HeaderedE
} }
func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
start, err = r.db.EventPositionInTopology( if r.backwardOrdering {
r.ctx, events[0].EventID(), start = *r.from
) if events[len(events)-1].Type() == gomatrixserverlib.MRoomCreate {
if err != nil { // NOTSPEC: We've hit the beginning of the room so there's really nowhere
err = fmt.Errorf("EventPositionInTopology: for start event %s: %w", events[0].EventID(), err) // else to go. This seems to fix Riot iOS from looping on /messages endlessly.
return end = types.TopologyToken{}
} } else {
if r.backwardOrdering && events[len(events)-1].Type() == gomatrixserverlib.MRoomCreate { end, err = r.db.EventPositionInTopology(
// We've hit the beginning of the room so there's really nowhere else r.ctx, events[0].EventID(),
// to go. This seems to fix Riot iOS from looping on /messages endlessly. )
end = types.NewTopologyToken(0, 0)
} else {
end, err = r.db.EventPositionInTopology(
r.ctx, events[len(events)-1].EventID(),
)
if err != nil {
err = fmt.Errorf("EventPositionInTopology: for end event %s: %w", events[len(events)-1].EventID(), err)
return
}
if r.backwardOrdering {
// A stream/topological position is a cursor located between two events. // A stream/topological position is a cursor located between two events.
// While they are identified in the code by the event on their right (if // While they are identified in the code by the event on their right (if
// we consider a left to right chronological order), tokens need to refer // we consider a left to right chronological order), tokens need to refer
@ -398,6 +405,15 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
// end position we send in the response if we're going backward. // end position we send in the response if we're going backward.
end.Decrement() end.Decrement()
} }
} else {
start = *r.from
end, err = r.db.EventPositionInTopology(
r.ctx, events[len(events)-1].EventID(),
)
}
if err != nil {
err = fmt.Errorf("EventPositionInTopology: for end event %s: %w", events[len(events)-1].EventID(), err)
return
} }
return return
} }
@ -447,11 +463,11 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
// The condition in the SQL query is a strict "greater than" so // The condition in the SQL query is a strict "greater than" so
// we need to check against to-1. // we need to check against to-1.
streamPos := types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition) streamPos := types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition)
isSetLargeEnough = (r.to.PDUPosition()-1 == streamPos) isSetLargeEnough = (r.to.PDUPosition-1 == streamPos)
} }
} else { } else {
streamPos := types.StreamPosition(streamEvents[0].StreamPosition) streamPos := types.StreamPosition(streamEvents[0].StreamPosition)
isSetLargeEnough = (r.from.PDUPosition()-1 == streamPos) isSetLargeEnough = (r.from.PDUPosition-1 == streamPos)
} }
} }
@ -565,7 +581,7 @@ func setToDefault(
if backwardOrdering { if backwardOrdering {
// go 1 earlier than the first event so we correctly fetch the earliest event // go 1 earlier than the first event so we correctly fetch the earliest event
// this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound.
to = types.NewTopologyToken(0, 0) to = types.TopologyToken{}
} else { } else {
to, err = db.MaxTopologicalPosition(ctx, roomID) to, err = db.MaxTopologicalPosition(ctx, roomID)
} }

View file

@ -51,7 +51,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg) return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg, srp)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter", r0mux.Handle("/user/{userId}/filter",

View file

@ -16,11 +16,9 @@ package storage
import ( import (
"context" "context"
"time"
eduAPI "github.com/matrix-org/dendrite/eduserver/api" eduAPI "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -30,6 +28,27 @@ import (
type Database interface { type Database interface {
internal.PartitionStorer internal.PartitionStorer
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
RecentEvents(ctx context.Context, roomID string, r types.Range, limit int, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
GetBackwardTopologyPos(ctx context.Context, events []types.StreamEvent) (types.TopologyToken, error)
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)
InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error)
PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error)
RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []eduAPI.OutputReceiptEvent, error)
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
// AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices. // AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices.
@ -56,18 +75,6 @@ type Database interface {
// Returns an empty slice if no state events could be found for this room. // Returns an empty slice if no state events could be found for this room.
// Returns an error if there was an issue with the retrieval. // Returns an error if there was an issue with the retrieval.
GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error)
// SyncPosition returns the latest positions for syncing.
SyncPosition(ctx context.Context) (types.StreamingToken, error)
// IncrementalSync returns all the data needed in order to create an incremental
// sync response for the given user. Events returned will include any client
// transaction IDs associated with the given device. These transaction IDs come
// from when the device sent the event via an API that included a transaction
// ID. A response object must be provided for IncrementaSync to populate - it
// will not create one.
IncrementalSync(ctx context.Context, res *types.Response, device userapi.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error)
// CompleteSync returns a complete /sync API response for the given user. A response object
// must be provided for CompleteSync to populate - it will not create one.
CompleteSync(ctx context.Context, res *types.Response, device userapi.Device, numRecentEventsPerRoom int) (*types.Response, error)
// GetAccountDataInRange returns all account data for a given user inserted or // GetAccountDataInRange returns all account data for a given user inserted or
// updated between two given positions // updated between two given positions
// Returns a map following the format data[roomID] = []dataTypes // Returns a map following the format data[roomID] = []dataTypes
@ -97,15 +104,6 @@ type Database interface {
// DeletePeek deletes all peeks for a given room by a given user // DeletePeek deletes all peeks for a given room by a given user
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
// SetTypingTimeoutCallback sets a callback function that is called right after
// a user is removed from the typing user list due to timeout.
SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn)
// AddTypingUser adds a typing user to the typing cache.
// Returns the newly calculated sync position for typing notifications.
AddTypingUser(userID, roomID string, expireTime *time.Time) types.StreamPosition
// RemoveTypingUser removes a typing user from the typing cache.
// Returns the newly calculated sync position for typing notifications.
RemoveTypingUser(userID, roomID string) types.StreamPosition
// GetEventsInStreamingRange retrieves all of the events on a given ordering using the given extremities and limit. // GetEventsInStreamingRange retrieves all of the events on a given ordering using the given extremities and limit.
GetEventsInStreamingRange(ctx context.Context, from, to *types.StreamingToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) GetEventsInStreamingRange(ctx context.Context, from, to *types.StreamingToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
@ -120,28 +118,14 @@ type Database interface {
// matches the streamevent.transactionID device then the transaction ID gets // matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event. // added to the unsigned section of the output event.
StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent
// AddSendToDevice increases the EDU position in the cache and returns the stream position. // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
AddSendToDevice() types.StreamPosition // relevant events within the given ranges for the supplied user ID and device ID.
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns three lists: SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)
// - "events": a list of send-to-device events that should be included in the sync
// - "changes": a list of send-to-device events that should be updated in the database by
// CleanSendToDeviceUpdates
// - "deletions": a list of send-to-device events which have been confirmed as sent and
// can be deleted altogether by CleanSendToDeviceUpdates
// The token supplied should be the current requested sync token, e.g. from the "since"
// parameter.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error)
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device. // StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
StoreNewSendForDeviceMessage(ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the // CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified
// result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows // from position, preventing the send-to-device table from growing indefinitely.
// SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after CleanSendToDeviceUpdates(ctx context.Context, userID, deviceID string, before types.StreamPosition) (err error)
// starting to wait for an incremental sync with timeout).
// The token supplied should be the current requested sync token, e.g. from the "since"
// parameter.
CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
// SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent.
SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error)
// GetFilter looks up the filter associated with a given local user and filter ID. // GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists // Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database. // or if there was an error talking to the database.

View file

@ -58,6 +58,8 @@ CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url); CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url);
-- for querying membership states of users -- for querying membership states of users
CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
-- for querying state by event IDs
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id);
` `
const upsertRoomStateSQL = "" + const upsertRoomStateSQL = "" +
@ -76,7 +78,7 @@ const selectRoomIDsWithMembershipSQL = "" +
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectCurrentStateSQL = "" + const selectCurrentStateSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" +
" AND ( $2::text[] IS NULL OR sender = ANY($2) )" + " AND ( $2::text[] IS NULL OR sender = ANY($2) )" +
" AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" + " AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" +
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" + " AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
@ -92,10 +94,10 @@ const selectStateEventSQL = "" +
const selectEventsWithEventIDsSQL = "" + const selectEventsWithEventIDsSQL = "" +
// TODO: The session_id and transaction_id blanks are here because otherwise // TODO: The session_id and transaction_id blanks are here because otherwise
// the rowsToStreamEvents expects there to be exactly five columns. We need to // the rowsToStreamEvents expects there to be exactly six columns. We need to
// figure out if these really need to be in the DB, and if so, we need a // figure out if these really need to be in the DB, and if so, we need a
// better permanent fix for this. - neilalexander, 2 Jan 2020 // better permanent fix for this. - neilalexander, 2 Jan 2020
"SELECT added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" + "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
" FROM syncapi_current_room_state WHERE event_id = ANY($1)" " FROM syncapi_current_room_state WHERE event_id = ANY($1)"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
@ -278,13 +280,14 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
result := []*gomatrixserverlib.HeaderedEvent{} result := []*gomatrixserverlib.HeaderedEvent{}
for rows.Next() { for rows.Next() {
var eventID string
var eventBytes []byte var eventBytes []byte
if err := rows.Scan(&eventBytes); err != nil { if err := rows.Scan(&eventID, &eventBytes); err != nil {
return nil, err return nil, err
} }
// TODO: Handle redacted events // TODO: Handle redacted events
var ev gomatrixserverlib.HeaderedEvent var ev gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(eventBytes, &ev); err != nil { if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err return nil, err
} }
result = append(result, &ev) result = append(result, &ev)

View file

@ -0,0 +1,67 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpFixSequences, DownFixSequences)
goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
func LoadFixSequences(m *sqlutil.Migrations) {
m.AddMigration(UpFixSequences, DownFixSequences)
}
func UpFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
DELETE FROM syncapi_receipts;
-- Use the new syncapi_receipts_id sequence.
CREATE SEQUENCE IF NOT EXISTS syncapi_receipt_id;
ALTER SEQUENCE IF EXISTS syncapi_receipt_id RESTART WITH 1;
ALTER TABLE syncapi_receipts ALTER COLUMN id SET DEFAULT nextval('syncapi_receipt_id');
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
DELETE FROM syncapi_receipts;
-- Revert back to using the syncapi_stream_id sequence.
DROP SEQUENCE IF EXISTS syncapi_receipt_id;
ALTER TABLE syncapi_receipts ALTER COLUMN id SET DEFAULT nextval('syncapi_stream_id');
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -0,0 +1,48 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE syncapi_send_to_device
DROP COLUMN IF EXISTS sent_by_token;
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE syncapi_send_to_device
ADD COLUMN IF NOT EXISTS sent_by_token TEXT;
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -79,20 +79,20 @@ const insertEventSQL = "" +
"RETURNING id" "RETURNING id"
const selectEventsSQL = "" + const selectEventsSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)"
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC LIMIT $4" " ORDER BY id DESC LIMIT $4"
const selectRecentEventsForSyncSQL = "" + const selectRecentEventsForSyncSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +
" ORDER BY id DESC LIMIT $4" " ORDER BY id DESC LIMIT $4"
const selectEarlyEventsSQL = "" + const selectEarlyEventsSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC LIMIT $4" " ORDER BY id ASC LIMIT $4"
@ -413,6 +413,7 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
var result []types.StreamEvent var result []types.StreamEvent
for rows.Next() { for rows.Next() {
var ( var (
eventID string
streamPos types.StreamPosition streamPos types.StreamPosition
eventBytes []byte eventBytes []byte
excludeFromSync bool excludeFromSync bool
@ -420,12 +421,12 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
txnID *string txnID *string
transactionID *api.TransactionID transactionID *api.TransactionID
) )
if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil { if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil {
return nil, err return nil, err
} }
// TODO: Handle redacted events // TODO: Handle redacted events
var ev gomatrixserverlib.HeaderedEvent var ev gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(eventBytes, &ev); err != nil { if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err return nil, err
} }

View file

@ -30,11 +30,12 @@ import (
) )
const receiptsSchema = ` const receiptsSchema = `
CREATE SEQUENCE IF NOT EXISTS syncapi_stream_id; CREATE SEQUENCE IF NOT EXISTS syncapi_receipt_id;
-- Stores data about receipts -- Stores data about receipts
CREATE TABLE IF NOT EXISTS syncapi_receipts ( CREATE TABLE IF NOT EXISTS syncapi_receipts (
-- The ID -- The ID
id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'), id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_receipt_id'),
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
receipt_type TEXT NOT NULL, receipt_type TEXT NOT NULL,
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
@ -50,18 +51,22 @@ const upsertReceipt = "" +
" (room_id, receipt_type, user_id, event_id, receipt_ts)" + " (room_id, receipt_type, user_id, event_id, receipt_ts)" +
" VALUES ($1, $2, $3, $4, $5)" + " VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT (room_id, receipt_type, user_id)" + " ON CONFLICT (room_id, receipt_type, user_id)" +
" DO UPDATE SET id = nextval('syncapi_stream_id'), event_id = $4, receipt_ts = $5" + " DO UPDATE SET id = nextval('syncapi_receipt_id'), event_id = $4, receipt_ts = $5" +
" RETURNING id" " RETURNING id"
const selectRoomReceipts = "" + const selectRoomReceipts = "" +
"SELECT room_id, receipt_type, user_id, event_id, receipt_ts" + "SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" +
" FROM syncapi_receipts" + " FROM syncapi_receipts" +
" WHERE room_id = ANY($1) AND id > $2" " WHERE room_id = ANY($1) AND id > $2"
const selectMaxReceiptIDSQL = "" +
"SELECT MAX(id) FROM syncapi_receipts"
type receiptStatements struct { type receiptStatements struct {
db *sql.DB db *sql.DB
upsertReceipt *sql.Stmt upsertReceipt *sql.Stmt
selectRoomReceipts *sql.Stmt selectRoomReceipts *sql.Stmt
selectMaxReceiptID *sql.Stmt
} }
func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) {
@ -78,6 +83,9 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) {
if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
} }
if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
}
return r, nil return r, nil
} }
@ -87,20 +95,37 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
return return
} }
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
lastPos := streamPos
rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos) rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to query room receipts: %w", err) return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed")
var res []api.OutputReceiptEvent var res []api.OutputReceiptEvent
for rows.Next() { for rows.Next() {
r := api.OutputReceiptEvent{} r := api.OutputReceiptEvent{}
err = rows.Scan(&r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp) var id types.StreamPosition
err = rows.Scan(&id, &r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp)
if err != nil { if err != nil {
return res, fmt.Errorf("unable to scan row to api.Receipts: %w", err) return 0, res, fmt.Errorf("unable to scan row to api.Receipts: %w", err)
} }
res = append(res, r) res = append(res, r)
if id > lastPos {
lastPos = id
}
} }
return res, rows.Err() return lastPos, res, rows.Err()
}
func (s *receiptStatements) SelectMaxReceiptID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
return
} }

View file

@ -19,7 +19,6 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
@ -38,47 +37,36 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
-- The device ID to send the message to. -- The device ID to send the message to.
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
-- The event content JSON. -- The event content JSON.
content TEXT NOT NULL, content TEXT NOT NULL
-- The token that was supplied to the /sync at the time that this
-- message was included in a sync response, or NULL if we haven't
-- included it in a /sync response yet.
sent_by_token TEXT
); );
` `
const insertSendToDeviceMessageSQL = ` const insertSendToDeviceMessageSQL = `
INSERT INTO syncapi_send_to_device (user_id, device_id, content) INSERT INTO syncapi_send_to_device (user_id, device_id, content)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
` RETURNING id
const countSendToDeviceMessagesSQL = `
SELECT COUNT(*)
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
` `
const selectSendToDeviceMessagesSQL = ` const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content, sent_by_token SELECT id, user_id, device_id, content
FROM syncapi_send_to_device FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC ORDER BY id DESC
` `
const updateSentSendToDeviceMessagesSQL = ` const deleteSendToDeviceMessagesSQL = `
UPDATE syncapi_send_to_device SET sent_by_token = $1 DELETE FROM syncapi_send_to_device
WHERE id = ANY($2) WHERE user_id = $1 AND device_id = $2 AND id < $3
` `
const deleteSendToDeviceMessagesSQL = ` const selectMaxSendToDeviceIDSQL = "" +
DELETE FROM syncapi_send_to_device WHERE id = ANY($1) "SELECT MAX(id) FROM syncapi_send_to_device"
`
type sendToDeviceStatements struct { type sendToDeviceStatements struct {
insertSendToDeviceMessageStmt *sql.Stmt insertSendToDeviceMessageStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt deleteSendToDeviceMessagesStmt *sql.Stmt
updateSentSendToDeviceMessagesStmt *sql.Stmt selectMaxSendToDeviceIDStmt *sql.Stmt
deleteSendToDeviceMessagesStmt *sql.Stmt
} }
func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
@ -90,16 +78,13 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err return nil, err
} }
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
return nil, err return nil, err
} }
if s.updateSentSendToDeviceMessagesStmt, err = db.Prepare(updateSentSendToDeviceMessagesSQL); err != nil { if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
return nil, err return nil, err
} }
if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil { if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil {
return nil, err return nil, err
} }
return s, nil return s, nil
@ -107,66 +92,60 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
func (s *sendToDeviceStatements) InsertSendToDeviceMessage( func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) { ) (pos types.StreamPosition, err error) {
_, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).QueryRowContext(ctx, userID, deviceID, content).Scan(&pos)
return return
} }
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (count int, err error) {
row := sqlutil.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
if err = row.Scan(&count); err != nil {
return
}
return count, nil
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
) (events []types.SendToDeviceEvent, err error) { ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
if err != nil { if err != nil {
return return
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
for rows.Next() { for rows.Next() {
var id types.SendToDeviceNID var id types.StreamPosition
var userID, deviceID, content string var userID, deviceID, content string
var sentByToken *string if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil {
return return
} }
if id > lastPos {
lastPos = id
}
event := types.SendToDeviceEvent{ event := types.SendToDeviceEvent{
ID: id, ID: id,
UserID: userID, UserID: userID,
DeviceID: deviceID, DeviceID: deviceID,
} }
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
return continue
}
if sentByToken != nil {
if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
event.SentByToken = &token
}
} }
events = append(events, event) events = append(events, event)
} }
if lastPos == 0 {
return events, rows.Err() lastPos = to
} }
return lastPos, events, rows.Err()
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids))
return
} }
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition,
) (err error) { ) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids)) _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos)
return
}
func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
return return
} }

View file

@ -20,9 +20,9 @@ import (
// Import the postgres database driver. // Import the postgres database driver.
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/storage/shared"
) )
@ -36,6 +36,7 @@ type SyncServerDatasource struct {
} }
// NewDatabase creates a new sync server database // NewDatabase creates a new sync server database
// nolint:gocyclo
func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) {
var d SyncServerDatasource var d SyncServerDatasource
var err error var err error
@ -86,6 +87,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: d.writer, Writer: d.writer,
@ -99,7 +106,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
Filter: filter, Filter: filter,
SendToDevice: sendToDevice, SendToDevice: sendToDevice,
Receipts: receipts, Receipts: receipts,
EDUCache: cache.New(),
} }
return &d, nil return &d, nil
} }

File diff suppressed because it is too large Load diff

View file

@ -46,6 +46,8 @@ CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url); CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url);
-- for querying membership states of users -- for querying membership states of users
-- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; -- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
-- for querying state by event IDs
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id);
` `
const upsertRoomStateSQL = "" + const upsertRoomStateSQL = "" +
@ -64,7 +66,7 @@ const selectRoomIDsWithMembershipSQL = "" +
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectCurrentStateSQL = "" + const selectCurrentStateSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" +
" AND ( $2 IS NULL OR sender IN ($2) )" + " AND ( $2 IS NULL OR sender IN ($2) )" +
" AND ( $3 IS NULL OR NOT(sender IN ($3)) )" + " AND ( $3 IS NULL OR NOT(sender IN ($3)) )" +
" AND ( $4 IS NULL OR type IN ($4) )" + " AND ( $4 IS NULL OR type IN ($4) )" +
@ -80,10 +82,10 @@ const selectStateEventSQL = "" +
const selectEventsWithEventIDsSQL = "" + const selectEventsWithEventIDsSQL = "" +
// TODO: The session_id and transaction_id blanks are here because otherwise // TODO: The session_id and transaction_id blanks are here because otherwise
// the rowsToStreamEvents expects there to be exactly five columns. We need to // the rowsToStreamEvents expects there to be exactly six columns. We need to
// figure out if these really need to be in the DB, and if so, we need a // figure out if these really need to be in the DB, and if so, we need a
// better permanent fix for this. - neilalexander, 2 Jan 2020 // better permanent fix for this. - neilalexander, 2 Jan 2020
"SELECT added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" + "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
" FROM syncapi_current_room_state WHERE event_id IN ($1)" " FROM syncapi_current_room_state WHERE event_id IN ($1)"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
@ -289,13 +291,14 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
result := []*gomatrixserverlib.HeaderedEvent{} result := []*gomatrixserverlib.HeaderedEvent{}
for rows.Next() { for rows.Next() {
var eventID string
var eventBytes []byte var eventBytes []byte
if err := rows.Scan(&eventBytes); err != nil { if err := rows.Scan(&eventID, &eventBytes); err != nil {
return nil, err return nil, err
} }
// TODO: Handle redacted events // TODO: Handle redacted events
var ev gomatrixserverlib.HeaderedEvent var ev gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(eventBytes, &ev); err != nil { if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err return nil, err
} }
result = append(result, &ev) result = append(result, &ev)

View file

@ -0,0 +1,59 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpFixSequences, DownFixSequences)
goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
func LoadFixSequences(m *sqlutil.Migrations) {
m.AddMigration(UpFixSequences, DownFixSequences)
}
func UpFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
DELETE FROM syncapi_receipts;
UPDATE syncapi_stream_id SET stream_id=1 WHERE stream_name="receipt";
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
DELETE FROM syncapi_receipts;
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -0,0 +1,67 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
DROP TABLE syncapi_send_to_device;
CREATE TABLE syncapi_send_to_device(
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
content TEXT NOT NULL
);
INSERT INTO syncapi_send_to_device SELECT id, user_id, device_id, content FROM syncapi_send_to_device_backup;
DROP TABLE syncapi_send_to_device_backup;
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
DROP TABLE syncapi_send_to_device;
CREATE TABLE syncapi_send_to_device(
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
content TEXT NOT NULL,
sent_by_token TEXT
);
INSERT INTO syncapi_send_to_device SELECT id, user_id, device_id, content FROM syncapi_send_to_device_backup;
DROP TABLE syncapi_send_to_device_backup;
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}

View file

@ -56,20 +56,20 @@ const insertEventSQL = "" +
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = $13" "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = $13"
const selectEventsSQL = "" + const selectEventsSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1" "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC LIMIT $4" " ORDER BY id DESC LIMIT $4"
const selectRecentEventsForSyncSQL = "" + const selectRecentEventsForSyncSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +
" ORDER BY id DESC LIMIT $4" " ORDER BY id DESC LIMIT $4"
const selectEarlyEventsSQL = "" + const selectEarlyEventsSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC LIMIT $4" " ORDER BY id ASC LIMIT $4"
@ -428,6 +428,7 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
var result []types.StreamEvent var result []types.StreamEvent
for rows.Next() { for rows.Next() {
var ( var (
eventID string
streamPos types.StreamPosition streamPos types.StreamPosition
eventBytes []byte eventBytes []byte
excludeFromSync bool excludeFromSync bool
@ -435,12 +436,12 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
txnID *string txnID *string
transactionID *api.TransactionID transactionID *api.TransactionID
) )
if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil { if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil {
return nil, err return nil, err
} }
// TODO: Handle redacted events // TODO: Handle redacted events
var ev gomatrixserverlib.HeaderedEvent var ev gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(eventBytes, &ev); err != nil { if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err return nil, err
} }

View file

@ -51,15 +51,19 @@ const upsertReceipt = "" +
" DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9" " DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9"
const selectRoomReceipts = "" + const selectRoomReceipts = "" +
"SELECT room_id, receipt_type, user_id, event_id, receipt_ts" + "SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" +
" FROM syncapi_receipts" + " FROM syncapi_receipts" +
" WHERE id > $1 and room_id in ($2)" " WHERE id > $1 and room_id in ($2)"
const selectMaxReceiptIDSQL = "" +
"SELECT MAX(id) FROM syncapi_receipts"
type receiptStatements struct { type receiptStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
upsertReceipt *sql.Stmt upsertReceipt *sql.Stmt
selectRoomReceipts *sql.Stmt selectRoomReceipts *sql.Stmt
selectMaxReceiptID *sql.Stmt
} }
func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) { func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) {
@ -77,12 +81,15 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Re
if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
} }
if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
}
return r, nil return r, nil
} }
// UpsertReceipt creates new user receipts // UpsertReceipt creates new user receipts
func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) {
pos, err = r.streamIDStatements.nextStreamID(ctx, txn) pos, err = r.streamIDStatements.nextReceiptID(ctx, txn)
if err != nil { if err != nil {
return return
} }
@ -92,9 +99,9 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
} }
// SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp // SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
lastPos := streamPos
params := make([]interface{}, len(roomIDs)+1) params := make([]interface{}, len(roomIDs)+1)
params[0] = streamPos params[0] = streamPos
for k, v := range roomIDs { for k, v := range roomIDs {
@ -102,17 +109,33 @@ func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs
} }
rows, err := r.db.QueryContext(ctx, selectSQL, params...) rows, err := r.db.QueryContext(ctx, selectSQL, params...)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to query room receipts: %w", err) return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed")
var res []api.OutputReceiptEvent var res []api.OutputReceiptEvent
for rows.Next() { for rows.Next() {
r := api.OutputReceiptEvent{} r := api.OutputReceiptEvent{}
err = rows.Scan(&r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp) var id types.StreamPosition
err = rows.Scan(&id, &r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp)
if err != nil { if err != nil {
return res, fmt.Errorf("unable to scan row to api.Receipts: %w", err) return 0, res, fmt.Errorf("unable to scan row to api.Receipts: %w", err)
} }
res = append(res, r) res = append(res, r)
if id > lastPos {
lastPos = id
}
} }
return res, rows.Err() return lastPos, res, rows.Err()
}
func (s *receiptStatements) SelectMaxReceiptID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
return
} }

View file

@ -18,12 +18,12 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/sirupsen/logrus"
) )
const sendToDeviceSchema = ` const sendToDeviceSchema = `
@ -36,11 +36,7 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
-- The device ID to send the message to. -- The device ID to send the message to.
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
-- The event content JSON. -- The event content JSON.
content TEXT NOT NULL, content TEXT NOT NULL
-- The token that was supplied to the /sync at the time that this
-- message was included in a sync response, or NULL if we haven't
-- included it in a /sync response yet.
sent_by_token TEXT
); );
` `
@ -49,33 +45,27 @@ const insertSendToDeviceMessageSQL = `
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
` `
const countSendToDeviceMessagesSQL = `
SELECT COUNT(*)
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
`
const selectSendToDeviceMessagesSQL = ` const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content, sent_by_token SELECT id, user_id, device_id, content
FROM syncapi_send_to_device FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC ORDER BY id DESC
` `
const updateSentSendToDeviceMessagesSQL = ` const deleteSendToDeviceMessagesSQL = `
UPDATE syncapi_send_to_device SET sent_by_token = $1 DELETE FROM syncapi_send_to_device
WHERE id IN ($2) WHERE user_id = $1 AND device_id = $2 AND id < $3
` `
const deleteSendToDeviceMessagesSQL = ` const selectMaxSendToDeviceIDSQL = "" +
DELETE FROM syncapi_send_to_device WHERE id IN ($1) "SELECT MAX(id) FROM syncapi_send_to_device"
`
type sendToDeviceStatements struct { type sendToDeviceStatements struct {
db *sql.DB db *sql.DB
insertSendToDeviceMessageStmt *sql.Stmt insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt deleteSendToDeviceMessagesStmt *sql.Stmt
selectMaxSendToDeviceIDStmt *sql.Stmt
} }
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
@ -86,91 +76,85 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err return nil, err
} }
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
return nil, err return nil, err
} }
if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
func (s *sendToDeviceStatements) InsertSendToDeviceMessage( func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) { ) (pos types.StreamPosition, err error) {
_, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) var result sql.Result
result, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
if p, err := result.LastInsertId(); err != nil {
return 0, err
} else {
pos = types.StreamPosition(p)
}
return return
} }
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (count int, err error) {
row := sqlutil.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
if err = row.Scan(&count); err != nil {
return
}
return count, nil
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
) (events []types.SendToDeviceEvent, err error) { ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
if err != nil { if err != nil {
return return
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
for rows.Next() { for rows.Next() {
var id types.SendToDeviceNID var id types.StreamPosition
var userID, deviceID, content string var userID, deviceID, content string
var sentByToken *string if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil { logrus.WithError(err).Errorf("Failed to retrieve send-to-device message")
return return
} }
if id > lastPos {
lastPos = id
}
event := types.SendToDeviceEvent{ event := types.SendToDeviceEvent{
ID: id, ID: id,
UserID: userID, UserID: userID,
DeviceID: deviceID, DeviceID: deviceID,
} }
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
return logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
} continue
if sentByToken != nil {
if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
event.SentByToken = &token
}
} }
events = append(events, event) events = append(events, event)
} }
if lastPos == 0 {
return events, rows.Err() lastPos = to
}
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID,
) (err error) {
query := strings.Replace(updateSentSendToDeviceMessagesSQL, "($2)", sqlutil.QueryVariadic(1+len(nids)), 1)
params := make([]interface{}, 1+len(nids))
params[0] = token
for k, v := range nids {
params[k+1] = v
} }
_, err = txn.ExecContext(ctx, query, params...) return lastPos, events, rows.Err()
return
} }
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition,
) (err error) { ) (err error) {
query := strings.Replace(deleteSendToDeviceMessagesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos)
params := make([]interface{}, 1+len(nids)) return
for k, v := range nids { }
params[k] = v
} func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
_, err = txn.ExecContext(ctx, query, params...) ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
return return
} }

View file

@ -18,6 +18,8 @@ CREATE TABLE IF NOT EXISTS syncapi_stream_id (
); );
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0) INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0)
ON CONFLICT DO NOTHING; ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("receipt", 0)
ON CONFLICT DO NOTHING;
` `
const increaseStreamIDStmt = "" + const increaseStreamIDStmt = "" +
@ -56,3 +58,13 @@ func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos
err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos) err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos)
return return
} }
func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "receipt"); err != nil {
return
}
err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
return
}

View file

@ -21,10 +21,10 @@ import (
// Import the sqlite3 package // Import the sqlite3 package
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/storage/shared"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
) )
// SyncServerDatasource represents a sync server datasource which manages // SyncServerDatasource represents a sync server datasource which manages
@ -46,13 +46,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
return nil, err return nil, err
} }
d.writer = sqlutil.NewExclusiveWriter() d.writer = sqlutil.NewExclusiveWriter()
if err = d.prepare(); err != nil { if err = d.prepare(dbProperties); err != nil {
return nil, err return nil, err
} }
return &d, nil return &d, nil
} }
func (d *SyncServerDatasource) prepare() (err error) { // nolint:gocyclo
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil { if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
return err return err
} }
@ -99,6 +100,12 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil { if err != nil {
return err return err
} }
m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: d.writer, Writer: d.writer,
@ -112,7 +119,6 @@ func (d *SyncServerDatasource) prepare() (err error) {
Filter: filter, Filter: filter,
SendToDevice: sendToDevice, SendToDevice: sendToDevice,
Receipts: receipts, Receipts: receipts,
EDUCache: cache.New(),
} }
return nil return nil
} }

View file

@ -1,5 +1,7 @@
package storage_test package storage_test
// TODO: Fix these tests
/*
import ( import (
"context" "context"
"crypto/ed25519" "crypto/ed25519"
@ -165,9 +167,9 @@ func TestSyncResponse(t *testing.T) {
{ {
Name: "IncrementalSync penultimate", Name: "IncrementalSync penultimate",
DoSync: func() (*types.Response, error) { DoSync: func() (*types.Response, error) {
from := types.NewStreamToken( // pretend we are at the penultimate event from := types.StreamingToken{ // pretend we are at the penultimate event
positions[len(positions)-2], types.StreamPosition(0), nil, PDUPosition: positions[len(positions)-2],
) }
res := types.NewResponse() res := types.NewResponse()
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
}, },
@ -178,9 +180,9 @@ func TestSyncResponse(t *testing.T) {
{ {
Name: "IncrementalSync limited", Name: "IncrementalSync limited",
DoSync: func() (*types.Response, error) { DoSync: func() (*types.Response, error) {
from := types.NewStreamToken( // pretend we are 10 events behind from := types.StreamingToken{ // pretend we are 10 events behind
positions[len(positions)-11], types.StreamPosition(0), nil, PDUPosition: positions[len(positions)-11],
) }
res := types.NewResponse() res := types.NewResponse()
// limit is set to 5 // limit is set to 5
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
@ -222,8 +224,13 @@ func TestSyncResponse(t *testing.T) {
if err != nil { if err != nil {
st.Fatalf("failed to do sync: %s", err) st.Fatalf("failed to do sync: %s", err)
} }
next := types.NewStreamToken(latest.PDUPosition(), latest.EDUPosition(), nil) next := types.StreamingToken{
if res.NextBatch != next.String() { PDUPosition: latest.PDUPosition,
TypingPosition: latest.TypingPosition,
ReceiptPosition: latest.ReceiptPosition,
SendToDevicePosition: latest.SendToDevicePosition,
}
if res.NextBatch.String() != next.String() {
st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String()) st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String())
} }
roomRes, ok := res.Rooms.Join[testRoomID] roomRes, ok := res.Rooms.Join[testRoomID]
@ -245,9 +252,9 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err) t.Fatalf("failed to get SyncPosition: %s", err)
} }
from := types.NewStreamToken( from := types.StreamingToken{
positions[len(positions)-2], types.StreamPosition(0), nil, PDUPosition: positions[len(positions)-2],
) }
res := types.NewResponse() res := types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
@ -261,7 +268,7 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
// returns the last event "Message 10" // returns the last event "Message 10"
assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:])) assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:]))
prev := roomRes.Timeline.PrevBatch prev := roomRes.Timeline.PrevBatch.String()
if prev == "" { if prev == "" {
t.Fatalf("IncrementalSync expected prev_batch token") t.Fatalf("IncrementalSync expected prev_batch token")
} }
@ -271,7 +278,7 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
} }
// backpaginate 5 messages starting at the latest position. // backpaginate 5 messages starting at the latest position.
// head towards the beginning of time // head towards the beginning of time
to := types.NewTopologyToken(0, 0) to := types.TopologyToken{}
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true) paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true)
if err != nil { if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err) t.Fatalf("GetEventsInRange returned an error: %s", err)
@ -291,7 +298,7 @@ func TestGetEventsInRangeWithStreamToken(t *testing.T) {
t.Fatalf("failed to get SyncPosition: %s", err) t.Fatalf("failed to get SyncPosition: %s", err)
} }
// head towards the beginning of time // head towards the beginning of time
to := types.NewStreamToken(0, 0, nil) to := types.StreamingToken{}
// backpaginate 5 messages starting at the latest position. // backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true) paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
@ -313,7 +320,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err) t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
} }
// head towards the beginning of time // head towards the beginning of time
to := types.NewTopologyToken(0, 0) to := types.TopologyToken{}
// backpaginate 5 messages starting at the latest position. // backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true) paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true)
@ -382,7 +389,7 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) {
t.Fatalf("failed to get EventPositionInTopology for event: %s", err) t.Fatalf("failed to get EventPositionInTopology for event: %s", err)
} }
// head towards the beginning of time // head towards the beginning of time
to := types.NewTopologyToken(0, 0) to := types.TopologyToken{}
testCases := []struct { testCases := []struct {
Name string Name string
@ -458,7 +465,7 @@ func TestGetEventsInTopologicalRangeMultiRoom(t *testing.T) {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err) t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
} }
// head towards the beginning of time // head towards the beginning of time
to := types.NewTopologyToken(0, 0) to := types.TopologyToken{}
// Query using room B as room A was inserted first and hence A will have lower stream positions but identical depths, // Query using room B as room A was inserted first and hence A will have lower stream positions but identical depths,
// allowing this bug to surface. // allowing this bug to surface.
@ -508,7 +515,7 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) {
} }
// head towards the beginning of time // head towards the beginning of time
to := types.NewTopologyToken(0, 0) to := types.TopologyToken{}
// starting at `from`, backpaginate to the beginning of time, asserting as we go. // starting at `from`, backpaginate to the beginning of time, asserting as we go.
chunkSize = 3 chunkSize = 3
@ -534,20 +541,20 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point there should be no messages. We haven't sent anything // At this point there should be no messages. We haven't sent anything
// yet. // yet.
events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0, nil)) _, events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 { if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("first call should have no updates") t.Fatal("first call should have no updates")
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0, nil)) err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{})
if err != nil { if err != nil {
return return
} }
// Try sending a message. // Try sending a message.
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, types.StreamPosition(0), "alice", "one", gomatrixserverlib.SendToDeviceEvent{ streamPos, err := db.StoreNewSendForDeviceMessage(ctx, "alice", "one", gomatrixserverlib.SendToDeviceEvent{
Sender: "bob", Sender: "bob",
Type: "m.type", Type: "m.type",
Content: json.RawMessage("{}"), Content: json.RawMessage("{}"),
@ -559,14 +566,14 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should get exactly one message. We're sending the sync position // At this point we should get exactly one message. We're sending the sync position
// that we were given from the update and the send-to-device update will be updated // that we were given from the update and the send-to-device update will be updated
// in the database to reflect that this was the sync position we sent the message at. // in the database to reflect that this was the sync position we sent the message at.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos, nil)) _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 { if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
t.Fatal("second call should have one update") t.Fatal("second call should have one update")
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos, nil)) err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos})
if err != nil { if err != nil {
return return
} }
@ -574,35 +581,35 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should still have one message because we haven't progressed the // At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying // sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position. // with the same position.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos, nil)) _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 { if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("third call should have one update still") t.Fatal("third call should have one update still")
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos, nil)) err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos})
if err != nil { if err != nil {
return return
} }
// At this point we should now have no updates, because we've progressed the sync // At this point we should now have no updates, because we've progressed the sync
// position. Therefore the update from before will not be sent again. // position. Therefore the update from before will not be sent again.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1, nil)) _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 1})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 { if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
t.Fatal("fourth call should have no updates") t.Fatal("fourth call should have no updates")
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1, nil)) err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos + 1})
if err != nil { if err != nil {
return return
} }
// At this point we should still have no updates, because no new updates have been // At this point we should still have no updates, because no new updates have been
// sent. // sent.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2, nil)) _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 2})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -639,7 +646,7 @@ func TestInviteBehaviour(t *testing.T) {
} }
// both invite events should appear in a new sync // both invite events should appear in a new sync
beforeRetireRes := types.NewResponse() beforeRetireRes := types.NewResponse()
beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.NewStreamToken(0, 0, nil), latest, 0, false) beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.StreamingToken{}, latest, 0, false)
if err != nil { if err != nil {
t.Fatalf("IncrementalSync failed: %s", err) t.Fatalf("IncrementalSync failed: %s", err)
} }
@ -654,19 +661,15 @@ func TestInviteBehaviour(t *testing.T) {
t.Fatalf("failed to get SyncPosition: %s", err) t.Fatalf("failed to get SyncPosition: %s", err)
} }
res := types.NewResponse() res := types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.NewStreamToken(0, 0, nil), latest, 0, false) res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.StreamingToken{}, latest, 0, false)
if err != nil { if err != nil {
t.Fatalf("IncrementalSync failed: %s", err) t.Fatalf("IncrementalSync failed: %s", err)
} }
assertInvitedToRooms(t, res, []string{inviteRoom2}) assertInvitedToRooms(t, res, []string{inviteRoom2})
// a sync after we have received both invites should result in a leave for the retired room // a sync after we have received both invites should result in a leave for the retired room
beforeRetireTok, err := types.NewStreamTokenFromString(beforeRetireRes.NextBatch)
if err != nil {
t.Fatalf("NewStreamTokenFromString cannot parse next batch '%s' : %s", beforeRetireRes.NextBatch, err)
}
res = types.NewResponse() res = types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, beforeRetireTok, latest, 0, false) res, err = db.IncrementalSync(ctx, res, testUserDeviceA, beforeRetireRes.NextBatch, latest, 0, false)
if err != nil { if err != nil {
t.Fatalf("IncrementalSync failed: %s", err) t.Fatalf("IncrementalSync failed: %s", err)
} }
@ -745,3 +748,4 @@ func reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.Header
} }
return out return out
} }
*/

View file

@ -146,11 +146,10 @@ type BackwardsExtremities interface {
// sync parameter isn't later then we will keep including the updates in the // sync parameter isn't later then we will keep including the updates in the
// sync response, as the client is seemingly trying to repeat the same /sync. // sync response, as the client is seemingly trying to repeat the same /sync.
type SendToDevice interface { type SendToDevice interface {
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (err error) InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error)
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (events []types.SendToDeviceEvent, err error) SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error)
UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error) DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from types.StreamPosition) (err error)
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error) SelectMaxSendToDeviceMessageID(ctx context.Context, txn *sql.Tx) (id int64, err error)
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
} }
type Filter interface { type Filter interface {
@ -160,5 +159,6 @@ type Filter interface {
type Receipts interface { type Receipts interface {
UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []eduAPI.OutputReceiptEvent, error)
SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error)
} }

View file

@ -0,0 +1,130 @@
package streams
import (
"context"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
type AccountDataStreamProvider struct {
StreamProvider
userAPI userapi.UserInternalAPI
}
func (p *AccountDataStreamProvider) Setup() {
p.StreamProvider.Setup()
p.latestMutex.Lock()
defer p.latestMutex.Unlock()
id, err := p.DB.MaxStreamPositionForAccountData(context.Background())
if err != nil {
panic(err)
}
p.latest = id
}
func (p *AccountDataStreamProvider) CompleteSync(
ctx context.Context,
req *types.SyncRequest,
) types.StreamPosition {
dataReq := &userapi.QueryAccountDataRequest{
UserID: req.Device.UserID,
}
dataRes := &userapi.QueryAccountDataResponse{}
if err := p.userAPI.QueryAccountData(ctx, dataReq, dataRes); err != nil {
req.Log.WithError(err).Error("p.userAPI.QueryAccountData failed")
return p.LatestPosition(ctx)
}
for datatype, databody := range dataRes.GlobalAccountData {
req.Response.AccountData.Events = append(
req.Response.AccountData.Events,
gomatrixserverlib.ClientEvent{
Type: datatype,
Content: gomatrixserverlib.RawJSON(databody),
},
)
}
for r, j := range req.Response.Rooms.Join {
for datatype, databody := range dataRes.RoomAccountData[r] {
j.AccountData.Events = append(
j.AccountData.Events,
gomatrixserverlib.ClientEvent{
Type: datatype,
Content: gomatrixserverlib.RawJSON(databody),
},
)
req.Response.Rooms.Join[r] = j
}
}
return p.LatestPosition(ctx)
}
func (p *AccountDataStreamProvider) IncrementalSync(
ctx context.Context,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
r := types.Range{
From: from,
To: to,
}
accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead
dataTypes, err := p.DB.GetAccountDataInRange(
ctx, req.Device.UserID, r, &accountDataFilter,
)
if err != nil {
req.Log.WithError(err).Error("p.DB.GetAccountDataInRange failed")
return from
}
// Iterate over the rooms
for roomID, dataTypes := range dataTypes {
// Request the missing data from the database
for _, dataType := range dataTypes {
dataReq := userapi.QueryAccountDataRequest{
UserID: req.Device.UserID,
RoomID: roomID,
DataType: dataType,
}
dataRes := userapi.QueryAccountDataResponse{}
err = p.userAPI.QueryAccountData(ctx, &dataReq, &dataRes)
if err != nil {
req.Log.WithError(err).Error("p.userAPI.QueryAccountData failed")
continue
}
if roomID == "" {
if globalData, ok := dataRes.GlobalAccountData[dataType]; ok {
req.Response.AccountData.Events = append(
req.Response.AccountData.Events,
gomatrixserverlib.ClientEvent{
Type: dataType,
Content: gomatrixserverlib.RawJSON(globalData),
},
)
}
} else {
if roomData, ok := dataRes.RoomAccountData[roomID][dataType]; ok {
joinData := *types.NewJoinResponse()
if existing, ok := req.Response.Rooms.Join[roomID]; ok {
joinData = existing
}
joinData.AccountData.Events = append(
joinData.AccountData.Events,
gomatrixserverlib.ClientEvent{
Type: dataType,
Content: gomatrixserverlib.RawJSON(roomData),
},
)
req.Response.Rooms.Join[roomID] = joinData
}
}
}
}
return to
}

View file

@ -0,0 +1,43 @@
package streams
import (
"context"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/types"
)
type DeviceListStreamProvider struct {
PartitionedStreamProvider
rsAPI api.RoomserverInternalAPI
keyAPI keyapi.KeyInternalAPI
}
func (p *DeviceListStreamProvider) CompleteSync(
ctx context.Context,
req *types.SyncRequest,
) types.LogPosition {
return p.IncrementalSync(ctx, req, types.LogPosition{}, p.LatestPosition(ctx))
}
func (p *DeviceListStreamProvider) IncrementalSync(
ctx context.Context,
req *types.SyncRequest,
from, to types.LogPosition,
) types.LogPosition {
var err error
to, _, err = internal.DeviceListCatchup(context.Background(), p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
if err != nil {
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
return from
}
err = internal.DeviceOTKCounts(req.Context, p.keyAPI, req.Device.UserID, req.Device.ID, req.Response)
if err != nil {
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
return from
}
return to
}

View file

@ -0,0 +1,64 @@
package streams
import (
"context"
"github.com/matrix-org/dendrite/syncapi/types"
)
type InviteStreamProvider struct {
StreamProvider
}
func (p *InviteStreamProvider) Setup() {
p.StreamProvider.Setup()
p.latestMutex.Lock()
defer p.latestMutex.Unlock()
id, err := p.DB.MaxStreamPositionForInvites(context.Background())
if err != nil {
panic(err)
}
p.latest = id
}
func (p *InviteStreamProvider) CompleteSync(
ctx context.Context,
req *types.SyncRequest,
) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
}
func (p *InviteStreamProvider) IncrementalSync(
ctx context.Context,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
r := types.Range{
From: from,
To: to,
}
invites, retiredInvites, err := p.DB.InviteEventsInRange(
ctx, req.Device.UserID, r,
)
if err != nil {
req.Log.WithError(err).Error("p.DB.InviteEventsInRange failed")
return from
}
for roomID, inviteEvent := range invites {
ir := types.NewInviteResponse(inviteEvent)
req.Response.Rooms.Invite[roomID] = *ir
}
for roomID := range retiredInvites {
if _, ok := req.Response.Rooms.Join[roomID]; !ok {
lr := types.NewLeaveResponse()
req.Response.Rooms.Leave[roomID] = *lr
}
}
return to
}

Some files were not shown because too many files have changed in this diff Show more