Merge branch 'master' into matthew/peeking-over-fed

This commit is contained in:
Neil Alexander 2020-12-18 13:18:10 +00:00 committed by GitHub
commit 8508af345e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 1059 additions and 603 deletions

View file

@ -185,6 +185,7 @@ linters:
- gocyclo
- goimports # Does everything gofmt does
- gosimple
- govet
- ineffassign
- megacheck
- misspell # Check code comments, whereas misspell in CI checks *.md files

View file

@ -19,4 +19,4 @@ fi
go install -trimpath -ldflags "$FLAGS" -v $PWD/`dirname $0`/cmd/...
GOOS=js GOARCH=wasm go build -trimpath -ldflags "$FLAGS" -o main.wasm ./cmd/dendritejs
GOOS=js GOARCH=wasm go build -trimpath -ldflags "$FLAGS" -o bin/main.wasm ./cmd/dendritejs

View file

@ -77,7 +77,7 @@ global:
# Naffka database options. Not required when using Kafka.
naffka_database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_naffka?sslmode=disable
max_open_conns: 100
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -98,7 +98,7 @@ app_service_api:
connect: http://appservice_api:7777
database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_appservice?sslmode=disable
max_open_conns: 100
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -173,7 +173,7 @@ federation_sender:
connect: http://federation_sender:7775
database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_federationsender?sslmode=disable
max_open_conns: 100
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -199,7 +199,7 @@ key_server:
connect: http://key_server:7779
database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_keyserver?sslmode=disable
max_open_conns: 100
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -212,7 +212,7 @@ media_api:
listen: http://0.0.0.0:8074
database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_mediaapi?sslmode=disable
max_open_conns: 100
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -248,7 +248,7 @@ room_server:
connect: http://room_server:7770
database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_roomserver?sslmode=disable
max_open_conns: 100
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -259,7 +259,7 @@ signing_key_server:
connect: http://signing_key_server:7780
database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_signingkeyserver?sslmode=disable
max_open_conns: 100
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -288,7 +288,7 @@ sync_api:
listen: http://0.0.0.0:8073
database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_syncapi?sslmode=disable
max_open_conns: 100
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -299,12 +299,12 @@ user_api:
connect: http://user_api:7781
account_database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_account?sslmode=disable
max_open_conns: 100
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
device_database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_device?sslmode=disable
max_open_conns: 100
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1

View file

@ -24,8 +24,6 @@ fi
echo "Installing golangci-lint..."
# Make a backup of go.{mod,sum} first
# TODO: Once go 1.13 is out, use go get's -mod=readonly option
# https://github.com/golang/go/issues/30667
cp go.mod go.mod.bak && cp go.sum go.sum.bak
go get github.com/golangci/golangci-lint/cmd/golangci-lint@v1.19.1

View file

@ -17,6 +17,7 @@ package routing
import (
"net/http"
"sync"
"time"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
@ -27,6 +28,7 @@ import (
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
)
@ -40,6 +42,25 @@ var (
userRoomSendMutexes sync.Map // (roomID+userID) -> mutex. mutexes to ensure correct ordering of sendEvents
)
func init() {
prometheus.MustRegister(sendEventDuration)
}
var sendEventDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",
Subsystem: "clientapi",
Name: "sendevent_duration_millis",
Help: "How long it takes to build and submit a new event from the client API to the roomserver",
Buckets: []float64{ // milliseconds
5, 10, 25, 50, 75, 100, 250, 500,
1000, 2000, 3000, 4000, 5000, 6000,
7000, 8000, 9000, 10000, 15000, 20000,
},
},
[]string{"action"},
)
// SendEvent implements:
// /rooms/{roomID}/send/{eventType}
// /rooms/{roomID}/send/{eventType}/{txnID}
@ -75,10 +96,12 @@ func SendEvent(
mutex.(*sync.Mutex).Lock()
defer mutex.(*sync.Mutex).Unlock()
startedGeneratingEvent := time.Now()
e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, rsAPI)
if resErr != nil {
return *resErr
}
timeToGenerateEvent := time.Since(startedGeneratingEvent)
var txnAndSessionID *api.TransactionID
if txnID != nil {
@ -90,6 +113,7 @@ func SendEvent(
// pass the new event to the roomserver and receive the correct event ID
// event ID in case of duplicate transaction is discarded
startedSubmittingEvent := time.Now()
if err := api.SendEvents(
req.Context(), rsAPI,
api.KindNew,
@ -102,6 +126,7 @@ func SendEvent(
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError()
}
timeToSubmitEvent := time.Since(startedSubmittingEvent)
util.GetLogger(req.Context()).WithFields(logrus.Fields{
"event_id": e.EventID(),
"room_id": roomID,
@ -117,6 +142,11 @@ func SendEvent(
txnCache.AddTransaction(device.AccessToken, *txnID, &res)
}
// Take a note of how long it took to generate the event vs submit
// it to the roomserver.
sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds()))
sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvent.Milliseconds()))
return res
}

View file

@ -65,6 +65,8 @@ func main() {
cfg.FederationSender.DisableTLSValidation = true
cfg.MSCs.MSCs = []string{"msc2836"}
cfg.Logging[0].Level = "trace"
// don't hit matrix.org when running tests!!!
cfg.SigningKeyServer.KeyPerspectives = config.KeyPerspectives{}
}
j, err := yaml.Marshal(cfg)

View file

@ -80,12 +80,6 @@ brew services start kafka
## Configuration
### SQLite database setup
Dendrite can use the built-in SQLite database engine for small setups.
The SQLite databases do not need to be pre-built - Dendrite will
create them automatically at startup.
### PostgreSQL database setup
Assuming that PostgreSQL 9.6 (or later) is installed:
@ -96,7 +90,23 @@ Assuming that PostgreSQL 9.6 (or later) is installed:
sudo -u postgres createuser -P dendrite
```
* Create the component databases:
At this point you have a choice on whether to run all of the Dendrite
components from a single database, or for each component to have its
own database. For most deployments, running from a single database will
be sufficient, although you may wish to separate them if you plan to
split out the databases across multiple machines in the future.
On macOS, omit `sudo -u postgres` from the below commands.
* If you want to run all Dendrite components from a single database:
```bash
sudo -u postgres createdb -O dendrite dendrite
```
... in which case your connection string will look like `postgres://user:pass@database/dendrite`.
* If you want to run each Dendrite component with its own database:
```bash
for i in mediaapi syncapi roomserver signingkeyserver federationsender appservice keyserver userapi_account userapi_device naffka; do
@ -104,14 +114,22 @@ Assuming that PostgreSQL 9.6 (or later) is installed:
done
```
(On macOS, omit `sudo -u postgres` from the above commands.)
... in which case your connection string will look like `postgres://user:pass@database/dendrite_componentname`.
### SQLite database setup
**WARNING:** SQLite is suitable for small experimental deployments only and should not be used in production - use PostgreSQL instead for any user-facing federating installation!
Dendrite can use the built-in SQLite database engine for small setups.
The SQLite databases do not need to be pre-built - Dendrite will
create them automatically at startup.
### Server key generation
Each Dendrite installation requires:
- A unique Matrix signing private key
- A valid and trusted TLS certificate and private key
* A unique Matrix signing private key
* A valid and trusted TLS certificate and private key
To generate a Matrix signing private key:
@ -119,7 +137,7 @@ To generate a Matrix signing private key:
./bin/generate-keys --private-key matrix_key.pem
```
**Warning:** Make sure take a safe backup of this key! You will likely need it if you want to reinstall Dendrite, or
**WARNING:** Make sure take a safe backup of this key! You will likely need it if you want to reinstall Dendrite, or
any other Matrix homeserver, on the same domain name in the future. If you lose this key, you may have trouble joining
federated rooms.
@ -140,7 +158,9 @@ Create config file, based on `dendrite-config.yaml`. Call it `dendrite.yaml`. Th
* The `server_name` entry to reflect the hostname of your Dendrite server
* The `database` lines with an updated connection string based on your
desired setup, e.g. replacing `database` with the name of the database:
* For Postgres: `postgres://dendrite:password@localhost/database`, e.g. `postgres://dendrite:password@localhost/dendrite_userapi_account.db`
* For Postgres: `postgres://dendrite:password@localhost/database`, e.g.
* `postgres://dendrite:password@localhost/dendrite_userapi_account` to connect to PostgreSQL with SSL/TLS
* `postgres://dendrite:password@localhost/dendrite_userapi_account?sslmode=disable` to connect to PostgreSQL without SSL/TLS
* For SQLite on disk: `file:component.db` or `file:///path/to/component.db`, e.g. `file:userapi_account.db`
* Postgres and SQLite can be mixed and matched on different components as desired.
* The `use_naffka` option if using Naffka in a monolith deployment
@ -295,4 +315,3 @@ amongst other things.
```bash
./bin/dendrite-polylith-multi --config=dendrite.yaml userapi
```

View file

@ -242,6 +242,8 @@ func (oq *destinationQueue) backgroundSend() {
if !oq.running.CAS(false, true) {
return
}
destinationQueueRunning.Inc()
defer destinationQueueRunning.Dec()
defer oq.running.Store(false)
// Mark the queue as overflowed, so we will consult the database
@ -295,10 +297,14 @@ func (oq *destinationQueue) backgroundSend() {
// time.
duration := time.Until(*until)
log.Warnf("Backing off %q for %s", oq.destination, duration)
oq.backingOff.Store(true)
destinationQueueBackingOff.Inc()
select {
case <-time.After(duration):
case <-oq.interruptBackoff:
}
destinationQueueBackingOff.Dec()
oq.backingOff.Store(false)
}
// Work out which PDUs/EDUs to include in the next transaction.

View file

@ -27,6 +27,7 @@ import (
"github.com/matrix-org/dendrite/federationsender/storage/shared"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
@ -45,6 +46,37 @@ type OutgoingQueues struct {
queues map[gomatrixserverlib.ServerName]*destinationQueue
}
func init() {
prometheus.MustRegister(
destinationQueueTotal, destinationQueueRunning,
destinationQueueBackingOff,
)
}
var destinationQueueTotal = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "federationsender",
Name: "destination_queues_total",
},
)
var destinationQueueRunning = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "federationsender",
Name: "destination_queues_running",
},
)
var destinationQueueBackingOff = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "federationsender",
Name: "destination_queues_backing_off",
},
)
// NewOutgoingQueues makes a new OutgoingQueues
func NewOutgoingQueues(
db storage.Database,
@ -116,6 +148,7 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d
defer oqs.queuesMutex.Unlock()
oq := oqs.queues[destination]
if oq == nil {
destinationQueueTotal.Inc()
oq = &destinationQueue{
db: oqs.db,
rsAPI: oqs.rsAPI,

View file

@ -0,0 +1,45 @@
package caching
import (
"github.com/matrix-org/dendrite/roomserver/types"
)
// WARNING: This cache is mutable because it's entirely possible that
// the IsStub or StateSnaphotNID fields can change, even though the
// room version and room NID fields will not. This is only safe because
// the RoomInfoCache is used ONLY within the roomserver and because it
// will be kept up-to-date by the latest events updater. It MUST NOT be
// used from other components as we currently have no way to invalidate
// the cache in downstream components.
const (
RoomInfoCacheName = "roominfo"
RoomInfoCacheMaxEntries = 1024
RoomInfoCacheMutable = true
)
// RoomInfosCache contains the subset of functions needed for
// a room Info cache. It must only be used from the roomserver only
// It is not safe for use from other components.
type RoomInfoCache interface {
GetRoomInfo(roomID string) (roomInfo types.RoomInfo, ok bool)
StoreRoomInfo(roomID string, roomInfo types.RoomInfo)
}
// GetRoomInfo must only be called from the roomserver only. It is not
// safe for use from other components.
func (c Caches) GetRoomInfo(roomID string) (types.RoomInfo, bool) {
val, found := c.RoomInfos.Get(roomID)
if found && val != nil {
if roomInfo, ok := val.(types.RoomInfo); ok {
return roomInfo, true
}
}
return types.RoomInfo{}, false
}
// StoreRoomInfo must only be called from the roomserver only. It is not
// safe for use from other components.
func (c Caches) StoreRoomInfo(roomID string, roomInfo types.RoomInfo) {
c.RoomInfos.Set(roomID, roomInfo)
}

View file

@ -15,10 +15,6 @@ const (
RoomServerEventTypeNIDsCacheMaxEntries = 64
RoomServerEventTypeNIDsCacheMutable = false
RoomServerRoomNIDsCacheName = "roomserver_room_nids"
RoomServerRoomNIDsCacheMaxEntries = 1024
RoomServerRoomNIDsCacheMutable = false
RoomServerRoomIDsCacheName = "roomserver_room_ids"
RoomServerRoomIDsCacheMaxEntries = 1024
RoomServerRoomIDsCacheMutable = false
@ -27,6 +23,7 @@ const (
type RoomServerCaches interface {
RoomServerNIDsCache
RoomVersionCache
RoomInfoCache
}
// RoomServerNIDsCache contains the subset of functions needed for
@ -38,9 +35,6 @@ type RoomServerNIDsCache interface {
GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool)
StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID)
GetRoomServerRoomNID(roomID string) (types.RoomNID, bool)
StoreRoomServerRoomNID(roomID string, nid types.RoomNID)
GetRoomServerRoomID(roomNID types.RoomNID) (string, bool)
StoreRoomServerRoomID(roomNID types.RoomNID, roomID string)
}
@ -73,21 +67,6 @@ func (c Caches) StoreRoomServerEventTypeNID(eventType string, nid types.EventTyp
c.RoomServerEventTypeNIDs.Set(eventType, nid)
}
func (c Caches) GetRoomServerRoomNID(roomID string) (types.RoomNID, bool) {
val, found := c.RoomServerRoomNIDs.Get(roomID)
if found && val != nil {
if roomNID, ok := val.(types.RoomNID); ok {
return roomNID, true
}
}
return 0, false
}
func (c Caches) StoreRoomServerRoomNID(roomID string, roomNID types.RoomNID) {
c.RoomServerRoomNIDs.Set(roomID, roomNID)
c.RoomServerRoomIDs.Set(strconv.Itoa(int(roomNID)), roomID)
}
func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) {
val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID)))
if found && val != nil {
@ -99,5 +78,5 @@ func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) {
}
func (c Caches) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) {
c.StoreRoomServerRoomNID(roomID, roomNID)
c.RoomServerRoomIDs.Set(strconv.Itoa(int(roomNID)), roomID)
}

