Merge branch 'master' into matthew/peeking-over-fed

This commit is contained in:
Neil Alexander 2020-12-18 13:18:10 +00:00 committed by GitHub
commit 8508af345e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 1059 additions and 603 deletions

View file

@ -185,6 +185,7 @@ linters:
- gocyclo - gocyclo
- goimports # Does everything gofmt does - goimports # Does everything gofmt does
- gosimple - gosimple
- govet
- ineffassign - ineffassign
- megacheck - megacheck
- misspell # Check code comments, whereas misspell in CI checks *.md files - misspell # Check code comments, whereas misspell in CI checks *.md files

View file

@ -19,4 +19,4 @@ fi
go install -trimpath -ldflags "$FLAGS" -v $PWD/`dirname $0`/cmd/... go install -trimpath -ldflags "$FLAGS" -v $PWD/`dirname $0`/cmd/...
GOOS=js GOARCH=wasm go build -trimpath -ldflags "$FLAGS" -o main.wasm ./cmd/dendritejs GOOS=js GOARCH=wasm go build -trimpath -ldflags "$FLAGS" -o bin/main.wasm ./cmd/dendritejs

View file

@ -77,7 +77,7 @@ global:
# Naffka database options. Not required when using Kafka. # Naffka database options. Not required when using Kafka.
naffka_database: naffka_database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_naffka?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_naffka?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -98,7 +98,7 @@ app_service_api:
connect: http://appservice_api:7777 connect: http://appservice_api:7777
database: database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_appservice?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_appservice?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -173,7 +173,7 @@ federation_sender:
connect: http://federation_sender:7775 connect: http://federation_sender:7775
database: database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_federationsender?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_federationsender?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -199,7 +199,7 @@ key_server:
connect: http://key_server:7779 connect: http://key_server:7779
database: database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_keyserver?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_keyserver?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -212,7 +212,7 @@ media_api:
listen: http://0.0.0.0:8074 listen: http://0.0.0.0:8074
database: database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_mediaapi?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_mediaapi?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -248,7 +248,7 @@ room_server:
connect: http://room_server:7770 connect: http://room_server:7770
database: database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_roomserver?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_roomserver?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -259,7 +259,7 @@ signing_key_server:
connect: http://signing_key_server:7780 connect: http://signing_key_server:7780
database: database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_signingkeyserver?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_signingkeyserver?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -288,7 +288,7 @@ sync_api:
listen: http://0.0.0.0:8073 listen: http://0.0.0.0:8073
database: database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_syncapi?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_syncapi?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -299,12 +299,12 @@ user_api:
connect: http://user_api:7781 connect: http://user_api:7781
account_database: account_database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_account?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_account?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
device_database: device_database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_device?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_device?sslmode=disable
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1

View file

@ -24,8 +24,6 @@ fi
echo "Installing golangci-lint..." echo "Installing golangci-lint..."
# Make a backup of go.{mod,sum} first # Make a backup of go.{mod,sum} first
# TODO: Once go 1.13 is out, use go get's -mod=readonly option
# https://github.com/golang/go/issues/30667
cp go.mod go.mod.bak && cp go.sum go.sum.bak cp go.mod go.mod.bak && cp go.sum go.sum.bak
go get github.com/golangci/golangci-lint/cmd/golangci-lint@v1.19.1 go get github.com/golangci/golangci-lint/cmd/golangci-lint@v1.19.1

View file

@ -17,6 +17,7 @@ package routing
import ( import (
"net/http" "net/http"
"sync" "sync"
"time"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
@ -27,6 +28,7 @@ import (
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -40,6 +42,25 @@ var (
userRoomSendMutexes sync.Map // (roomID+userID) -> mutex. mutexes to ensure correct ordering of sendEvents userRoomSendMutexes sync.Map // (roomID+userID) -> mutex. mutexes to ensure correct ordering of sendEvents
) )
func init() {
prometheus.MustRegister(sendEventDuration)
}
var sendEventDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",
Subsystem: "clientapi",
Name: "sendevent_duration_millis",
Help: "How long it takes to build and submit a new event from the client API to the roomserver",
Buckets: []float64{ // milliseconds
5, 10, 25, 50, 75, 100, 250, 500,
1000, 2000, 3000, 4000, 5000, 6000,
7000, 8000, 9000, 10000, 15000, 20000,
},
},
[]string{"action"},
)
// SendEvent implements: // SendEvent implements:
// /rooms/{roomID}/send/{eventType} // /rooms/{roomID}/send/{eventType}
// /rooms/{roomID}/send/{eventType}/{txnID} // /rooms/{roomID}/send/{eventType}/{txnID}
@ -75,10 +96,12 @@ func SendEvent(
mutex.(*sync.Mutex).Lock() mutex.(*sync.Mutex).Lock()
defer mutex.(*sync.Mutex).Unlock() defer mutex.(*sync.Mutex).Unlock()
startedGeneratingEvent := time.Now()
e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, rsAPI) e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, rsAPI)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
timeToGenerateEvent := time.Since(startedGeneratingEvent)
var txnAndSessionID *api.TransactionID var txnAndSessionID *api.TransactionID
if txnID != nil { if txnID != nil {
@ -90,6 +113,7 @@ func SendEvent(
// pass the new event to the roomserver and receive the correct event ID // pass the new event to the roomserver and receive the correct event ID
// event ID in case of duplicate transaction is discarded // event ID in case of duplicate transaction is discarded
startedSubmittingEvent := time.Now()
if err := api.SendEvents( if err := api.SendEvents(
req.Context(), rsAPI, req.Context(), rsAPI,
api.KindNew, api.KindNew,
@ -102,6 +126,7 @@ func SendEvent(
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
timeToSubmitEvent := time.Since(startedSubmittingEvent)
util.GetLogger(req.Context()).WithFields(logrus.Fields{ util.GetLogger(req.Context()).WithFields(logrus.Fields{
"event_id": e.EventID(), "event_id": e.EventID(),
"room_id": roomID, "room_id": roomID,
@ -117,6 +142,11 @@ func SendEvent(
txnCache.AddTransaction(device.AccessToken, *txnID, &res) txnCache.AddTransaction(device.AccessToken, *txnID, &res)
} }
// Take a note of how long it took to generate the event vs submit
// it to the roomserver.
sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds()))
sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvent.Milliseconds()))
return res return res
} }

View file

@ -65,6 +65,8 @@ func main() {
cfg.FederationSender.DisableTLSValidation = true cfg.FederationSender.DisableTLSValidation = true
cfg.MSCs.MSCs = []string{"msc2836"} cfg.MSCs.MSCs = []string{"msc2836"}
cfg.Logging[0].Level = "trace" cfg.Logging[0].Level = "trace"
// don't hit matrix.org when running tests!!!
cfg.SigningKeyServer.KeyPerspectives = config.KeyPerspectives{}
} }
j, err := yaml.Marshal(cfg) j, err := yaml.Marshal(cfg)

View file

@ -80,12 +80,6 @@ brew services start kafka
## Configuration ## Configuration
### SQLite database setup
Dendrite can use the built-in SQLite database engine for small setups.
The SQLite databases do not need to be pre-built - Dendrite will
create them automatically at startup.
### PostgreSQL database setup ### PostgreSQL database setup
Assuming that PostgreSQL 9.6 (or later) is installed: Assuming that PostgreSQL 9.6 (or later) is installed:
@ -96,7 +90,23 @@ Assuming that PostgreSQL 9.6 (or later) is installed:
sudo -u postgres createuser -P dendrite sudo -u postgres createuser -P dendrite
``` ```
* Create the component databases: At this point you have a choice on whether to run all of the Dendrite
components from a single database, or for each component to have its
own database. For most deployments, running from a single database will
be sufficient, although you may wish to separate them if you plan to
split out the databases across multiple machines in the future.
On macOS, omit `sudo -u postgres` from the below commands.
* If you want to run all Dendrite components from a single database:
```bash
sudo -u postgres createdb -O dendrite dendrite
```
... in which case your connection string will look like `postgres://user:pass@database/dendrite`.
* If you want to run each Dendrite component with its own database:
```bash ```bash
for i in mediaapi syncapi roomserver signingkeyserver federationsender appservice keyserver userapi_account userapi_device naffka; do for i in mediaapi syncapi roomserver signingkeyserver federationsender appservice keyserver userapi_account userapi_device naffka; do
@ -104,14 +114,22 @@ Assuming that PostgreSQL 9.6 (or later) is installed:
done done
``` ```
(On macOS, omit `sudo -u postgres` from the above commands.) ... in which case your connection string will look like `postgres://user:pass@database/dendrite_componentname`.
### SQLite database setup
**WARNING:** SQLite is suitable for small experimental deployments only and should not be used in production - use PostgreSQL instead for any user-facing federating installation!
Dendrite can use the built-in SQLite database engine for small setups.
The SQLite databases do not need to be pre-built - Dendrite will
create them automatically at startup.
### Server key generation ### Server key generation
Each Dendrite installation requires: Each Dendrite installation requires:
- A unique Matrix signing private key * A unique Matrix signing private key
- A valid and trusted TLS certificate and private key * A valid and trusted TLS certificate and private key
To generate a Matrix signing private key: To generate a Matrix signing private key:
@ -119,7 +137,7 @@ To generate a Matrix signing private key:
./bin/generate-keys --private-key matrix_key.pem ./bin/generate-keys --private-key matrix_key.pem
``` ```
**Warning:** Make sure take a safe backup of this key! You will likely need it if you want to reinstall Dendrite, or **WARNING:** Make sure take a safe backup of this key! You will likely need it if you want to reinstall Dendrite, or
any other Matrix homeserver, on the same domain name in the future. If you lose this key, you may have trouble joining any other Matrix homeserver, on the same domain name in the future. If you lose this key, you may have trouble joining
federated rooms. federated rooms.
@ -129,8 +147,8 @@ For testing, you can generate a self-signed certificate and key, although this w
./bin/generate-keys --tls-cert server.crt --tls-key server.key ./bin/generate-keys --tls-cert server.crt --tls-key server.key
``` ```
If you have server keys from an older Synapse instance, If you have server keys from an older Synapse instance,
[convert them](serverkeyformat.md#converting-synapse-keys) to Dendrite's PEM [convert them](serverkeyformat.md#converting-synapse-keys) to Dendrite's PEM
format and configure them as `old_private_keys` in your config. format and configure them as `old_private_keys` in your config.
### Configuration file ### Configuration file
@ -140,7 +158,9 @@ Create config file, based on `dendrite-config.yaml`. Call it `dendrite.yaml`. Th
* The `server_name` entry to reflect the hostname of your Dendrite server * The `server_name` entry to reflect the hostname of your Dendrite server
* The `database` lines with an updated connection string based on your * The `database` lines with an updated connection string based on your
desired setup, e.g. replacing `database` with the name of the database: desired setup, e.g. replacing `database` with the name of the database:
* For Postgres: `postgres://dendrite:password@localhost/database`, e.g. `postgres://dendrite:password@localhost/dendrite_userapi_account.db` * For Postgres: `postgres://dendrite:password@localhost/database`, e.g.
* `postgres://dendrite:password@localhost/dendrite_userapi_account` to connect to PostgreSQL with SSL/TLS
* `postgres://dendrite:password@localhost/dendrite_userapi_account?sslmode=disable` to connect to PostgreSQL without SSL/TLS
* For SQLite on disk: `file:component.db` or `file:///path/to/component.db`, e.g. `file:userapi_account.db` * For SQLite on disk: `file:component.db` or `file:///path/to/component.db`, e.g. `file:userapi_account.db`
* Postgres and SQLite can be mixed and matched on different components as desired. * Postgres and SQLite can be mixed and matched on different components as desired.
* The `use_naffka` option if using Naffka in a monolith deployment * The `use_naffka` option if using Naffka in a monolith deployment
@ -295,4 +315,3 @@ amongst other things.
```bash ```bash
./bin/dendrite-polylith-multi --config=dendrite.yaml userapi ./bin/dendrite-polylith-multi --config=dendrite.yaml userapi
``` ```

View file

@ -242,6 +242,8 @@ func (oq *destinationQueue) backgroundSend() {
if !oq.running.CAS(false, true) { if !oq.running.CAS(false, true) {
return return
} }
destinationQueueRunning.Inc()
defer destinationQueueRunning.Dec()
defer oq.running.Store(false) defer oq.running.Store(false)
// Mark the queue as overflowed, so we will consult the database // Mark the queue as overflowed, so we will consult the database
@ -295,10 +297,14 @@ func (oq *destinationQueue) backgroundSend() {
// time. // time.
duration := time.Until(*until) duration := time.Until(*until)
log.Warnf("Backing off %q for %s", oq.destination, duration) log.Warnf("Backing off %q for %s", oq.destination, duration)
oq.backingOff.Store(true)
destinationQueueBackingOff.Inc()
select { select {
case <-time.After(duration): case <-time.After(duration):
case <-oq.interruptBackoff: case <-oq.interruptBackoff:
} }
destinationQueueBackingOff.Dec()
oq.backingOff.Store(false)
} }
// Work out which PDUs/EDUs to include in the next transaction. // Work out which PDUs/EDUs to include in the next transaction.

View file

@ -27,6 +27,7 @@ import (
"github.com/matrix-org/dendrite/federationsender/storage/shared" "github.com/matrix-org/dendrite/federationsender/storage/shared"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@ -45,6 +46,37 @@ type OutgoingQueues struct {
queues map[gomatrixserverlib.ServerName]*destinationQueue queues map[gomatrixserverlib.ServerName]*destinationQueue
} }
func init() {
prometheus.MustRegister(
destinationQueueTotal, destinationQueueRunning,
destinationQueueBackingOff,
)
}
var destinationQueueTotal = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "federationsender",
Name: "destination_queues_total",
},
)
var destinationQueueRunning = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "federationsender",
Name: "destination_queues_running",
},
)
var destinationQueueBackingOff = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "federationsender",
Name: "destination_queues_backing_off",
},
)
// NewOutgoingQueues makes a new OutgoingQueues // NewOutgoingQueues makes a new OutgoingQueues
func NewOutgoingQueues( func NewOutgoingQueues(
db storage.Database, db storage.Database,
@ -116,6 +148,7 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d
defer oqs.queuesMutex.Unlock() defer oqs.queuesMutex.Unlock()
oq := oqs.queues[destination] oq := oqs.queues[destination]
if oq == nil { if oq == nil {
destinationQueueTotal.Inc()
oq = &destinationQueue{ oq = &destinationQueue{
db: oqs.db, db: oqs.db,
rsAPI: oqs.rsAPI, rsAPI: oqs.rsAPI,

View file

@ -0,0 +1,45 @@
package caching
import (
"github.com/matrix-org/dendrite/roomserver/types"
)
// WARNING: This cache is mutable because it's entirely possible that
// the IsStub or StateSnaphotNID fields can change, even though the
// room version and room NID fields will not. This is only safe because
// the RoomInfoCache is used ONLY within the roomserver and because it
// will be kept up-to-date by the latest events updater. It MUST NOT be
// used from other components as we currently have no way to invalidate
// the cache in downstream components.
const (
RoomInfoCacheName = "roominfo"
RoomInfoCacheMaxEntries = 1024
RoomInfoCacheMutable = true
)
// RoomInfosCache contains the subset of functions needed for
// a room Info cache. It must only be used from the roomserver only
// It is not safe for use from other components.
type RoomInfoCache interface {
GetRoomInfo(roomID string) (roomInfo types.RoomInfo, ok bool)
StoreRoomInfo(roomID string, roomInfo types.RoomInfo)
}
// GetRoomInfo must only be called from the roomserver only. It is not
// safe for use from other components.
func (c Caches) GetRoomInfo(roomID string) (types.RoomInfo, bool) {
val, found := c.RoomInfos.Get(roomID)
if found && val != nil {
if roomInfo, ok := val.(types.RoomInfo); ok {
return roomInfo, true
}
}
return types.RoomInfo{}, false
}
// StoreRoomInfo must only be called from the roomserver only. It is not
// safe for use from other components.
func (c Caches) StoreRoomInfo(roomID string, roomInfo types.RoomInfo) {
c.RoomInfos.Set(roomID, roomInfo)
}

View file

@ -15,10 +15,6 @@ const (
RoomServerEventTypeNIDsCacheMaxEntries = 64 RoomServerEventTypeNIDsCacheMaxEntries = 64
RoomServerEventTypeNIDsCacheMutable = false RoomServerEventTypeNIDsCacheMutable = false
RoomServerRoomNIDsCacheName = "roomserver_room_nids"
RoomServerRoomNIDsCacheMaxEntries = 1024
RoomServerRoomNIDsCacheMutable = false
RoomServerRoomIDsCacheName = "roomserver_room_ids" RoomServerRoomIDsCacheName = "roomserver_room_ids"
RoomServerRoomIDsCacheMaxEntries = 1024 RoomServerRoomIDsCacheMaxEntries = 1024
RoomServerRoomIDsCacheMutable = false RoomServerRoomIDsCacheMutable = false
@ -27,6 +23,7 @@ const (
type RoomServerCaches interface { type RoomServerCaches interface {
RoomServerNIDsCache RoomServerNIDsCache
RoomVersionCache RoomVersionCache
RoomInfoCache
} }
// RoomServerNIDsCache contains the subset of functions needed for // RoomServerNIDsCache contains the subset of functions needed for
@ -38,9 +35,6 @@ type RoomServerNIDsCache interface {
GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool) GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool)
StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID) StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID)
GetRoomServerRoomNID(roomID string) (types.RoomNID, bool)
StoreRoomServerRoomNID(roomID string, nid types.RoomNID)
GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool)
StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string)
} }
@ -73,21 +67,6 @@ func (c Caches) StoreRoomServerEventTypeNID(eventType string, nid types.EventTyp
c.RoomServerEventTypeNIDs.Set(eventType, nid) c.RoomServerEventTypeNIDs.Set(eventType, nid)
} }
func (c Caches) GetRoomServerRoomNID(roomID string) (types.RoomNID, bool) {
val, found := c.RoomServerRoomNIDs.Get(roomID)
if found && val != nil {
if roomNID, ok := val.(types.RoomNID); ok {
return roomNID, true
}
}
return 0, false
}
func (c Caches) StoreRoomServerRoomNID(roomID string, roomNID types.RoomNID) {
c.RoomServerRoomNIDs.Set(roomID, roomNID)
c.RoomServerRoomIDs.Set(strconv.Itoa(int(roomNID)), roomID)
}
func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) { func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) {
val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID))) val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID)))
if found && val != nil { if found && val != nil {
@ -99,5 +78,5 @@ func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) {
} }
func (c Caches) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) { func (c Caches) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) {
c.StoreRoomServerRoomNID(roomID, roomNID) c.RoomServerRoomIDs.Set(strconv.Itoa(int(roomNID)), roomID)
} }