View file

@ -10,6 +10,7 @@ type Caches struct {
RoomServerEventTypeNIDs Cache // RoomServerNIDsCache
RoomServerRoomNIDs Cache // RoomServerNIDsCache
RoomServerRoomIDs Cache // RoomServerNIDsCache
RoomInfos Cache // RoomInfoCache
FederationEvents Cache // FederationEventsCache
}

View file

@ -45,19 +45,19 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
if err != nil {
return nil, err
}
roomServerRoomNIDs, err := NewInMemoryLRUCachePartition(
RoomServerRoomNIDsCacheName,
RoomServerRoomNIDsCacheMutable,
RoomServerRoomNIDsCacheMaxEntries,
roomServerRoomIDs, err := NewInMemoryLRUCachePartition(
RoomServerRoomIDsCacheName,
RoomServerRoomIDsCacheMutable,
RoomServerRoomIDsCacheMaxEntries,
enablePrometheus,
)
if err != nil {
return nil, err
}
roomServerRoomIDs, err := NewInMemoryLRUCachePartition(
RoomServerRoomIDsCacheName,
RoomServerRoomIDsCacheMutable,
RoomServerRoomIDsCacheMaxEntries,
roomInfos, err := NewInMemoryLRUCachePartition(
RoomInfoCacheName,
RoomInfoCacheMutable,
RoomInfoCacheMaxEntries,
enablePrometheus,
)
if err != nil {
@ -77,8 +77,8 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
ServerKeys: serverKeys,
RoomServerStateKeyNIDs: roomServerStateKeyNIDs,
RoomServerEventTypeNIDs: roomServerEventTypeNIDs,
RoomServerRoomNIDs: roomServerRoomNIDs,
RoomServerRoomIDs: roomServerRoomIDs,
RoomInfos: roomInfos,
FederationEvents: federationEvents,
}, nil
}

View file

@ -82,6 +82,7 @@ func (s *keyChangesStatements) SelectKeyChanges(
if toOffset == sarama.OffsetNewest {
toOffset = math.MaxInt64
}
latestOffset = fromOffset
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
if err != nil {
return nil, 0, err

View file

@ -83,6 +83,7 @@ func (s *keyChangesStatements) SelectKeyChanges(
if toOffset == sarama.OffsetNewest {
toOffset = math.MaxInt64
}
latestOffset = fromOffset
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
if err != nil {
return nil, 0, err

View file

@ -54,10 +54,8 @@ type inputWorker struct {
input chan *inputTask
}
// Guarded by a CAS on w.running
func (w *inputWorker) start() {
if !w.running.CAS(false, true) {
return
}
defer w.running.Store(false)
for {
select {
@ -118,7 +116,7 @@ func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) er
// InputRoomEvents implements api.RoomserverInternalAPI
func (r *Inputer) InputRoomEvents(
ctx context.Context,
_ context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
) {
@ -142,7 +140,7 @@ func (r *Inputer) InputRoomEvents(
// room - the channel will be quite small as it's just pointer types.
w, _ := r.workers.LoadOrStore(roomID, &inputWorker{
r: r,
input: make(chan *inputTask, 10),
input: make(chan *inputTask, 32),
})
worker := w.(*inputWorker)
@ -150,13 +148,15 @@ func (r *Inputer) InputRoomEvents(
// the wait group, so that the worker can notify us when this specific
// task has been finished.
tasks[i] = &inputTask{
ctx: ctx,
ctx: context.Background(),
event: &request.InputRoomEvents[i],
wg: wg,
}
// Send the task to the worker.
go worker.start()
if worker.running.CAS(false, true) {
go worker.start()
}
worker.input <- tasks[i]
}

View file

@ -120,8 +120,8 @@ const bulkSelectEventNIDSQL = "" +
const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)"
const selectRoomNIDForEventNIDSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_nid = $1"
const selectRoomNIDsForEventNIDsSQL = "" +
"SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid = ANY($1)"
type eventStatements struct {
insertEventStmt *sql.Stmt
@ -137,7 +137,7 @@ type eventStatements struct {
bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt
selectMaxEventDepthStmt *sql.Stmt
selectRoomNIDForEventNIDStmt *sql.Stmt
selectRoomNIDsForEventNIDsStmt *sql.Stmt
}
func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
@ -161,7 +161,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
{&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL},
}.Prepare(db)
}
@ -432,11 +432,24 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
return result, nil
}
func (s *eventStatements) SelectRoomNIDForEventNID(
ctx context.Context, eventNID types.EventNID,
) (roomNID types.RoomNID, err error) {
err = s.selectRoomNIDForEventNIDStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID)
return
func (s *eventStatements) SelectRoomNIDsForEventNIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]types.RoomNID, error) {
rows, err := s.selectRoomNIDsForEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed")
result := make(map[types.EventNID]types.RoomNID)
for rows.Next() {
var eventNID types.EventNID
var roomNID types.RoomNID
if err = rows.Scan(&eventNID, &roomNID); err != nil {
return nil, err
}
result[eventNID] = roomNID
}
return result, nil
}
func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array {

View file

@ -18,7 +18,6 @@ package postgres
import (
"context"
"database/sql"
"errors"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
@ -69,8 +68,8 @@ const selectLatestEventNIDsForUpdateSQL = "" +
const updateLatestEventNIDsSQL = "" +
"UPDATE roomserver_rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1"
const selectRoomVersionForRoomNIDSQL = "" +
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1"
const selectRoomVersionsForRoomNIDsSQL = "" +
"SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid = ANY($1)"
const selectRoomInfoSQL = "" +
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
@ -90,7 +89,7 @@ type roomStatements struct {
selectLatestEventNIDsStmt *sql.Stmt
selectLatestEventNIDsForUpdateStmt *sql.Stmt
updateLatestEventNIDsStmt *sql.Stmt
selectRoomVersionForRoomNIDStmt *sql.Stmt
selectRoomVersionsForRoomNIDsStmt *sql.Stmt
selectRoomInfoStmt *sql.Stmt
selectRoomIDsStmt *sql.Stmt
bulkSelectRoomIDsStmt *sql.Stmt
@ -109,7 +108,7 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
{&s.selectRoomVersionsForRoomNIDsStmt, selectRoomVersionsForRoomNIDsSQL},
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
{&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL},
@ -219,15 +218,24 @@ func (s *roomStatements) UpdateLatestEventNIDs(
return err
}
func (s *roomStatements) SelectRoomVersionForRoomNID(
ctx context.Context, roomNID types.RoomNID,
) (gomatrixserverlib.RoomVersion, error) {
var roomVersion gomatrixserverlib.RoomVersion
err := s.selectRoomVersionForRoomNIDStmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion)
if err == sql.ErrNoRows {
return roomVersion, errors.New("room not found")
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNIDs []types.RoomNID,
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
rows, err := s.selectRoomVersionsForRoomNIDsStmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
if err != nil {
return nil, err
}
return roomVersion, err
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
for rows.Next() {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err
}
result[roomNID] = roomVersion
}
return result, nil
}
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
@ -271,3 +279,11 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []strin
}
return roomNIDs, nil
}
func roomNIDsAsArray(roomNIDs []types.RoomNID) pq.Int64Array {
nids := make([]int64, len(roomNIDs))
for i := range roomNIDs {
nids[i] = int64(roomNIDs[i])
}
return nids
}

View file

@ -105,6 +105,13 @@ func (u *LatestEventsUpdater) SetLatestEvents(
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
}
if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
roomInfo.StateSnapshotNID = currentStateSnapshotNID
roomInfo.IsStub = false
u.d.Cache.StoreRoomInfo(roomID, roomInfo)
}
}
return nil
})
}

View file

@ -124,7 +124,15 @@ func (d *Database) StateEntriesForTuples(
}
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return d.RoomsTable.SelectRoomInfo(ctx, roomID)
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
return &roomInfo, nil
}
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, roomID)
if err == nil && roomInfo != nil {
d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID)
d.Cache.StoreRoomInfo(roomID, *roomInfo)
}
return roomInfo, err
}
func (d *Database) AddState(
@ -313,25 +321,39 @@ func (d *Database) Events(
if err != nil {
eventIDs = map[types.EventNID]string{}
}
results := make([]types.Event, len(eventJSONs))
for i, eventJSON := range eventJSONs {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
result := &results[i]
result.EventNID = eventJSON.EventNID
roomNID, err = d.EventsTable.SelectRoomNIDForEventNID(ctx, eventJSON.EventNID)
if err != nil {
return nil, err
}
if roomID, ok := d.Cache.GetRoomServerRoomID(roomNID); ok {
roomVersion, _ = d.Cache.GetRoomVersion(roomID)
}
if roomVersion == "" {
roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID)
if err != nil {
return nil, err
var roomNIDs map[types.EventNID]types.RoomNID
roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, eventNIDs)
if err != nil {
return nil, err
}
uniqueRoomNIDs := make(map[types.RoomNID]struct{})
for _, n := range roomNIDs {
uniqueRoomNIDs[n] = struct{}{}
}
roomVersions := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
fetchNIDList := make([]types.RoomNID, 0, len(uniqueRoomNIDs))
for n := range uniqueRoomNIDs {
if roomID, ok := d.Cache.GetRoomServerRoomID(n); ok {
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
roomVersions[n] = roomInfo.RoomVersion
continue
}
}
fetchNIDList = append(fetchNIDList, n)
}
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, fetchNIDList)
if err != nil {
return nil, err
}
for n, v := range dbRoomVersions {
roomVersions[n] = v
}
results := make([]types.Event, len(eventJSONs))
for i, eventJSON := range eventJSONs {
result := &results[i]
result.EventNID = eventJSON.EventNID
roomNID := roomNIDs[result.EventNID]
roomVersion := roomVersions[roomNID]
result.Event, err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID(
eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion,
)
@ -552,8 +574,8 @@ func (d *Database) assignRoomNID(
ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (types.RoomNID, error) {
if roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID); ok {
return roomNID, nil
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
return roomInfo.RoomNID, nil
}
// Check if we already have a numeric ID in the database.
roomNID, err := d.RoomsTable.SelectRoomNID(ctx, txn, roomID)
@ -565,9 +587,6 @@ func (d *Database) assignRoomNID(
roomNID, err = d.RoomsTable.SelectRoomNID(ctx, txn, roomID)
}
}
if err == nil {
d.Cache.StoreRoomServerRoomNID(roomID, roomNID)
}
return roomNID, err
}

View file

@ -95,8 +95,8 @@ const bulkSelectEventNIDSQL = "" +
const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
const selectRoomNIDForEventNIDSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_nid = $1"
const selectRoomNIDsForEventNIDsSQL = "" +
"SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)"
type eventStatements struct {
db *sql.DB
@ -112,7 +112,7 @@ type eventStatements struct {
bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt
selectRoomNIDForEventNIDStmt *sql.Stmt
//selectRoomNIDsForEventNIDsStmt *sql.Stmt
}
func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
@ -137,7 +137,7 @@ func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
//{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
}.Prepare(db)
}
@ -480,11 +480,33 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
return result, nil
}
func (s *eventStatements) SelectRoomNIDForEventNID(
ctx context.Context, eventNID types.EventNID,
) (roomNID types.RoomNID, err error) {
err = s.selectRoomNIDForEventNIDStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID)
return
func (s *eventStatements) SelectRoomNIDsForEventNIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]types.RoomNID, error) {
sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
sqlPrep, err := s.db.Prepare(sqlStr)
if err != nil {
return nil, err
}
iEventNIDs := make([]interface{}, len(eventNIDs))
for i, v := range eventNIDs {
iEventNIDs[i] = v
}
rows, err := sqlPrep.QueryContext(ctx, iEventNIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed")
result := make(map[types.EventNID]types.RoomNID)
for rows.Next() {
var eventNID types.EventNID
var roomNID types.RoomNID
if err = rows.Scan(&eventNID, &roomNID); err != nil {
return nil, err
}
result[eventNID] = roomNID
}
return result, nil
}
func eventNIDsAsArray(eventNIDs []types.EventNID) string {

View file

@ -19,7 +19,6 @@ import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
@ -60,8 +59,8 @@ const selectLatestEventNIDsForUpdateSQL = "" +
const updateLatestEventNIDsSQL = "" +
"UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
const selectRoomVersionForRoomNIDSQL = "" +
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1"
const selectRoomVersionsForRoomNIDsSQL = "" +
"SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)"
const selectRoomInfoSQL = "" +
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
@ -82,9 +81,9 @@ type roomStatements struct {
selectLatestEventNIDsStmt *sql.Stmt
selectLatestEventNIDsForUpdateStmt *sql.Stmt
updateLatestEventNIDsStmt *sql.Stmt
selectRoomVersionForRoomNIDStmt *sql.Stmt
selectRoomInfoStmt *sql.Stmt
selectRoomIDsStmt *sql.Stmt
//selectRoomVersionForRoomNIDStmt *sql.Stmt
selectRoomInfoStmt *sql.Stmt
selectRoomIDsStmt *sql.Stmt
}
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
@ -101,7 +100,7 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
//{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL},
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
}.Prepare(db)
@ -223,15 +222,33 @@ func (s *roomStatements) UpdateLatestEventNIDs(
return err
}
func (s *roomStatements) SelectRoomVersionForRoomNID(
ctx context.Context, roomNID types.RoomNID,
) (gomatrixserverlib.RoomVersion, error) {
var roomVersion gomatrixserverlib.RoomVersion
err := s.selectRoomVersionForRoomNIDStmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion)
if err == sql.ErrNoRows {
return roomVersion, errors.New("room not found")
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNIDs []types.RoomNID,
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
sqlPrep, err := s.db.Prepare(sqlStr)
if err != nil {
return nil, err
}
return roomVersion, err
iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
}
rows, err := sqlPrep.QueryContext(ctx, iRoomNIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
for rows.Next() {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err
}
result[roomNID] = roomVersion
}
return result, nil
}
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {

View file

@ -10,8 +10,9 @@ import (
)
type EventJSONPair struct {
EventNID types.EventNID
EventJSON []byte
EventNID types.EventNID
RoomVersion gomatrixserverlib.RoomVersion
EventJSON []byte
}
type EventJSON interface {
@ -58,7 +59,7 @@ type Events interface {
// If an event ID is not in the database then it is omitted from the map.
BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
SelectRoomNIDForEventNID(ctx context.Context, eventNID types.EventNID) (roomNID types.RoomNID, err error)
SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
}
type Rooms interface {
@ -67,7 +68,7 @@ type Rooms interface {
SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error)
SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error)
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
SelectRoomVersionsForRoomNIDs(ctx context.Context, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
SelectRoomIDs(ctx context.Context) ([]string, error)
BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error)

View file

@ -92,7 +92,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error
}).Panicf("could not save account data")
}
s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.NewStreamToken(pduPos, 0, nil))
s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.StreamingToken{PDUPosition: pduPos})
return nil
}

View file

@ -88,7 +88,7 @@ func (s *OutputReceiptEventConsumer) onMessage(msg *sarama.ConsumerMessage) erro
return err
}
// update stream position
s.notifier.OnNewReceipt(types.NewStreamToken(0, streamPos, nil))
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
return nil
}

View file