View file

@ -10,6 +10,7 @@ type Caches struct {
RoomServerEventTypeNIDs Cache // RoomServerNIDsCache RoomServerEventTypeNIDs Cache // RoomServerNIDsCache
RoomServerRoomNIDs Cache // RoomServerNIDsCache RoomServerRoomNIDs Cache // RoomServerNIDsCache
RoomServerRoomIDs Cache // RoomServerNIDsCache RoomServerRoomIDs Cache // RoomServerNIDsCache
RoomInfos Cache // RoomInfoCache
FederationEvents Cache // FederationEventsCache FederationEvents Cache // FederationEventsCache
} }

View file

@ -45,19 +45,19 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
roomServerRoomNIDs, err := NewInMemoryLRUCachePartition( roomServerRoomIDs, err := NewInMemoryLRUCachePartition(
RoomServerRoomNIDsCacheName, RoomServerRoomIDsCacheName,
RoomServerRoomNIDsCacheMutable, RoomServerRoomIDsCacheMutable,
RoomServerRoomNIDsCacheMaxEntries, RoomServerRoomIDsCacheMaxEntries,
enablePrometheus, enablePrometheus,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
roomServerRoomIDs, err := NewInMemoryLRUCachePartition( roomInfos, err := NewInMemoryLRUCachePartition(
RoomServerRoomIDsCacheName, RoomInfoCacheName,
RoomServerRoomIDsCacheMutable, RoomInfoCacheMutable,
RoomServerRoomIDsCacheMaxEntries, RoomInfoCacheMaxEntries,
enablePrometheus, enablePrometheus,
) )
if err != nil { if err != nil {
@ -77,8 +77,8 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
ServerKeys: serverKeys, ServerKeys: serverKeys,
RoomServerStateKeyNIDs: roomServerStateKeyNIDs, RoomServerStateKeyNIDs: roomServerStateKeyNIDs,
RoomServerEventTypeNIDs: roomServerEventTypeNIDs, RoomServerEventTypeNIDs: roomServerEventTypeNIDs,
RoomServerRoomNIDs: roomServerRoomNIDs,
RoomServerRoomIDs: roomServerRoomIDs, RoomServerRoomIDs: roomServerRoomIDs,
RoomInfos: roomInfos,
FederationEvents: federationEvents, FederationEvents: federationEvents,
}, nil }, nil
} }

View file

@ -82,6 +82,7 @@ func (s *keyChangesStatements) SelectKeyChanges(
if toOffset == sarama.OffsetNewest { if toOffset == sarama.OffsetNewest {
toOffset = math.MaxInt64 toOffset = math.MaxInt64
} }
latestOffset = fromOffset
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err

View file

@ -83,6 +83,7 @@ func (s *keyChangesStatements) SelectKeyChanges(
if toOffset == sarama.OffsetNewest { if toOffset == sarama.OffsetNewest {
toOffset = math.MaxInt64 toOffset = math.MaxInt64
} }
latestOffset = fromOffset
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err

View file

@ -54,10 +54,8 @@ type inputWorker struct {
input chan *inputTask input chan *inputTask
} }
// Guarded by a CAS on w.running
func (w *inputWorker) start() { func (w *inputWorker) start() {
if !w.running.CAS(false, true) {
return
}
defer w.running.Store(false) defer w.running.Store(false)
for { for {
select { select {
@ -118,7 +116,7 @@ func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) er
// InputRoomEvents implements api.RoomserverInternalAPI // InputRoomEvents implements api.RoomserverInternalAPI
func (r *Inputer) InputRoomEvents( func (r *Inputer) InputRoomEvents(
ctx context.Context, _ context.Context,
request *api.InputRoomEventsRequest, request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse, response *api.InputRoomEventsResponse,
) { ) {
@ -142,7 +140,7 @@ func (r *Inputer) InputRoomEvents(
// room - the channel will be quite small as it's just pointer types. // room - the channel will be quite small as it's just pointer types.
w, _ := r.workers.LoadOrStore(roomID, &inputWorker{ w, _ := r.workers.LoadOrStore(roomID, &inputWorker{
r: r, r: r,
input: make(chan *inputTask, 10), input: make(chan *inputTask, 32),
}) })
worker := w.(*inputWorker) worker := w.(*inputWorker)
@ -150,13 +148,15 @@ func (r *Inputer) InputRoomEvents(
// the wait group, so that the worker can notify us when this specific // the wait group, so that the worker can notify us when this specific
// task has been finished. // task has been finished.
tasks[i] = &inputTask{ tasks[i] = &inputTask{
ctx: ctx, ctx: context.Background(),
event: &request.InputRoomEvents[i], event: &request.InputRoomEvents[i],
wg: wg, wg: wg,
} }
// Send the task to the worker. // Send the task to the worker.
go worker.start() if worker.running.CAS(false, true) {
go worker.start()
}
worker.input <- tasks[i] worker.input <- tasks[i]
} }

View file

@ -120,8 +120,8 @@ const bulkSelectEventNIDSQL = "" +
const selectMaxEventDepthSQL = "" + const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)" "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)"
const selectRoomNIDForEventNIDSQL = "" + const selectRoomNIDsForEventNIDsSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_nid = $1" "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid = ANY($1)"
type eventStatements struct { type eventStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
@ -137,7 +137,7 @@ type eventStatements struct {
bulkSelectEventIDStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt
selectMaxEventDepthStmt *sql.Stmt selectMaxEventDepthStmt *sql.Stmt
selectRoomNIDForEventNIDStmt *sql.Stmt selectRoomNIDsForEventNIDsStmt *sql.Stmt
} }
func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
@ -161,7 +161,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, {&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -432,11 +432,24 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
return result, nil return result, nil
} }
func (s *eventStatements) SelectRoomNIDForEventNID( func (s *eventStatements) SelectRoomNIDsForEventNIDs(
ctx context.Context, eventNID types.EventNID, ctx context.Context, eventNIDs []types.EventNID,
) (roomNID types.RoomNID, err error) { ) (map[types.EventNID]types.RoomNID, error) {
err = s.selectRoomNIDForEventNIDStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) rows, err := s.selectRoomNIDsForEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
return if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed")
result := make(map[types.EventNID]types.RoomNID)
for rows.Next() {
var eventNID types.EventNID
var roomNID types.RoomNID
if err = rows.Scan(&eventNID, &roomNID); err != nil {
return nil, err
}
result[eventNID] = roomNID
}
return result, nil
} }
func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array { func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array {

View file

@ -18,7 +18,6 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -69,8 +68,8 @@ const selectLatestEventNIDsForUpdateSQL = "" +
const updateLatestEventNIDsSQL = "" + const updateLatestEventNIDsSQL = "" +
"UPDATE roomserver_rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1" "UPDATE roomserver_rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1"
const selectRoomVersionForRoomNIDSQL = "" + const selectRoomVersionsForRoomNIDsSQL = "" +
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid = ANY($1)"
const selectRoomInfoSQL = "" + const selectRoomInfoSQL = "" +
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
@ -90,7 +89,7 @@ type roomStatements struct {
selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt
selectLatestEventNIDsForUpdateStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt
updateLatestEventNIDsStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt
selectRoomVersionForRoomNIDStmt *sql.Stmt selectRoomVersionsForRoomNIDsStmt *sql.Stmt
selectRoomInfoStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt
selectRoomIDsStmt *sql.Stmt selectRoomIDsStmt *sql.Stmt
bulkSelectRoomIDsStmt *sql.Stmt bulkSelectRoomIDsStmt *sql.Stmt
@ -109,7 +108,7 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, {&s.selectRoomVersionsForRoomNIDsStmt, selectRoomVersionsForRoomNIDsSQL},
{&s.selectRoomInfoStmt, selectRoomInfoSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL},
{&s.selectRoomIDsStmt, selectRoomIDsSQL}, {&s.selectRoomIDsStmt, selectRoomIDsSQL},
{&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL}, {&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL},
@ -219,15 +218,24 @@ func (s *roomStatements) UpdateLatestEventNIDs(
return err return err
} }
func (s *roomStatements) SelectRoomVersionForRoomNID( func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, roomNIDs []types.RoomNID,
) (gomatrixserverlib.RoomVersion, error) { ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
var roomVersion gomatrixserverlib.RoomVersion rows, err := s.selectRoomVersionsForRoomNIDsStmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
err := s.selectRoomVersionForRoomNIDStmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) if err != nil {
if err == sql.ErrNoRows { return nil, err
return roomVersion, errors.New("room not found")
} }
return roomVersion, err defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
for rows.Next() {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err
}
result[roomNID] = roomVersion
}
return result, nil
} }
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
@ -271,3 +279,11 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []strin
} }
return roomNIDs, nil return roomNIDs, nil
} }
func roomNIDsAsArray(roomNIDs []types.RoomNID) pq.Int64Array {
nids := make([]int64, len(roomNIDs))
for i := range roomNIDs {
nids[i] = int64(roomNIDs[i])
}
return nids
}

View file

@ -105,6 +105,13 @@ func (u *LatestEventsUpdater) SetLatestEvents(
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil { if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err) return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
} }
if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
roomInfo.StateSnapshotNID = currentStateSnapshotNID
roomInfo.IsStub = false
u.d.Cache.StoreRoomInfo(roomID, roomInfo)
}
}
return nil return nil
}) })
} }

View file

@ -124,7 +124,15 @@ func (d *Database) StateEntriesForTuples(
} }
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return d.RoomsTable.SelectRoomInfo(ctx, roomID) if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
return &roomInfo, nil
}
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, roomID)
if err == nil && roomInfo != nil {
d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID)
d.Cache.StoreRoomInfo(roomID, *roomInfo)
}
return roomInfo, err
} }
func (d *Database) AddState( func (d *Database) AddState(
@ -313,25 +321,39 @@ func (d *Database) Events(
if err != nil { if err != nil {
eventIDs = map[types.EventNID]string{} eventIDs = map[types.EventNID]string{}
} }
results := make([]types.Event, len(eventJSONs)) var roomNIDs map[types.EventNID]types.RoomNID
for i, eventJSON := range eventJSONs { roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, eventNIDs)
var roomNID types.RoomNID if err != nil {
var roomVersion gomatrixserverlib.RoomVersion return nil, err
result := &results[i] }
result.EventNID = eventJSON.EventNID uniqueRoomNIDs := make(map[types.RoomNID]struct{})
roomNID, err = d.EventsTable.SelectRoomNIDForEventNID(ctx, eventJSON.EventNID) for _, n := range roomNIDs {
if err != nil { uniqueRoomNIDs[n] = struct{}{}
return nil, err }
} roomVersions := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
if roomID, ok := d.Cache.GetRoomServerRoomID(roomNID); ok { fetchNIDList := make([]types.RoomNID, 0, len(uniqueRoomNIDs))
roomVersion, _ = d.Cache.GetRoomVersion(roomID) for n := range uniqueRoomNIDs {
} if roomID, ok := d.Cache.GetRoomServerRoomID(n); ok {
if roomVersion == "" { if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID) roomVersions[n] = roomInfo.RoomVersion
if err != nil { continue
return nil, err
} }
} }
fetchNIDList = append(fetchNIDList, n)
}
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, fetchNIDList)
if err != nil {
return nil, err
}
for n, v := range dbRoomVersions {
roomVersions[n] = v
}
results := make([]types.Event, len(eventJSONs))
for i, eventJSON := range eventJSONs {
result := &results[i]
result.EventNID = eventJSON.EventNID
roomNID := roomNIDs[result.EventNID]
roomVersion := roomVersions[roomNID]
result.Event, err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID( result.Event, err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID(
eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion, eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion,
) )
@ -552,8 +574,8 @@ func (d *Database) assignRoomNID(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion, roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (types.RoomNID, error) { ) (types.RoomNID, error) {
if roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID); ok { if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
return roomNID, nil return roomInfo.RoomNID, nil
} }
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
roomNID, err := d.RoomsTable.SelectRoomNID(ctx, txn, roomID) roomNID, err := d.RoomsTable.SelectRoomNID(ctx, txn, roomID)
@ -565,9 +587,6 @@ func (d *Database) assignRoomNID(
roomNID, err = d.RoomsTable.SelectRoomNID(ctx, txn, roomID) roomNID, err = d.RoomsTable.SelectRoomNID(ctx, txn, roomID)
} }
} }
if err == nil {
d.Cache.StoreRoomServerRoomNID(roomID, roomNID)
}
return roomNID, err return roomNID, err
} }