@ -94,10 +94,8 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage)
"event_type": output.Type,
}).Info("sync API received send-to-device event from EDU server")
streamPos := s.db.AddSendToDevice()
_, err = s.db.StoreNewSendForDeviceMessage(
context.TODO(), streamPos, output.UserID, output.DeviceID, output.SendToDeviceEvent,
streamPos, err := s.db.StoreNewSendForDeviceMessage(
context.TODO(), output.UserID, output.DeviceID, output.SendToDeviceEvent,
)
if err != nil {
log.WithError(err).Errorf("failed to store send-to-device message")
@ -107,7 +105,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage)
s.notifier.OnNewSendToDevice(
output.UserID,
[]string{output.DeviceID},
types.NewStreamToken(0, streamPos, nil),
types.StreamingToken{SendToDevicePosition: streamPos},
)
return nil

View file

@ -64,10 +64,7 @@ func NewOutputTypingEventConsumer(
// Start consuming from EDU api
func (s *OutputTypingEventConsumer) Start() error {
s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) {
s.notifier.OnNewEvent(
nil, roomID, nil,
types.NewStreamToken(0, types.StreamPosition(latestSyncPosition), nil),
)
s.notifier.OnNewTyping(roomID, types.StreamingToken{TypingPosition: types.StreamPosition(latestSyncPosition)})
})
return s.typingConsumer.Start()
@ -95,6 +92,6 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error
typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID)
}
s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.NewStreamToken(0, typingPos, nil))
s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos})
return nil
}

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
syncinternal "github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
syncapi "github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
@ -114,12 +113,12 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
return err
}
// TODO: f.e queryRes.UserIDsToCount : notify users by waking up streams
posUpdate := types.NewStreamToken(0, 0, map[string]*types.LogPosition{
syncinternal.DeviceListLogName: {
posUpdate := types.StreamingToken{
DeviceListPosition: types.LogPosition{
Offset: msg.Offset,
Partition: msg.Partition,
},
})
}
for userID := range queryRes.UserIDsToCount {
s.notifier.OnNewKeyChange(posUpdate, userID, output.UserID)
}

View file

@ -181,7 +181,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
return err
}
s.notifier.OnNewEvent(ev, "", nil, types.NewStreamToken(pduPos, 0, nil))
s.notifier.OnNewEvent(ev, "", nil, types.StreamingToken{PDUPosition: pduPos})
return nil
}
@ -220,7 +220,7 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent(
return err
}
s.notifier.OnNewEvent(ev, "", nil, types.NewStreamToken(pduPos, 0, nil))
s.notifier.OnNewEvent(ev, "", nil, types.StreamingToken{PDUPosition: pduPos})
return nil
}
@ -259,6 +259,12 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gom
func (s *OutputRoomEventConsumer) onNewInviteEvent(
ctx context.Context, msg api.OutputNewInviteEvent,
) error {
if msg.Event.StateKey() == nil {
log.WithFields(log.Fields{
"event": string(msg.Event.JSON()),
}).Panicf("roomserver output log: invite has no state key")
return nil
}
pduPos, err := s.db.AddInviteEvent(ctx, msg.Event)
if err != nil {
// panic rather than continue with an inconsistent database
@ -269,14 +275,14 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
}).Panicf("roomserver output log: write invite failure")
return nil
}
s.notifier.OnNewEvent(msg.Event, "", nil, types.NewStreamToken(pduPos, 0, nil))
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, *msg.Event.StateKey())
return nil
}
func (s *OutputRoomEventConsumer) onRetireInviteEvent(
ctx context.Context, msg api.OutputRetireInviteEvent,
) error {
sp, err := s.db.RetireInviteEvent(ctx, msg.EventID)
pduPos, err := s.db.RetireInviteEvent(ctx, msg.EventID)
if err != nil {
// panic rather than continue with an inconsistent database
log.WithFields(log.Fields{
@ -287,7 +293,7 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
}
// Notify any active sync requests that the invite has been retired.
// Invites share the same stream counter as PDUs
s.notifier.OnNewEvent(nil, "", []string{msg.TargetUserID}, types.NewStreamToken(sp, 0, nil))
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID)
return nil
}
@ -307,7 +313,7 @@ func (s *OutputRoomEventConsumer) onNewPeek(
// we need to wake up the users who might need to now be peeking into this room,
// so we send in a dummy event to trigger a wakeup
s.notifier.OnNewEvent(nil, msg.RoomID, nil, types.NewStreamToken(sp, 0, nil))
s.notifier.OnNewEvent(nil, msg.RoomID, nil, types.StreamingToken{PDUPosition: sp})
return nil
}
@ -327,7 +333,7 @@ func (s *OutputRoomEventConsumer) onRetirePeek(
// we need to wake up the users who might need to now be peeking into this room,
// so we send in a dummy event to trigger a wakeup
s.notifier.OnNewEvent(nil, msg.RoomID, nil, types.NewStreamToken(sp, 0, nil))
s.notifier.OnNewEvent(nil, msg.RoomID, nil, types.StreamingToken{PDUPosition: sp})
return nil
}

View file

@ -73,15 +73,13 @@ func DeviceListCatchup(
offset = sarama.OffsetOldest
// Extract partition/offset from sync token
// TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make.
logOffset := from.Log(DeviceListLogName)
if logOffset != nil {
partition = logOffset.Partition
offset = logOffset.Offset
if !from.DeviceListPosition.IsEmpty() {
partition = from.DeviceListPosition.Partition
offset = from.DeviceListPosition.Offset
}
var toOffset int64
toOffset = sarama.OffsetNewest
toLog := to.Log(DeviceListLogName)
if toLog != nil && toLog.Offset > 0 {
if toLog := to.DeviceListPosition; toLog.Partition == partition && toLog.Offset > 0 {
toOffset = toLog.Offset
}
var queryRes api.QueryKeyChangesResponse
@ -130,11 +128,11 @@ func DeviceListCatchup(
}
}
// set the new token
to.SetLog(DeviceListLogName, &types.LogPosition{
to.DeviceListPosition = types.LogPosition{
Partition: queryRes.Partition,
Offset: queryRes.Offset,
})
res.NextBatch = to.String()
}
res.NextBatch.ApplyUpdates(to)
return hasNew, nil
}

View file

@ -16,13 +16,13 @@ import (
var (
syncingUser = "@alice:localhost"
emptyToken = types.NewStreamToken(0, 0, nil)
newestToken = types.NewStreamToken(0, 0, map[string]*types.LogPosition{
DeviceListLogName: {
emptyToken = types.StreamingToken{}
newestToken = types.StreamingToken{
DeviceListPosition: types.LogPosition{
Offset: sarama.OffsetNewest,
Partition: 0,
},
})
}
)
type mockKeyAPI struct{}

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
@ -49,9 +50,10 @@ type messagesReq struct {
}
type messagesResp struct {
Start string `json:"start"`
End string `json:"end"`
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
Start string `json:"start"`
StartStream string `json:"start_stream,omitempty"` // NOTSPEC: so clients can hit /messages then immediately /sync with a latest sync token
End string `json:"end"`
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
}
const defaultMessagesLimit = 10
@ -65,6 +67,7 @@ func OnIncomingMessagesRequest(
federation *gomatrixserverlib.FederationClient,
rsAPI api.RoomserverInternalAPI,
cfg *config.SyncAPI,
srp *sync.RequestPool,
) util.JSONResponse {
var err error
@ -84,9 +87,18 @@ func OnIncomingMessagesRequest(
// Extract parameters from the request's URL.
// Pagination tokens.
var fromStream *types.StreamingToken
from, err := types.NewTopologyTokenFromString(req.URL.Query().Get("from"))
fromQuery := req.URL.Query().Get("from")
emptyFromSupplied := fromQuery == ""
if emptyFromSupplied {
// NOTSPEC: We will pretend they used the latest sync token if no ?from= was provided.
// We do this to allow clients to get messages without having to call `/sync` e.g Cerulean
currPos := srp.Notifier.CurrentPosition()
fromQuery = currPos.String()
}
from, err := types.NewTopologyTokenFromString(fromQuery)
if err != nil {
fs, err2 := types.NewStreamTokenFromString(req.URL.Query().Get("from"))
fs, err2 := types.NewStreamTokenFromString(fromQuery)
fromStream = &fs
if err2 != nil {
return util.JSONResponse{
@ -185,14 +197,19 @@ func OnIncomingMessagesRequest(
"return_end": end.String(),
}).Info("Responding")
res := messagesResp{
Chunk: clientEvents,
Start: start.String(),
End: end.String(),
}
if emptyFromSupplied {
res.StartStream = fromStream.String()
}
// Respond with the events.
return util.JSONResponse{
Code: http.StatusOK,
JSON: messagesResp{
Chunk: clientEvents,
Start: start.String(),
End: end.String(),
},
JSON: res,
}
}
@ -381,7 +398,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
if r.backwardOrdering && events[len(events)-1].Type() == gomatrixserverlib.MRoomCreate {
// We've hit the beginning of the room so there's really nowhere else
// to go. This seems to fix Riot iOS from looping on /messages endlessly.
end = types.NewTopologyToken(0, 0)
end = types.TopologyToken{}
} else {
end, err = r.db.EventPositionInTopology(
r.ctx, events[len(events)-1].EventID(),
@ -447,11 +464,11 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
// The condition in the SQL query is a strict "greater than" so
// we need to check against to-1.
streamPos := types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition)
isSetLargeEnough = (r.to.PDUPosition()-1 == streamPos)
isSetLargeEnough = (r.to.PDUPosition-1 == streamPos)
}
} else {
streamPos := types.StreamPosition(streamEvents[0].StreamPosition)
isSetLargeEnough = (r.from.PDUPosition()-1 == streamPos)
isSetLargeEnough = (r.from.PDUPosition-1 == streamPos)
}
}
@ -565,7 +582,7 @@ func setToDefault(
if backwardOrdering {
// go 1 earlier than the first event so we correctly fetch the earliest event
// this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound.
to = types.NewTopologyToken(0, 0)
to = types.TopologyToken{}
} else {
to, err = db.MaxTopologicalPosition(ctx, roomID)
}

View file

@ -51,7 +51,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg)
return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg, srp)
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter",

View file

@ -130,9 +130,9 @@ type Database interface {
// can be deleted altogether by CleanSendToDeviceUpdates
// The token supplied should be the current requested sync token, e.g. from the "since"
// parameter.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error)
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (pos types.StreamPosition, events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error)
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
StoreNewSendForDeviceMessage(ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the
// result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows
// SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after

View file

@ -58,6 +58,8 @@ CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url);
-- for querying membership states of users
CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
-- for querying state by event IDs
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id);
`
const upsertRoomStateSQL = "" +

View file

@ -0,0 +1,66 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpFixSequences, DownFixSequences)
}
func LoadFixSequences(m *sqlutil.Migrations) {
m.AddMigration(UpFixSequences, DownFixSequences)
}
func UpFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
DELETE FROM syncapi_receipts;
-- Use the new syncapi_receipts_id sequence.
CREATE SEQUENCE IF NOT EXISTS syncapi_receipt_id;
ALTER SEQUENCE IF EXISTS syncapi_receipt_id RESTART WITH 1;
ALTER TABLE syncapi_receipts ALTER COLUMN id SET DEFAULT nextval('syncapi_receipt_id');
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
DELETE FROM syncapi_receipts;
-- Revert back to using the syncapi_stream_id sequence.
DROP SEQUENCE IF EXISTS syncapi_receipt_id;
ALTER TABLE syncapi_receipts ALTER COLUMN id SET DEFAULT nextval('syncapi_stream_id');
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -30,11 +30,12 @@ import (
)
const receiptsSchema = `
CREATE SEQUENCE IF NOT EXISTS syncapi_stream_id;
CREATE SEQUENCE IF NOT EXISTS syncapi_receipt_id;
-- Stores data about receipts
CREATE TABLE IF NOT EXISTS syncapi_receipts (
-- The ID
id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'),
id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_receipt_id'),
room_id TEXT NOT NULL,
receipt_type TEXT NOT NULL,
user_id TEXT NOT NULL,
@ -50,18 +51,22 @@ const upsertReceipt = "" +
" (room_id, receipt_type, user_id, event_id, receipt_ts)" +
" VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT (room_id, receipt_type, user_id)" +
" DO UPDATE SET id = nextval('syncapi_stream_id'), event_id = $4, receipt_ts = $5" +
" DO UPDATE SET id = nextval('syncapi_receipt_id'), event_id = $4, receipt_ts = $5" +
" RETURNING id"
const selectRoomReceipts = "" +
"SELECT room_id, receipt_type, user_id, event_id, receipt_ts" +
"SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" +
" FROM syncapi_receipts" +
" WHERE room_id = ANY($1) AND id > $2"
const selectMaxReceiptIDSQL = "" +
"SELECT MAX(id) FROM syncapi_receipts"
type receiptStatements struct {
db *sql.DB
upsertReceipt *sql.Stmt
selectRoomReceipts *sql.Stmt
selectMaxReceiptID *sql.Stmt
}
func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) {
@ -78,6 +83,9 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) {
if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
}
if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
}
return r, nil
}
@ -87,20 +95,37 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
return
}
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) {
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
lastPos := types.StreamPosition(0)
rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos)
if err != nil {
return nil, fmt.Errorf("unable to query room receipts: %w", err)
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed")
var res []api.OutputReceiptEvent
for rows.Next() {
r := api.OutputReceiptEvent{}
err = rows.Scan(&r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp)
var id types.StreamPosition
err = rows.Scan(&id, &r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp)
if err != nil {
return res, fmt.Errorf("unable to scan row to api.Receipts: %w", err)
return 0, res, fmt.Errorf("unable to scan row to api.Receipts: %w", err)
}
res = append(res, r)
if id > lastPos {
lastPos = id
}
}
return res, rows.Err()
return lastPos, res, rows.Err()
}
func (s *receiptStatements) SelectMaxReceiptID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
return
}

View file