View file

@ -95,8 +95,8 @@ const bulkSelectEventNIDSQL = "" +
const selectMaxEventDepthSQL = "" + const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
const selectRoomNIDForEventNIDSQL = "" + const selectRoomNIDsForEventNIDsSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_nid = $1" "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)"
type eventStatements struct { type eventStatements struct {
db *sql.DB db *sql.DB
@ -112,7 +112,7 @@ type eventStatements struct {
bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt
selectRoomNIDForEventNIDStmt *sql.Stmt //selectRoomNIDsForEventNIDsStmt *sql.Stmt
} }
func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
@ -137,7 +137,7 @@ func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, //{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -480,11 +480,33 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
return result, nil return result, nil
} }
func (s *eventStatements) SelectRoomNIDForEventNID( func (s *eventStatements) SelectRoomNIDsForEventNIDs(
ctx context.Context, eventNID types.EventNID, ctx context.Context, eventNIDs []types.EventNID,
) (roomNID types.RoomNID, err error) { ) (map[types.EventNID]types.RoomNID, error) {
err = s.selectRoomNIDForEventNIDStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
return sqlPrep, err := s.db.Prepare(sqlStr)
if err != nil {
return nil, err
}
iEventNIDs := make([]interface{}, len(eventNIDs))
for i, v := range eventNIDs {
iEventNIDs[i] = v
}
rows, err := sqlPrep.QueryContext(ctx, iEventNIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed")
result := make(map[types.EventNID]types.RoomNID)
for rows.Next() {
var eventNID types.EventNID
var roomNID types.RoomNID
if err = rows.Scan(&eventNID, &roomNID); err != nil {
return nil, err
}
result[eventNID] = roomNID
}
return result, nil
} }
func eventNIDsAsArray(eventNIDs []types.EventNID) string { func eventNIDsAsArray(eventNIDs []types.EventNID) string {

View file

@ -19,7 +19,6 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strings" "strings"
@ -60,8 +59,8 @@ const selectLatestEventNIDsForUpdateSQL = "" +
const updateLatestEventNIDsSQL = "" + const updateLatestEventNIDsSQL = "" +
"UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4" "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
const selectRoomVersionForRoomNIDSQL = "" + const selectRoomVersionsForRoomNIDsSQL = "" +
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)"
const selectRoomInfoSQL = "" + const selectRoomInfoSQL = "" +
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
@ -82,9 +81,9 @@ type roomStatements struct {
selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt
selectLatestEventNIDsForUpdateStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt
updateLatestEventNIDsStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt
selectRoomVersionForRoomNIDStmt *sql.Stmt //selectRoomVersionForRoomNIDStmt *sql.Stmt
selectRoomInfoStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt
selectRoomIDsStmt *sql.Stmt selectRoomIDsStmt *sql.Stmt
} }
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
@ -101,7 +100,7 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, //{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL},
{&s.selectRoomInfoStmt, selectRoomInfoSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL},
{&s.selectRoomIDsStmt, selectRoomIDsSQL}, {&s.selectRoomIDsStmt, selectRoomIDsSQL},
}.Prepare(db) }.Prepare(db)
@ -223,15 +222,33 @@ func (s *roomStatements) UpdateLatestEventNIDs(
return err return err
} }
func (s *roomStatements) SelectRoomVersionForRoomNID( func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, roomNIDs []types.RoomNID,
) (gomatrixserverlib.RoomVersion, error) { ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
var roomVersion gomatrixserverlib.RoomVersion sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
err := s.selectRoomVersionForRoomNIDStmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) sqlPrep, err := s.db.Prepare(sqlStr)
if err == sql.ErrNoRows { if err != nil {
return roomVersion, errors.New("room not found") return nil, err
} }
return roomVersion, err iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
}
rows, err := sqlPrep.QueryContext(ctx, iRoomNIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
for rows.Next() {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err
}
result[roomNID] = roomVersion
}
return result, nil
} }
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {

View file

@ -10,8 +10,9 @@ import (
) )
type EventJSONPair struct { type EventJSONPair struct {
EventNID types.EventNID EventNID types.EventNID
EventJSON []byte RoomVersion gomatrixserverlib.RoomVersion
EventJSON []byte
} }
type EventJSON interface { type EventJSON interface {
@ -58,7 +59,7 @@ type Events interface {
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
SelectRoomNIDForEventNID(ctx context.Context, eventNID types.EventNID) (roomNID types.RoomNID, err error) SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
} }
type Rooms interface { type Rooms interface {
@ -67,7 +68,7 @@ type Rooms interface {
SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error)
SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error)
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) SelectRoomVersionsForRoomNIDs(ctx context.Context, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
SelectRoomIDs(ctx context.Context) ([]string, error) SelectRoomIDs(ctx context.Context) ([]string, error)
BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error)

View file

@ -92,7 +92,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error
}).Panicf("could not save account data") }).Panicf("could not save account data")
} }
s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.NewStreamToken(pduPos, 0, nil)) s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.StreamingToken{PDUPosition: pduPos})
return nil return nil
} }

View file

@ -88,7 +88,7 @@ func (s *OutputReceiptEventConsumer) onMessage(msg *sarama.ConsumerMessage) erro
return err return err
} }
// update stream position // update stream position
s.notifier.OnNewReceipt(types.NewStreamToken(0, streamPos, nil)) s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
return nil return nil
} }

View file

@ -94,10 +94,8 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage)
"event_type": output.Type, "event_type": output.Type,
}).Info("sync API received send-to-device event from EDU server") }).Info("sync API received send-to-device event from EDU server")
streamPos := s.db.AddSendToDevice() streamPos, err := s.db.StoreNewSendForDeviceMessage(
context.TODO(), output.UserID, output.DeviceID, output.SendToDeviceEvent,
_, err = s.db.StoreNewSendForDeviceMessage(
context.TODO(), streamPos, output.UserID, output.DeviceID, output.SendToDeviceEvent,
) )
if err != nil { if err != nil {
log.WithError(err).Errorf("failed to store send-to-device message") log.WithError(err).Errorf("failed to store send-to-device message")
@ -107,7 +105,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage)
s.notifier.OnNewSendToDevice( s.notifier.OnNewSendToDevice(
output.UserID, output.UserID,
[]string{output.DeviceID}, []string{output.DeviceID},
types.NewStreamToken(0, streamPos, nil), types.StreamingToken{SendToDevicePosition: streamPos},
) )
return nil return nil

View file

@ -64,10 +64,7 @@ func NewOutputTypingEventConsumer(
// Start consuming from EDU api // Start consuming from EDU api
func (s *OutputTypingEventConsumer) Start() error { func (s *OutputTypingEventConsumer) Start() error {
s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) { s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) {
s.notifier.OnNewEvent( s.notifier.OnNewTyping(roomID, types.StreamingToken{TypingPosition: types.StreamPosition(latestSyncPosition)})
nil, roomID, nil,
types.NewStreamToken(0, types.StreamPosition(latestSyncPosition), nil),
)
}) })
return s.typingConsumer.Start() return s.typingConsumer.Start()
@ -95,6 +92,6 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error
typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID) typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID)
} }
s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.NewStreamToken(0, typingPos, nil)) s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos})
return nil return nil
} }

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
syncinternal "github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
syncapi "github.com/matrix-org/dendrite/syncapi/sync" syncapi "github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -114,12 +113,12 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
return err return err
} }
// TODO: f.e queryRes.UserIDsToCount : notify users by waking up streams // TODO: f.e queryRes.UserIDsToCount : notify users by waking up streams
posUpdate := types.NewStreamToken(0, 0, map[string]*types.LogPosition{ posUpdate := types.StreamingToken{
syncinternal.DeviceListLogName: { DeviceListPosition: types.LogPosition{
Offset: msg.Offset, Offset: msg.Offset,
Partition: msg.Partition, Partition: msg.Partition,
}, },
}) }
for userID := range queryRes.UserIDsToCount { for userID := range queryRes.UserIDsToCount {
s.notifier.OnNewKeyChange(posUpdate, userID, output.UserID) s.notifier.OnNewKeyChange(posUpdate, userID, output.UserID)
} }

View file

@ -181,7 +181,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
return err return err
} }
s.notifier.OnNewEvent(ev, "", nil, types.NewStreamToken(pduPos, 0, nil)) s.notifier.OnNewEvent(ev, "", nil, types.StreamingToken{PDUPosition: pduPos})
return nil return nil
} }
@ -220,7 +220,7 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent(
return err return err
} }
s.notifier.OnNewEvent(ev, "", nil, types.NewStreamToken(pduPos, 0, nil)) s.notifier.OnNewEvent(ev, "", nil, types.StreamingToken{PDUPosition: pduPos})
return nil return nil
} }
@ -259,6 +259,12 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gom
func (s *OutputRoomEventConsumer) onNewInviteEvent( func (s *OutputRoomEventConsumer) onNewInviteEvent(
ctx context.Context, msg api.OutputNewInviteEvent, ctx context.Context, msg api.OutputNewInviteEvent,
) error { ) error {
if msg.Event.StateKey() == nil {
log.WithFields(log.Fields{
"event": string(msg.Event.JSON()),
}).Panicf("roomserver output log: invite has no state key")
return nil
}
pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) pduPos, err := s.db.AddInviteEvent(ctx, msg.Event)
if err != nil { if err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
@ -269,14 +275,14 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
}).Panicf("roomserver output log: write invite failure") }).Panicf("roomserver output log: write invite failure")
return nil return nil
} }
s.notifier.OnNewEvent(msg.Event, "", nil, types.NewStreamToken(pduPos, 0, nil)) s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, *msg.Event.StateKey())
return nil return nil
} }
func (s *OutputRoomEventConsumer) onRetireInviteEvent( func (s *OutputRoomEventConsumer) onRetireInviteEvent(
ctx context.Context, msg api.OutputRetireInviteEvent, ctx context.Context, msg api.OutputRetireInviteEvent,
) error { ) error {
sp, err := s.db.RetireInviteEvent(ctx, msg.EventID) pduPos, err := s.db.RetireInviteEvent(ctx, msg.EventID)
if err != nil { if err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@ -287,7 +293,7 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
} }
// Notify any active sync requests that the invite has been retired. // Notify any active sync requests that the invite has been retired.
// Invites share the same stream counter as PDUs // Invites share the same stream counter as PDUs
s.notifier.OnNewEvent(nil, "", []string{msg.TargetUserID}, types.NewStreamToken(sp, 0, nil)) s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID)
return nil return nil
} }
@ -307,7 +313,7 @@ func (s *OutputRoomEventConsumer) onNewPeek(
// we need to wake up the users who might need to now be peeking into this room, // we need to wake up the users who might need to now be peeking into this room,
// so we send in a dummy event to trigger a wakeup // so we send in a dummy event to trigger a wakeup
s.notifier.OnNewEvent(nil, msg.RoomID, nil, types.NewStreamToken(sp, 0, nil)) s.notifier.OnNewEvent(nil, msg.RoomID, nil, types.StreamingToken{PDUPosition: sp})
return nil return nil
} }
@ -327,7 +333,7 @@ func (s *OutputRoomEventConsumer) onRetirePeek(
// we need to wake up the users who might need to now be peeking into this room, // we need to wake up the users who might need to now be peeking into this room,
// so we send in a dummy event to trigger a wakeup // so we send in a dummy event to trigger a wakeup
s.notifier.OnNewEvent(nil, msg.RoomID, nil, types.NewStreamToken(sp, 0, nil)) s.notifier.OnNewEvent(nil, msg.RoomID, nil, types.StreamingToken{PDUPosition: sp})
return nil return nil
} }

View file

@ -73,15 +73,13 @@ func DeviceListCatchup(
offset = sarama.OffsetOldest offset = sarama.OffsetOldest
// Extract partition/offset from sync token // Extract partition/offset from sync token
// TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make. // TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make.
logOffset := from.Log(DeviceListLogName) if !from.DeviceListPosition.IsEmpty() {
if logOffset != nil { partition = from.DeviceListPosition.Partition
partition = logOffset.Partition offset = from.DeviceListPosition.Offset
offset = logOffset.Offset
} }
var toOffset int64 var toOffset int64
toOffset = sarama.OffsetNewest toOffset = sarama.OffsetNewest
toLog := to.Log(DeviceListLogName) if toLog := to.DeviceListPosition; toLog.Partition == partition && toLog.Offset > 0 {
if toLog != nil && toLog.Offset > 0 {
toOffset = toLog.Offset toOffset = toLog.Offset
} }
var queryRes api.QueryKeyChangesResponse var queryRes api.QueryKeyChangesResponse
@ -130,11 +128,11 @@ func DeviceListCatchup(
} }
} }
// set the new token // set the new token
to.SetLog(DeviceListLogName, &types.LogPosition{ to.DeviceListPosition = types.LogPosition{
Partition: queryRes.Partition, Partition: queryRes.Partition,
Offset: queryRes.Offset, Offset: queryRes.Offset,
}) }
res.NextBatch = to.String() res.NextBatch.ApplyUpdates(to)
return hasNew, nil return hasNew, nil
} }

View file

@ -16,13 +16,13 @@ import (
var ( var (
syncingUser = "@alice:localhost" syncingUser = "@alice:localhost"
emptyToken = types.NewStreamToken(0, 0, nil) emptyToken = types.StreamingToken{}
newestToken = types.NewStreamToken(0, 0, map[string]*types.LogPosition{ newestToken = types.StreamingToken{
DeviceListLogName: { DeviceListPosition: types.LogPosition{
Offset: sarama.OffsetNewest, Offset: sarama.OffsetNewest,
Partition: 0, Partition: 0,
}, },
}) }
) )
type mockKeyAPI struct{} type mockKeyAPI struct{}

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -49,9 +50,10 @@ type messagesReq struct {
} }
type messagesResp struct { type messagesResp struct {
Start string `json:"start"` Start string `json:"start"`
End string `json:"end"` StartStream string `json:"start_stream,omitempty"` // NOTSPEC: so clients can hit /messages then immediately /sync with a latest sync token
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` End string `json:"end"`
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
} }
const defaultMessagesLimit = 10 const defaultMessagesLimit = 10
@ -65,6 +67,7 @@ func OnIncomingMessagesRequest(
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
cfg *config.SyncAPI, cfg *config.SyncAPI,
srp *sync.RequestPool,
) util.JSONResponse { ) util.JSONResponse {
var err error var err error
@ -84,9 +87,18 @@ func OnIncomingMessagesRequest(
// Extract parameters from the request's URL. // Extract parameters from the request's URL.
// Pagination tokens. // Pagination tokens.
var fromStream *types.StreamingToken var fromStream *types.StreamingToken
from, err := types.NewTopologyTokenFromString(req.URL.Query().Get("from")) fromQuery := req.URL.Query().Get("from")
emptyFromSupplied := fromQuery == ""
if emptyFromSupplied {
// NOTSPEC: We will pretend they used the latest sync token if no ?from= was provided.
// We do this to allow clients to get messages without having to call `/sync` e.g Cerulean
currPos := srp.Notifier.CurrentPosition()
fromQuery = currPos.String()
}
from, err := types.NewTopologyTokenFromString(fromQuery)
if err != nil { if err != nil {
fs, err2 := types.NewStreamTokenFromString(req.URL.Query().Get("from")) fs, err2 := types.NewStreamTokenFromString(fromQuery)
fromStream = &fs fromStream = &fs
if err2 != nil { if err2 != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -185,14 +197,19 @@ func OnIncomingMessagesRequest(
"return_end": end.String(), "return_end": end.String(),
}).Info("Responding") }).Info("Responding")
res := messagesResp{
Chunk: clientEvents,
Start: start.String(),
End: end.String(),
}
if emptyFromSupplied {
res.StartStream = fromStream.String()
}
// Respond with the events. // Respond with the events.
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: messagesResp{ JSON: res,
Chunk: clientEvents,
Start: start.String(),
End: end.String(),
},
} }
} }
@ -381,7 +398,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
if r.backwardOrdering && events[len(events)-1].Type() == gomatrixserverlib.MRoomCreate { if r.backwardOrdering && events[len(events)-1].Type() == gomatrixserverlib.MRoomCreate {
// We've hit the beginning of the room so there's really nowhere else // We've hit the beginning of the room so there's really nowhere else
// to go. This seems to fix Riot iOS from looping on /messages endlessly. // to go. This seems to fix Riot iOS from looping on /messages endlessly.
end = types.NewTopologyToken(0, 0) end = types.TopologyToken{}
} else { } else {
end, err = r.db.EventPositionInTopology( end, err = r.db.EventPositionInTopology(
r.ctx, events[len(events)-1].EventID(), r.ctx, events[len(events)-1].EventID(),
@ -447,11 +464,11 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
// The condition in the SQL query is a strict "greater than" so // The condition in the SQL query is a strict "greater than" so
// we need to check against to-1. // we need to check against to-1.
streamPos := types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition) streamPos := types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition)
isSetLargeEnough = (r.to.PDUPosition()-1 == streamPos) isSetLargeEnough = (r.to.PDUPosition-1 == streamPos)
} }
} else { } else {
streamPos := types.StreamPosition(streamEvents[0].StreamPosition) streamPos := types.StreamPosition(streamEvents[0].StreamPosition)
isSetLargeEnough = (r.from.PDUPosition()-1 == streamPos) isSetLargeEnough = (r.from.PDUPosition-1 == streamPos)
} }
} }
@ -565,7 +582,7 @@ func setToDefault(
if backwardOrdering { if backwardOrdering {
// go 1 earlier than the first event so we correctly fetch the earliest event // go 1 earlier than the first event so we correctly fetch the earliest event
// this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound.
to = types.NewTopologyToken(0, 0) to = types.TopologyToken{}
} else { } else {
to, err = db.MaxTopologicalPosition(ctx, roomID) to, err = db.MaxTopologicalPosition(ctx, roomID)
} }

View file

@ -51,7 +51,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg) return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg, srp)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter", r0mux.Handle("/user/{userId}/filter",

View file

@ -130,9 +130,9 @@ type Database interface {
// can be deleted altogether by CleanSendToDeviceUpdates // can be deleted altogether by CleanSendToDeviceUpdates
// The token supplied should be the current requested sync token, e.g. from the "since" // The token supplied should be the current requested sync token, e.g. from the "since"
// parameter. // parameter.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error) SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (pos types.StreamPosition, events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error)
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device. // StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
StoreNewSendForDeviceMessage(ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the // CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the
// result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows // result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows
// SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after // SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after

View file

@ -58,6 +58,8 @@ CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url); CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url);
-- for querying membership states of users -- for querying membership states of users
CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
-- for querying state by event IDs
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id);
` `
const upsertRoomStateSQL = "" + const upsertRoomStateSQL = "" +

View file

@ -0,0 +1,66 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpFixSequences, DownFixSequences)
}
func LoadFixSequences(m *sqlutil.Migrations) {
m.AddMigration(UpFixSequences, DownFixSequences)
}
func UpFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
DELETE FROM syncapi_receipts;
-- Use the new syncapi_receipts_id sequence.
CREATE SEQUENCE IF NOT EXISTS syncapi_receipt_id;
ALTER SEQUENCE IF EXISTS syncapi_receipt_id RESTART WITH 1;
ALTER TABLE syncapi_receipts ALTER COLUMN id SET DEFAULT nextval('syncapi_receipt_id');
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
DELETE FROM syncapi_receipts;
-- Revert back to using the syncapi_stream_id sequence.
DROP SEQUENCE IF EXISTS syncapi_receipt_id;
ALTER TABLE syncapi_receipts ALTER COLUMN id SET DEFAULT nextval('syncapi_stream_id');
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -30,11 +30,12 @@ import (
) )
const receiptsSchema = ` const receiptsSchema = `
CREATE SEQUENCE IF NOT EXISTS syncapi_stream_id; CREATE SEQUENCE IF NOT EXISTS syncapi_receipt_id;
-- Stores data about receipts -- Stores data about receipts
CREATE TABLE IF NOT EXISTS syncapi_receipts ( CREATE TABLE IF NOT EXISTS syncapi_receipts (
-- The ID -- The ID
id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'), id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_receipt_id'),
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
receipt_type TEXT NOT NULL, receipt_type TEXT NOT NULL,
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
@ -50,18 +51,22 @@ const upsertReceipt = "" +
" (room_id, receipt_type, user_id, event_id, receipt_ts)" + " (room_id, receipt_type, user_id, event_id, receipt_ts)" +
" VALUES ($1, $2, $3, $4, $5)" + " VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT (room_id, receipt_type, user_id)" + " ON CONFLICT (room_id, receipt_type, user_id)" +
" DO UPDATE SET id = nextval('syncapi_stream_id'), event_id = $4, receipt_ts = $5" + " DO UPDATE SET id = nextval('syncapi_receipt_id'), event_id = $4, receipt_ts = $5" +
" RETURNING id" " RETURNING id"
const selectRoomReceipts = "" + const selectRoomReceipts = "" +
"SELECT room_id, receipt_type, user_id, event_id, receipt_ts" + "SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" +
" FROM syncapi_receipts" + " FROM syncapi_receipts" +
" WHERE room_id = ANY($1) AND id > $2" " WHERE room_id = ANY($1) AND id > $2"
const selectMaxReceiptIDSQL = "" +
"SELECT MAX(id) FROM syncapi_receipts"
type receiptStatements struct { type receiptStatements struct {
db *sql.DB db *sql.DB
upsertReceipt *sql.Stmt upsertReceipt *sql.Stmt
selectRoomReceipts *sql.Stmt selectRoomReceipts *sql.Stmt
selectMaxReceiptID *sql.Stmt
} }
func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) {
@ -78,6 +83,9 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) {
if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
} }
if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
}
return r, nil return r, nil
} }
@ -87,20 +95,37 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
return return
} }
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
lastPos := types.StreamPosition(0)
rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos) rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to query room receipts: %w", err) return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed")
var res []api.OutputReceiptEvent var res []api.OutputReceiptEvent
for rows.Next() { for rows.Next() {
r := api.OutputReceiptEvent{} r := api.OutputReceiptEvent{}
err = rows.Scan(&r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp) var id types.StreamPosition
err = rows.Scan(&id, &r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp)
if err != nil { if err != nil {
return res, fmt.Errorf("unable to scan row to api.Receipts: %w", err) return 0, res, fmt.Errorf("unable to scan row to api.Receipts: %w", err)
} }
res = append(res, r) res = append(res, r)
if id > lastPos {
lastPos = id
}
} }
return res, rows.Err() return lastPos, res, rows.Err()
}
func (s *receiptStatements) SelectMaxReceiptID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
return
} }

View file

@ -49,6 +49,7 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
const insertSendToDeviceMessageSQL = ` const insertSendToDeviceMessageSQL = `
INSERT INTO syncapi_send_to_device (user_id, device_id, content) INSERT INTO syncapi_send_to_device (user_id, device_id, content)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
RETURNING id
` `
const countSendToDeviceMessagesSQL = ` const countSendToDeviceMessagesSQL = `
@ -107,8 +108,8 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
func (s *sendToDeviceStatements) InsertSendToDeviceMessage( func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) { ) (pos types.StreamPosition, err error) {
_, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).QueryRowContext(ctx, userID, deviceID, content).Scan(&pos)
return return
} }
@ -124,7 +125,7 @@ func (s *sendToDeviceStatements) CountSendToDeviceMessages(
func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (events []types.SendToDeviceEvent, err error) { ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
if err != nil { if err != nil {
return return
@ -152,9 +153,12 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
} }
} }
events = append(events, event) events = append(events, event)
if types.StreamPosition(id) > lastPos {
lastPos = types.StreamPosition(id)
}
} }
return events, rows.Err() return lastPos, events, rows.Err()
} }
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/storage/shared"
) )
@ -36,6 +37,7 @@ type SyncServerDatasource struct {
} }
// NewDatabase creates a new sync server database // NewDatabase creates a new sync server database
// nolint:gocyclo
func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) {
var d SyncServerDatasource var d SyncServerDatasource
var err error var err error
@ -86,6 +88,11 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: d.writer, Writer: d.writer,

View file

@ -78,8 +78,8 @@ func (d *Database) GetEventsInStreamingRange(
backwardOrdering bool, backwardOrdering bool,
) (events []types.StreamEvent, err error) { ) (events []types.StreamEvent, err error) {
r := types.Range{ r := types.Range{
From: from.PDUPosition(), From: from.PDUPosition,
To: to.PDUPosition(), To: to.PDUPosition,
Backwards: backwardOrdering, Backwards: backwardOrdering,
} }
if backwardOrdering { if backwardOrdering {
@ -391,16 +391,16 @@ func (d *Database) GetEventsInTopologicalRange(
var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
if backwardOrdering { if backwardOrdering {
// Backward ordering means the 'from' token has a higher depth than the 'to' token // Backward ordering means the 'from' token has a higher depth than the 'to' token
minDepth = to.Depth() minDepth = to.Depth
maxDepth = from.Depth() maxDepth = from.Depth
// for cases where we have say 5 events with the same depth, the TopologyToken needs to // for cases where we have say 5 events with the same depth, the TopologyToken needs to
// know which of the 5 the client has seen. This is done by using the PDU position. // know which of the 5 the client has seen. This is done by using the PDU position.
// Events with the same maxDepth but less than this PDU position will be returned. // Events with the same maxDepth but less than this PDU position will be returned.
maxStreamPosForMaxDepth = from.PDUPosition() maxStreamPosForMaxDepth = from.PDUPosition
} else { } else {
// Forward ordering means the 'from' token has a lower depth than the 'to' token. // Forward ordering means the 'from' token has a lower depth than the 'to' token.
minDepth = from.Depth() minDepth = from.Depth
maxDepth = to.Depth() maxDepth = to.Depth
} }
// Select the event IDs from the defined range. // Select the event IDs from the defined range.
@ -440,9 +440,9 @@ func (d *Database) MaxTopologicalPosition(
) (types.TopologyToken, error) { ) (types.TopologyToken, error) {
depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID) depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID)
if err != nil { if err != nil {
return types.NewTopologyToken(0, 0), err return types.TopologyToken{}, err
} }
return types.NewTopologyToken(depth, streamPos), nil return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil
} }
func (d *Database) EventPositionInTopology( func (d *Database) EventPositionInTopology(
@ -450,9 +450,9 @@ func (d *Database) EventPositionInTopology(
) (types.TopologyToken, error) { ) (types.TopologyToken, error) {
depth, stream, err := d.Topology.SelectPositionInTopology(ctx, nil, eventID) depth, stream, err := d.Topology.SelectPositionInTopology(ctx, nil, eventID)
if err != nil { if err != nil {
return types.NewTopologyToken(0, 0), err return types.TopologyToken{}, err
} }
return types.NewTopologyToken(depth, stream), nil return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil
} }
func (d *Database) syncPositionTx( func (d *Database) syncPositionTx(
@ -483,7 +483,17 @@ func (d *Database) syncPositionTx(
if maxPeekID > maxEventID { if maxPeekID > maxEventID {
maxEventID = maxPeekID maxEventID = maxPeekID
} }
sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.EDUCache.GetLatestSyncPosition()), nil) maxReceiptID, err := d.Receipts.SelectMaxReceiptID(ctx, txn)
if err != nil {
return sp, err
}
// TODO: complete these positions
sp = types.StreamingToken{
PDUPosition: types.StreamPosition(maxEventID),
TypingPosition: types.StreamPosition(d.EDUCache.GetLatestSyncPosition()),
ReceiptPosition: types.StreamPosition(maxReceiptID),
InvitePosition: types.StreamPosition(maxInviteID),
}
return return
} }
@ -534,11 +544,6 @@ func (d *Database) addPDUDeltaToResponse(
} }
} }
// TODO: This should be done in getStateDeltas
if err = d.addInvitesToResponse(ctx, txn, device.UserID, r, res); err != nil {
return nil, fmt.Errorf("d.addInvitesToResponse: %w", err)
}
succeeded = true succeeded = true
return joinedRoomIDs, nil return joinedRoomIDs, nil
} }
@ -555,7 +560,7 @@ func (d *Database) addTypingDeltaToResponse(
for _, roomID := range joinedRoomIDs { for _, roomID := range joinedRoomIDs {
var jr types.JoinResponse var jr types.JoinResponse
if typingUsers, updated := d.EDUCache.GetTypingUsersIfUpdatedAfter( if typingUsers, updated := d.EDUCache.GetTypingUsersIfUpdatedAfter(
roomID, int64(since.EDUPosition()), roomID, int64(since.TypingPosition),
); updated { ); updated {
ev := gomatrixserverlib.ClientEvent{ ev := gomatrixserverlib.ClientEvent{
Type: gomatrixserverlib.MTyping, Type: gomatrixserverlib.MTyping,
@ -574,6 +579,7 @@ func (d *Database) addTypingDeltaToResponse(
res.Rooms.Join[roomID] = jr res.Rooms.Join[roomID] = jr
} }
} }
res.NextBatch.TypingPosition = types.StreamPosition(d.EDUCache.GetLatestSyncPosition())
return nil return nil
} }
@ -584,7 +590,7 @@ func (d *Database) addReceiptDeltaToResponse(
joinedRoomIDs []string, joinedRoomIDs []string,
res *types.Response, res *types.Response,
) error { ) error {
receipts, err := d.Receipts.SelectRoomReceiptsAfter(context.TODO(), joinedRoomIDs, since.EDUPosition()) lastPos, receipts, err := d.Receipts.SelectRoomReceiptsAfter(context.TODO(), joinedRoomIDs, since.ReceiptPosition)
if err != nil { if err != nil {
return fmt.Errorf("unable to select receipts for rooms: %w", err) return fmt.Errorf("unable to select receipts for rooms: %w", err)
} }
@ -629,6 +635,7 @@ func (d *Database) addReceiptDeltaToResponse(
res.Rooms.Join[roomID] = jr res.Rooms.Join[roomID] = jr
} }
res.NextBatch.ReceiptPosition = lastPos
return nil return nil
} }
@ -639,7 +646,7 @@ func (d *Database) addEDUDeltaToResponse(
joinedRoomIDs []string, joinedRoomIDs []string,
res *types.Response, res *types.Response,
) error { ) error {
if fromPos.EDUPosition() != toPos.EDUPosition() { if fromPos.TypingPosition != toPos.TypingPosition {
// add typing deltas // add typing deltas
if err := d.addTypingDeltaToResponse(fromPos, joinedRoomIDs, res); err != nil { if err := d.addTypingDeltaToResponse(fromPos, joinedRoomIDs, res); err != nil {
return fmt.Errorf("unable to apply typing delta to response: %w", err) return fmt.Errorf("unable to apply typing delta to response: %w", err)
@ -647,8 +654,8 @@ func (d *Database) addEDUDeltaToResponse(
} }
// Check on initial sync and if EDUPositions differ // Check on initial sync and if EDUPositions differ
if (fromPos.EDUPosition() == 0 && toPos.EDUPosition() == 0) || if (fromPos.ReceiptPosition == 0 && toPos.ReceiptPosition == 0) ||
fromPos.EDUPosition() != toPos.EDUPosition() { fromPos.ReceiptPosition != toPos.ReceiptPosition {
if err := d.addReceiptDeltaToResponse(fromPos, joinedRoomIDs, res); err != nil { if err := d.addReceiptDeltaToResponse(fromPos, joinedRoomIDs, res); err != nil {
return fmt.Errorf("unable to apply receipts to response: %w", err) return fmt.Errorf("unable to apply receipts to response: %w", err)
} }
@ -682,15 +689,14 @@ func (d *Database) IncrementalSync(
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
wantFullState bool, wantFullState bool,
) (*types.Response, error) { ) (*types.Response, error) {
nextBatchPos := fromPos.WithUpdates(toPos) res.NextBatch = fromPos.WithUpdates(toPos)
res.NextBatch = nextBatchPos.String()
var joinedRoomIDs []string var joinedRoomIDs []string
var err error var err error
if fromPos.PDUPosition() != toPos.PDUPosition() || wantFullState { if fromPos.PDUPosition != toPos.PDUPosition || wantFullState {
r := types.Range{ r := types.Range{
From: fromPos.PDUPosition(), From: fromPos.PDUPosition,
To: toPos.PDUPosition(), To: toPos.PDUPosition,
} }
joinedRoomIDs, err = d.addPDUDeltaToResponse( joinedRoomIDs, err = d.addPDUDeltaToResponse(
ctx, device, r, numRecentEventsPerRoom, wantFullState, res, ctx, device, r, numRecentEventsPerRoom, wantFullState, res,
@ -716,6 +722,14 @@ func (d *Database) IncrementalSync(
return nil, fmt.Errorf("d.addEDUDeltaToResponse: %w", err) return nil, fmt.Errorf("d.addEDUDeltaToResponse: %w", err)
} }
ir := types.Range{
From: fromPos.InvitePosition,
To: toPos.InvitePosition,
}
if err = d.addInvitesToResponse(ctx, nil, device.UserID, ir, res); err != nil {
return nil, fmt.Errorf("d.addInvitesToResponse: %w", err)
}
return res, nil return res, nil
} }
@ -772,10 +786,14 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
} }
r := types.Range{ r := types.Range{
From: 0, From: 0,
To: toPos.PDUPosition(), To: toPos.PDUPosition,
}
ir := types.Range{
From: 0,
To: toPos.InvitePosition,
} }
res.NextBatch = toPos.String() res.NextBatch.ApplyUpdates(toPos)
// Extract room state and recent events for all rooms the user is joined to. // Extract room state and recent events for all rooms the user is joined to.
joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
@ -815,7 +833,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
} }
} }
if err = d.addInvitesToResponse(ctx, txn, userID, r, res); err != nil { if err = d.addInvitesToResponse(ctx, txn, userID, ir, res); err != nil {
return return
} }
@ -875,16 +893,18 @@ func (d *Database) getJoinResponseForCompleteSync(
// Retrieve the backward topology position, i.e. the position of the // Retrieve the backward topology position, i.e. the position of the
// oldest event in the room's topology. // oldest event in the room's topology.
var prevBatchStr string var prevBatch *types.TopologyToken
if len(recentStreamEvents) > 0 { if len(recentStreamEvents) > 0 {
var backwardTopologyPos, backwardStreamPos types.StreamPosition var backwardTopologyPos, backwardStreamPos types.StreamPosition
backwardTopologyPos, backwardStreamPos, err = d.Topology.SelectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) backwardTopologyPos, backwardStreamPos, err = d.Topology.SelectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID())
if err != nil { if err != nil {
return return
} }
prevBatch := types.NewTopologyToken(backwardTopologyPos, backwardStreamPos) prevBatch = &types.TopologyToken{
Depth: backwardTopologyPos,
PDUPosition: backwardStreamPos,
}
prevBatch.Decrement() prevBatch.Decrement()
prevBatchStr = prevBatch.String()
} }
// We don't include a device here as we don't need to send down // We don't include a device here as we don't need to send down
@ -893,7 +913,7 @@ func (d *Database) getJoinResponseForCompleteSync(
recentEvents := d.StreamEventsToEvents(&device, recentStreamEvents) recentEvents := d.StreamEventsToEvents(&device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr = types.NewJoinResponse() jr = types.NewJoinResponse()
jr.Timeline.PrevBatch = prevBatchStr jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
@ -915,7 +935,7 @@ func (d *Database) CompleteSync(
// Use a zero value SyncPosition for fromPos so all EDU states are added. // Use a zero value SyncPosition for fromPos so all EDU states are added.
err = d.addEDUDeltaToResponse( err = d.addEDUDeltaToResponse(
types.NewStreamToken(0, 0, nil), toPos, joinedRoomIDs, res, types.StreamingToken{}, toPos, joinedRoomIDs, res,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("d.addEDUDeltaToResponse: %w", err) return nil, fmt.Errorf("d.addEDUDeltaToResponse: %w", err)
@ -965,7 +985,7 @@ func (d *Database) getBackwardTopologyPos(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
events []types.StreamEvent, events []types.StreamEvent,
) (types.TopologyToken, error) { ) (types.TopologyToken, error) {
zeroToken := types.NewTopologyToken(0, 0) zeroToken := types.TopologyToken{}
if len(events) == 0 { if len(events) == 0 {
return zeroToken, nil return zeroToken, nil
} }
@ -973,7 +993,7 @@ func (d *Database) getBackwardTopologyPos(
if err != nil { if err != nil {
return zeroToken, err return zeroToken, err
} }
tok := types.NewTopologyToken(pos, spos) tok := types.TopologyToken{Depth: pos, PDUPosition: spos}
tok.Decrement() tok.Decrement()
return tok, nil return tok, nil
} }
@ -1021,7 +1041,7 @@ func (d *Database) addRoomDeltaToResponse(
case gomatrixserverlib.Join: case gomatrixserverlib.Join:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = prevBatch.String() jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@ -1029,7 +1049,7 @@ func (d *Database) addRoomDeltaToResponse(
case gomatrixserverlib.Peek: case gomatrixserverlib.Peek:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = prevBatch.String() jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@ -1040,7 +1060,7 @@ func (d *Database) addRoomDeltaToResponse(
// TODO: recentEvents may contain events that this user is not allowed to see because they are // TODO: recentEvents may contain events that this user is not allowed to see because they are
// no longer in the room. // no longer in the room.
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = prevBatch.String() lr.Timeline.PrevBatch = &prevBatch
lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@ -1361,39 +1381,40 @@ func (d *Database) SendToDeviceUpdatesWaiting(
} }
func (d *Database) StoreNewSendForDeviceMessage( func (d *Database) StoreNewSendForDeviceMessage(
ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent, ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
) (types.StreamPosition, error) { ) (newPos types.StreamPosition, err error) {
j, err := json.Marshal(event) j, err := json.Marshal(event)
if err != nil { if err != nil {
return streamPos, err return 0, err
} }
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee // Delegate the database write task to the SendToDeviceWriter. It'll guarantee
// that we don't lock the table for writes in more than one place. // that we don't lock the table for writes in more than one place.
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.SendToDevice.InsertSendToDeviceMessage( newPos, err = d.SendToDevice.InsertSendToDeviceMessage(
ctx, txn, userID, deviceID, string(j), ctx, txn, userID, deviceID, string(j),
) )
return err
}) })
if err != nil { if err != nil {
return streamPos, err return 0, err
} }
return streamPos, nil return 0, nil
} }
func (d *Database) SendToDeviceUpdatesForSync( func (d *Database) SendToDeviceUpdatesForSync(
ctx context.Context, ctx context.Context,
userID, deviceID string, userID, deviceID string,
token types.StreamingToken, token types.StreamingToken,
) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) { ) (types.StreamPosition, []types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) {
// First of all, get our send-to-device updates for this user. // First of all, get our send-to-device updates for this user.
events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID) lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID)
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) return 0, nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
} }
// If there's nothing to do then stop here. // If there's nothing to do then stop here.
if len(events) == 0 { if len(events) == 0 {
return nil, nil, nil, nil return 0, nil, nil, nil, nil
} }
// Work out whether we need to update any of the database entries. // Work out whether we need to update any of the database entries.
@ -1420,7 +1441,7 @@ func (d *Database) SendToDeviceUpdatesForSync(
} }
} }
return toReturn, toUpdate, toDelete, nil return lastPos, toReturn, toUpdate, toDelete, nil
} }
func (d *Database) CleanSendToDeviceUpdates( func (d *Database) CleanSendToDeviceUpdates(
@ -1507,5 +1528,6 @@ func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId
} }
func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) { func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) {
return d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos) _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos)
return receipts, err
} }

View file

@ -46,6 +46,8 @@ CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url); CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url);
-- for querying membership states of users -- for querying membership states of users
-- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; -- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
-- for querying state by event IDs
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id);
` `
const upsertRoomStateSQL = "" + const upsertRoomStateSQL = "" +

View file

@ -0,0 +1,58 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpFixSequences, DownFixSequences)
}
func LoadFixSequences(m *sqlutil.Migrations) {
m.AddMigration(UpFixSequences, DownFixSequences)
}
func UpFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
DELETE FROM syncapi_receipts;
UPDATE syncapi_stream_id SET stream_id=1 WHERE stream_name="receipt";
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
DELETE FROM syncapi_receipts;
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -51,15 +51,19 @@ const upsertReceipt = "" +
" DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9" " DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9"
const selectRoomReceipts = "" + const selectRoomReceipts = "" +
"SELECT room_id, receipt_type, user_id, event_id, receipt_ts" + "SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" +
" FROM syncapi_receipts" + " FROM syncapi_receipts" +
" WHERE id > $1 and room_id in ($2)" " WHERE id > $1 and room_id in ($2)"
const selectMaxReceiptIDSQL = "" +
"SELECT MAX(id) FROM syncapi_receipts"
type receiptStatements struct { type receiptStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
upsertReceipt *sql.Stmt upsertReceipt *sql.Stmt
selectRoomReceipts *sql.Stmt selectRoomReceipts *sql.Stmt
selectMaxReceiptID *sql.Stmt
} }
func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) { func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) {
@ -77,12 +81,15 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Re
if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
} }
if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
}
return r, nil return r, nil
} }
// UpsertReceipt creates new user receipts // UpsertReceipt creates new user receipts
func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) {
pos, err = r.streamIDStatements.nextStreamID(ctx, txn) pos, err = r.streamIDStatements.nextReceiptID(ctx, txn)
if err != nil { if err != nil {
return return
} }
@ -92,9 +99,9 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
} }
// SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp // SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
lastPos := types.StreamPosition(0)
params := make([]interface{}, len(roomIDs)+1) params := make([]interface{}, len(roomIDs)+1)
params[0] = streamPos params[0] = streamPos
for k, v := range roomIDs { for k, v := range roomIDs {
@ -102,17 +109,33 @@ func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs
} }
rows, err := r.db.QueryContext(ctx, selectSQL, params...) rows, err := r.db.QueryContext(ctx, selectSQL, params...)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to query room receipts: %w", err) return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed")
var res []api.OutputReceiptEvent var res []api.OutputReceiptEvent
for rows.Next() { for rows.Next() {
r := api.OutputReceiptEvent{} r := api.OutputReceiptEvent{}
err = rows.Scan(&r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp) var id types.StreamPosition
err = rows.Scan(&id, &r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp)
if err != nil { if err != nil {
return res, fmt.Errorf("unable to scan row to api.Receipts: %w", err) return 0, res, fmt.Errorf("unable to scan row to api.Receipts: %w", err)
} }
res = append(res, r) res = append(res, r)
if id > lastPos {
lastPos = id
}
} }
return res, rows.Err() return lastPos, res, rows.Err()
}
func (s *receiptStatements) SelectMaxReceiptID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
return
} }

View file

@ -100,8 +100,14 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
func (s *sendToDeviceStatements) InsertSendToDeviceMessage( func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) { ) (pos types.StreamPosition, err error) {
_, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) var result sql.Result
result, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
if p, err := result.LastInsertId(); err != nil {
return 0, err
} else {
pos = types.StreamPosition(p)
}
return return
} }
@ -117,7 +123,7 @@ func (s *sendToDeviceStatements) CountSendToDeviceMessages(
func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (events []types.SendToDeviceEvent, err error) { ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
if err != nil { if err != nil {
return return
@ -145,9 +151,12 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
} }
} }
events = append(events, event) events = append(events, event)
if types.StreamPosition(id) > lastPos {
lastPos = types.StreamPosition(id)
}
} }
return events, rows.Err() return lastPos, events, rows.Err()
} }
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(

View file

@ -18,6 +18,8 @@ CREATE TABLE IF NOT EXISTS syncapi_stream_id (
); );
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0) INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0)
ON CONFLICT DO NOTHING; ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("receipt", 0)
ON CONFLICT DO NOTHING;
` `
const increaseStreamIDStmt = "" + const increaseStreamIDStmt = "" +
@ -56,3 +58,13 @@ func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos
err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos) err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos)
return return
} }
func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "receipt"); err != nil {
return
}
err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
return
}

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/storage/shared"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
) )
// SyncServerDatasource represents a sync server datasource which manages // SyncServerDatasource represents a sync server datasource which manages
@ -46,13 +47,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
return nil, err return nil, err
} }
d.writer = sqlutil.NewExclusiveWriter() d.writer = sqlutil.NewExclusiveWriter()
if err = d.prepare(); err != nil { if err = d.prepare(dbProperties); err != nil {
return nil, err return nil, err
} }
return &d, nil return &d, nil
} }
func (d *SyncServerDatasource) prepare() (err error) { // nolint:gocyclo
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil { if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
return err return err
} }
@ -99,6 +101,11 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil { if err != nil {
return err return err
} }
m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: d.writer, Writer: d.writer,

View file

@ -165,9 +165,9 @@ func TestSyncResponse(t *testing.T) {
{ {
Name: "IncrementalSync penultimate", Name: "IncrementalSync penultimate",
DoSync: func() (*types.Response, error) { DoSync: func() (*types.Response, error) {
from := types.NewStreamToken( // pretend we are at the penultimate event from := types.StreamingToken{ // pretend we are at the penultimate event
positions[len(positions)-2], types.StreamPosition(0), nil, PDUPosition: positions[len(positions)-2],
) }
res := types.NewResponse() res := types.NewResponse()
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
}, },
@ -178,9 +178,9 @@ func TestSyncResponse(t *testing.T) {
{ {
Name: "IncrementalSync limited", Name: "IncrementalSync limited",
DoSync: func() (*types.Response, error) { DoSync: func() (*types.Response, error) {
from := types.NewStreamToken( // pretend we are 10 events behind from := types.StreamingToken{ // pretend we are 10 events behind
positions[len(positions)-11], types.StreamPosition(0), nil, PDUPosition: positions[len(positions)-11],
) }
res := types.NewResponse() res := types.NewResponse()
// limit is set to 5 // limit is set to 5
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
@ -222,8 +222,13 @@ func TestSyncResponse(t *testing.T) {
if err != nil { if err != nil {
st.Fatalf("failed to do sync: %s", err) st.Fatalf("failed to do sync: %s", err)
} }
next := types.NewStreamToken(latest.PDUPosition(), latest.EDUPosition(), nil) next := types.StreamingToken{
if res.NextBatch != next.String() { PDUPosition: latest.PDUPosition,
TypingPosition: latest.TypingPosition,
ReceiptPosition: latest.ReceiptPosition,
SendToDevicePosition: latest.SendToDevicePosition,
}
if res.NextBatch.String() != next.String() {
st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String()) st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String())
} }
roomRes, ok := res.Rooms.Join[testRoomID] roomRes, ok := res.Rooms.Join[testRoomID]
@ -245,9 +250,9 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err) t.Fatalf("failed to get SyncPosition: %s", err)
} }
from := types.NewStreamToken( from := types.StreamingToken{
positions[len(positions)-2], types.StreamPosition(0), nil, PDUPosition: positions[len(positions)-2],
) }
res := types.NewResponse() res := types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
@ -261,7 +266,7 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
// returns the last event "Message 10" // returns the last event "Message 10"
assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:])) assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:]))
prev := roomRes.Timeline.PrevBatch prev := roomRes.Timeline.PrevBatch.String()
if prev == "" { if prev == "" {
t.Fatalf("IncrementalSync expected prev_batch token") t.Fatalf("IncrementalSync expected prev_batch token")
} }
@ -271,7 +276,7 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
} }
// backpaginate 5 messages starting at the latest position. // backpaginate 5 messages starting at the latest position.
// head towards the beginning of time // head towards the beginning of time
to := types.NewTopologyToken(0, 0) to := types.TopologyToken{}
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true) paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true)
if err != nil { if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err) t.Fatalf("GetEventsInRange returned an error: %s", err)
@ -291,7 +296,7 @@ func TestGetEventsInRangeWithStreamToken(t *testing.T) {
t.Fatalf("failed to get SyncPosition: %s", err) t.Fatalf("failed to get SyncPosition: %s", err)
} }
// head towards the beginning of time // head towards the beginning of time
to := types.NewStreamToken(0, 0, nil) to := types.StreamingToken{}
// backpaginate 5 messages starting at the latest position. // backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true) paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
@ -313,7 +318,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err) t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
} }
// head towards the beginning of time // head towards the beginning of time
to := types.NewTopologyToken(0, 0) to := types.TopologyToken{}
// backpaginate 5 messages starting at the latest position. // backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true) paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true)
@ -382,7 +387,7 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) {
t.Fatalf("failed to get EventPositionInTopology for event: %s", err) t.Fatalf("failed to get EventPositionInTopology for event: %s", err)
} }
// head towards the beginning of time // head towards the beginning of time
to := types.NewTopologyToken(0, 0) to := types.TopologyToken{}
testCases := []struct { testCases := []struct {
Name string Name string
@ -458,7 +463,7 @@ func TestGetEventsInTopologicalRangeMultiRoom(t *testing.T) {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err) t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
} }
// head towards the beginning of time // head towards the beginning of time
to := types.NewTopologyToken(0, 0) to := types.TopologyToken{}
// Query using room B as room A was inserted first and hence A will have lower stream positions but identical depths, // Query using room B as room A was inserted first and hence A will have lower stream positions but identical depths,
// allowing this bug to surface. // allowing this bug to surface.
@ -508,7 +513,7 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) {
} }
// head towards the beginning of time // head towards the beginning of time
to := types.NewTopologyToken(0, 0) to := types.TopologyToken{}
// starting at `from`, backpaginate to the beginning of time, asserting as we go. // starting at `from`, backpaginate to the beginning of time, asserting as we go.
chunkSize = 3 chunkSize = 3
@ -534,20 +539,20 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point there should be no messages. We haven't sent anything // At this point there should be no messages. We haven't sent anything
// yet. // yet.
events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0, nil)) _, events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 { if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("first call should have no updates") t.Fatal("first call should have no updates")
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0, nil)) err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{})
if err != nil { if err != nil {
return return
} }
// Try sending a message. // Try sending a message.
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, types.StreamPosition(0), "alice", "one", gomatrixserverlib.SendToDeviceEvent{ streamPos, err := db.StoreNewSendForDeviceMessage(ctx, "alice", "one", gomatrixserverlib.SendToDeviceEvent{
Sender: "bob", Sender: "bob",
Type: "m.type", Type: "m.type",
Content: json.RawMessage("{}"), Content: json.RawMessage("{}"),
@ -559,14 +564,14 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should get exactly one message. We're sending the sync position // At this point we should get exactly one message. We're sending the sync position
// that we were given from the update and the send-to-device update will be updated // that we were given from the update and the send-to-device update will be updated
// in the database to reflect that this was the sync position we sent the message at. // in the database to reflect that this was the sync position we sent the message at.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos, nil)) _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 { if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
t.Fatal("second call should have one update") t.Fatal("second call should have one update")
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos, nil)) err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos})
if err != nil { if err != nil {
return return
} }
@ -574,35 +579,35 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should still have one message because we haven't progressed the // At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying // sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position. // with the same position.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos, nil)) _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 { if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("third call should have one update still") t.Fatal("third call should have one update still")
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos, nil)) err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos})
if err != nil { if err != nil {
return return
} }
// At this point we should now have no updates, because we've progressed the sync // At this point we should now have no updates, because we've progressed the sync
// position. Therefore the update from before will not be sent again. // position. Therefore the update from before will not be sent again.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1, nil)) _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 1})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 { if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
t.Fatal("fourth call should have no updates") t.Fatal("fourth call should have no updates")
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1, nil)) err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos + 1})
if err != nil { if err != nil {
return return
} }
// At this point we should still have no updates, because no new updates have been // At this point we should still have no updates, because no new updates have been
// sent. // sent.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2, nil)) _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 2})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -639,7 +644,7 @@ func TestInviteBehaviour(t *testing.T) {
} }
// both invite events should appear in a new sync // both invite events should appear in a new sync
beforeRetireRes := types.NewResponse() beforeRetireRes := types.NewResponse()
beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.NewStreamToken(0, 0, nil), latest, 0, false) beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.StreamingToken{}, latest, 0, false)
if err != nil { if err != nil {
t.Fatalf("IncrementalSync failed: %s", err) t.Fatalf("IncrementalSync failed: %s", err)
} }
@ -654,19 +659,15 @@ func TestInviteBehaviour(t *testing.T) {
t.Fatalf("failed to get SyncPosition: %s", err) t.Fatalf("failed to get SyncPosition: %s", err)
} }
res := types.NewResponse() res := types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.NewStreamToken(0, 0, nil), latest, 0, false) res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.StreamingToken{}, latest, 0, false)
if err != nil { if err != nil {
t.Fatalf("IncrementalSync failed: %s", err) t.Fatalf("IncrementalSync failed: %s", err)
} }
assertInvitedToRooms(t, res, []string{inviteRoom2}) assertInvitedToRooms(t, res, []string{inviteRoom2})
// a sync after we have received both invites should result in a leave for the retired room // a sync after we have received both invites should result in a leave for the retired room
beforeRetireTok, err := types.NewStreamTokenFromString(beforeRetireRes.NextBatch)
if err != nil {
t.Fatalf("NewStreamTokenFromString cannot parse next batch '%s' : %s", beforeRetireRes.NextBatch, err)
}
res = types.NewResponse() res = types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, beforeRetireTok, latest, 0, false) res, err = db.IncrementalSync(ctx, res, testUserDeviceA, beforeRetireRes.NextBatch, latest, 0, false)
if err != nil { if err != nil {
t.Fatalf("IncrementalSync failed: %s", err) t.Fatalf("IncrementalSync failed: %s", err)
} }

View file

@ -146,8 +146,8 @@ type BackwardsExtremities interface {
// sync parameter isn't later then we will keep including the updates in the // sync parameter isn't later then we will keep including the updates in the
// sync response, as the client is seemingly trying to repeat the same /sync. // sync response, as the client is seemingly trying to repeat the same /sync.
type SendToDevice interface { type SendToDevice interface {
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (err error) InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error)
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (events []types.SendToDeviceEvent, err error) SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error)
UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error) UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error)
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error) DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error)
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error) CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
@ -160,5 +160,6 @@ type Filter interface {
type Receipts interface { type Receipts interface {
UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []eduAPI.OutputReceiptEvent, error)
SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error)
} }

View file

@ -77,9 +77,8 @@ func (n *Notifier) OnNewEvent(
// This needs to be done PRIOR to waking up users as they will read this value. // This needs to be done PRIOR to waking up users as they will read this value.
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.currPos.ApplyUpdates(posUpdate)
n.removeEmptyUserStreams() n.removeEmptyUserStreams()
if ev != nil { if ev != nil {
@ -113,11 +112,11 @@ func (n *Notifier) OnNewEvent(
} }
} }
n.wakeupUsers(usersToNotify, peekingDevicesToNotify, latestPos) n.wakeupUsers(usersToNotify, peekingDevicesToNotify, n.currPos)
} else if roomID != "" { } else if roomID != "" {
n.wakeupUsers(n.joinedUsers(roomID), n.PeekingDevices(roomID), latestPos) n.wakeupUsers(n.joinedUsers(roomID), n.PeekingDevices(roomID), n.currPos)
} else if len(userIDs) > 0 { } else if len(userIDs) > 0 {
n.wakeupUsers(userIDs, nil, latestPos) n.wakeupUsers(userIDs, nil, n.currPos)
} else { } else {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"posUpdate": posUpdate.String, "posUpdate": posUpdate.String,
@ -155,20 +154,33 @@ func (n *Notifier) OnNewSendToDevice(
) { ) {
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.wakeupUserDevice(userID, deviceIDs, latestPos) n.currPos.ApplyUpdates(posUpdate)
n.wakeupUserDevice(userID, deviceIDs, n.currPos)
} }
// OnNewReceipt updates the current position // OnNewReceipt updates the current position
func (n *Notifier) OnNewReceipt( func (n *Notifier) OnNewTyping(
roomID string,
posUpdate types.StreamingToken, posUpdate types.StreamingToken,
) { ) {
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers(n.joinedUsers(roomID), nil, n.currPos)
}
// OnNewReceipt updates the current position
func (n *Notifier) OnNewReceipt(
roomID string,
posUpdate types.StreamingToken,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers(n.joinedUsers(roomID), nil, n.currPos)
} }
func (n *Notifier) OnNewKeyChange( func (n *Notifier) OnNewKeyChange(
@ -176,9 +188,19 @@ func (n *Notifier) OnNewKeyChange(
) { ) {
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers([]string{wakeUserID}, nil, latestPos) n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
}
func (n *Notifier) OnNewInvite(
posUpdate types.StreamingToken, wakeUserID string,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
} }
// GetListener returns a UserStreamListener that can be used to wait for // GetListener returns a UserStreamListener that can be used to wait for

View file

@ -32,11 +32,11 @@ var (
randomMessageEvent gomatrixserverlib.HeaderedEvent randomMessageEvent gomatrixserverlib.HeaderedEvent
aliceInviteBobEvent gomatrixserverlib.HeaderedEvent aliceInviteBobEvent gomatrixserverlib.HeaderedEvent
bobLeaveEvent gomatrixserverlib.HeaderedEvent bobLeaveEvent gomatrixserverlib.HeaderedEvent
syncPositionVeryOld = types.NewStreamToken(5, 0, nil) syncPositionVeryOld = types.StreamingToken{PDUPosition: 5}
syncPositionBefore = types.NewStreamToken(11, 0, nil) syncPositionBefore = types.StreamingToken{PDUPosition: 11}
syncPositionAfter = types.NewStreamToken(12, 0, nil) syncPositionAfter = types.StreamingToken{PDUPosition: 12}
syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition(), 1, nil) //syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition, 1, 0, 0, nil)
syncPositionAfter2 = types.NewStreamToken(13, 0, nil) syncPositionAfter2 = types.StreamingToken{PDUPosition: 13}
) )
var ( var (
@ -205,6 +205,9 @@ func TestNewInviteEventForUser(t *testing.T) {
} }
// Test an EDU-only update wakes up the request. // Test an EDU-only update wakes up the request.
// TODO: Fix this test, invites wake up with an incremented
// PDU position, not EDU position
/*
func TestEDUWakeup(t *testing.T) { func TestEDUWakeup(t *testing.T) {
n := NewNotifier(syncPositionAfter) n := NewNotifier(syncPositionAfter)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
@ -229,6 +232,7 @@ func TestEDUWakeup(t *testing.T) {
wg.Wait() wg.Wait()
} }
*/
// Test that all blocked requests get woken up on a new event. // Test that all blocked requests get woken up on a new event.
func TestMultipleRequestWakeup(t *testing.T) { func TestMultipleRequestWakeup(t *testing.T) {
@ -331,7 +335,7 @@ func waitForEvents(n *Notifier, req syncRequest) (types.StreamingToken, error) {
return types.StreamingToken{}, fmt.Errorf( return types.StreamingToken{}, fmt.Errorf(
"waitForEvents timed out waiting for %s (pos=%v)", req.device.UserID, req.since, "waitForEvents timed out waiting for %s (pos=%v)", req.device.UserID, req.since,
) )
case <-listener.GetNotifyChannel(*req.since): case <-listener.GetNotifyChannel(req.since):
p := listener.GetSyncPosition() p := listener.GetSyncPosition()
return p, nil return p, nil
} }
@ -361,7 +365,7 @@ func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) syn
ID: deviceID, ID: deviceID,
}, },
timeout: 1 * time.Minute, timeout: 1 * time.Minute,
since: &since, since: since,
wantFullState: false, wantFullState: false,
limit: DefaultTimelineLimit, limit: DefaultTimelineLimit,
log: util.GetLogger(context.TODO()), log: util.GetLogger(context.TODO()),

View file

@ -46,7 +46,7 @@ type syncRequest struct {
device userapi.Device device userapi.Device
limit int limit int
timeout time.Duration timeout time.Duration
since *types.StreamingToken // nil means that no since token was supplied since types.StreamingToken // nil means that no since token was supplied
wantFullState bool wantFullState bool
log *log.Entry log *log.Entry
} }
@ -55,18 +55,13 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
timeout := getTimeout(req.URL.Query().Get("timeout")) timeout := getTimeout(req.URL.Query().Get("timeout"))
fullState := req.URL.Query().Get("full_state") fullState := req.URL.Query().Get("full_state")
wantFullState := fullState != "" && fullState != "false" wantFullState := fullState != "" && fullState != "false"
var since *types.StreamingToken since, sinceStr := types.StreamingToken{}, req.URL.Query().Get("since")
sinceStr := req.URL.Query().Get("since")
if sinceStr != "" { if sinceStr != "" {
tok, err := types.NewStreamTokenFromString(sinceStr) var err error
since, err = types.NewStreamTokenFromString(sinceStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
since = &tok
}
if since == nil {
tok := types.NewStreamToken(0, 0, nil)
since = &tok
} }
timelineLimit := DefaultTimelineLimit timelineLimit := DefaultTimelineLimit
// TODO: read from stored filters too // TODO: read from stored filters too

View file

@ -35,6 +35,7 @@ import (
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -43,7 +44,7 @@ type RequestPool struct {
db storage.Database db storage.Database
cfg *config.SyncAPI cfg *config.SyncAPI
userAPI userapi.UserInternalAPI userAPI userapi.UserInternalAPI
notifier *Notifier Notifier *Notifier
keyAPI keyapi.KeyInternalAPI keyAPI keyapi.KeyInternalAPI
rsAPI roomserverAPI.RoomserverInternalAPI rsAPI roomserverAPI.RoomserverInternalAPI
lastseen sync.Map lastseen sync.Map
@ -99,6 +100,30 @@ func (rp *RequestPool) updateLastSeen(req *http.Request, device *userapi.Device)
rp.lastseen.Store(device.UserID+device.ID, time.Now()) rp.lastseen.Store(device.UserID+device.ID, time.Now())
} }
func init() {
prometheus.MustRegister(
activeSyncRequests, waitingSyncRequests,
)
}
var activeSyncRequests = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "syncapi",
Name: "active_sync_requests",
Help: "The number of sync requests that are active right now",
},
)
var waitingSyncRequests = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "syncapi",
Name: "waiting_sync_requests",
Help: "The number of sync requests that are waiting to be woken by a notifier",
},
)
// OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be // OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be
// called in a dedicated goroutine for this request. This function will block the goroutine // called in a dedicated goroutine for this request. This function will block the goroutine
// until a response is ready, or it times out. // until a response is ready, or it times out.
@ -122,9 +147,12 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
"limit": syncReq.limit, "limit": syncReq.limit,
}) })
activeSyncRequests.Inc()
defer activeSyncRequests.Dec()
rp.updateLastSeen(req, device) rp.updateLastSeen(req, device)
currPos := rp.notifier.CurrentPosition() currPos := rp.Notifier.CurrentPosition()
if rp.shouldReturnImmediately(syncReq) { if rp.shouldReturnImmediately(syncReq) {
syncData, err = rp.currentSyncForUser(*syncReq, currPos) syncData, err = rp.currentSyncForUser(*syncReq, currPos)
@ -139,13 +167,16 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
} }
} }
waitingSyncRequests.Inc()
defer waitingSyncRequests.Dec()
// Otherwise, we wait for the notifier to tell us if something *may* have // Otherwise, we wait for the notifier to tell us if something *may* have
// happened. We loop in case it turns out that nothing did happen. // happened. We loop in case it turns out that nothing did happen.
timer := time.NewTimer(syncReq.timeout) // case of timeout=0 is handled above timer := time.NewTimer(syncReq.timeout) // case of timeout=0 is handled above
defer timer.Stop() defer timer.Stop()
userStreamListener := rp.notifier.GetListener(*syncReq) userStreamListener := rp.Notifier.GetListener(*syncReq)
defer userStreamListener.Close() defer userStreamListener.Close()
// We need the loop in case userStreamListener wakes up even if there isn't // We need the loop in case userStreamListener wakes up even if there isn't
@ -154,13 +185,12 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
// respond with, so we skip the return an go back to waiting for content to // respond with, so we skip the return an go back to waiting for content to
// be sent down or the request timing out. // be sent down or the request timing out.
var hasTimedOut bool var hasTimedOut bool
sincePos := *syncReq.since sincePos := syncReq.since
for { for {
select { select {
// Wait for notifier to wake us up // Wait for notifier to wake us up
case <-userStreamListener.GetNotifyChannel(sincePos): case <-userStreamListener.GetNotifyChannel(sincePos):
currPos = userStreamListener.GetSyncPosition() currPos = userStreamListener.GetSyncPosition()
sincePos = currPos
// Or for timeout to expire // Or for timeout to expire
case <-timer.C: case <-timer.C:
// We just need to ensure we get out of the select after reaching the // We just need to ensure we get out of the select after reaching the
@ -248,30 +278,30 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
res := types.NewResponse() res := types.NewResponse()
// See if we have any new tasks to do for the send-to-device messaging. // See if we have any new tasks to do for the send-to-device messaging.
events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, *req.since) lastPos, events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, req.since)
if err != nil { if err != nil {
return nil, fmt.Errorf("rp.db.SendToDeviceUpdatesForSync: %w", err) return nil, fmt.Errorf("rp.db.SendToDeviceUpdatesForSync: %w", err)
} }
// TODO: handle ignored users // TODO: handle ignored users
if req.since.PDUPosition() == 0 && req.since.EDUPosition() == 0 { if req.since.IsEmpty() {
res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit) res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit)
if err != nil { if err != nil {
return res, fmt.Errorf("rp.db.CompleteSync: %w", err) return res, fmt.Errorf("rp.db.CompleteSync: %w", err)
} }
} else { } else {
res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState) res, err = rp.db.IncrementalSync(req.ctx, res, req.device, req.since, latestPos, req.limit, req.wantFullState)
if err != nil { if err != nil {
return res, fmt.Errorf("rp.db.IncrementalSync: %w", err) return res, fmt.Errorf("rp.db.IncrementalSync: %w", err)
} }
} }
accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead
res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter) res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition, &accountDataFilter)
if err != nil { if err != nil {
return res, fmt.Errorf("rp.appendAccountData: %w", err) return res, fmt.Errorf("rp.appendAccountData: %w", err)
} }
res, err = rp.appendDeviceLists(res, req.device.UserID, *req.since, latestPos) res, err = rp.appendDeviceLists(res, req.device.UserID, req.since, latestPos)
if err != nil { if err != nil {
return res, fmt.Errorf("rp.appendDeviceLists: %w", err) return res, fmt.Errorf("rp.appendDeviceLists: %w", err)
} }
@ -285,7 +315,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
// Then add the updates into the sync response. // Then add the updates into the sync response.
if len(updates) > 0 || len(deletions) > 0 { if len(updates) > 0 || len(deletions) > 0 {
// Handle the updates and deletions in the database. // Handle the updates and deletions in the database.
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, *req.since) err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, req.since)
if err != nil { if err != nil {
return res, fmt.Errorf("rp.db.CleanSendToDeviceUpdates: %w", err) return res, fmt.Errorf("rp.db.CleanSendToDeviceUpdates: %w", err)
} }
@ -295,15 +325,9 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
for _, event := range events { for _, event := range events {
res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent) res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent)
} }
// Get the next_batch from the sync response and increase the
// EDU counter.
if pos, perr := types.NewStreamTokenFromString(res.NextBatch); perr == nil {
pos.Positions[1]++
res.NextBatch = pos.String()
}
} }
res.NextBatch.SendToDevicePosition = lastPos
return res, err return res, err
} }
@ -328,7 +352,7 @@ func (rp *RequestPool) appendAccountData(
// data keys were set between two message. This isn't a huge issue since the // data keys were set between two message. This isn't a huge issue since the
// duplicate data doesn't represent a huge quantity of data, but an optimisation // duplicate data doesn't represent a huge quantity of data, but an optimisation
// here would be making sure each data is sent only once to the client. // here would be making sure each data is sent only once to the client.
if req.since == nil || (req.since.PDUPosition() == 0 && req.since.EDUPosition() == 0) { if req.since.IsEmpty() {
// If this is the initial sync, we don't need to check if a data has // If this is the initial sync, we don't need to check if a data has
// already been sent. Instead, we send the whole batch. // already been sent. Instead, we send the whole batch.
dataReq := &userapi.QueryAccountDataRequest{ dataReq := &userapi.QueryAccountDataRequest{
@ -363,7 +387,7 @@ func (rp *RequestPool) appendAccountData(
} }
r := types.Range{ r := types.Range{
From: req.since.PDUPosition(), From: req.since.PDUPosition,
To: currentPos, To: currentPos,
} }
// If both positions are the same, it means that the data was saved after the // If both positions are the same, it means that the data was saved after the
@ -433,7 +457,7 @@ func (rp *RequestPool) appendAccountData(
// or timeout=0, or full_state=true, in any of the cases the request should // or timeout=0, or full_state=true, in any of the cases the request should
// return immediately. // return immediately.
func (rp *RequestPool) shouldReturnImmediately(syncReq *syncRequest) bool { func (rp *RequestPool) shouldReturnImmediately(syncReq *syncRequest) bool {
if syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState { if syncReq.since.IsEmpty() || syncReq.timeout == 0 || syncReq.wantFullState {
return true return true
} }
waiting, werr := rp.db.SendToDeviceUpdatesWaiting(context.TODO(), syncReq.device.UserID, syncReq.device.ID) waiting, werr := rp.db.SendToDeviceUpdatesWaiting(context.TODO(), syncReq.device.UserID, syncReq.device.ID)

View file

@ -16,9 +16,7 @@ package types
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"sort"
"strconv" "strconv"
"strings" "strings"
@ -46,6 +44,10 @@ type LogPosition struct {
Offset int64 Offset int64
} }
func (p *LogPosition) IsEmpty() bool {
return p.Offset == 0
}
// IsAfter returns true if this position is after `lp`. // IsAfter returns true if this position is after `lp`.
func (p *LogPosition) IsAfter(lp *LogPosition) bool { func (p *LogPosition) IsAfter(lp *LogPosition) bool {
if lp == nil { if lp == nil {
@ -107,108 +109,125 @@ const (
) )
type StreamingToken struct { type StreamingToken struct {
syncToken PDUPosition StreamPosition
logs map[string]*LogPosition TypingPosition StreamPosition
ReceiptPosition StreamPosition
SendToDevicePosition StreamPosition
InvitePosition StreamPosition
DeviceListPosition LogPosition
} }
func (t *StreamingToken) SetLog(name string, lp *LogPosition) { // This will be used as a fallback by json.Marshal.
if t.logs == nil { func (s StreamingToken) MarshalText() ([]byte, error) {
t.logs = make(map[string]*LogPosition) return []byte(s.String()), nil
}
t.logs[name] = lp
} }
func (t *StreamingToken) Log(name string) *LogPosition { // This will be used as a fallback by json.Unmarshal.
l, ok := t.logs[name] func (s *StreamingToken) UnmarshalText(text []byte) (err error) {
if !ok { *s, err = NewStreamTokenFromString(string(text))
return nil return err
}
return l
} }
func (t *StreamingToken) PDUPosition() StreamPosition { func (t StreamingToken) String() string {
return t.Positions[0] posStr := fmt.Sprintf(
} "s%d_%d_%d_%d_%d",
func (t *StreamingToken) EDUPosition() StreamPosition { t.PDUPosition, t.TypingPosition,
return t.Positions[1] t.ReceiptPosition, t.SendToDevicePosition,
} t.InvitePosition,
func (t *StreamingToken) String() string { )
var logStrings []string if dl := t.DeviceListPosition; !dl.IsEmpty() {
for name, lp := range t.logs { posStr += fmt.Sprintf(".dl-%d-%d", dl.Partition, dl.Offset)
logStr := fmt.Sprintf("%s-%d-%d", name, lp.Partition, lp.Offset)
logStrings = append(logStrings, logStr)
} }
sort.Strings(logStrings) return posStr
// E.g s11_22_33.dl0-134.ab1-441
return strings.Join(append([]string{t.syncToken.String()}, logStrings...), ".")
} }
// IsAfter returns true if ANY position in this token is greater than `other`. // IsAfter returns true if ANY position in this token is greater than `other`.
func (t *StreamingToken) IsAfter(other StreamingToken) bool { func (t *StreamingToken) IsAfter(other StreamingToken) bool {
for i := range other.Positions { switch {
if t.Positions[i] > other.Positions[i] { case t.PDUPosition > other.PDUPosition:
return true return true
} case t.TypingPosition > other.TypingPosition:
} return true
for name := range t.logs { case t.ReceiptPosition > other.ReceiptPosition:
otherLog := other.Log(name) return true
if otherLog == nil { case t.SendToDevicePosition > other.SendToDevicePosition:
continue return true
} case t.InvitePosition > other.InvitePosition:
if t.logs[name].IsAfter(otherLog) { return true
return true case t.DeviceListPosition.IsAfter(&other.DeviceListPosition):
} return true
} }
return false return false
} }
func (t *StreamingToken) IsEmpty() bool {
return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition == 0 && t.DeviceListPosition.IsEmpty()
}
// WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken. // WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken.
// If the latter StreamingToken contains a field that is not 0, it is considered an update, // If the latter StreamingToken contains a field that is not 0, it is considered an update,
// and its value will replace the corresponding value in the StreamingToken on which WithUpdates is called. // and its value will replace the corresponding value in the StreamingToken on which WithUpdates is called.
// If the other token has a log, they will replace any existing log on this token. // If the other token has a log, they will replace any existing log on this token.
func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) { func (t *StreamingToken) WithUpdates(other StreamingToken) StreamingToken {
ret.Type = t.Type ret := *t
ret.Positions = make([]StreamPosition, len(t.Positions)) ret.ApplyUpdates(other)
for i := range t.Positions {
ret.Positions[i] = t.Positions[i]
if other.Positions[i] == 0 {
continue
}
ret.Positions[i] = other.Positions[i]
}
ret.logs = make(map[string]*LogPosition)
for name := range t.logs {
otherLog := other.Log(name)
if otherLog == nil {
continue
}
copy := *otherLog
ret.logs[name] = &copy
}
return ret return ret
} }
type TopologyToken struct { // ApplyUpdates applies any changes from the supplied StreamingToken. If the supplied
syncToken // streaming token contains any positions that are not 0, they are considered updates
// and will overwrite the value in the token.
func (t *StreamingToken) ApplyUpdates(other StreamingToken) {
if other.PDUPosition > 0 {
t.PDUPosition = other.PDUPosition
}
if other.TypingPosition > 0 {
t.TypingPosition = other.TypingPosition
}
if other.ReceiptPosition > 0 {
t.ReceiptPosition = other.ReceiptPosition
}
if other.SendToDevicePosition > 0 {
t.SendToDevicePosition = other.SendToDevicePosition
}
if other.InvitePosition > 0 {
t.InvitePosition = other.InvitePosition
}
if other.DeviceListPosition.Offset > 0 {
t.DeviceListPosition = other.DeviceListPosition
}
} }
func (t *TopologyToken) Depth() StreamPosition { type TopologyToken struct {
return t.Positions[0] Depth StreamPosition
PDUPosition StreamPosition
} }
func (t *TopologyToken) PDUPosition() StreamPosition {
return t.Positions[1] // This will be used as a fallback by json.Marshal.
func (t TopologyToken) MarshalText() ([]byte, error) {
return []byte(t.String()), nil
} }
// This will be used as a fallback by json.Unmarshal.
func (t *TopologyToken) UnmarshalText(text []byte) (err error) {
*t, err = NewTopologyTokenFromString(string(text))
return err
}
func (t *TopologyToken) StreamToken() StreamingToken { func (t *TopologyToken) StreamToken() StreamingToken {
return NewStreamToken(t.PDUPosition(), 0, nil) return StreamingToken{
PDUPosition: t.PDUPosition,
}
} }
func (t *TopologyToken) String() string {
return t.syncToken.String() func (t TopologyToken) String() string {
return fmt.Sprintf("t%d_%d", t.Depth, t.PDUPosition)
} }
// Decrement the topology token to one event earlier. // Decrement the topology token to one event earlier.
func (t *TopologyToken) Decrement() { func (t *TopologyToken) Decrement() {
depth := t.Positions[0] depth := t.Depth
pduPos := t.Positions[1] pduPos := t.PDUPosition
if depth-1 <= 0 { if depth-1 <= 0 {
// nothing can be lower than this // nothing can be lower than this
depth = 1 depth = 1
@ -223,151 +242,95 @@ func (t *TopologyToken) Decrement() {
if depth < 1 { if depth < 1 {
depth = 1 depth = 1
} }
t.Positions = []StreamPosition{ t.Depth = depth
depth, pduPos, t.PDUPosition = pduPos
}
} }
// NewSyncTokenFromString takes a string of the form "xyyyy..." where "x" func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) {
// represents the type of a pagination token and "yyyy..." the token itself, and if len(tok) < 1 {
// parses it in order to create a new instance of SyncToken. Returns an err = fmt.Errorf("empty topology token")
// error if the token couldn't be parsed into an int64, or if the token type return
// isn't a known type (returns ErrInvalidSyncTokenType in the latter
// case).
func newSyncTokenFromString(s string) (token *syncToken, categories []string, err error) {
if len(s) == 0 {
return nil, nil, ErrInvalidSyncTokenLen
} }
if tok[0] != SyncTokenTypeTopology[0] {
token = new(syncToken) err = fmt.Errorf("topology token must start with 't'")
var positions []string return
switch t := SyncTokenType(s[:1]); t {
case SyncTokenTypeStream, SyncTokenTypeTopology:
token.Type = t
categories = strings.Split(s[1:], ".")
positions = strings.Split(categories[0], "_")
default:
return nil, nil, ErrInvalidSyncTokenType
} }
parts := strings.Split(tok[1:], "_")
for _, pos := range positions { var positions [2]StreamPosition
if posInt, err := strconv.ParseInt(pos, 10, 64); err != nil { for i, p := range parts {
return nil, nil, err if i > len(positions) {
} else if posInt < 0 { break
return nil, nil, errors.New("negative position not allowed")
} else {
token.Positions = append(token.Positions, StreamPosition(posInt))
} }
var pos int
pos, err = strconv.Atoi(p)
if err != nil {
return
}
positions[i] = StreamPosition(pos)
}
token = TopologyToken{
Depth: positions[0],
PDUPosition: positions[1],
} }
return return
} }
// NewTopologyToken creates a new sync token for /messages
func NewTopologyToken(depth, streamPos StreamPosition) TopologyToken {
if depth < 0 {
depth = 1
}
return TopologyToken{
syncToken: syncToken{
Type: SyncTokenTypeTopology,
Positions: []StreamPosition{depth, streamPos},
},
}
}
func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) {
t, _, err := newSyncTokenFromString(tok)
if err != nil {
return
}
if t.Type != SyncTokenTypeTopology {
err = fmt.Errorf("token %s is not a topology token", tok)
return
}
if len(t.Positions) < 2 {
err = fmt.Errorf("token %s wrong number of values, got %d want at least 2", tok, len(t.Positions))
return
}
return TopologyToken{
syncToken: *t,
}, nil
}
// NewStreamToken creates a new sync token for /sync
func NewStreamToken(pduPos, eduPos StreamPosition, logs map[string]*LogPosition) StreamingToken {
if logs == nil {
logs = make(map[string]*LogPosition)
}
return StreamingToken{
syncToken: syncToken{
Type: SyncTokenTypeStream,
Positions: []StreamPosition{pduPos, eduPos},
},
logs: logs,
}
}
func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
t, categories, err := newSyncTokenFromString(tok) if len(tok) < 1 {
if err != nil { err = fmt.Errorf("empty stream token")
return return
} }
if t.Type != SyncTokenTypeStream { if tok[0] != SyncTokenTypeStream[0] {
err = fmt.Errorf("token %s is not a streaming token", tok) err = fmt.Errorf("stream token must start with 's'")
return return
} }
if len(t.Positions) < 2 { categories := strings.Split(tok[1:], ".")
err = fmt.Errorf("token %s wrong number of values, got %d want at least 2", tok, len(t.Positions)) parts := strings.Split(categories[0], "_")
return var positions [5]StreamPosition
for i, p := range parts {
if i > len(positions) {
break
}
var pos int
pos, err = strconv.Atoi(p)
if err != nil {
return
}
positions[i] = StreamPosition(pos)
} }
logs := make(map[string]*LogPosition) token = StreamingToken{
if len(categories) > 1 { PDUPosition: positions[0],
// dl-0-1234 TypingPosition: positions[1],
// $log_name-$partition-$offset ReceiptPosition: positions[2],
for _, logStr := range categories[1:] { SendToDevicePosition: positions[3],
segments := strings.Split(logStr, "-") InvitePosition: positions[4],
if len(segments) != 3 { }
err = fmt.Errorf("token %s - invalid log: %s", tok, logStr) // dl-0-1234
// $log_name-$partition-$offset
for _, logStr := range categories[1:] {
segments := strings.Split(logStr, "-")
if len(segments) != 3 {
err = fmt.Errorf("invalid log position %q", logStr)
return
}
switch segments[0] {
case "dl":
// Device list syncing
var partition, offset int
if partition, err = strconv.Atoi(segments[1]); err != nil {
return return
} }
var partition int64 if offset, err = strconv.Atoi(segments[2]); err != nil {
partition, err = strconv.ParseInt(segments[1], 10, 32)
if err != nil {
return return
} }
var offset int64 token.DeviceListPosition.Partition = int32(partition)
offset, err = strconv.ParseInt(segments[2], 10, 64) token.DeviceListPosition.Offset = int64(offset)
if err != nil { default:
return err = fmt.Errorf("unrecognised token type %q", segments[0])
} return
logs[segments[0]] = &LogPosition{
Partition: int32(partition),
Offset: offset,
}
} }
} }
return StreamingToken{ return token, nil
syncToken: *t,
logs: logs,
}, nil
}
// syncToken represents a syncapi token, used for interactions with
// /sync or /messages, for example.
type syncToken struct {
Type SyncTokenType
// A list of stream positions, their meanings vary depending on the token type.
Positions []StreamPosition
}
// String translates a SyncToken to a string of the "xyyyy..." (see
// NewSyncToken to know what it represents).
func (p *syncToken) String() string {
posStr := make([]string, len(p.Positions))
for i := range p.Positions {
posStr[i] = strconv.FormatInt(int64(p.Positions[i]), 10)
}
return fmt.Sprintf("%s%s", p.Type, strings.Join(posStr, "_"))
} }
// PrevEventRef represents a reference to a previous event in a state event upgrade // PrevEventRef represents a reference to a previous event in a state event upgrade
@ -379,7 +342,7 @@ type PrevEventRef struct {
// Response represents a /sync API response. See https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-sync // Response represents a /sync API response. See https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-sync
type Response struct { type Response struct {
NextBatch string `json:"next_batch"` NextBatch StreamingToken `json:"next_batch"`
AccountData struct { AccountData struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
} `json:"account_data,omitempty"` } `json:"account_data,omitempty"`
@ -443,7 +406,7 @@ type JoinResponse struct {
Timeline struct { Timeline struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
Limited bool `json:"limited"` Limited bool `json:"limited"`
PrevBatch string `json:"prev_batch"` PrevBatch *TopologyToken `json:"prev_batch,omitempty"`
} `json:"timeline"` } `json:"timeline"`
Ephemeral struct { Ephemeral struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
@ -501,7 +464,7 @@ type LeaveResponse struct {
Timeline struct { Timeline struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
Limited bool `json:"limited"` Limited bool `json:"limited"`
PrevBatch string `json:"prev_batch"` PrevBatch *TopologyToken `json:"prev_batch,omitempty"`
} `json:"timeline"` } `json:"timeline"`
} }

View file

@ -10,30 +10,14 @@ import (
func TestNewSyncTokenWithLogs(t *testing.T) { func TestNewSyncTokenWithLogs(t *testing.T) {
tests := map[string]*StreamingToken{ tests := map[string]*StreamingToken{
"s4_0": { "s4_0_0_0_0": {
syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}}, PDUPosition: 4,
logs: make(map[string]*LogPosition),
}, },
"s4_0.dl-0-123": { "s4_0_0_0_0.dl-0-123": {
syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}}, PDUPosition: 4,
logs: map[string]*LogPosition{ DeviceListPosition: LogPosition{
"dl": { Partition: 0,
Partition: 0, Offset: 123,
Offset: 123,
},
},
},
"s4_0.ab-1-14419482332.dl-0-123": {
syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}},
logs: map[string]*LogPosition{
"ab": {
Partition: 1,
Offset: 14419482332,
},
"dl": {
Partition: 0,
Offset: 123,
},
}, },
}, },
} }
@ -56,16 +40,22 @@ func TestNewSyncTokenWithLogs(t *testing.T) {
} }
} }
func TestNewSyncTokenFromString(t *testing.T) { func TestSyncTokens(t *testing.T) {
shouldPass := map[string]syncToken{ shouldPass := map[string]string{
"s4_0": NewStreamToken(4, 0, nil).syncToken, "s4_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, LogPosition{}}.String(),
"s3_1": NewStreamToken(3, 1, nil).syncToken, "s3_1_0_0_0.dl-1-2": StreamingToken{3, 1, 0, 0, 0, LogPosition{1, 2}}.String(),
"t3_1": NewTopologyToken(3, 1).syncToken, "s3_1_2_3_5": StreamingToken{3, 1, 2, 3, 5, LogPosition{}}.String(),
"t3_1": TopologyToken{3, 1}.String(),
}
for a, b := range shouldPass {
if a != b {
t.Errorf("expected %q, got %q", a, b)
}
} }
shouldFail := []string{ shouldFail := []string{
"", "",
"s_1",
"s_", "s_",
"a3_4", "a3_4",
"b", "b",
@ -74,19 +64,15 @@ func TestNewSyncTokenFromString(t *testing.T) {
"2", "2",
} }
for test, expected := range shouldPass { for _, f := range append(shouldFail, "t1_2") {
result, _, err := newSyncTokenFromString(test) if _, err := NewStreamTokenFromString(f); err == nil {
if err != nil { t.Errorf("NewStreamTokenFromString %q should have failed", f)
t.Error(err)
}
if result.String() != expected.String() {
t.Errorf("%s expected %v but got %v", test, expected.String(), result.String())
} }
} }
for _, test := range shouldFail { for _, f := range append(shouldFail, "s1_2_3_4") {
if _, _, err := newSyncTokenFromString(test); err == nil { if _, err := NewTopologyTokenFromString(f); err == nil {
t.Errorf("input '%v' should have errored but didn't", test) t.Errorf("NewTopologyTokenFromString %q should have failed", f)
} }
} }
} }

View file

@ -141,18 +141,14 @@ New users appear in /keys/changes
Local delete device changes appear in v2 /sync Local delete device changes appear in v2 /sync
Local new device changes appear in v2 /sync Local new device changes appear in v2 /sync
Local update device changes appear in v2 /sync Local update device changes appear in v2 /sync
Users receive device_list updates for their own devices
Get left notifs for other users in sync and /keys/changes when user leaves Get left notifs for other users in sync and /keys/changes when user leaves
Local device key changes get to remote servers Local device key changes get to remote servers
Local device key changes get to remote servers with correct prev_id Local device key changes get to remote servers with correct prev_id
Server correctly handles incoming m.device_list_update Server correctly handles incoming m.device_list_update
Device deletion propagates over federation
If remote user leaves room, changes device and rejoins we see update in sync If remote user leaves room, changes device and rejoins we see update in sync
If remote user leaves room, changes device and rejoins we see update in /keys/changes If remote user leaves room, changes device and rejoins we see update in /keys/changes
If remote user leaves room we no longer receive device updates If remote user leaves room we no longer receive device updates
If a device list update goes missing, the server resyncs on the next one If a device list update goes missing, the server resyncs on the next one
Get left notifs in sync and /keys/changes when other user leaves
Can query remote device keys using POST after notification
Server correctly resyncs when client query keys and there is no remote cache Server correctly resyncs when client query keys and there is no remote cache
Server correctly resyncs when server leaves and rejoins a room Server correctly resyncs when server leaves and rejoins a room
Device list doesn't change if remote server is down Device list doesn't change if remote server is down
@ -503,3 +499,4 @@ Forgetting room does not show up in v2 /sync
Can forget room you've been kicked from Can forget room you've been kicked from
/whois /whois
/joined_members return joined members /joined_members return joined members
A next_batch token can be used in the v1 messages API