@ -49,6 +49,7 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
const insertSendToDeviceMessageSQL = `
INSERT INTO syncapi_send_to_device (user_id, device_id, content)
VALUES ($1, $2, $3)
RETURNING id
`
const countSendToDeviceMessagesSQL = `
@ -107,8 +108,8 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
) (pos types.StreamPosition, err error) {
err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).QueryRowContext(ctx, userID, deviceID, content).Scan(&pos)
return
}
@ -124,7 +125,7 @@ func (s *sendToDeviceStatements) CountSendToDeviceMessages(
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (events []types.SendToDeviceEvent, err error) {
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
if err != nil {
return
@ -152,9 +153,12 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
}
}
events = append(events, event)
if types.StreamPosition(id) > lastPos {
lastPos = types.StreamPosition(id)
}
}
return events, rows.Err()
return lastPos, events, rows.Err()
}
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/shared"
)
@ -36,6 +37,7 @@ type SyncServerDatasource struct {
}
// NewDatabase creates a new sync server database
// nolint:gocyclo
func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) {
var d SyncServerDatasource
var err error
@ -86,6 +88,11 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
if err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
Writer: d.writer,

View file

@ -78,8 +78,8 @@ func (d *Database) GetEventsInStreamingRange(
backwardOrdering bool,
) (events []types.StreamEvent, err error) {
r := types.Range{
From: from.PDUPosition(),
To: to.PDUPosition(),
From: from.PDUPosition,
To: to.PDUPosition,
Backwards: backwardOrdering,
}
if backwardOrdering {
@ -391,16 +391,16 @@ func (d *Database) GetEventsInTopologicalRange(
var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
if backwardOrdering {
// Backward ordering means the 'from' token has a higher depth than the 'to' token
minDepth = to.Depth()
maxDepth = from.Depth()
minDepth = to.Depth
maxDepth = from.Depth
// for cases where we have say 5 events with the same depth, the TopologyToken needs to
// know which of the 5 the client has seen. This is done by using the PDU position.
// Events with the same maxDepth but less than this PDU position will be returned.
maxStreamPosForMaxDepth = from.PDUPosition()
maxStreamPosForMaxDepth = from.PDUPosition
} else {
// Forward ordering means the 'from' token has a lower depth than the 'to' token.
minDepth = from.Depth()
maxDepth = to.Depth()
minDepth = from.Depth
maxDepth = to.Depth
}
// Select the event IDs from the defined range.
@ -440,9 +440,9 @@ func (d *Database) MaxTopologicalPosition(
) (types.TopologyToken, error) {
depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID)
if err != nil {
return types.NewTopologyToken(0, 0), err
return types.TopologyToken{}, err
}
return types.NewTopologyToken(depth, streamPos), nil
return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil
}
func (d *Database) EventPositionInTopology(
@ -450,9 +450,9 @@ func (d *Database) EventPositionInTopology(
) (types.TopologyToken, error) {
depth, stream, err := d.Topology.SelectPositionInTopology(ctx, nil, eventID)
if err != nil {
return types.NewTopologyToken(0, 0), err
return types.TopologyToken{}, err
}
return types.NewTopologyToken(depth, stream), nil
return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil
}
func (d *Database) syncPositionTx(
@ -483,7 +483,17 @@ func (d *Database) syncPositionTx(
if maxPeekID > maxEventID {
maxEventID = maxPeekID
}
sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.EDUCache.GetLatestSyncPosition()), nil)
maxReceiptID, err := d.Receipts.SelectMaxReceiptID(ctx, txn)
if err != nil {
return sp, err
}
// TODO: complete these positions
sp = types.StreamingToken{
PDUPosition: types.StreamPosition(maxEventID),
TypingPosition: types.StreamPosition(d.EDUCache.GetLatestSyncPosition()),
ReceiptPosition: types.StreamPosition(maxReceiptID),
InvitePosition: types.StreamPosition(maxInviteID),
}
return
}
@ -534,11 +544,6 @@ func (d *Database) addPDUDeltaToResponse(
}
}
// TODO: This should be done in getStateDeltas
if err = d.addInvitesToResponse(ctx, txn, device.UserID, r, res); err != nil {
return nil, fmt.Errorf("d.addInvitesToResponse: %w", err)
}
succeeded = true
return joinedRoomIDs, nil
}
@ -555,7 +560,7 @@ func (d *Database) addTypingDeltaToResponse(
for _, roomID := range joinedRoomIDs {
var jr types.JoinResponse
if typingUsers, updated := d.EDUCache.GetTypingUsersIfUpdatedAfter(
roomID, int64(since.EDUPosition()),
roomID, int64(since.TypingPosition),
); updated {
ev := gomatrixserverlib.ClientEvent{
Type: gomatrixserverlib.MTyping,
@ -574,6 +579,7 @@ func (d *Database) addTypingDeltaToResponse(
res.Rooms.Join[roomID] = jr
}
}
res.NextBatch.TypingPosition = types.StreamPosition(d.EDUCache.GetLatestSyncPosition())
return nil
}
@ -584,7 +590,7 @@ func (d *Database) addReceiptDeltaToResponse(
joinedRoomIDs []string,
res *types.Response,
) error {
receipts, err := d.Receipts.SelectRoomReceiptsAfter(context.TODO(), joinedRoomIDs, since.EDUPosition())
lastPos, receipts, err := d.Receipts.SelectRoomReceiptsAfter(context.TODO(), joinedRoomIDs, since.ReceiptPosition)
if err != nil {
return fmt.Errorf("unable to select receipts for rooms: %w", err)
}
@ -629,6 +635,7 @@ func (d *Database) addReceiptDeltaToResponse(
res.Rooms.Join[roomID] = jr
}
res.NextBatch.ReceiptPosition = lastPos
return nil
}
@ -639,7 +646,7 @@ func (d *Database) addEDUDeltaToResponse(
joinedRoomIDs []string,
res *types.Response,
) error {
if fromPos.EDUPosition() != toPos.EDUPosition() {
if fromPos.TypingPosition != toPos.TypingPosition {
// add typing deltas
if err := d.addTypingDeltaToResponse(fromPos, joinedRoomIDs, res); err != nil {
return fmt.Errorf("unable to apply typing delta to response: %w", err)
@ -647,8 +654,8 @@ func (d *Database) addEDUDeltaToResponse(
}
// Check on initial sync and if EDUPositions differ
if (fromPos.EDUPosition() == 0 && toPos.EDUPosition() == 0) ||
fromPos.EDUPosition() != toPos.EDUPosition() {
if (fromPos.ReceiptPosition == 0 && toPos.ReceiptPosition == 0) ||
fromPos.ReceiptPosition != toPos.ReceiptPosition {
if err := d.addReceiptDeltaToResponse(fromPos, joinedRoomIDs, res); err != nil {
return fmt.Errorf("unable to apply receipts to response: %w", err)
}
@ -682,15 +689,14 @@ func (d *Database) IncrementalSync(
numRecentEventsPerRoom int,
wantFullState bool,
) (*types.Response, error) {
nextBatchPos := fromPos.WithUpdates(toPos)
res.NextBatch = nextBatchPos.String()
res.NextBatch = fromPos.WithUpdates(toPos)
var joinedRoomIDs []string
var err error
if fromPos.PDUPosition() != toPos.PDUPosition() || wantFullState {
if fromPos.PDUPosition != toPos.PDUPosition || wantFullState {
r := types.Range{
From: fromPos.PDUPosition(),
To: toPos.PDUPosition(),
From: fromPos.PDUPosition,
To: toPos.PDUPosition,
}
joinedRoomIDs, err = d.addPDUDeltaToResponse(
ctx, device, r, numRecentEventsPerRoom, wantFullState, res,
@ -716,6 +722,14 @@ func (d *Database) IncrementalSync(
return nil, fmt.Errorf("d.addEDUDeltaToResponse: %w", err)
}
ir := types.Range{
From: fromPos.InvitePosition,
To: toPos.InvitePosition,
}
if err = d.addInvitesToResponse(ctx, nil, device.UserID, ir, res); err != nil {
return nil, fmt.Errorf("d.addInvitesToResponse: %w", err)
}
return res, nil
}
@ -772,10 +786,14 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
}
r := types.Range{
From: 0,
To: toPos.PDUPosition(),
To: toPos.PDUPosition,
}
ir := types.Range{
From: 0,
To: toPos.InvitePosition,
}
res.NextBatch = toPos.String()
res.NextBatch.ApplyUpdates(toPos)
// Extract room state and recent events for all rooms the user is joined to.
joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
@ -815,7 +833,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
}
}
if err = d.addInvitesToResponse(ctx, txn, userID, r, res); err != nil {
if err = d.addInvitesToResponse(ctx, txn, userID, ir, res); err != nil {
return
}
@ -875,16 +893,18 @@ func (d *Database) getJoinResponseForCompleteSync(
// Retrieve the backward topology position, i.e. the position of the
// oldest event in the room's topology.
var prevBatchStr string
var prevBatch *types.TopologyToken
if len(recentStreamEvents) > 0 {
var backwardTopologyPos, backwardStreamPos types.StreamPosition
backwardTopologyPos, backwardStreamPos, err = d.Topology.SelectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID())
if err != nil {
return
}
prevBatch := types.NewTopologyToken(backwardTopologyPos, backwardStreamPos)
prevBatch = &types.TopologyToken{
Depth: backwardTopologyPos,
PDUPosition: backwardStreamPos,
}
prevBatch.Decrement()
prevBatchStr = prevBatch.String()
}
// We don't include a device here as we don't need to send down
@ -893,7 +913,7 @@ func (d *Database) getJoinResponseForCompleteSync(
recentEvents := d.StreamEventsToEvents(&device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents)
jr = types.NewJoinResponse()
jr.Timeline.PrevBatch = prevBatchStr
jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
@ -915,7 +935,7 @@ func (d *Database) CompleteSync(
// Use a zero value SyncPosition for fromPos so all EDU states are added.
err = d.addEDUDeltaToResponse(
types.NewStreamToken(0, 0, nil), toPos, joinedRoomIDs, res,
types.StreamingToken{}, toPos, joinedRoomIDs, res,
)
if err != nil {
return nil, fmt.Errorf("d.addEDUDeltaToResponse: %w", err)
@ -965,7 +985,7 @@ func (d *Database) getBackwardTopologyPos(
ctx context.Context, txn *sql.Tx,
events []types.StreamEvent,
) (types.TopologyToken, error) {
zeroToken := types.NewTopologyToken(0, 0)
zeroToken := types.TopologyToken{}
if len(events) == 0 {
return zeroToken, nil
}
@ -973,7 +993,7 @@ func (d *Database) getBackwardTopologyPos(
if err != nil {
return zeroToken, err
}
tok := types.NewTopologyToken(pos, spos)
tok := types.TopologyToken{Depth: pos, PDUPosition: spos}
tok.Decrement()
return tok, nil
}
@ -1021,7 +1041,7 @@ func (d *Database) addRoomDeltaToResponse(
case gomatrixserverlib.Join:
jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = prevBatch.String()
jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@ -1029,7 +1049,7 @@ func (d *Database) addRoomDeltaToResponse(
case gomatrixserverlib.Peek:
jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = prevBatch.String()
jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@ -1040,7 +1060,7 @@ func (d *Database) addRoomDeltaToResponse(
// TODO: recentEvents may contain events that this user is not allowed to see because they are
// no longer in the room.
lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = prevBatch.String()
lr.Timeline.PrevBatch = &prevBatch
lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@ -1361,39 +1381,40 @@ func (d *Database) SendToDeviceUpdatesWaiting(
}
func (d *Database) StoreNewSendForDeviceMessage(
ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
) (types.StreamPosition, error) {
ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
) (newPos types.StreamPosition, err error) {
j, err := json.Marshal(event)
if err != nil {
return streamPos, err
return 0, err
}
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee
// that we don't lock the table for writes in more than one place.
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.SendToDevice.InsertSendToDeviceMessage(
newPos, err = d.SendToDevice.InsertSendToDeviceMessage(
ctx, txn, userID, deviceID, string(j),
)
return err
})
if err != nil {
return streamPos, err
return 0, err
}
return streamPos, nil
return 0, nil
}
func (d *Database) SendToDeviceUpdatesForSync(
ctx context.Context,
userID, deviceID string,
token types.StreamingToken,
) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) {
) (types.StreamPosition, []types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) {
// First of all, get our send-to-device updates for this user.
events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID)
lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID)
if err != nil {
return nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
return 0, nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
}
// If there's nothing to do then stop here.
if len(events) == 0 {
return nil, nil, nil, nil
return 0, nil, nil, nil, nil
}
// Work out whether we need to update any of the database entries.
@ -1420,7 +1441,7 @@ func (d *Database) SendToDeviceUpdatesForSync(
}
}
return toReturn, toUpdate, toDelete, nil
return lastPos, toReturn, toUpdate, toDelete, nil
}
func (d *Database) CleanSendToDeviceUpdates(
@ -1507,5 +1528,6 @@ func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId
}
func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) {
return d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos)
_, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos)
return receipts, err
}

View file

@ -46,6 +46,8 @@ CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url);
-- for querying membership states of users
-- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
-- for querying state by event IDs
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id);
`
const upsertRoomStateSQL = "" +

View file

@ -0,0 +1,58 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpFixSequences, DownFixSequences)
}
func LoadFixSequences(m *sqlutil.Migrations) {
m.AddMigration(UpFixSequences, DownFixSequences)
}
func UpFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
DELETE FROM syncapi_receipts;
UPDATE syncapi_stream_id SET stream_id=1 WHERE stream_name="receipt";
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
DELETE FROM syncapi_receipts;
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -51,15 +51,19 @@ const upsertReceipt = "" +
" DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9"
const selectRoomReceipts = "" +
"SELECT room_id, receipt_type, user_id, event_id, receipt_ts" +
"SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" +
" FROM syncapi_receipts" +
" WHERE id > $1 and room_id in ($2)"
const selectMaxReceiptIDSQL = "" +
"SELECT MAX(id) FROM syncapi_receipts"
type receiptStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
upsertReceipt *sql.Stmt
selectRoomReceipts *sql.Stmt
selectMaxReceiptID *sql.Stmt
}
func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) {
@ -77,12 +81,15 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Re
if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
}
if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
}
return r, nil
}
// UpsertReceipt creates new user receipts
func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) {
pos, err = r.streamIDStatements.nextStreamID(ctx, txn)
pos, err = r.streamIDStatements.nextReceiptID(ctx, txn)
if err != nil {
return
}
@ -92,9 +99,9 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
}
// SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) {
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
lastPos := types.StreamPosition(0)
params := make([]interface{}, len(roomIDs)+1)
params[0] = streamPos
for k, v := range roomIDs {
@ -102,17 +109,33 @@ func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs
}
rows, err := r.db.QueryContext(ctx, selectSQL, params...)
if err != nil {
return nil, fmt.Errorf("unable to query room receipts: %w", err)
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed")
var res []api.OutputReceiptEvent
for rows.Next() {
r := api.OutputReceiptEvent{}
err = rows.Scan(&r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp)
var id types.StreamPosition
err = rows.Scan(&id, &r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp)
if err != nil {
return res, fmt.Errorf("unable to scan row to api.Receipts: %w", err)
return 0, res, fmt.Errorf("unable to scan row to api.Receipts: %w", err)
}
res = append(res, r)
if id > lastPos {
lastPos = id
}
}
return res, rows.Err()
return lastPos, res, rows.Err()
}
func (s *receiptStatements) SelectMaxReceiptID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
return
}

View file

@ -100,8 +100,14 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
) (pos types.StreamPosition, err error) {
var result sql.Result
result, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
if p, err := result.LastInsertId(); err != nil {
return 0, err
} else {
pos = types.StreamPosition(p)
}
return
}
@ -117,7 +123,7 @@ func (s *sendToDeviceStatements) CountSendToDeviceMessages(
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (events []types.SendToDeviceEvent, err error) {
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
if err != nil {
return
@ -145,9 +151,12 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
}
}
events = append(events, event)
if types.StreamPosition(id) > lastPos {
lastPos = types.StreamPosition(id)
}
}
return events, rows.Err()
return lastPos, events, rows.Err()
}
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(

View file

@ -18,6 +18,8 @@ CREATE TABLE IF NOT EXISTS syncapi_stream_id (
);
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0)
ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("receipt", 0)
ON CONFLICT DO NOTHING;
`
const increaseStreamIDStmt = "" +
@ -56,3 +58,13 @@ func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos
err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos)
return
}
func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "receipt"); err != nil {
return
}
err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
return
}

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/shared"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
)
// SyncServerDatasource represents a sync server datasource which manages
@ -46,13 +47,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
return nil, err
}
d.writer = sqlutil.NewExclusiveWriter()
if err = d.prepare(); err != nil {
if err = d.prepare(dbProperties); err != nil {
return nil, err
}
return &d, nil
}
func (d *SyncServerDatasource) prepare() (err error) {
// nolint:gocyclo
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
return err
}
@ -99,6 +101,11 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil {
return err
}
m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return err
}
d.Database = shared.Database{
DB: d.db,
Writer: d.writer,

View file

@ -165,9 +165,9 @@ func TestSyncResponse(t *testing.T) {
{
Name: "IncrementalSync penultimate",
DoSync: func() (*types.Response, error) {
from := types.NewStreamToken( // pretend we are at the penultimate event
positions[len(positions)-2], types.StreamPosition(0), nil,
)
from := types.StreamingToken{ // pretend we are at the penultimate event
PDUPosition: positions[len(positions)-2],
}
res := types.NewResponse()
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
},
@ -178,9 +178,9 @@ func TestSyncResponse(t *testing.T) {
{
Name: "IncrementalSync limited",
DoSync: func() (*types.Response, error) {
from := types.NewStreamToken( // pretend we are 10 events behind
positions[len(positions)-11], types.StreamPosition(0), nil,
)
from := types.StreamingToken{ // pretend we are 10 events behind
PDUPosition: positions[len(positions)-11],
}
res := types.NewResponse()
// limit is set to 5
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
@ -222,8 +222,13 @@ func TestSyncResponse(t *testing.T) {
if err != nil {
st.Fatalf("failed to do sync: %s", err)
}
next := types.NewStreamToken(latest.PDUPosition(), latest.EDUPosition(), nil)
if res.NextBatch != next.String() {
next := types.StreamingToken{
PDUPosition: latest.PDUPosition,
TypingPosition: latest.TypingPosition,
ReceiptPosition: latest.ReceiptPosition,
SendToDevicePosition: latest.SendToDevicePosition,
}
if res.NextBatch.String() != next.String() {
st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String())
}
roomRes, ok := res.Rooms.Join[testRoomID]
@ -245,9 +250,9 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err)
}
from := types.NewStreamToken(
positions[len(positions)-2], types.StreamPosition(0), nil,
)
from := types.StreamingToken{
PDUPosition: positions[len(positions)-2],
}
res := types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
@ -261,7 +266,7 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
// returns the last event "Message 10"
assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:]))
prev := roomRes.Timeline.PrevBatch
prev := roomRes.Timeline.PrevBatch.String()
if prev == "" {
t.Fatalf("IncrementalSync expected prev_batch token")
}
@ -271,7 +276,7 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
}
// backpaginate 5 messages starting at the latest position.
// head towards the beginning of time
to := types.NewTopologyToken(0, 0)
to := types.TopologyToken{}
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true)
if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err)
@ -291,7 +296,7 @@ func TestGetEventsInRangeWithStreamToken(t *testing.T) {
t.Fatalf("failed to get SyncPosition: %s", err)
}
// head towards the beginning of time
to := types.NewStreamToken(0, 0, nil)
to := types.StreamingToken{}
// backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
@ -313,7 +318,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
}
// head towards the beginning of time
to := types.NewTopologyToken(0, 0)
to := types.TopologyToken{}
// backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true)
@ -382,7 +387,7 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) {
t.Fatalf("failed to get EventPositionInTopology for event: %s", err)
}
// head towards the beginning of time
to := types.NewTopologyToken(0, 0)
to := types.TopologyToken{}
testCases := []struct {
Name string
@ -458,7 +463,7 @@ func TestGetEventsInTopologicalRangeMultiRoom(t *testing.T) {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
}
// head towards the beginning of time
to := types.NewTopologyToken(0, 0)
to := types.TopologyToken{}
// Query using room B as room A was inserted first and hence A will have lower stream positions but identical depths,
// allowing this bug to surface.
@ -508,7 +513,7 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) {
}
// head towards the beginning of time
to := types.NewTopologyToken(0, 0)
to := types.TopologyToken{}
// starting at `from`, backpaginate to the beginning of time, asserting as we go.
chunkSize = 3
@ -534,20 +539,20 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point there should be no messages. We haven't sent anything
// yet.
events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0, nil))
_, events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{})
if err != nil {
t.Fatal(err)
}
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("first call should have no updates")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0, nil))
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{})
if err != nil {
return
}
// Try sending a message.
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, types.StreamPosition(0), "alice", "one", gomatrixserverlib.SendToDeviceEvent{
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, "alice", "one", gomatrixserverlib.SendToDeviceEvent{
Sender: "bob",
Type: "m.type",
Content: json.RawMessage("{}"),
@ -559,14 +564,14 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should get exactly one message. We're sending the sync position
// that we were given from the update and the send-to-device update will be updated
// in the database to reflect that this was the sync position we sent the message at.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos, nil))
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
if err != nil {
t.Fatal(err)
}
if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
t.Fatal("second call should have one update")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos, nil))
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos})
if err != nil {
return
}
@ -574,35 +579,35 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos, nil))
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
if err != nil {
t.Fatal(err)
}
if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("third call should have one update still")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos, nil))
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos})
if err != nil {
return
}
// At this point we should now have no updates, because we've progressed the sync
// position. Therefore the update from before will not be sent again.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1, nil))
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 1})
if err != nil {
t.Fatal(err)
}
if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
t.Fatal("fourth call should have no updates")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1, nil))
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos + 1})
if err != nil {
return
}
// At this point we should still have no updates, because no new updates have been
// sent.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2, nil))
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 2})
if err != nil {
t.Fatal(err)
}
@ -639,7 +644,7 @@ func TestInviteBehaviour(t *testing.T) {
}
// both invite events should appear in a new sync
beforeRetireRes := types.NewResponse()
beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.NewStreamToken(0, 0, nil), latest, 0, false)
beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.StreamingToken{}, latest, 0, false)
if err != nil {
t.Fatalf("IncrementalSync failed: %s", err)
}
@ -654,19 +659,15 @@ func TestInviteBehaviour(t *testing.T) {
t.Fatalf("failed to get SyncPosition: %s", err)
}
res := types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.NewStreamToken(0, 0, nil), latest, 0, false)
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.StreamingToken{}, latest, 0, false)
if err != nil {
t.Fatalf("IncrementalSync failed: %s", err)
}
assertInvitedToRooms(t, res, []string{inviteRoom2})
// a sync after we have received both invites should result in a leave for the retired room
beforeRetireTok, err := types.NewStreamTokenFromString(beforeRetireRes.NextBatch)
if err != nil {
t.Fatalf("NewStreamTokenFromString cannot parse next batch '%s' : %s", beforeRetireRes.NextBatch, err)
}
res = types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, beforeRetireTok, latest, 0, false)
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, beforeRetireRes.NextBatch, latest, 0, false)
if err != nil {
t.Fatalf("IncrementalSync failed: %s", err)
}

View file

@ -146,8 +146,8 @@ type BackwardsExtremities interface {
// sync parameter isn't later then we will keep including the updates in the
// sync response, as the client is seemingly trying to repeat the same /sync.
type SendToDevice interface {
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (err error)
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (events []types.SendToDeviceEvent, err error)
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error)
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error)
UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error)
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error)
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
@ -160,5 +160,6 @@ type Filter interface {
type Receipts interface {
UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error)
SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []eduAPI.OutputReceiptEvent, error)
SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error)
}

View file

@ -77,9 +77,8 @@ func (n *Notifier) OnNewEvent(
// This needs to be done PRIOR to waking up users as they will read this value.
n.streamLock.Lock()
defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.currPos.ApplyUpdates(posUpdate)
n.removeEmptyUserStreams()
if ev != nil {
@ -113,11 +112,11 @@ func (n *Notifier) OnNewEvent(
}
}
n.wakeupUsers(usersToNotify, peekingDevicesToNotify, latestPos)
n.wakeupUsers(usersToNotify, peekingDevicesToNotify, n.currPos)
} else if roomID != "" {
n.wakeupUsers(n.joinedUsers(roomID), n.PeekingDevices(roomID), latestPos)
n.wakeupUsers(n.joinedUsers(roomID), n.PeekingDevices(roomID), n.currPos)
} else if len(userIDs) > 0 {
n.wakeupUsers(userIDs, nil, latestPos)
n.wakeupUsers(userIDs, nil, n.currPos)
} else {
log.WithFields(log.Fields{
"posUpdate": posUpdate.String,
@ -155,20 +154,33 @@ func (n *Notifier) OnNewSendToDevice(
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.wakeupUserDevice(userID, deviceIDs, latestPos)
n.currPos.ApplyUpdates(posUpdate)
n.wakeupUserDevice(userID, deviceIDs, n.currPos)
}
// OnNewReceipt updates the current position
func (n *Notifier) OnNewReceipt(
func (n *Notifier) OnNewTyping(
roomID string,
posUpdate types.StreamingToken,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers(n.joinedUsers(roomID), nil, n.currPos)
}
// OnNewReceipt updates the current position
func (n *Notifier) OnNewReceipt(
roomID string,
posUpdate types.StreamingToken,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers(n.joinedUsers(roomID), nil, n.currPos)
}
func (n *Notifier) OnNewKeyChange(
@ -176,9 +188,19 @@ func (n *Notifier) OnNewKeyChange(
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.wakeupUsers([]string{wakeUserID}, nil, latestPos)
n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
}
func (n *Notifier) OnNewInvite(
posUpdate types.StreamingToken, wakeUserID string,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
}
// GetListener returns a UserStreamListener that can be used to wait for

View file

@ -32,11 +32,11 @@ var (
randomMessageEvent gomatrixserverlib.HeaderedEvent
aliceInviteBobEvent gomatrixserverlib.HeaderedEvent
bobLeaveEvent gomatrixserverlib.HeaderedEvent
syncPositionVeryOld = types.NewStreamToken(5, 0, nil)
syncPositionBefore = types.NewStreamToken(11, 0, nil)
syncPositionAfter = types.NewStreamToken(12, 0, nil)
syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition(), 1, nil)
syncPositionAfter2 = types.NewStreamToken(13, 0, nil)
syncPositionVeryOld = types.StreamingToken{PDUPosition: 5}
syncPositionBefore = types.StreamingToken{PDUPosition: 11}
syncPositionAfter = types.StreamingToken{PDUPosition: 12}
//syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition, 1, 0, 0, nil)
syncPositionAfter2 = types.StreamingToken{PDUPosition: 13}
)
var (
@ -205,6 +205,9 @@ func TestNewInviteEventForUser(t *testing.T) {
}
// Test an EDU-only update wakes up the request.
// TODO: Fix this test, invites wake up with an incremented
// PDU position, not EDU position
/*
func TestEDUWakeup(t *testing.T) {
n := NewNotifier(syncPositionAfter)
n.setUsersJoinedToRooms(map[string][]string{
@ -229,6 +232,7 @@ func TestEDUWakeup(t *testing.T) {
wg.Wait()
}
*/
// Test that all blocked requests get woken up on a new event.
func TestMultipleRequestWakeup(t *testing.T) {
@ -331,7 +335,7 @@ func waitForEvents(n *Notifier, req syncRequest) (types.StreamingToken, error) {
return types.StreamingToken{}, fmt.Errorf(
"waitForEvents timed out waiting for %s (pos=%v)", req.device.UserID, req.since,
)
case <-listener.GetNotifyChannel(*req.since):
case <-listener.GetNotifyChannel(req.since):
p := listener.GetSyncPosition()
return p, nil
}
@ -361,7 +365,7 @@ func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) syn
ID: deviceID,
},
timeout: 1 * time.Minute,
since: &since,
since: since,
wantFullState: false,
limit: DefaultTimelineLimit,
log: util.GetLogger(context.TODO()),

View file

@ -46,7 +46,7 @@ type syncRequest struct {
device userapi.Device
limit int
timeout time.Duration
since *types.StreamingToken // nil means that no since token was supplied
since types.StreamingToken // nil means that no since token was supplied
wantFullState bool
log *log.Entry
}
@ -55,18 +55,13 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
timeout := getTimeout(req.URL.Query().Get("timeout"))
fullState := req.URL.Query().Get("full_state")
wantFullState := fullState != "" && fullState != "false"
var since *types.StreamingToken
sinceStr := req.URL.Query().Get("since")
since, sinceStr := types.StreamingToken{}, req.URL.Query().Get("since")
if sinceStr != "" {
tok, err := types.NewStreamTokenFromString(sinceStr)
var err error
since, err = types.NewStreamTokenFromString(sinceStr)
if err != nil {
return nil, err
}
since = &tok
}
if since == nil {
tok := types.NewStreamToken(0, 0, nil)
since = &tok
}
timelineLimit := DefaultTimelineLimit
// TODO: read from stored filters too

View file

@ -35,6 +35,7 @@ import (
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
)
@ -43,7 +44,7 @@ type RequestPool struct {
db storage.Database
cfg *config.SyncAPI
userAPI userapi.UserInternalAPI
notifier *Notifier
Notifier *Notifier
keyAPI keyapi.KeyInternalAPI
rsAPI roomserverAPI.RoomserverInternalAPI
lastseen sync.Map
@ -99,6 +100,30 @@ func (rp *RequestPool) updateLastSeen(req *http.Request, device *userapi.Device)
rp.lastseen.Store(device.UserID+device.ID, time.Now())
}
func init() {
prometheus.MustRegister(
activeSyncRequests, waitingSyncRequests,
)
}
var activeSyncRequests = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "syncapi",
Name: "active_sync_requests",
Help: "The number of sync requests that are active right now",
},
)
var waitingSyncRequests = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "syncapi",
Name: "waiting_sync_requests",
Help: "The number of sync requests that are waiting to be woken by a notifier",
},
)
// OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be
// called in a dedicated goroutine for this request. This function will block the goroutine
// until a response is ready, or it times out.
@ -122,9 +147,12 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
"limit": syncReq.limit,
})
activeSyncRequests.Inc()
defer activeSyncRequests.Dec()
rp.updateLastSeen(req, device)
currPos := rp.notifier.CurrentPosition()
currPos := rp.Notifier.CurrentPosition()
if rp.shouldReturnImmediately(syncReq) {
syncData, err = rp.currentSyncForUser(*syncReq, currPos)
@ -139,13 +167,16 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
}
}
waitingSyncRequests.Inc()
defer waitingSyncRequests.Dec()
// Otherwise, we wait for the notifier to tell us if something *may* have
// happened. We loop in case it turns out that nothing did happen.
timer := time.NewTimer(syncReq.timeout) // case of timeout=0 is handled above
defer timer.Stop()
userStreamListener := rp.notifier.GetListener(*syncReq)
userStreamListener := rp.Notifier.GetListener(*syncReq)
defer userStreamListener.Close()
// We need the loop in case userStreamListener wakes up even if there isn't
@ -154,13 +185,12 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
// respond with, so we skip the return an go back to waiting for content to
// be sent down or the request timing out.
var hasTimedOut bool
sincePos := *syncReq.since
sincePos := syncReq.since
for {
select {
// Wait for notifier to wake us up
case <-userStreamListener.GetNotifyChannel(sincePos):
currPos = userStreamListener.GetSyncPosition()
sincePos = currPos
// Or for timeout to expire
case <-timer.C:
// We just need to ensure we get out of the select after reaching the
@ -248,30 +278,30 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
res := types.NewResponse()
// See if we have any new tasks to do for the send-to-device messaging.
events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, *req.since)
lastPos, events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, req.since)
if err != nil {
return nil, fmt.Errorf("rp.db.SendToDeviceUpdatesForSync: %w", err)
}
// TODO: handle ignored users
if req.since.PDUPosition() == 0 && req.since.EDUPosition() == 0 {
if req.since.IsEmpty() {
res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit)
if err != nil {
return res, fmt.Errorf("rp.db.CompleteSync: %w", err)
}
} else {
res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState)
res, err = rp.db.IncrementalSync(req.ctx, res, req.device, req.since, latestPos, req.limit, req.wantFullState)
if err != nil {
return res, fmt.Errorf("rp.db.IncrementalSync: %w", err)
}
}
accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead
res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter)
res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition, &accountDataFilter)
if err != nil {
return res, fmt.Errorf("rp.appendAccountData: %w", err)
}
res, err = rp.appendDeviceLists(res, req.device.UserID, *req.since, latestPos)
res, err = rp.appendDeviceLists(res, req.device.UserID, req.since, latestPos)
if err != nil {
return res, fmt.Errorf("rp.appendDeviceLists: %w", err)
}
@ -285,7 +315,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
// Then add the updates into the sync response.
if len(updates) > 0 || len(deletions) > 0 {
// Handle the updates and deletions in the database.
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, *req.since)
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, req.since)
if err != nil {
return res, fmt.Errorf("rp.db.CleanSendToDeviceUpdates: %w", err)
}
@ -295,15 +325,9 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
for _, event := range events {
res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent)
}
// Get the next_batch from the sync response and increase the
// EDU counter.
if pos, perr := types.NewStreamTokenFromString(res.NextBatch); perr == nil {
pos.Positions[1]++
res.NextBatch = pos.String()
}
}
res.NextBatch.SendToDevicePosition = lastPos
return res, err
}
@ -328,7 +352,7 @@ func (rp *RequestPool) appendAccountData(
// data keys were set between two message. This isn't a huge issue since the
// duplicate data doesn't represent a huge quantity of data, but an optimisation
// here would be making sure each data is sent only once to the client.
if req.since == nil || (req.since.PDUPosition() == 0 && req.since.EDUPosition() == 0) {
if req.since.IsEmpty() {
// If this is the initial sync, we don't need to check if a data has
// already been sent. Instead, we send the whole batch.
dataReq := &userapi.QueryAccountDataRequest{
@ -363,7 +387,7 @@ func (rp *RequestPool) appendAccountData(
}
r := types.Range{
From: req.since.PDUPosition(),
From: req.since.PDUPosition,
To: currentPos,
}
// If both positions are the same, it means that the data was saved after the
@ -433,7 +457,7 @@ func (rp *RequestPool) appendAccountData(
// or timeout=0, or full_state=true, in any of the cases the request should
// return immediately.
func (rp *RequestPool) shouldReturnImmediately(syncReq *syncRequest) bool {
if syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState {
if syncReq.since.IsEmpty() || syncReq.timeout == 0 || syncReq.wantFullState {
return true
}
waiting, werr := rp.db.SendToDeviceUpdatesWaiting(context.TODO(), syncReq.device.UserID, syncReq.device.ID)

View file

@ -16,9 +16,7 @@ package types
import (
"encoding/json"
"errors"
"fmt"
"sort"
"strconv"
"strings"
@ -46,6 +44,10 @@ type LogPosition struct {
Offset int64
}
func (p *LogPosition) IsEmpty() bool {
return p.Offset == 0
}
// IsAfter returns true if this position is after `lp`.
func (p *LogPosition) IsAfter(lp *LogPosition) bool {
if lp == nil {
@ -107,108 +109,125 @@ const (
)
type StreamingToken struct {
syncToken
logs map[string]*LogPosition
PDUPosition StreamPosition
TypingPosition StreamPosition
ReceiptPosition StreamPosition
SendToDevicePosition StreamPosition
InvitePosition StreamPosition
DeviceListPosition LogPosition
}
func (t *StreamingToken) SetLog(name string, lp *LogPosition) {
if t.logs == nil {
t.logs = make(map[string]*LogPosition)
}
t.logs[name] = lp
// This will be used as a fallback by json.Marshal.
func (s StreamingToken) MarshalText() ([]byte, error) {
return []byte(s.String()), nil
}
func (t *StreamingToken) Log(name string) *LogPosition {
l, ok := t.logs[name]
if !ok {
return nil
}
return l
// This will be used as a fallback by json.Unmarshal.
func (s *StreamingToken) UnmarshalText(text []byte) (err error) {
*s, err = NewStreamTokenFromString(string(text))
return err
}
func (t *StreamingToken) PDUPosition() StreamPosition {
return t.Positions[0]
}
func (t *StreamingToken) EDUPosition() StreamPosition {
return t.Positions[1]
}
func (t *StreamingToken) String() string {
var logStrings []string
for name, lp := range t.logs {
logStr := fmt.Sprintf("%s-%d-%d", name, lp.Partition, lp.Offset)
logStrings = append(logStrings, logStr)
func (t StreamingToken) String() string {
posStr := fmt.Sprintf(
"s%d_%d_%d_%d_%d",
t.PDUPosition, t.TypingPosition,
t.ReceiptPosition, t.SendToDevicePosition,
t.InvitePosition,
)
if dl := t.DeviceListPosition; !dl.IsEmpty() {
posStr += fmt.Sprintf(".dl-%d-%d", dl.Partition, dl.Offset)
}
sort.Strings(logStrings)
// E.g s11_22_33.dl0-134.ab1-441
return strings.Join(append([]string{t.syncToken.String()}, logStrings...), ".")
return posStr
}
// IsAfter returns true if ANY position in this token is greater than `other`.
func (t *StreamingToken) IsAfter(other StreamingToken) bool {
for i := range other.Positions {
if t.Positions[i] > other.Positions[i] {
return true
}
}
for name := range t.logs {
otherLog := other.Log(name)
if otherLog == nil {
continue
}
if t.logs[name].IsAfter(otherLog) {
return true
}
switch {
case t.PDUPosition > other.PDUPosition:
return true
case t.TypingPosition > other.TypingPosition:
return true
case t.ReceiptPosition > other.ReceiptPosition:
return true
case t.SendToDevicePosition > other.SendToDevicePosition:
return true
case t.InvitePosition > other.InvitePosition:
return true
case t.DeviceListPosition.IsAfter(&other.DeviceListPosition):
return true
}
return false
}
func (t *StreamingToken) IsEmpty() bool {
return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition == 0 && t.DeviceListPosition.IsEmpty()
}
// WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken.
// If the latter StreamingToken contains a field that is not 0, it is considered an update,
// and its value will replace the corresponding value in the StreamingToken on which WithUpdates is called.
// If the other token has a log, they will replace any existing log on this token.
func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) {
ret.Type = t.Type
ret.Positions = make([]StreamPosition, len(t.Positions))
for i := range t.Positions {
ret.Positions[i] = t.Positions[i]
if other.Positions[i] == 0 {
continue
}
ret.Positions[i] = other.Positions[i]
}
ret.logs = make(map[string]*LogPosition)
for name := range t.logs {
otherLog := other.Log(name)
if otherLog == nil {
continue
}
copy := *otherLog
ret.logs[name] = &copy
}
func (t *StreamingToken) WithUpdates(other StreamingToken) StreamingToken {
ret := *t
ret.ApplyUpdates(other)
return ret
}
type TopologyToken struct {
syncToken
// ApplyUpdates applies any changes from the supplied StreamingToken. If the supplied
// streaming token contains any positions that are not 0, they are considered updates
// and will overwrite the value in the token.
func (t *StreamingToken) ApplyUpdates(other StreamingToken) {
if other.PDUPosition > 0 {
t.PDUPosition = other.PDUPosition
}
if other.TypingPosition > 0 {
t.TypingPosition = other.TypingPosition
}
if other.ReceiptPosition > 0 {
t.ReceiptPosition = other.ReceiptPosition
}
if other.SendToDevicePosition > 0 {
t.SendToDevicePosition = other.SendToDevicePosition
}
if other.InvitePosition > 0 {
t.InvitePosition = other.InvitePosition
}
if other.DeviceListPosition.Offset > 0 {
t.DeviceListPosition = other.DeviceListPosition
}
}
func (t *TopologyToken) Depth() StreamPosition {
return t.Positions[0]
type TopologyToken struct {
Depth StreamPosition
PDUPosition StreamPosition
}
func (t *TopologyToken) PDUPosition() StreamPosition {
return t.Positions[1]
// This will be used as a fallback by json.Marshal.
func (t TopologyToken) MarshalText() ([]byte, error) {
return []byte(t.String()), nil
}
// This will be used as a fallback by json.Unmarshal.
func (t *TopologyToken) UnmarshalText(text []byte) (err error) {
*t, err = NewTopologyTokenFromString(string(text))
return err
}
func (t *TopologyToken) StreamToken() StreamingToken {
return NewStreamToken(t.PDUPosition(), 0, nil)
return StreamingToken{
PDUPosition: t.PDUPosition,
}
}
func (t *TopologyToken) String() string {
return t.syncToken.String()
func (t TopologyToken) String() string {
return fmt.Sprintf("t%d_%d", t.Depth, t.PDUPosition)
}
// Decrement the topology token to one event earlier.
func (t *TopologyToken) Decrement() {
depth := t.Positions[0]
pduPos := t.Positions[1]
depth := t.Depth
pduPos := t.PDUPosition
if depth-1 <= 0 {
// nothing can be lower than this
depth = 1
@ -223,151 +242,95 @@ func (t *TopologyToken) Decrement() {
if depth < 1 {
depth = 1
}
t.Positions = []StreamPosition{
depth, pduPos,
}
t.Depth = depth
t.PDUPosition = pduPos
}
// NewSyncTokenFromString takes a string of the form "xyyyy..." where "x"
// represents the type of a pagination token and "yyyy..." the token itself, and
// parses it in order to create a new instance of SyncToken. Returns an
// error if the token couldn't be parsed into an int64, or if the token type
// isn't a known type (returns ErrInvalidSyncTokenType in the latter
// case).
func newSyncTokenFromString(s string) (token *syncToken, categories []string, err error) {
if len(s) == 0 {
return nil, nil, ErrInvalidSyncTokenLen
func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) {
if len(tok) < 1 {
err = fmt.Errorf("empty topology token")
return
}
token = new(syncToken)
var positions []string
switch t := SyncTokenType(s[:1]); t {
case SyncTokenTypeStream, SyncTokenTypeTopology:
token.Type = t
categories = strings.Split(s[1:], ".")
positions = strings.Split(categories[0], "_")
default:
return nil, nil, ErrInvalidSyncTokenType
if tok[0] != SyncTokenTypeTopology[0] {
err = fmt.Errorf("topology token must start with 't'")
return
}
for _, pos := range positions {
if posInt, err := strconv.ParseInt(pos, 10, 64); err != nil {
return nil, nil, err
} else if posInt < 0 {
return nil, nil, errors.New("negative position not allowed")
} else {
token.Positions = append(token.Positions, StreamPosition(posInt))
parts := strings.Split(tok[1:], "_")
var positions [2]StreamPosition
for i, p := range parts {
if i > len(positions) {
break
}
var pos int
pos, err = strconv.Atoi(p)
if err != nil {
return
}
positions[i] = StreamPosition(pos)
}
token = TopologyToken{
Depth: positions[0],
PDUPosition: positions[1],
}
return
}
// NewTopologyToken creates a new sync token for /messages
func NewTopologyToken(depth, streamPos StreamPosition) TopologyToken {
if depth < 0 {
depth = 1
}
return TopologyToken{
syncToken: syncToken{
Type: SyncTokenTypeTopology,
Positions: []StreamPosition{depth, streamPos},
},
}
}
func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) {
t, _, err := newSyncTokenFromString(tok)
if err != nil {
return
}
if t.Type != SyncTokenTypeTopology {
err = fmt.Errorf("token %s is not a topology token", tok)
return
}
if len(t.Positions) < 2 {
err = fmt.Errorf("token %s wrong number of values, got %d want at least 2", tok, len(t.Positions))
return
}
return TopologyToken{
syncToken: *t,
}, nil
}
// NewStreamToken creates a new sync token for /sync
func NewStreamToken(pduPos, eduPos StreamPosition, logs map[string]*LogPosition) StreamingToken {
if logs == nil {
logs = make(map[string]*LogPosition)
}
return StreamingToken{
syncToken: syncToken{
Type: SyncTokenTypeStream,
Positions: []StreamPosition{pduPos, eduPos},
},
logs: logs,
}
}
func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
t, categories, err := newSyncTokenFromString(tok)
if err != nil {
if len(tok) < 1 {
err = fmt.Errorf("empty stream token")
return
}
if t.Type != SyncTokenTypeStream {
err = fmt.Errorf("token %s is not a streaming token", tok)
if tok[0] != SyncTokenTypeStream[0] {
err = fmt.Errorf("stream token must start with 's'")
return
}
if len(t.Positions) < 2 {
err = fmt.Errorf("token %s wrong number of values, got %d want at least 2", tok, len(t.Positions))
return
categories := strings.Split(tok[1:], ".")
parts := strings.Split(categories[0], "_")
var positions [5]StreamPosition
for i, p := range parts {
if i > len(positions) {
break
}
var pos int
pos, err = strconv.Atoi(p)
if err != nil {
return
}
positions[i] = StreamPosition(pos)
}
logs := make(map[string]*LogPosition)
if len(categories) > 1 {
// dl-0-1234
// $log_name-$partition-$offset
for _, logStr := range categories[1:] {
segments := strings.Split(logStr, "-")
if len(segments) != 3 {
err = fmt.Errorf("token %s - invalid log: %s", tok, logStr)
token = StreamingToken{
PDUPosition: positions[0],
TypingPosition: positions[1],
ReceiptPosition: positions[2],
SendToDevicePosition: positions[3],
InvitePosition: positions[4],
}
// dl-0-1234
// $log_name-$partition-$offset
for _, logStr := range categories[1:] {
segments := strings.Split(logStr, "-")
if len(segments) != 3 {
err = fmt.Errorf("invalid log position %q", logStr)
return
}
switch segments[0] {
case "dl":
// Device list syncing
var partition, offset int
if partition, err = strconv.Atoi(segments[1]); err != nil {
return
}
var partition int64
partition, err = strconv.ParseInt(segments[1], 10, 32)
if err != nil {
if offset, err = strconv.Atoi(segments[2]); err != nil {
return
}
var offset int64
offset, err = strconv.ParseInt(segments[2], 10, 64)
if err != nil {
return
}
logs[segments[0]] = &LogPosition{
Partition: int32(partition),
Offset: offset,
}
token.DeviceListPosition.Partition = int32(partition)
token.DeviceListPosition.Offset = int64(offset)
default:
err = fmt.Errorf("unrecognised token type %q", segments[0])
return
}
}
return StreamingToken{
syncToken: *t,
logs: logs,
}, nil
}
// syncToken represents a syncapi token, used for interactions with
// /sync or /messages, for example.
type syncToken struct {
Type SyncTokenType
// A list of stream positions, their meanings vary depending on the token type.
Positions []StreamPosition
}
// String translates a SyncToken to a string of the "xyyyy..." (see
// NewSyncToken to know what it represents).
func (p *syncToken) String() string {
posStr := make([]string, len(p.Positions))
for i := range p.Positions {
posStr[i] = strconv.FormatInt(int64(p.Positions[i]), 10)
}
return fmt.Sprintf("%s%s", p.Type, strings.Join(posStr, "_"))
return token, nil
}
// PrevEventRef represents a reference to a previous event in a state event upgrade
@ -379,7 +342,7 @@ type PrevEventRef struct {
// Response represents a /sync API response. See https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-sync
type Response struct {
NextBatch string `json:"next_batch"`
NextBatch StreamingToken `json:"next_batch"`
AccountData struct {
Events []gomatrixserverlib.ClientEvent `json:"events"`
} `json:"account_data,omitempty"`
@ -443,7 +406,7 @@ type JoinResponse struct {
Timeline struct {
Events []gomatrixserverlib.ClientEvent `json:"events"`
Limited bool `json:"limited"`
PrevBatch string `json:"prev_batch"`
PrevBatch *TopologyToken `json:"prev_batch,omitempty"`
} `json:"timeline"`
Ephemeral struct {
Events []gomatrixserverlib.ClientEvent `json:"events"`
@ -501,7 +464,7 @@ type LeaveResponse struct {
Timeline struct {
Events []gomatrixserverlib.ClientEvent `json:"events"`
Limited bool `json:"limited"`
PrevBatch string `json:"prev_batch"`
PrevBatch *TopologyToken `json:"prev_batch,omitempty"`
} `json:"timeline"`
}

View file

@ -10,30 +10,14 @@ import (
func TestNewSyncTokenWithLogs(t *testing.T) {
tests := map[string]*StreamingToken{
"s4_0": {
syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}},
logs: make(map[string]*LogPosition),
"s4_0_0_0_0": {
PDUPosition: 4,
},
"s4_0.dl-0-123": {
syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}},
logs: map[string]*LogPosition{
"dl": {
Partition: 0,
Offset: 123,
},
},
},
"s4_0.ab-1-14419482332.dl-0-123": {
syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}},
logs: map[string]*LogPosition{
"ab": {
Partition: 1,
Offset: 14419482332,
},
"dl": {
Partition: 0,
Offset: 123,
},
"s4_0_0_0_0.dl-0-123": {
PDUPosition: 4,
DeviceListPosition: LogPosition{
Partition: 0,
Offset: 123,
},
},
}
@ -56,16 +40,22 @@ func TestNewSyncTokenWithLogs(t *testing.T) {
}
}
func TestNewSyncTokenFromString(t *testing.T) {
shouldPass := map[string]syncToken{
"s4_0": NewStreamToken(4, 0, nil).syncToken,
"s3_1": NewStreamToken(3, 1, nil).syncToken,
"t3_1": NewTopologyToken(3, 1).syncToken,
func TestSyncTokens(t *testing.T) {
shouldPass := map[string]string{
"s4_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, LogPosition{}}.String(),
"s3_1_0_0_0.dl-1-2": StreamingToken{3, 1, 0, 0, 0, LogPosition{1, 2}}.String(),
"s3_1_2_3_5": StreamingToken{3, 1, 2, 3, 5, LogPosition{}}.String(),
"t3_1": TopologyToken{3, 1}.String(),
}
for a, b := range shouldPass {
if a != b {
t.Errorf("expected %q, got %q", a, b)
}
}
shouldFail := []string{
"",
"s_1",
"s_",
"a3_4",
"b",
@ -74,19 +64,15 @@ func TestNewSyncTokenFromString(t *testing.T) {
"2",
}
for test, expected := range shouldPass {
result, _, err := newSyncTokenFromString(test)
if err != nil {
t.Error(err)
}
if result.String() != expected.String() {
t.Errorf("%s expected %v but got %v", test, expected.String(), result.String())
for _, f := range append(shouldFail, "t1_2") {
if _, err := NewStreamTokenFromString(f); err == nil {
t.Errorf("NewStreamTokenFromString %q should have failed", f)
}
}
for _, test := range shouldFail {
if _, _, err := newSyncTokenFromString(test); err == nil {
t.Errorf("input '%v' should have errored but didn't", test)
for _, f := range append(shouldFail, "s1_2_3_4") {
if _, err := NewTopologyTokenFromString(f); err == nil {
t.Errorf("NewTopologyTokenFromString %q should have failed", f)
}
}
}

View file

@ -141,18 +141,14 @@ New users appear in /keys/changes
Local delete device changes appear in v2 /sync
Local new device changes appear in v2 /sync
Local update device changes appear in v2 /sync
Users receive device_list updates for their own devices
Get left notifs for other users in sync and /keys/changes when user leaves
Local device key changes get to remote servers
Local device key changes get to remote servers with correct prev_id
Server correctly handles incoming m.device_list_update
Device deletion propagates over federation
If remote user leaves room, changes device and rejoins we see update in sync
If remote user leaves room, changes device and rejoins we see update in /keys/changes
If remote user leaves room we no longer receive device updates
If a device list update goes missing, the server resyncs on the next one
Get left notifs in sync and /keys/changes when other user leaves
Can query remote device keys using POST after notification
Server correctly resyncs when client query keys and there is no remote cache
Server correctly resyncs when server leaves and rejoins a room
Device list doesn't change if remote server is down
@ -503,3 +499,4 @@ Forgetting room does not show up in v2 /sync
Can forget room you've been kicked from
/whois
/joined_members return joined members
A next_batch token can be used in the v1 messages API