Merge branch 'main' into s7evink/docker

Neil Alexander 2022-10-03 15:43:39 +01:00 committed by GitHub
commit 61aaad087e
104 changed files with 3507 additions and 2429 deletions


@ -1,5 +1,38 @@
# Changelog
## Dendrite 0.10.1 (2022-09-30)
### Features
* The built-in NATS Server has been updated to version 2.9.2
### Fixes
* A regression introduced in 0.10.0 in `/sync` as a result of transaction errors has been fixed
* Account data updates will no longer send duplicate output events
## Dendrite 0.10.0 (2022-09-30)
### Features
* High performance full-text searching has been added to Dendrite
* Search must be enabled in the [`search` section of the `sync_api` config](https://github.com/matrix-org/dendrite/blob/6348486a1365c7469a498101f5035a9b6bd16d22/dendrite-sample.monolith.yaml#L279-L290) before it can be used
* The search index is stored on the filesystem rather than the sync API database, so a path to a suitable storage location on disk must be configured
* Sync requests should now complete faster and use considerably fewer database connections as a result of better transactional isolation
* The notifications code has been refactored to hopefully make notifications more reliable
* A new `/_dendrite/admin/refreshDevices/{userID}` admin endpoint has been added for forcing a refresh of a remote user's device lists without having to modify the database by hand
* A new `/_dendrite/admin/fulltext/reindex` admin endpoint has been added for rebuilding the search index (although this may take some time)
### Fixes
* A number of bugs in the device list updater have been fixed, which should help considerably with federated device list synchronisation and E2EE reliability
* A state resolution bug has been fixed which should help to prevent unexpected state resets
* The deprecated `"origin"` field in events will now be correctly ignored in all cases
* Room versions 8 and 9 will now correctly evaluate `"knock"` join rules and membership states
* A database index has been added to speed up finding room memberships in the sync API (contributed by [PiotrKozimor](https://github.com/PiotrKozimor))
* The client API will now return an `M_UNRECOGNIZED` error for unknown endpoints/methods, which should help with client error handling
* A bug has been fixed when updating push rules which could result in `database is locked` on SQLite
## Dendrite 0.9.9 (2022-09-22)
### Features


@ -30,6 +30,8 @@ import (
"sync"
"time"
"go.uber.org/atomic"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/clientapi/userutil"
@ -66,6 +68,7 @@ const (
PeerTypeRemote = pineconeRouter.PeerTypeRemote
PeerTypeMulticast = pineconeRouter.PeerTypeMulticast
PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth
PeerTypeBonjour = pineconeRouter.PeerTypeBonjour
)
type DendriteMonolith struct {
@ -82,6 +85,10 @@ type DendriteMonolith struct {
userAPI userapiAPI.UserInternalAPI
}
func (m *DendriteMonolith) PublicKey() string {
return m.PineconeRouter.PublicKey().String()
}
func (m *DendriteMonolith) BaseURL() string {
return fmt.Sprintf("http://%s", m.listener.Addr().String())
}
@ -94,6 +101,20 @@ func (m *DendriteMonolith) SessionCount() int {
return len(m.PineconeQUIC.Protocol("matrix").Sessions())
}
func (m *DendriteMonolith) RegisterNetworkInterface(name string, index int, mtu int, up bool, broadcast bool, loopback bool, pointToPoint bool, multicast bool, addrs string) {
m.PineconeMulticast.RegisterInterface(pineconeMulticast.InterfaceInfo{
Name: name,
Index: index,
Mtu: mtu,
Up: up,
Broadcast: broadcast,
Loopback: loopback,
PointToPoint: pointToPoint,
Multicast: multicast,
Addrs: addrs,
})
}
func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) {
if enabled {
m.PineconeMulticast.Start()
@ -105,7 +126,9 @@ func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) {
func (m *DendriteMonolith) SetStaticPeer(uri string) {
m.PineconeManager.RemovePeers()
m.PineconeManager.AddPeer(strings.TrimSpace(uri))
for _, uri := range strings.Split(uri, ",") {
m.PineconeManager.AddPeer(strings.TrimSpace(uri))
}
}
func (m *DendriteMonolith) DisconnectType(peertype int) {
@ -134,32 +157,21 @@ func (m *DendriteMonolith) Conduit(zone string, peertype int) (*Conduit, error)
go func() {
conduit.portMutex.Lock()
defer conduit.portMutex.Unlock()
loop:
for i := 1; i <= 10; i++ {
logrus.Errorf("Attempting authenticated connect (attempt %d)", i)
var err error
conduit.port, err = m.PineconeRouter.Connect(
l,
pineconeRouter.ConnectionZone(zone),
pineconeRouter.ConnectionPeerType(peertype),
)
switch err {
case io.ErrClosedPipe:
logrus.Errorf("Authenticated connect failed due to closed pipe (attempt %d)", i)
return
case io.EOF:
logrus.Errorf("Authenticated connect failed due to EOF (attempt %d)", i)
break loop
case nil:
logrus.Errorf("Authenticated connect succeeded, connected to port %d (attempt %d)", conduit.port, i)
return
default:
logrus.WithError(err).Errorf("Authenticated connect failed (attempt %d)", i)
time.Sleep(time.Second)
}
logrus.Errorf("Attempting authenticated connect")
var err error
if conduit.port, err = m.PineconeRouter.Connect(
l,
pineconeRouter.ConnectionZone(zone),
pineconeRouter.ConnectionPeerType(peertype),
); err != nil {
logrus.Errorf("Authenticated connect failed: %s", err)
_ = l.Close()
_ = r.Close()
_ = conduit.Close()
return
}
_ = l.Close()
_ = r.Close()
logrus.Infof("Authenticated connect succeeded (port %d)", conduit.port)
}()
return conduit, nil
}
@ -269,19 +281,21 @@ func (m *DendriteMonolith) Start() {
cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))
cfg.Global.PrivateKey = sk
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.InMemory = true
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-federationsender.db", m.StorageDirectory, prefix))
cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory))
cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory))
cfg.Global.JetStream.InMemory = false
cfg.Global.JetStream.StoragePath = config.Path(filepath.Join(m.CacheDirectory, prefix))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", filepath.Join(m.StorageDirectory, prefix)))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", filepath.Join(m.StorageDirectory, prefix)))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", filepath.Join(m.StorageDirectory, prefix)))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(m.StorageDirectory, prefix)))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(m.StorageDirectory, prefix)))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", filepath.Join(m.StorageDirectory, prefix)))
cfg.MediaAPI.BasePath = config.Path(filepath.Join(m.CacheDirectory, "media"))
cfg.MediaAPI.AbsBasePath = config.Path(filepath.Join(m.CacheDirectory, "media"))
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
cfg.SyncAPI.Fulltext.Enabled = true
cfg.SyncAPI.Fulltext.IndexPath = config.Path(filepath.Join(m.CacheDirectory, "search"))
if err = cfg.Derive(); err != nil {
panic(err)
}
@ -395,6 +409,7 @@ func (m *DendriteMonolith) Stop() {
const MaxFrameSize = types.MaxFrameSize
type Conduit struct {
closed atomic.Bool
conn net.Conn
port types.SwitchPortID
portMutex sync.Mutex
@ -407,10 +422,16 @@ func (c *Conduit) Port() int {
}
func (c *Conduit) Read(b []byte) (int, error) {
if c.closed.Load() {
return 0, io.EOF
}
return c.conn.Read(b)
}
func (c *Conduit) ReadCopy() ([]byte, error) {
if c.closed.Load() {
return nil, io.EOF
}
var buf [65535 * 2]byte
n, err := c.conn.Read(buf[:])
if err != nil {
@ -420,9 +441,16 @@ func (c *Conduit) ReadCopy() ([]byte, error) {
}
func (c *Conduit) Write(b []byte) (int, error) {
if c.closed.Load() {
return 0, io.EOF
}
return c.conn.Write(b)
}
func (c *Conduit) Close() error {
if c.closed.Load() {
return io.ErrClosedPipe
}
c.closed.Store(true)
return c.conn.Close()
}


@ -2,16 +2,23 @@ package routing
import (
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/httputil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
@ -138,3 +145,49 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
},
}
}
func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, natsClient *nats.Conn) util.JSONResponse {
_, err := natsClient.RequestMsg(nats.NewMsg(cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex)), time.Second*10)
if err != nil {
logrus.WithError(err).Error("failed to publish nats message")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.ClientKeyAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
userID := vars["userID"]
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if domain == cfg.Matrix.ServerName {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidParam("Can not mark local device list as stale"),
}
}
err = keyAPI.PerformMarkAsStaleIfNeeded(req.Context(), &api.PerformMarkAsStaleRequest{
UserID: userID,
Domain: domain,
}, &struct{}{})
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown(fmt.Sprintf("Failed to mark device list as stale: %s", err)),
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}


@ -20,6 +20,12 @@ import (
"strings"
"github.com/gorilla/mux"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/nats-io/nats.go"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/auth"
@ -34,11 +40,6 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/nats-io/nats.go"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
)
// Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client
@ -161,6 +162,18 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/fulltext/reindex",
httputil.MakeAdminAPI("admin_fultext_reindex", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminReindex(req, cfg, device, natsClient)
}),
).Methods(http.MethodGet, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/refreshDevices/{userID}",
httputil.MakeAdminAPI("admin_refresh_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminMarkAsStale(req, cfg, keyAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
// server notifications
if cfg.Matrix.ServerNotices.Enabled {
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")


@ -64,7 +64,7 @@ var (
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
isAdmin = flag.Bool("admin", false, "Create an admin account")
resetPassword = flag.Bool("reset-password", false, "Deprecated")
serverURL = flag.String("url", "https://localhost:8448", "The URL to connect to.")
serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.")
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server")
)


@ -89,6 +89,7 @@ func main() {
if configFlagSet {
cfg = setup.ParseFlags(true)
sk = cfg.Global.PrivateKey
pk = sk.Public().(ed25519.PublicKey)
} else {
keyfile := filepath.Join(*instanceDir, *instanceName) + ".pem"
if _, err := os.Stat(keyfile); os.IsNotExist(err) {
@ -142,6 +143,9 @@ func main() {
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName)))
cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
cfg.MediaAPI.BasePath = config.Path(*instanceDir)
cfg.SyncAPI.Fulltext.Enabled = true
cfg.SyncAPI.Fulltext.IndexPath = config.Path(*instanceDir)
if err := cfg.Derive(); err != nil {
panic(err)
}


@ -134,6 +134,9 @@ func main() {
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName)))
cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
cfg.MediaAPI.BasePath = config.Path(*instanceDir)
cfg.SyncAPI.Fulltext.Enabled = true
cfg.SyncAPI.Fulltext.IndexPath = config.Path(*instanceDir)
if err := cfg.Derive(); err != nil {
panic(err)
}


@ -5,10 +5,11 @@ import (
"fmt"
"path/filepath"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
"gopkg.in/yaml.v2"
"github.com/matrix-org/dendrite/setup/config"
)
func main() {
@ -82,6 +83,12 @@ func main() {
EnableInbound: true,
EnableOutbound: true,
}
cfg.SyncAPI.Fulltext = config.Fulltext{
Enabled: true,
IndexPath: config.Path(filepath.Join(*dirPath, "searchindex")),
InMemory: true,
Language: "en",
}
}
} else {
var err error


@ -275,10 +275,19 @@ sync_api:
# address of the client. This is likely required if Dendrite is running behind
# a reverse proxy server.
# real_ip_header: X-Real-IP
fulltext:
# Configuration for the full-text search engine.
search:
# Whether or not search is enabled.
enabled: false
index_path: "./fulltextindex"
language: "en" # more possible languages can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang
# The path where the search index will be created.
index_path: "./searchindex"
# The language most likely to be used on the server - used when indexing, to
# ensure the returned results match expectations. A full list of possible languages
# can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang
language: "en"
# Configuration for the User API.
user_api:


@ -326,10 +326,19 @@ sync_api:
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
fulltext:
# Configuration for the full-text search engine.
search:
# Whether or not search is enabled.
enabled: false
index_path: "./fulltextindex"
language: "en" # more possible languages can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang
# The path where the search index will be created.
index_path: "./searchindex"
# The language most likely to be used on the server - used when indexing, to
# ensure the returned results match expectations. A full list of possible languages
# can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang
language: "en"
# This option controls which HTTP header to inspect to find the real remote IP
# address of the client. This is likely required if Dendrite is running behind


@ -1,3 +1,4 @@
---
title: Creating user accounts
parent: Administration
@ -31,11 +32,11 @@ To create a new **admin account**, add the `-admin` flag:
./bin/create-account -config /path/to/dendrite.yaml -username USERNAME -admin
```
By default `create-account` uses `https://localhost:8448` to connect to Dendrite, this can be overwritten using
By default `create-account` uses `http://localhost:8008` to connect to Dendrite; this can be overridden using
the `-url` flag:
```bash
./bin/create-account -config /path/to/dendrite.yaml -username USERNAME -url http://localhost:8008
./bin/create-account -config /path/to/dendrite.yaml -username USERNAME -url https://localhost:8448
```
An example of using `create-account` when running in **Docker**, having found the `CONTAINERNAME` from `docker ps`:
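The example itself falls outside this diff hunk. A minimal sketch of the usual `docker exec` pattern, where the container name, config path and binary location are assumptions for illustration:

```bash
# Container name, config path and binary location are placeholders.
docker exec -it CONTAINERNAME \
  /usr/bin/create-account -config /etc/dendrite/dendrite.yaml -username USERNAME
```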


@ -57,6 +57,16 @@ Request body format:
Reset the password of a local user. The `localpart` is the username only, i.e. if
the full user ID is `@alice:domain.com` then the local part is `alice`.
## GET `/_dendrite/admin/fulltext/reindex`
This endpoint instructs Dendrite to reindex all searchable events (`m.room.message`, `m.room.topic` and `m.room.name`). An empty JSON body will be returned immediately.
Indexing is done in the background; the server logs progress every 1000 events (or fewer) as they are indexed. Once reindexing is done, you'll see something along the lines of `Indexed 69586 events in 53.68223182s` in your debug logs.
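A minimal sketch of calling this endpoint with curl, assuming Dendrite's client API listens on `http://localhost:8008` and `ADMIN_ACCESS_TOKEN` is the access token of an account created with `-admin`:

```bash
# Trigger a reindex of all searchable events (host and token are assumptions).
curl -X GET \
  -H "Authorization: Bearer ADMIN_ACCESS_TOKEN" \
  "http://localhost:8008/_dendrite/admin/fulltext/reindex"
```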
## POST `/_dendrite/admin/refreshDevices/{userID}`
This endpoint instructs Dendrite to immediately query `/devices/{userID}` on a federated server. An empty JSON body will be returned on success, and all locally stored devices/keys for that user will be updated. This can help resolve E2EE issues where the remote user can't decrypt messages.
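A minimal sketch of the corresponding call, under the same host and admin-token assumptions as above and with a hypothetical remote user ID; note that the endpoint rejects local user IDs:

```bash
# Force a device list refresh for a remote user (user ID is a placeholder).
curl -X POST \
  -H "Authorization: Bearer ADMIN_ACCESS_TOKEN" \
  "http://localhost:8008/_dendrite/admin/refreshDevices/@alice:remote.example.com"
```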
## POST `/_synapse/admin/v1/send_server_notice`
Request body format:


@ -87,6 +87,12 @@ and contain the following JSON document:
For example, this can be done with the following Caddy config:
```
handle /.well-known/matrix/server {
header Content-Type application/json
header Access-Control-Allow-Origin *
respond `"m.server": "matrix.example.com:8448"`
}
handle /.well-known/matrix/client {
header Content-Type application/json
header Access-Control-Allow-Origin *


@ -138,16 +138,18 @@ room_server:
conn_max_lifetime: -1
```
## Fulltext search
## Full-text search
Dendrite supports experimental fulltext indexing using [Bleve](https://github.com/blevesearch/bleve), it is configured in the `sync_api` section as follows. Depending on the language most likely to be used on the server, it might make sense to change the `language` used when indexing, to ensure the returned results match the expections. A full list of possible languages can be found [here](https://github.com/blevesearch/bleve/tree/master/analysis/lang).
Dendrite supports experimental full-text indexing using [Bleve](https://github.com/blevesearch/bleve). It is configured in the `sync_api` section as follows.
Depending on the language most likely to be used on the server, it might make sense to change the `language` used when indexing, to ensure the returned results match the expectations. A full list of possible languages can be found [here](https://github.com/blevesearch/bleve/tree/master/analysis/lang).
```yaml
sync_api:
# ...
fulltext:
search:
enabled: false
index_path: "./fulltextindex"
index_path: "./searchindex"
language: "en"
```


@ -80,7 +80,6 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats
return true
}
if originServerName != t.ServerName {
log.WithField("other_server", originServerName).Info("Suppressing send-to-device: originated elsewhere")
return true
}
// Extract the send-to-device event from msg.

go.mod

@ -22,11 +22,11 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5
github.com/matrix-org/pinecone v0.0.0-20220927101513-d0beb180f44d
github.com/matrix-org/gomatrixserverlib v0.0.0-20220929190355-91d455cd3621
github.com/matrix-org/pinecone v0.0.0-20220929155234-2ce51dd4a42c
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.15
github.com/nats-io/nats-server/v2 v2.9.1-0.20220920152220-52d7b481c4b5
github.com/nats-io/nats-server/v2 v2.9.2
github.com/nats-io/nats.go v1.17.0
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
@ -43,7 +43,7 @@ require (
github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.5-0.20220901155642-4f2abece817c
go.uber.org/atomic v1.10.0
golang.org/x/crypto v0.0.0-20220919173607-35f4265a4bc0
golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be
golang.org/x/image v0.0.0-20220902085622-e7cb96979f69
golang.org/x/mobile v0.0.0-20220722155234-aaac322e2105
golang.org/x/net v0.0.0-20220919232410-f2f64ebce3c1
@ -91,7 +91,7 @@ require (
github.com/h2non/filetype v1.1.3 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/juju/errors v1.0.0 // indirect
github.com/klauspost/compress v1.15.10 // indirect
github.com/klauspost/compress v1.15.11 // indirect
github.com/kr/pretty v0.3.0 // indirect
github.com/lucas-clemente/quic-go v0.29.0 // indirect
github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect
@ -120,9 +120,9 @@ require (
go.etcd.io/bbolt v1.3.6 // indirect
golang.org/x/exp v0.0.0-20220916125017-b168a2c6b86b // indirect
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 // indirect
golang.org/x/sys v0.0.0-20220928140112-f11e5e49a4ec // indirect
golang.org/x/text v0.3.8-0.20211004125949-5bd84dd9b33b // indirect
golang.org/x/time v0.0.0-20220920022843-2ce7c2934d45 // indirect
golang.org/x/time v0.0.0-20220922220347-f3bd1da661af // indirect
golang.org/x/tools v0.1.12 // indirect
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/macaroon.v2 v2.1.0 // indirect

go.sum

@ -347,8 +347,8 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
github.com/klauspost/compress v1.15.10 h1:Ai8UzuomSCDw90e1qNMtb15msBXsNpH6gzkkENQNcJo=
github.com/klauspost/compress v1.15.10/go.mod h1:QPwzmACJjUTFsnSHH934V6woptycfrDDJnH7hvFVbGM=
github.com/klauspost/compress v1.15.11 h1:Lcadnb3RKGin4FYM/orgq0qde+nc15E5Cbqg4B9Sx9c=
github.com/klauspost/compress v1.15.11/go.mod h1:QPwzmACJjUTFsnSHH934V6woptycfrDDJnH7hvFVbGM=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
@ -384,10 +384,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5 h1:cQMA9hip0WSp6cv7CUfButa9Jl/9E6kqWmQyOjx5A5s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.0.0-20220927101513-d0beb180f44d h1:kGPJ6Rg8nn5an2CbCZrRiuTNyNzE0rRMiqm4UXJYrRs=
github.com/matrix-org/pinecone v0.0.0-20220927101513-d0beb180f44d/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220929190355-91d455cd3621 h1:a8IaoSPDxevkgXnOUrtIW9AqVNvXBJAG0gtnX687S7g=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220929190355-91d455cd3621/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.0.0-20220929155234-2ce51dd4a42c h1:iCHLYwwlPsf4TYFrvhKdhQoAM2lXzcmDZYqwBNWcnVk=
github.com/matrix-org/pinecone v0.0.0-20220929155234-2ce51dd4a42c/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
@ -422,8 +422,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI=
github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
github.com/nats-io/nats-server/v2 v2.9.1-0.20220920152220-52d7b481c4b5 h1:G/YGSXcJ2bUofD8Ts49it4VNezaJLQldI6fZR+wIUts=
github.com/nats-io/nats-server/v2 v2.9.1-0.20220920152220-52d7b481c4b5/go.mod h1:BWKY6217RvhI+FDoOLZ2BH+hOC37xeKRBlQ1Lz7teKI=
github.com/nats-io/nats-server/v2 v2.9.2 h1:XNDgJgOYYaYlquLdbSHI3xssLipfKUOq3EmYIMNCOsE=
github.com/nats-io/nats-server/v2 v2.9.2/go.mod h1:4sq8wvrpbvSzL1n3ZfEYnH4qeUuIl5W990j3kw13rRk=
github.com/nats-io/nats.go v1.17.0 h1:1jp5BThsdGlN91hW0k3YEfJbfACjiOYtUiLXG0RL4IE=
github.com/nats-io/nats.go v1.17.0/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8=
@ -625,8 +625,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220919173607-35f4265a4bc0 h1:a5Yg6ylndHHYJqIPrdq0AhvR6KTvDTAvgBtaidhEevY=
golang.org/x/crypto v0.0.0-20220919173607-35f4265a4bc0/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be h1:fmw3UbQh+nxngCAHrDCCztao/kbYFnWjoqop8dHx05A=
golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -810,8 +810,8 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220730100132-1609e554cd39/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 h1:h+EGohizhe9XlX18rfpa8k8RAc5XyaeamM+0VHRd4lc=
golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220928140112-f11e5e49a4ec h1:BkDtF2Ih9xZ7le9ndzTA7KJow28VbQW3odyk/8drmuI=
golang.org/x/sys v0.0.0-20220928140112-f11e5e49a4ec/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.0.0-20220919170432-7a66f970e087 h1:tPwmk4vmvVCMdr98VgL4JH+qZxPL8fqlUOHnyOM8N3w=
@ -829,8 +829,8 @@ golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxb
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20220920022843-2ce7c2934d45 h1:yuLAip3bfURHClMG9VBdzPrQvCWjWiWUTBGV+/fCbUs=
golang.org/x/time v0.0.0-20220920022843-2ce7c2934d45/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20220922220347-f3bd1da661af h1:Yx9k8YCG3dvF87UAn2tu2HQLf2dt/eR1bXxpLMWeH+Y=
golang.org/x/time v0.0.0-20220922220347-f3bd1da661af/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=


@ -22,8 +22,9 @@ import (
"github.com/blevesearch/bleve/v2"
"github.com/blevesearch/bleve/v2/mapping"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/setup/config"
)
// Search contains all existing bleve.Index


@ -27,11 +27,11 @@ import (
func mustOpenIndex(t *testing.T, tempDir string) *fulltext.Search {
t.Helper()
cfg := config.Fulltext{}
cfg.Defaults(config.DefaultOpts{
Generate: true,
Monolithic: true,
})
cfg := config.Fulltext{
Enabled: true,
InMemory: true,
Language: "en",
}
if tempDir != "" {
cfg.IndexPath = config.Path(tempDir)
cfg.InMemory = false


@ -16,8 +16,8 @@ var build string
const (
VersionMajor = 0
VersionMinor = 9
VersionPatch = 9
VersionMinor = 10
VersionPatch = 1
VersionTag = "" // example: "rc1"
)


@ -45,6 +45,7 @@ type ClientKeyAPI interface {
PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error
// PerformClaimKeys claims one-time keys for use in pre-key messages
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
}
// API functions required by the userapi


@ -17,6 +17,7 @@ package internal
import (
"context"
"encoding/json"
"errors"
"fmt"
"hash/fnv"
"net"
@ -31,6 +32,7 @@ import (
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/process"
)
var (
@ -45,6 +47,9 @@ var (
)
)
const defaultWaitTime = time.Minute
const requestTimeout = time.Second * 30
func init() {
prometheus.MustRegister(
deviceListUpdateCount,
@ -80,6 +85,7 @@ func init() {
// In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is
// set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried.
type DeviceListUpdater struct {
process *process.ProcessContext
// A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1
// request to the remote server and race.
// TODO: Put in an LRU cache to bound growth
@ -131,10 +137,12 @@ type KeyChangeProducer interface {
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
func NewDeviceListUpdater(
db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer,
process *process.ProcessContext, db DeviceListUpdaterDatabase,
api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
) *DeviceListUpdater {
return &DeviceListUpdater{
process: process,
userIDToMutex: make(map[string]*sync.Mutex),
mu: &sync.Mutex{},
db: db,
@ -234,7 +242,7 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
"prev_ids": event.PrevID,
"display_name": event.DeviceDisplayName,
"deleted": event.Deleted,
}).Info("DeviceListUpdater.Update")
}).Trace("DeviceListUpdater.Update")
// if we haven't missed anything update the database and notify users
if exists || event.Deleted {
@ -378,111 +386,123 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
}
func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) {
deviceListUpdateCount.WithLabelValues(string(serverName)).Inc()
requestTimeout := time.Second * 30 // max amount of time we want to spend on each request
ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
defer cancel()
ctx := u.process.Context()
logger := util.GetLogger(ctx).WithField("server_name", serverName)
waitTime := 2 * time.Second
// fetch stale device lists
deviceListUpdateCount.WithLabelValues(string(serverName)).Inc()
waitTime := defaultWaitTime // How long should we wait to try again?
successCount := 0 // How many user requests failed?
userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName})
if err != nil {
logger.WithError(err).Error("Failed to load stale device lists")
return waitTime, true
}
failCount := 0
userLoop:
defer func() {
for _, userID := range userIDs {
// always clear the channel to unblock Update calls regardless of success/failure
u.clearChannel(userID)
}
}()
for _, userID := range userIDs {
if ctx.Err() != nil {
// we've timed out, give up and go to the back of the queue to let another server be processed.
failCount += 1
waitTime = time.Minute * 10
userWait, err := u.processServerUser(ctx, serverName, userID)
if err != nil {
if userWait > waitTime {
waitTime = userWait
}
break
}
res, err := u.fedClient.GetUserDevices(ctx, serverName, userID)
if err != nil {
failCount += 1
select {
case <-ctx.Done():
// we've timed out, give up and go to the back of the queue to let another server be processed.
waitTime = time.Minute * 10
break userLoop
default:
}
switch e := err.(type) {
case *fedsenderapi.FederationClientError:
if e.RetryAfter > 0 {
waitTime = e.RetryAfter
} else if e.Blacklisted {
waitTime = time.Hour * 8
break userLoop
} else if e.Code >= 300 {
// We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError
// are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds.
waitTime = time.Hour
break userLoop
}
case net.Error:
// Use the default waitTime, if it's a timeout.
// It probably doesn't make sense to try further users.
if !e.Timeout() {
waitTime = time.Minute * 10
logger.WithError(e).Error("GetUserDevices returned net.Error")
break userLoop
}
case gomatrix.HTTPError:
// The remote server returned an error, give it some time to recover.
// This is to avoid spamming remote servers, which may not be Matrix servers anymore.
if e.Code >= 300 {
waitTime = time.Hour
logger.WithError(e).Error("GetUserDevices returned gomatrix.HTTPError")
break userLoop
}
default:
// Something else failed
waitTime = time.Minute * 10
logger.WithError(err).WithField("user_id", userID).Debugf("GetUserDevices returned unknown error type: %T", err)
break userLoop
}
continue
}
if res.MasterKey != nil || res.SelfSigningKey != nil {
uploadReq := &api.PerformUploadDeviceKeysRequest{
UserID: userID,
}
uploadRes := &api.PerformUploadDeviceKeysResponse{}
if res.MasterKey != nil {
if err = sanityCheckKey(*res.MasterKey, userID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err == nil {
uploadReq.MasterKey = *res.MasterKey
}
}
if res.SelfSigningKey != nil {
if err = sanityCheckKey(*res.SelfSigningKey, userID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err == nil {
uploadReq.SelfSigningKey = *res.SelfSigningKey
}
}
_ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
}
err = u.updateDeviceList(&res)
if err != nil {
logger.WithError(err).WithField("user_id", userID).Error("Fetched device list but failed to store/emit it")
failCount += 1
}
successCount++
}
if failCount > 0 {
allUsersSucceeded := successCount == len(userIDs)
if !allUsersSucceeded {
logger.WithFields(logrus.Fields{
"total": len(userIDs),
"failed": failCount,
"skipped": len(userIDs) - failCount,
"waittime": waitTime,
"total": len(userIDs),
"succeeded": successCount,
"failed": len(userIDs) - successCount,
"wait_time": waitTime,
}).Warn("Failed to query device keys for some users")
}
for _, userID := range userIDs {
// always clear the channel to unblock Update calls regardless of success/failure
u.clearChannel(userID)
return waitTime, !allUsersSucceeded
}
func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) (time.Duration, error) {
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
defer cancel()
logger := util.GetLogger(ctx).WithFields(logrus.Fields{
"server_name": serverName,
"user_id": userID,
})
res, err := u.fedClient.GetUserDevices(ctx, serverName, userID)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return time.Minute * 10, err
}
switch e := err.(type) {
case *json.UnmarshalTypeError, *json.SyntaxError:
logger.WithError(err).Debugf("Device list update for %q contained invalid JSON", userID)
return defaultWaitTime, nil
case *fedsenderapi.FederationClientError:
if e.RetryAfter > 0 {
return e.RetryAfter, err
} else if e.Blacklisted {
return time.Hour * 8, err
} else if e.Code >= 300 {
// We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError
// are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds.
return time.Hour, err
}
case net.Error:
// Use the default waitTime, if it's a timeout.
// It probably doesn't make sense to try further users.
if !e.Timeout() {
logger.WithError(e).Debug("GetUserDevices returned net.Error")
return time.Minute * 10, err
}
case gomatrix.HTTPError:
// The remote server returned an error, give it some time to recover.
// This is to avoid spamming remote servers, which may not be Matrix servers anymore.
if e.Code >= 300 {
logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError")
return time.Hour, err
}
default:
// Something else failed
logger.WithError(err).Debugf("GetUserDevices returned unknown error type: %T", err)
return time.Minute * 10, err
}
}
return waitTime, failCount > 0
if res.UserID != userID {
logger.WithError(err).Debugf("User ID %q in device list update response doesn't match expected %q", res.UserID, userID)
return defaultWaitTime, nil
}
if res.MasterKey != nil || res.SelfSigningKey != nil {
uploadReq := &api.PerformUploadDeviceKeysRequest{
UserID: userID,
}
uploadRes := &api.PerformUploadDeviceKeysResponse{}
if res.MasterKey != nil {
if err = sanityCheckKey(*res.MasterKey, userID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err == nil {
uploadReq.MasterKey = *res.MasterKey
}
}
if res.SelfSigningKey != nil {
if err = sanityCheckKey(*res.SelfSigningKey, userID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err == nil {
uploadReq.SelfSigningKey = *res.SelfSigningKey
}
}
_ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
}
err = u.updateDeviceList(&res)
if err != nil {
logger.WithError(err).Error("Fetched device list but failed to store/emit it")
return defaultWaitTime, err
}
return defaultWaitTime, nil
}
func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error {

View file

@ -30,6 +30,7 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/process"
)
var (
@ -146,7 +147,7 @@ func TestUpdateHavePrevID(t *testing.T) {
}
ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{}
updater := NewDeviceListUpdater(db, ap, producer, nil, 1)
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1)
event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Foo Bar",
Deleted: false,
@ -218,7 +219,7 @@ func TestUpdateNoPrevID(t *testing.T) {
`)),
}, nil
})
updater := NewDeviceListUpdater(db, ap, producer, fedClient, 2)
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2)
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}
@ -287,7 +288,7 @@ func TestDebounce(t *testing.T) {
close(incomingFedReq)
return <-fedCh, nil
})
updater := NewDeviceListUpdater(db, ap, producer, fedClient, 1)
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1)
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}

View file

@ -228,14 +228,21 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
// PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present
// in our database.
func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error {
knownDevices, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{req.DeviceID}, true)
knownDevices, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{}, true)
if err != nil {
return err
}
if len(knownDevices) == 0 {
return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID)
return nil // fmt.Errorf("unknown user %s", req.UserID)
}
return nil
for i := range knownDevices {
if knownDevices[i].DeviceID == req.DeviceID {
return nil // we already know about this device
}
}
return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID)
}
// nolint:gocyclo


@ -58,7 +58,7 @@ func NewInternalAPI(
FedClient: fedClient,
Producer: keyChangeProducer,
}
updater := internal.NewDeviceListUpdater(db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable
ap.Updater = updater
go func() {
if err := updater.Start(); err != nil {


@ -37,16 +37,13 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/gorilla/mux"
"github.com/kardianos/minwinsvc"
@ -61,6 +58,8 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
rsinthttp "github.com/matrix-org/dendrite/roomserver/inthttp"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
userapi "github.com/matrix-org/dendrite/userapi/api"
userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp"
)
@ -392,17 +391,26 @@ func (b *BaseDendrite) configureHTTPErrors() {
_, _ = w.Write([]byte(fmt.Sprintf("405 %s not allowed on this endpoint", r.Method)))
}
clientNotFoundHandler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}`)) // nolint:misspell
}
notFoundCORSHandler := httputil.WrapHandlerInCORS(http.NotFoundHandler())
notAllowedCORSHandler := httputil.WrapHandlerInCORS(http.HandlerFunc(notAllowedHandler))
for _, router := range []*mux.Router{
b.PublicClientAPIMux, b.PublicMediaAPIMux,
b.DendriteAdminMux, b.SynapseAdminMux,
b.PublicWellKnownAPIMux,
b.PublicMediaAPIMux, b.DendriteAdminMux,
b.SynapseAdminMux, b.PublicWellKnownAPIMux,
} {
router.NotFoundHandler = notFoundCORSHandler
router.MethodNotAllowedHandler = notAllowedCORSHandler
}
// Special case so that we don't upset clients on the CS API.
b.PublicClientAPIMux.NotFoundHandler = http.HandlerFunc(clientNotFoundHandler)
b.PublicClientAPIMux.MethodNotAllowedHandler = http.HandlerFunc(clientNotFoundHandler)
}
// SetupAndServeHTTP sets up the HTTP server to serve endpoints registered on


@ -10,7 +10,7 @@ type SyncAPI struct {
RealIPHeader string `yaml:"real_ip_header"`
Fulltext Fulltext `yaml:"fulltext"`
Fulltext Fulltext `yaml:"search"`
}
func (c *SyncAPI) Defaults(opts DefaultOpts) {
@ -50,18 +50,14 @@ type Fulltext struct {
func (f *Fulltext) Defaults(opts DefaultOpts) {
f.Enabled = false
f.IndexPath = "./fulltextindex"
f.IndexPath = "./searchindex"
f.Language = "en"
if opts.Generate {
f.Enabled = true
f.InMemory = true
}
}
func (f *Fulltext) Verify(configErrs *ConfigErrors, isMonolith bool) {
if !f.Enabled {
return
}
checkNotEmpty(configErrs, "syncapi.fulltext.index_path", string(f.IndexPath))
checkNotEmpty(configErrs, "syncapi.fulltext.language", f.Language)
checkNotEmpty(configErrs, "syncapi.search.index_path", string(f.IndexPath))
checkNotEmpty(configErrs, "syncapi.search.language", f.Language)
}


@ -9,9 +9,10 @@ import (
"time"
"github.com/getsentry/sentry-go"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/sirupsen/logrus"
natsserver "github.com/nats-io/nats-server/v2/server"
natsclient "github.com/nats-io/nats.go"
@ -184,6 +185,8 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"},
OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"},
OutputRoomEvent: {"AppserviceRoomserverConsumer"},
OutputStreamEvent: {"UserAPISyncAPIStreamEventConsumer"},
OutputReadUpdate: {"UserAPISyncAPIReadUpdateConsumer"},
} {
streamName := cfg.Matrix.JetStream.Prefixed(stream)
for _, consumer := range consumers {


@ -94,16 +94,6 @@ var streams = []*nats.StreamConfig{
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
{
Name: OutputStreamEvent,
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
{
Name: OutputReadUpdate,
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
{
Name: OutputPresenceEvent,
Retention: nats.InterestPolicy,


@ -16,37 +16,42 @@ package consumers
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
)
// OutputClientDataConsumer consumes events that originated in the client API server.
type OutputClientDataConsumer struct {
ctx context.Context
jetstream nats.JetStreamContext
durable string
topic string
db storage.Database
stream types.StreamProvider
notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName
producer *producers.UserAPIReadProducer
ctx context.Context
jetstream nats.JetStreamContext
nats *nats.Conn
durable string
topic string
topicReIndex string
db storage.Database
stream streams.StreamProvider
notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName
fts *fulltext.Search
cfg *config.SyncAPI
}
// NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers.
@ -54,26 +59,93 @@ func NewOutputClientDataConsumer(
process *process.ProcessContext,
cfg *config.SyncAPI,
js nats.JetStreamContext,
nats *nats.Conn,
store storage.Database,
notifier *notifier.Notifier,
stream types.StreamProvider,
producer *producers.UserAPIReadProducer,
stream streams.StreamProvider,
fts *fulltext.Search,
) *OutputClientDataConsumer {
return &OutputClientDataConsumer{
ctx: process.Context(),
jetstream: js,
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData),
durable: cfg.Matrix.JetStream.Durable("SyncAPIAccountDataConsumer"),
db: store,
notifier: notifier,
stream: stream,
serverName: cfg.Matrix.ServerName,
producer: producer,
ctx: process.Context(),
jetstream: js,
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData),
topicReIndex: cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex),
durable: cfg.Matrix.JetStream.Durable("SyncAPIAccountDataConsumer"),
nats: nats,
db: store,
notifier: notifier,
stream: stream,
serverName: cfg.Matrix.ServerName,
fts: fts,
cfg: cfg,
}
}
// Start consuming from room servers
func (s *OutputClientDataConsumer) Start() error {
_, err := s.nats.Subscribe(s.topicReIndex, func(msg *nats.Msg) {
if err := msg.Ack(); err != nil {
return
}
if !s.cfg.Fulltext.Enabled {
logrus.Warn("Fulltext indexing is disabled")
return
}
ctx := context.Background()
logrus.Infof("Starting to index events")
var offset int
start := time.Now()
count := 0
var id int64 = 0
for {
evs, err := s.db.ReIndex(ctx, 1000, id)
if err != nil {
logrus.WithError(err).Errorf("unable to get events to index")
return
}
if len(evs) == 0 {
break
}
logrus.Debugf("Indexing %d events", len(evs))
elements := make([]fulltext.IndexElement, 0, len(evs))
for streamPos, ev := range evs {
id = streamPos
e := fulltext.IndexElement{
EventID: ev.EventID(),
RoomID: ev.RoomID(),
StreamPosition: streamPos,
}
e.SetContentType(ev.Type())
switch ev.Type() {
case "m.room.message":
e.Content = gjson.GetBytes(ev.Content(), "body").String()
case gomatrixserverlib.MRoomName:
e.Content = gjson.GetBytes(ev.Content(), "name").String()
case gomatrixserverlib.MRoomTopic:
e.Content = gjson.GetBytes(ev.Content(), "topic").String()
default:
continue
}
if strings.TrimSpace(e.Content) == "" {
continue
}
elements = append(elements, e)
}
if err = s.fts.Index(elements...); err != nil {
logrus.WithError(err).Error("unable to index events")
continue
}
offset += len(evs)
count += len(elements)
}
logrus.Infof("Indexed %d events in %v", count, time.Since(start))
})
if err != nil {
return err
}
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
@ -113,15 +185,6 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M
return false
}
if err = s.sendReadUpdate(ctx, userID, output); err != nil {
log.WithError(err).WithFields(logrus.Fields{
"user_id": userID,
"room_id": output.RoomID,
}).Errorf("Failed to generate read update")
sentry.CaptureException(err)
return false
}
if output.IgnoredUsers != nil {
if err := s.db.UpdateIgnoresForUser(ctx, userID, output.IgnoredUsers); err != nil {
log.WithError(err).WithFields(logrus.Fields{
@ -136,34 +199,3 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M
return true
}
func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error {
if output.Type != "m.fully_read" || output.ReadMarker == nil {
return nil
}
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if serverName != s.serverName {
return nil
}
var readPos types.StreamPosition
var fullyReadPos types.StreamPosition
if output.ReadMarker.Read != "" {
if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
}
}
if output.ReadMarker.FullyRead != "" {
if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err)
}
}
if readPos > 0 || fullyReadPos > 0 {
if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil {
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
}
}
return nil
}


@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
@ -40,7 +41,7 @@ type OutputKeyChangeEventConsumer struct {
topic string
db storage.Database
notifier *notifier.Notifier
stream types.StreamProvider
stream streams.StreamProvider
serverName gomatrixserverlib.ServerName // our server name
rsAPI roomserverAPI.SyncRoomserverAPI
}
@ -55,7 +56,7 @@ func NewOutputKeyChangeEventConsumer(
rsAPI roomserverAPI.SyncRoomserverAPI,
store storage.Database,
notifier *notifier.Notifier,
stream types.StreamProvider,
stream streams.StreamProvider,
) *OutputKeyChangeEventConsumer {
s := &OutputKeyChangeEventConsumer{
ctx: process.Context(),


@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
@ -39,7 +40,7 @@ type PresenceConsumer struct {
requestTopic string
presenceTopic string
db storage.Database
stream types.StreamProvider
stream streams.StreamProvider
notifier *notifier.Notifier
deviceAPI api.SyncUserAPI
cfg *config.SyncAPI
@ -54,7 +55,7 @@ func NewPresenceConsumer(
nats *nats.Conn,
db storage.Database,
notifier *notifier.Notifier,
stream types.StreamProvider,
stream streams.StreamProvider,
deviceAPI api.SyncUserAPI,
) *PresenceConsumer {
return &PresenceConsumer{


@ -16,22 +16,20 @@ package consumers
import (
"context"
"database/sql"
"fmt"
"strconv"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)
// OutputReceiptEventConsumer consumes events that originated in the EDU server.
@ -41,10 +39,9 @@ type OutputReceiptEventConsumer struct {
durable string
topic string
db storage.Database
stream types.StreamProvider
stream streams.StreamProvider
notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName
producer *producers.UserAPIReadProducer
}
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
@ -55,8 +52,7 @@ func NewOutputReceiptEventConsumer(
js nats.JetStreamContext,
store storage.Database,
notifier *notifier.Notifier,
stream types.StreamProvider,
producer *producers.UserAPIReadProducer,
stream streams.StreamProvider,
) *OutputReceiptEventConsumer {
return &OutputReceiptEventConsumer{
ctx: process.Context(),
@ -67,7 +63,6 @@ func NewOutputReceiptEventConsumer(
notifier: notifier,
stream: stream,
serverName: cfg.Matrix.ServerName,
producer: producer,
}
}
@ -111,42 +106,8 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
return true
}
if err = s.sendReadUpdate(ctx, output); err != nil {
log.WithError(err).WithFields(logrus.Fields{
"user_id": output.UserID,
"room_id": output.RoomID,
}).Errorf("Failed to generate read update")
sentry.CaptureException(err)
return false
}
s.stream.Advance(streamPos)
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
return true
}
func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output types.OutputReceiptEvent) error {
if output.Type != "m.read" {
return nil
}
_, serverName, err := gomatrixserverlib.SplitID('@', output.UserID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if serverName != s.serverName {
return nil
}
var readPos types.StreamPosition
if output.EventID != "" {
if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
}
}
if readPos > 0 {
if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil {
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
}
}
return nil
}

View file

@ -21,17 +21,21 @@ import (
"fmt"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// OutputRoomEventConsumer consumes events that originated in the room server.
@ -43,10 +47,10 @@ type OutputRoomEventConsumer struct {
durable string
topic string
db storage.Database
pduStream types.StreamProvider
inviteStream types.StreamProvider
pduStream streams.StreamProvider
inviteStream streams.StreamProvider
notifier *notifier.Notifier
producer *producers.UserAPIStreamEventProducer
fts *fulltext.Search
}
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -56,10 +60,10 @@ func NewOutputRoomEventConsumer(
js nats.JetStreamContext,
store storage.Database,
notifier *notifier.Notifier,
pduStream types.StreamProvider,
inviteStream types.StreamProvider,
pduStream streams.StreamProvider,
inviteStream streams.StreamProvider,
rsAPI api.SyncRoomserverAPI,
producer *producers.UserAPIStreamEventProducer,
fts *fulltext.Search,
) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{
ctx: process.Context(),
@ -72,7 +76,7 @@ func NewOutputRoomEventConsumer(
pduStream: pduStream,
inviteStream: inviteStream,
rsAPI: rsAPI,
producer: producer,
fts: fts,
}
}
@ -254,11 +258,11 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
}).Panicf("roomserver output log: write new event failure")
return nil
}
if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID())
sentry.CaptureException(err)
return err
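// A failure to write to the fulltext index is non-fatal: we only log a warning
// and carry on processing the event.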
if err = s.writeFTS(ev, pduPos); err != nil {
log.WithFields(log.Fields{
"event_id": ev.EventID(),
"type": ev.Type(),
}).WithError(err).Warn("failed to index fulltext element")
}
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
@ -304,6 +308,13 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent(
return nil
}
if err = s.writeFTS(ev, pduPos); err != nil {
log.WithFields(log.Fields{
"event_id": ev.EventID(),
"type": ev.Type(),
}).WithError(err).Warn("failed to index fulltext element")
}
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
return err
@ -440,8 +451,15 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.Head
}
stateKey := *event.StateKey()
prevEvent, err := s.db.GetStateEvent(
context.TODO(), event.RoomID(), event.Type(), stateKey,
snapshot, err := s.db.NewDatabaseSnapshot(s.ctx)
if err != nil {
return nil, err
}
var succeeded bool
defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
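// EndTransactionWithCheck rolls the snapshot back unless succeeded has been set
// to true by the time we return.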
prevEvent, err := snapshot.GetStateEvent(
s.ctx, event.RoomID(), event.Type(), stateKey,
)
if err != nil {
return event, err
@ -458,5 +476,42 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.Head
}
event.Event, err = event.SetUnsigned(prev)
succeeded = true
return event, err
}
func (s *OutputRoomEventConsumer) writeFTS(ev *gomatrixserverlib.HeaderedEvent, pduPosition types.StreamPosition) error {
if !s.cfg.Fulltext.Enabled {
return nil
}
e := fulltext.IndexElement{
EventID: ev.EventID(),
RoomID: ev.RoomID(),
StreamPosition: int64(pduPosition),
}
e.SetContentType(ev.Type())
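// Only message bodies, room names and topics are indexed. A redaction removes
// the redacted event from the index; all other event types are ignored.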
switch ev.Type() {
case "m.room.message":
e.Content = gjson.GetBytes(ev.Content(), "body").String()
case gomatrixserverlib.MRoomName:
e.Content = gjson.GetBytes(ev.Content(), "name").String()
case gomatrixserverlib.MRoomTopic:
e.Content = gjson.GetBytes(ev.Content(), "topic").String()
case gomatrixserverlib.MRoomRedaction:
log.Tracef("Redacting event: %s", ev.Redacts())
if err := s.fts.Delete(ev.Redacts()); err != nil {
return fmt.Errorf("failed to delete entry from fulltext index: %w", err)
}
return nil
default:
return nil
}
if e.Content != "" {
log.Tracef("Indexing element: %+v", e)
if err := s.fts.Index(e); err != nil {
return err
}
}
return nil
}

View file

@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
)
@ -43,7 +44,7 @@ type OutputSendToDeviceEventConsumer struct {
db storage.Database
keyAPI keyapi.SyncKeyAPI
serverName gomatrixserverlib.ServerName // our server name
stream types.StreamProvider
stream streams.StreamProvider
notifier *notifier.Notifier
}
@ -56,7 +57,7 @@ func NewOutputSendToDeviceEventConsumer(
store storage.Database,
keyAPI keyapi.SyncKeyAPI,
notifier *notifier.Notifier,
stream types.StreamProvider,
stream streams.StreamProvider,
) *OutputSendToDeviceEventConsumer {
return &OutputSendToDeviceEventConsumer{
ctx: process.Context(),

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
@ -36,7 +37,7 @@ type OutputTypingEventConsumer struct {
durable string
topic string
eduCache *caching.EDUCache
stream types.StreamProvider
stream streams.StreamProvider
notifier *notifier.Notifier
}
@ -48,7 +49,7 @@ func NewOutputTypingEventConsumer(
js nats.JetStreamContext,
eduCache *caching.EDUCache,
notifier *notifier.Notifier,
stream types.StreamProvider,
stream streams.StreamProvider,
) *OutputTypingEventConsumer {
return &OutputTypingEventConsumer{
ctx: process.Context(),

View file

@ -19,15 +19,17 @@ import (
"encoding/json"
"github.com/getsentry/sentry-go"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// OutputNotificationDataConsumer consumes events that originated in
@ -39,7 +41,7 @@ type OutputNotificationDataConsumer struct {
topic string
db storage.Database
notifier *notifier.Notifier
stream types.StreamProvider
stream streams.StreamProvider
}
// NewOutputNotificationDataConsumer creates a new consumer. Call
@ -50,7 +52,7 @@ func NewOutputNotificationDataConsumer(
js nats.JetStreamContext,
store storage.Database,
notifier *notifier.Notifier,
stream types.StreamProvider,
stream streams.StreamProvider,
) *OutputNotificationDataConsumer {
s := &OutputNotificationDataConsumer{
ctx: process.Context(),

View file

@ -100,7 +100,7 @@ func (ev eventVisibility) allowed() (allowed bool) {
// Returns the filtered events and an error, if any.
func ApplyHistoryVisibilityFilter(
ctx context.Context,
syncDB storage.Database,
syncDB storage.DatabaseTransaction,
rsAPI api.SyncRoomserverAPI,
events []*gomatrixserverlib.HeaderedEvent,
alwaysIncludeEventIDs map[string]struct{},

View file

@ -19,6 +19,7 @@ import (
"sync"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@ -318,18 +319,27 @@ func (n *Notifier) GetListener(req types.SyncRequest) UserDeviceStreamListener {
func (n *Notifier) Load(ctx context.Context, db storage.Database) error {
n.lock.Lock()
defer n.lock.Unlock()
roomToUsers, err := db.AllJoinedUsersInRooms(ctx)
snapshot, err := db.NewDatabaseSnapshot(ctx)
if err != nil {
return err
}
var succeeded bool
defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
roomToUsers, err := snapshot.AllJoinedUsersInRooms(ctx)
if err != nil {
return err
}
n.setUsersJoinedToRooms(roomToUsers)
roomToPeekingDevices, err := db.AllPeekingDevicesInRooms(ctx)
roomToPeekingDevices, err := snapshot.AllPeekingDevicesInRooms(ctx)
if err != nil {
return err
}
n.setPeekingDevices(roomToPeekingDevices)
succeeded = true
return nil
}
@ -338,12 +348,20 @@ func (n *Notifier) LoadRooms(ctx context.Context, db storage.Database, roomIDs [
n.lock.Lock()
defer n.lock.Unlock()
roomToUsers, err := db.AllJoinedUsersInRoom(ctx, roomIDs)
snapshot, err := db.NewDatabaseSnapshot(ctx)
if err != nil {
return err
}
var succeeded bool
defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
roomToUsers, err := snapshot.AllJoinedUsersInRoom(ctx, roomIDs)
if err != nil {
return err
}
n.setUsersJoinedToRooms(roomToUsers)
succeeded = true
return nil
}

View file

@ -1,62 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package producers
import (
"encoding/json"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// UserAPIProducer produces events for the user API server to consume
type UserAPIReadProducer struct {
Topic string
JetStream nats.JetStreamContext
}
// SendData sends account data to the user API server
func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error {
m := &nats.Msg{
Subject: p.Topic,
Header: nats.Header{},
}
m.Header.Set(jetstream.UserID, userID)
m.Header.Set(jetstream.RoomID, roomID)
data := types.ReadUpdate{
UserID: userID,
RoomID: roomID,
Read: readPos,
FullyRead: fullyReadPos,
}
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"user_id": userID,
"room_id": roomID,
"read_pos": readPos,
"fully_read_pos": fullyReadPos,
}).Tracef("Producing to topic '%s'", p.Topic)
_, err = p.JetStream.PublishMsg(m)
return err
}

View file

@ -1,60 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package producers
import (
"encoding/json"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// UserAPIProducer produces events for the user API server to consume
type UserAPIStreamEventProducer struct {
Topic string
JetStream nats.JetStreamContext
}
// SendData sends account data to the user API server
func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error {
m := &nats.Msg{
Subject: p.Topic,
Header: nats.Header{},
}
m.Header.Set(jetstream.RoomID, roomID)
data := types.StreamedEvent{
Event: event,
StreamPosition: pos,
}
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"room_id": roomID,
"event_id": event.EventID(),
"event_type": event.Type(),
"stream_pos": pos,
}).Tracef("Producing to topic '%s'", p.Topic)
_, err = p.JetStream.PublishMsg(m)
return err
}

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
@ -37,11 +38,11 @@ import (
type ContextRespsonse struct {
End string `json:"end"`
Event gomatrixserverlib.ClientEvent `json:"event"`
Event *gomatrixserverlib.ClientEvent `json:"event,omitempty"`
EventsAfter []gomatrixserverlib.ClientEvent `json:"events_after,omitempty"`
EventsBefore []gomatrixserverlib.ClientEvent `json:"events_before,omitempty"`
Start string `json:"start"`
State []gomatrixserverlib.ClientEvent `json:"state"`
State []gomatrixserverlib.ClientEvent `json:"state,omitempty"`
}
func Context(
@ -51,6 +52,13 @@ func Context(
roomID, eventID string,
lazyLoadCache caching.LazyLoadCache,
) util.JSONResponse {
snapshot, err := syncDB.NewDatabaseSnapshot(req.Context())
if err != nil {
return jsonerror.InternalServerError()
}
var succeeded bool
defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
filter, err := parseRoomEventFilter(req)
if err != nil {
errMsg := ""
@ -97,7 +105,7 @@ func Context(
ContainsURL: filter.ContainsURL,
}
id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID)
id, requestedEvent, err := snapshot.SelectContextEvent(ctx, roomID, eventID)
if err != nil {
if err == sql.ErrNoRows {
return util.JSONResponse{
@ -111,7 +119,7 @@ func Context(
// verify the user is allowed to see the context for this room/event
startTime := time.Now()
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context")
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context")
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
return jsonerror.InternalServerError()
@ -127,20 +135,20 @@ func Context(
}
}
eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, filter)
eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, roomID, filter)
if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Error("unable to fetch before events")
return jsonerror.InternalServerError()
}
_, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, roomID, filter)
_, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, roomID, filter)
if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Error("unable to fetch after events")
return jsonerror.InternalServerError()
}
startTime = time.Now()
eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, syncDB, rsAPI, eventsBefore, eventsAfter, device.UserID)
eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, device.UserID)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
return jsonerror.InternalServerError()
@ -152,7 +160,7 @@ func Context(
}).Debug("applied history visibility (context eventsBefore/eventsAfter)")
// TODO: Get the actual state at the last event returned by SelectContextAfterEvent
state, err := syncDB.CurrentState(ctx, roomID, &stateFilter, nil)
state, err := snapshot.CurrentState(ctx, roomID, &stateFilter, nil)
if err != nil {
logrus.WithError(err).Error("unable to fetch current room state")
return jsonerror.InternalServerError()
@ -162,8 +170,9 @@ func Context(
eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll)
newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache)
ev := gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll)
response := ContextRespsonse{
Event: gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll),
Event: &ev,
EventsAfter: eventsAfterClient,
EventsBefore: eventsBeforeClient,
State: gomatrixserverlib.HeaderedToClientEvents(newState, gomatrixserverlib.FormatAll),
@ -172,11 +181,12 @@ func Context(
if len(response.State) > filter.Limit {
response.State = response.State[len(response.State)-filter.Limit:]
}
start, end, err := getStartEnd(ctx, syncDB, eventsBefore, eventsAfter)
start, end, err := getStartEnd(ctx, snapshot, eventsBefore, eventsAfter)
if err == nil {
response.End = end.String()
response.Start = start.String()
}
succeeded = true
return util.JSONResponse{
Code: http.StatusOK,
JSON: response,
@ -187,7 +197,7 @@ func Context(
// by combining the events before and after the context event. Returns the filtered events,
// and an error, if any.
func applyHistoryVisibilityOnContextEvents(
ctx context.Context, syncDB storage.Database, rsAPI roomserver.SyncRoomserverAPI,
ctx context.Context, snapshot storage.DatabaseTransaction, rsAPI roomserver.SyncRoomserverAPI,
eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent,
userID string,
) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) {
@ -204,7 +214,7 @@ func applyHistoryVisibilityOnContextEvents(
}
allEvents := append(eventsBefore, eventsAfter...)
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, allEvents, nil, userID, "context")
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, allEvents, nil, userID, "context")
if err != nil {
return nil, nil, err
}
@ -221,15 +231,15 @@ func applyHistoryVisibilityOnContextEvents(
return filteredBefore, filteredAfter, nil
}
func getStartEnd(ctx context.Context, syncDB storage.Database, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
func getStartEnd(ctx context.Context, snapshot storage.DatabaseTransaction, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
if len(startEvents) > 0 {
start, err = syncDB.EventPositionInTopology(ctx, startEvents[0].EventID())
start, err = snapshot.EventPositionInTopology(ctx, startEvents[0].EventID())
if err != nil {
return
}
}
if len(endEvents) > 0 {
end, err = syncDB.EventPositionInTopology(ctx, endEvents[0].EventID())
end, err = snapshot.EventPositionInTopology(ctx, endEvents[0].EventID())
}
return
}

View file

@ -27,6 +27,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/internal"
@ -39,6 +40,7 @@ import (
type messagesReq struct {
ctx context.Context
db storage.Database
snapshot storage.DatabaseTransaction
rsAPI api.SyncRoomserverAPI
cfg *config.SyncAPI
roomID string
@ -70,6 +72,16 @@ func OnIncomingMessagesRequest(
) util.JSONResponse {
var err error
// NewDatabaseTransaction is used here instead of NewDatabaseSnapshot as we
// expect to be able to write to the database in response to a /messages
// request that requires backfilling from the roomserver or federation.
snapshot, err := db.NewDatabaseTransaction(req.Context())
if err != nil {
return jsonerror.InternalServerError()
}
var succeeded bool
defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
// check if the user has already forgotten about this room
isForgotten, roomExists, err := checkIsRoomForgotten(req.Context(), roomID, device.UserID, rsAPI)
if err != nil {
@ -132,7 +144,7 @@ func OnIncomingMessagesRequest(
}
} else {
fromStream = &streamToken
from, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering)
from, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering)
if err != nil {
logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken)
return jsonerror.InternalServerError()
@ -154,7 +166,7 @@ func OnIncomingMessagesRequest(
JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()),
}
} else {
to, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering)
to, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering)
if err != nil {
logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken)
return jsonerror.InternalServerError()
@ -165,7 +177,7 @@ func OnIncomingMessagesRequest(
// If "to" isn't provided, it defaults to either the earliest stream
// position (if we're going backward) or to the latest one (if we're
// going forward).
to, err = setToDefault(req.Context(), db, backwardOrdering, roomID)
to, err = setToDefault(req.Context(), snapshot, backwardOrdering, roomID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("setToDefault failed")
return jsonerror.InternalServerError()
@ -186,6 +198,7 @@ func OnIncomingMessagesRequest(
mReq := messagesReq{
ctx: req.Context(),
db: db,
snapshot: snapshot,
rsAPI: rsAPI,
cfg: cfg,
roomID: roomID,
@ -217,7 +230,7 @@ func OnIncomingMessagesRequest(
Start: start.String(),
End: end.String(),
}
res.applyLazyLoadMembers(req.Context(), db, roomID, device, filter.LazyLoadMembers, lazyLoadCache)
res.applyLazyLoadMembers(req.Context(), snapshot, roomID, device, filter.LazyLoadMembers, lazyLoadCache)
// If we didn't return any events, set the end to an empty string, so it will be omitted
// in the response JSON.
@ -229,6 +242,7 @@ func OnIncomingMessagesRequest(
}
// Respond with the events.
succeeded = true
return util.JSONResponse{
Code: http.StatusOK,
JSON: res,
@ -239,7 +253,7 @@ func OnIncomingMessagesRequest(
// LazyLoadMembers enabled.
func (m *messagesResp) applyLazyLoadMembers(
ctx context.Context,
db storage.Database,
db storage.DatabaseTransaction,
roomID string,
device *userapi.Device,
lazyLoad bool,
@ -292,7 +306,7 @@ func (r *messagesReq) retrieveEvents() (
end types.TopologyToken, err error,
) {
// Retrieve the events from the local database.
streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
streamEvents, err := r.snapshot.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
if err != nil {
err = fmt.Errorf("GetEventsInRange: %w", err)
return
@ -348,7 +362,7 @@ func (r *messagesReq) retrieveEvents() (
// Apply room history visibility filter
startTime := time.Now()
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, r.rsAPI, events, nil, r.device.UserID, "messages")
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.device.UserID, "messages")
logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime),
"room_id": r.roomID,
@ -366,7 +380,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
// else to go. This seems to stop Element iOS from looping on /messages endlessly.
end = types.TopologyToken{}
} else {
end, err = r.db.EventPositionInTopology(
end, err = r.snapshot.EventPositionInTopology(
r.ctx, events[0].EventID(),
)
// A stream/topological position is a cursor located between two events.
@ -378,7 +392,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
}
} else {
start = *r.from
end, err = r.db.EventPositionInTopology(
end, err = r.snapshot.EventPositionInTopology(
r.ctx, events[len(events)-1].EventID(),
)
}
@ -399,7 +413,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
func (r *messagesReq) handleEmptyEventsSlice() (
events []*gomatrixserverlib.HeaderedEvent, err error,
) {
backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID)
backwardExtremities, err := r.snapshot.BackwardExtremitiesForRoom(r.ctx, r.roomID)
// Check if we have backward extremities for this room.
if len(backwardExtremities) > 0 {
@ -443,7 +457,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
}
// Check if the slice contains a backward extremity.
backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID)
backwardExtremities, err := r.snapshot.BackwardExtremitiesForRoom(r.ctx, r.roomID)
if err != nil {
return
}
@ -463,7 +477,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
}
// Append the events we previously retrieved locally.
events = append(events, r.db.StreamEventsToEvents(nil, streamEvents)...)
events = append(events, r.snapshot.StreamEventsToEvents(nil, streamEvents)...)
sort.Sort(eventsByDepth(events))
return
@ -553,7 +567,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
// Returns an error if there was an issue with retrieving the latest position
// from the database
func setToDefault(
ctx context.Context, db storage.Database, backwardOrdering bool,
ctx context.Context, snapshot storage.DatabaseTransaction, backwardOrdering bool,
roomID string,
) (to types.TopologyToken, err error) {
if backwardOrdering {
@ -561,7 +575,7 @@ func setToDefault(
// this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound.
to = types.TopologyToken{}
} else {
to, err = db.MaxTopologicalPosition(ctx, roomID)
to, err = snapshot.MaxTopologicalPosition(ctx, roomID)
}
return

View file

@ -18,15 +18,18 @@ import (
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// Setup configures the given mux with sync-server listeners
@ -40,6 +43,7 @@ func Setup(
rsAPI api.SyncRoomserverAPI,
cfg *config.SyncAPI,
lazyLoadCache caching.LazyLoadCache,
fts *fulltext.Search,
) {
v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
@ -95,4 +99,24 @@ func Setup(
)
}),
).Methods(http.MethodGet, http.MethodOptions)
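// The /search endpoint answers 501 Not Implemented unless fulltext search is
// enabled in the sync API config. The optional next_batch form value is used to
// paginate through results.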
v3mux.Handle("/search",
httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if !cfg.Fulltext.Enabled {
return util.JSONResponse{
Code: http.StatusNotImplemented,
JSON: jsonerror.Unknown("Search has been disabled by the server administrator."),
}
}
var nextBatch *string
if err := req.ParseForm(); err != nil {
return jsonerror.InternalServerError()
}
if req.Form.Has("next_batch") {
nb := req.FormValue("next_batch")
nextBatch = &nb
}
return Search(req, device, syncDB, fts, nextBatch)
}),
).Methods(http.MethodPost, http.MethodOptions)
}

syncapi/routing/search.go Normal file
View file

@ -0,0 +1,353 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"context"
"net/http"
"sort"
"strconv"
"strings"
"time"
"github.com/blevesearch/bleve/v2/search"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/userapi/api"
)
// nolint:gocyclo
func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts *fulltext.Search, from *string) util.JSONResponse {
start := time.Now()
var (
searchReq SearchRequest
err error
ctx = req.Context()
)
resErr := httputil.UnmarshalJSONRequest(req, &searchReq)
if resErr != nil {
logrus.Error("failed to unmarshal search request")
return *resErr
}
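// from carries the next_batch token from a previous response, which is a plain
// integer offset into the full result set.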
nextBatch := 0
if from != nil && *from != "" {
nextBatch, err = strconv.Atoi(*from)
if err != nil {
return jsonerror.InternalServerError()
}
}
if searchReq.SearchCategories.RoomEvents.Filter.Limit == 0 {
searchReq.SearchCategories.RoomEvents.Filter.Limit = 5
}
snapshot, err := syncDB.NewDatabaseSnapshot(req.Context())
if err != nil {
return jsonerror.InternalServerError()
}
var succeeded bool
defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
// only search rooms the user is actually joined to
joinedRooms, err := snapshot.RoomIDsWithMembership(ctx, device.UserID, "join")
if err != nil {
return jsonerror.InternalServerError()
}
if len(joinedRooms) == 0 {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("User not joined to any rooms."),
}
}
joinedRoomsMap := make(map[string]struct{}, len(joinedRooms))
for _, roomID := range joinedRooms {
joinedRoomsMap[roomID] = struct{}{}
}
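// Intersect any requested room filter with the rooms the user is joined to, so
// that search results never include rooms the user isn't in.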
rooms := []string{}
if searchReq.SearchCategories.RoomEvents.Filter.Rooms != nil {
for _, roomID := range *searchReq.SearchCategories.RoomEvents.Filter.Rooms {
if _, ok := joinedRoomsMap[roomID]; ok {
rooms = append(rooms, roomID)
}
}
} else {
rooms = joinedRooms
}
if len(rooms) == 0 {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Unknown("User not allowed to search in this room(s)."),
}
}
orderByTime := searchReq.SearchCategories.RoomEvents.OrderBy == "recent"
result, err := fts.Search(
searchReq.SearchCategories.RoomEvents.SearchTerm,
rooms,
searchReq.SearchCategories.RoomEvents.Keys,
searchReq.SearchCategories.RoomEvents.Filter.Limit,
nextBatch,
orderByTime,
)
if err != nil {
logrus.WithError(err).Error("failed to search fulltext")
return jsonerror.InternalServerError()
}
logrus.Debugf("Search took %s", result.Took)
// From was specified but empty, return no results, only the count
if from != nil && *from == "" {
return util.JSONResponse{
Code: http.StatusOK,
JSON: SearchResponse{
SearchCategories: SearchCategories{
RoomEvents: RoomEvents{
Count: int(result.Total),
NextBatch: nil,
},
},
},
}
}
results := []Result{}
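// Collect the event IDs of all hits so they can be fetched from the database in
// a single query, keeping each hit so its score can be attached to the result.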
wantEvents := make([]string, 0, len(result.Hits))
eventScore := make(map[string]*search.DocumentMatch)
for _, hit := range result.Hits {
wantEvents = append(wantEvents, hit.ID)
eventScore[hit.ID] = hit
}
// Filter on m.room.message, as otherwise we also get events like m.reaction
// which "breaks" displaying results in Element Web.
types := []string{"m.room.message"}
roomFilter := &gomatrixserverlib.RoomEventFilter{
Rooms: &rooms,
Types: &types,
}
evs, err := syncDB.Events(ctx, wantEvents)
if err != nil {
logrus.WithError(err).Error("failed to get events from database")
return jsonerror.InternalServerError()
}
groups := make(map[string]RoomResult)
knownUsersProfiles := make(map[string]ProfileInfo)
// Sort the events by depth, as the returned values aren't ordered
if orderByTime {
sort.Slice(evs, func(i, j int) bool {
return evs[i].Depth() > evs[j].Depth()
})
}
stateForRooms := make(map[string][]gomatrixserverlib.ClientEvent)
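// For each hit, build its surrounding context events, the senders' profile
// information, the per-room result grouping and, if requested, the current room
// state.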
for _, event := range evs {
eventsBefore, eventsAfter, err := contextEvents(ctx, snapshot, event, roomFilter, searchReq)
if err != nil {
logrus.WithError(err).Error("failed to get context events")
return jsonerror.InternalServerError()
}
startToken, endToken, err := getStartEnd(ctx, snapshot, eventsBefore, eventsAfter)
if err != nil {
logrus.WithError(err).Error("failed to get start/end")
return jsonerror.InternalServerError()
}
profileInfos := make(map[string]ProfileInfo)
for _, ev := range append(eventsBefore, eventsAfter...) {
profile, ok := knownUsersProfiles[event.Sender()]
if !ok {
stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), gomatrixserverlib.MRoomMember, ev.Sender())
if err != nil {
logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile")
continue
}
if stateEvent == nil {
continue
}
profile = ProfileInfo{
AvatarURL: gjson.GetBytes(stateEvent.Content(), "avatar_url").Str,
DisplayName: gjson.GetBytes(stateEvent.Content(), "displayname").Str,
}
knownUsersProfiles[event.Sender()] = profile
}
profileInfos[ev.Sender()] = profile
}
results = append(results, Result{
Context: SearchContextResponse{
Start: startToken.String(),
End: endToken.String(),
EventsAfter: gomatrixserverlib.HeaderedToClientEvents(eventsAfter, gomatrixserverlib.FormatSync),
EventsBefore: gomatrixserverlib.HeaderedToClientEvents(eventsBefore, gomatrixserverlib.FormatSync),
ProfileInfo: profileInfos,
},
Rank: eventScore[event.EventID()].Score,
Result: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll),
})
roomGroup := groups[event.RoomID()]
roomGroup.Results = append(roomGroup.Results, event.EventID())
groups[event.RoomID()] = roomGroup
if _, ok := stateForRooms[event.RoomID()]; searchReq.SearchCategories.RoomEvents.IncludeState && !ok {
stateFilter := gomatrixserverlib.DefaultStateFilter()
state, err := snapshot.CurrentState(ctx, event.RoomID(), &stateFilter, nil)
if err != nil {
logrus.WithError(err).Error("unable to get current state")
return jsonerror.InternalServerError()
}
stateForRooms[event.RoomID()] = gomatrixserverlib.HeaderedToClientEvents(state, gomatrixserverlib.FormatSync)
}
}
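// If there are more results to fetch, hand the offset of the next result back
// to the client as next_batch.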
var nextBatchResult *string = nil
if int(result.Total) > nextBatch+len(results) {
nb := strconv.Itoa(len(results) + nextBatch)
nextBatchResult = &nb
} else if int(result.Total) == nextBatch+len(results) {
// Sytest expects a next_batch even if we don't actually have any more results
nb := ""
nextBatchResult = &nb
}
res := SearchResponse{
SearchCategories: SearchCategories{
RoomEvents: RoomEvents{
Count: int(result.Total),
Groups: Groups{RoomID: groups},
Results: results,
NextBatch: nextBatchResult,
Highlights: strings.Split(searchReq.SearchCategories.RoomEvents.SearchTerm, " "),
State: stateForRooms,
},
},
}
logrus.Debugf("Full search request took %v", time.Since(start))
succeeded = true
return util.JSONResponse{
Code: http.StatusOK,
JSON: res,
}
}
// contextEvents returns the events around a given eventID
func contextEvents(
ctx context.Context,
snapshot storage.DatabaseTransaction,
event *gomatrixserverlib.HeaderedEvent,
roomFilter *gomatrixserverlib.RoomEventFilter,
searchReq SearchRequest,
) ([]*gomatrixserverlib.HeaderedEvent, []*gomatrixserverlib.HeaderedEvent, error) {
id, _, err := snapshot.SelectContextEvent(ctx, event.RoomID(), event.EventID())
if err != nil {
logrus.WithError(err).Error("failed to query context event")
return nil, nil, err
}
roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.BeforeLimit
eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, event.RoomID(), roomFilter)
if err != nil {
logrus.WithError(err).Error("failed to query before context event")
return nil, nil, err
}
roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.AfterLimit
_, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, event.RoomID(), roomFilter)
if err != nil {
logrus.WithError(err).Error("failed to query after context event")
return nil, nil, err
}
return eventsBefore, eventsAfter, err
}
type SearchRequest struct {
SearchCategories struct {
RoomEvents struct {
EventContext struct {
AfterLimit int `json:"after_limit,omitempty"`
BeforeLimit int `json:"before_limit,omitempty"`
IncludeProfile bool `json:"include_profile,omitempty"`
} `json:"event_context"`
Filter gomatrixserverlib.StateFilter `json:"filter"`
Groupings struct {
GroupBy []struct {
Key string `json:"key"`
} `json:"group_by"`
} `json:"groupings"`
IncludeState bool `json:"include_state"`
Keys []string `json:"keys"`
OrderBy string `json:"order_by"`
SearchTerm string `json:"search_term"`
} `json:"room_events"`
} `json:"search_categories"`
}
type SearchResponse struct {
SearchCategories SearchCategories `json:"search_categories"`
}
type RoomResult struct {
NextBatch *string `json:"next_batch,omitempty"`
Order int `json:"order"`
Results []string `json:"results"`
}
type Groups struct {
RoomID map[string]RoomResult `json:"room_id"`
}
type Result struct {
Context SearchContextResponse `json:"context"`
Rank float64 `json:"rank"`
Result gomatrixserverlib.ClientEvent `json:"result"`
}
type SearchContextResponse struct {
End string `json:"end"`
EventsAfter []gomatrixserverlib.ClientEvent `json:"events_after"`
EventsBefore []gomatrixserverlib.ClientEvent `json:"events_before"`
Start string `json:"start"`
ProfileInfo map[string]ProfileInfo `json:"profile_info"`
}
type ProfileInfo struct {
AvatarURL string `json:"avatar_url"`
DisplayName string `json:"display_name"`
}
type RoomEvents struct {
Count int `json:"count"`
Groups Groups `json:"groups"`
Highlights []string `json:"highlights"`
NextBatch *string `json:"next_batch,omitempty"`
Results []Result `json:"results"`
State map[string][]gomatrixserverlib.ClientEvent `json:"state,omitempty"`
}
type SearchCategories struct {
RoomEvents RoomEvents `json:"room_events"`
}

View file

@ -17,17 +17,18 @@ package storage
import (
"context"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage/shared"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
type Database interface {
Presence
type DatabaseTransaction interface {
sqlutil.Transaction
SharedUsers
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
@ -36,6 +37,7 @@ type Database interface {
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
@ -43,23 +45,77 @@ type Database interface {
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error)
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
GetBackwardTopologyPos(ctx context.Context, events []types.StreamEvent) (types.TopologyToken, error)
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)
InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error)
InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error)
PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error)
RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error)
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
// AllJoinedUsersInRoom returns a map of room ID to a list of all joined user IDs for a given room.
AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
// AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices.
AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error)
// Events looks up a list of events by their event IDs.
// Returns a list of events matching the requested IDs found in the database.
// If an event is not found in the database then it will be omitted from the list.
// Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events.
Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
// If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
// GetStateEventsForRoom fetches the state events for a given room.
// Returns an empty slice if no state events could be found for this room.
// Returns an error if there was an issue with the retrieval.
GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error)
// GetAccountDataInRange returns all account data for a given user inserted or
// updated between two given positions
// Returns a map following the format data[roomID] = []dataTypes
// If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error
GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, types.StreamPosition, error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event.
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error)
// MaxTopologicalPosition returns the highest topological position for a given room.
MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error)
// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
// matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event.
StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
// relevant events within the given ranges for the supplied user ID and device ID.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)
// GetRoomReceipts gets all receipts for a given roomID
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error)
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error)
IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error)
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership.
SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
// GetUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
}
type Database interface {
Presence
Notifications
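// NewDatabaseSnapshot opens a transaction intended for consistent reads, while
// NewDatabaseTransaction opens one that may also write (see the /messages
// handler). Both are finished via sqlutil.EndTransactionWithCheck.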
NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseTransaction, error)
NewDatabaseTransaction(ctx context.Context) (*shared.DatabaseTransaction, error)
// Events looks up a list of events by their event IDs.
// Returns a list of events matching the requested IDs found in the database.
// If an event is not found in the database then it will be omitted from the list.
@ -76,20 +132,6 @@ type Database interface {
// PurgeRoomState completely purges room state from the sync API. This is done when
// receiving an output event that completely resets the state.
PurgeRoomState(ctx context.Context, roomID string) error
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
// If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
// GetStateEventsForRoom fetches the state events for a given room.
// Returns an empty slice if no state events could be found for this room.
// Returns an error if there was an issue with the retrieval.
GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error)
// GetAccountDataInRange returns all account data for a given user inserted or
// updated between two given positions
// Returns a map following the format data[roomID] = []dataTypes
// If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error
GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, types.StreamPosition, error)
// UpsertAccountData keeps track of new or updated account data, by saving the type
// of the new/updated data, and the user ID and room ID the data is related to (an
// empty room ID means the data isn't specific to any room)
@ -113,21 +155,6 @@ type Database interface {
// DeletePeeks deletes all peeks for a given room by a given user
// Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event.
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error)
// MaxTopologicalPosition returns the highest topological position for a given room.
MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error)
// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
// matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event.
StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
// relevant events within the given ranges for the supplied user ID and device ID.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
// CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified
@ -145,37 +172,21 @@ type Database interface {
RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error
// StoreReceipt stores new receipt events
StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
// GetRoomReceipts gets all receipts for a given roomID
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error)
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
// GetUserUnreadNotificationCounts returns statistics per room a user is interested in.
GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error)
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error)
IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error)
UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership.
SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error)
}
type Presence interface {
UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error)
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error)
}
type SharedUsers interface {
// SharedUsers returns a subset of otherUserIDs that share a room with userID.
SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error)
}
type Notifications interface {
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
}

View file

@ -99,14 +99,15 @@ func (s *accountDataStatements) InsertAccountData(
}
func (s *accountDataStatements) SelectAccountDataInRange(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
userID string,
r types.Range,
accountDataEventFilter *gomatrixserverlib.EventFilter,
) (data map[string][]string, pos types.StreamPosition, err error) {
data = make(map[string][]string)
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(),
rows, err := sqlutil.TxStmt(txn, s.selectAccountDataInRangeStmt).QueryContext(
ctx, userID, r.Low(), r.High(),
pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.NotTypes)),
accountDataEventFilter.Limit,

View file

@@ -79,9 +79,9 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
}
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
ctx context.Context, roomID string,
ctx context.Context, txn *sql.Tx, roomID string,
) (bwExtrems map[string][]string, err error) {
rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID)
rows, err := sqlutil.TxStmt(txn, s.selectBackwardExtremitiesForRoomStmt).QueryContext(ctx, roomID)
if err != nil {
return
}

View file

@@ -104,12 +104,7 @@ const selectStateEventSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
const selectEventsWithEventIDsSQL = "" +
// TODO: The session_id and transaction_id blanks are here because
// the rowsToStreamEvents expects there to be exactly seven columns. We need to
// figure out if these really need to be in the DB, and if so, we need a
// better permanent fix for this. - neilalexander, 2 Jan 2020
"SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id, history_visibility" +
" FROM syncapi_current_room_state WHERE event_id = ANY($1)"
"SELECT event_id, added_at, headered_event_json, history_visibility FROM syncapi_current_room_state WHERE event_id = ANY($1)"
const selectSharedUsersSQL = "" +
"SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
@@ -185,9 +180,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
func (s *currentRoomStateStatements) SelectJoinedUsers(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
) (map[string][]string, error) {
rows, err := s.selectJoinedUsersStmt.QueryContext(ctx)
rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt).QueryContext(ctx)
if err != nil {
return nil, err
}
@@ -209,9 +204,9 @@ func (s *currentRoomStateStatements) SelectJoinedUsers(
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
func (s *currentRoomStateStatements) SelectJoinedUsersInRoom(
ctx context.Context, roomIDs []string,
ctx context.Context, txn *sql.Tx, roomIDs []string,
) (map[string][]string, error) {
rows, err := s.selectJoinedUsersInRoomStmt.QueryContext(ctx, pq.StringArray(roomIDs))
rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersInRoomStmt).QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil {
return nil, err
}
@@ -365,7 +360,36 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed")
return rowsToStreamEvents(rows)
return currentRoomStateRowsToStreamEvents(rows)
}
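// currentRoomStateRowsToStreamEvents scans rows from syncapi_current_room_state,
// which no longer selects the placeholder session_id, transaction_id and
// exclude_from_sync columns, so it cannot share rowsToStreamEvents.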
func currentRoomStateRowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
var events []types.StreamEvent
for rows.Next() {
var (
eventID string
streamPos types.StreamPosition
eventBytes []byte
historyVisibility gomatrixserverlib.HistoryVisibility
)
if err := rows.Scan(&eventID, &streamPos, &eventBytes, &historyVisibility); err != nil {
return nil, err
}
// TODO: Handle redacted events
var ev gomatrixserverlib.HeaderedEvent
if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err
}
ev.Visibility = historyVisibility
events = append(events, types.StreamEvent{
HeaderedEvent: &ev,
StreamPosition: streamPos,
})
}
return events, nil
}
func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
@@ -387,9 +411,9 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
}
func (s *currentRoomStateStatements) SelectStateEvent(
ctx context.Context, roomID, evType, stateKey string,
ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) {
stmt := s.selectStateEventStmt
stmt := sqlutil.TxStmt(txn, s.selectStateEventStmt)
var res []byte
err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res)
if err == sql.ErrNoRows {

View file

@@ -19,6 +19,7 @@ import (
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -73,11 +74,11 @@ func NewPostgresFilterTable(db *sql.DB) (tables.Filter, error) {
}
func (s *filterStatements) SelectFilter(
ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string,
) error {
// Retrieve filter from database (stored as canonical JSON)
var filterData []byte
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
err := sqlutil.TxStmt(txn, s.selectFilterStmt).QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
if err != nil {
return err
}
@@ -90,7 +91,7 @@ func (s *filterStatements) SelectFilter(
}
func (s *filterStatements) InsertFilter(
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) {
var existingFilterID string
@@ -111,8 +112,9 @@ func (s *filterStatements) InsertFilter(
// This can result in a race condition when two clients try to insert the
// same filter and localpart at the same time, however this is not a
// problem as both calls will result in the same filterID
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
localpart, filterJSON).Scan(&existingFilterID)
err = sqlutil.TxStmt(txn, s.selectFilterIDByContentStmt).QueryRowContext(
ctx, localpart, filterJSON,
).Scan(&existingFilterID)
if err != nil && err != sql.ErrNoRows {
return "", err
}
@@ -122,7 +124,7 @@ func (s *filterStatements) InsertFilter(
}
// Otherwise insert the filter and return the new ID
err = s.insertFilterStmt.QueryRowContext(ctx, filterJSON, localpart).
err = sqlutil.TxStmt(txn, s.insertFilterStmt).QueryRowContext(ctx, filterJSON, localpart).
Scan(&filterID)
return
}

View file

@@ -55,7 +55,7 @@ const deleteInviteEventSQL = "" +
"UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 AND deleted=FALSE RETURNING id"
const selectInviteEventsInRangeSQL = "" +
"SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
"SELECT id, room_id, headered_event_json, deleted FROM syncapi_invite_events" +
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC"
@@ -99,7 +99,7 @@ func (s *inviteEventsStatements) InsertInviteEvent(
return
}
err = s.insertInviteEventStmt.QueryRowContext(
err = sqlutil.TxStmt(txn, s.insertInviteEventStmt).QueryRowContext(
ctx,
inviteEvent.RoomID(),
inviteEvent.EventID(),
@@ -121,23 +121,28 @@ func (s *inviteEventsStatements) DeleteInviteEvent(
// active invites for the target user ID in the supplied range.
func (s *inviteEventsStatements) SelectInviteEventsInRange(
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) {
) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) {
var lastPos types.StreamPosition
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
if err != nil {
return nil, nil, err
return nil, nil, lastPos, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
result := map[string]*gomatrixserverlib.HeaderedEvent{}
retired := map[string]*gomatrixserverlib.HeaderedEvent{}
for rows.Next() {
var (
id types.StreamPosition
roomID string
eventJSON []byte
deleted bool
)
if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil {
return nil, nil, err
if err = rows.Scan(&id, &roomID, &eventJSON, &deleted); err != nil {
return nil, nil, lastPos, err
}
if id > lastPos {
lastPos = id
}
// if we have seen this room before, it has a higher stream position and hence takes priority
@@ -150,7 +155,7 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange(
var event *gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(eventJSON, &event); err != nil {
return nil, nil, err
return nil, nil, lastPos, err
}
if deleted {
@@ -159,7 +164,10 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange(
result[roomID] = event
}
}
return result, retired, rows.Err()
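// If the range contained no invites at all, report the top of the range so
// that the caller's next stream position still advances.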
if lastPos == 0 {
lastPos = r.To
}
return result, retired, lastPos, rows.Err()
}
func (s *inviteEventsStatements) SelectMaxInviteID(

View file

@@ -18,6 +18,8 @@ import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
@@ -33,15 +35,15 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro
r := &notificationDataStatements{}
return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
{&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms},
{&r.selectMaxID, selectMaxNotificationIDSQL},
}.Prepare(db)
}
type notificationDataStatements struct {
upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt
selectMaxID *sql.Stmt
upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCountsForRooms *sql.Stmt
selectMaxID *sql.Stmt
}
const notificationDataSchema = `
@@ -61,12 +63,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4
RETURNING id`
const selectUserUnreadNotificationCountsSQL = `SELECT
id, room_id, notification_count, highlight_count
FROM syncapi_notification_data
WHERE
user_id = $1 AND
id BETWEEN $2 + 1 AND $3`
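// selectUserUnreadNotificationsForRooms returns the stored notification and
// highlight counts for the given user in each of the given rooms.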
const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count
FROM syncapi_notification_data
WHERE user_id = $1 AND
room_id = ANY($2)`
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
@@ -75,20 +75,20 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context,
return
}
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl)
func (r *notificationDataStatements) SelectUserUnreadCountsForRooms(
ctx context.Context, txn *sql.Tx, userID string, roomIDs []string,
) (map[string]*eventutil.NotificationData, error) {
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCountsForRooms).QueryContext(ctx, userID, pq.Array(roomIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed")
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCountsForRooms: rows.close() failed")
roomCounts := map[string]*eventutil.NotificationData{}
var roomID string
var notificationCount, highlightCount int
for rows.Next() {
var id types.StreamPosition
var roomID string
var notificationCount, highlightCount int
if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil {
if err = rows.Scan(&roomID, &notificationCount, &highlightCount); err != nil {
return nil, err
}

View file

@@ -166,6 +166,8 @@ const selectContextAfterEventSQL = "" +
" AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" +
" ORDER BY id ASC LIMIT $3"
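// selectSearchSQL feeds the fulltext reindexer: it returns events with a stream
// ID greater than $1, restricted to the given event types, in batches of at most $3 rows.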
const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type = ANY($2) ORDER BY id ASC LIMIT $3"
type outputRoomEventsStatements struct {
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
@@ -180,6 +182,7 @@ type outputRoomEventsStatements struct {
selectContextEventStmt *sql.Stmt
selectContextBeforeEventStmt *sql.Stmt
selectContextAfterEventStmt *sql.Stmt
selectSearchStmt *sql.Stmt
}
func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
@@ -215,15 +218,16 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
{&s.selectContextEventStmt, selectContextEventSQL},
{&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL},
{&s.selectContextAfterEventStmt, selectContextAfterEventSQL},
{&s.selectSearchStmt, selectSearchSQL},
}.Prepare(db)
}
func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error {
func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error {
headeredJSON, err := json.Marshal(event)
if err != nil {
return err
}
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
_, err = sqlutil.TxStmt(txn, s.updateEventJSONStmt).ExecContext(ctx, headeredJSON, event.EventID())
return err
}
@@ -632,3 +636,27 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
}
return result, rows.Err()
}
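// ReIndex returns up to limit events with a stream ID greater than afterID and
// a type matching one of the supplied types, keyed by stream ID, so that the
// fulltext search index can be rebuilt in batches.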
func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) {
rows, err := sqlutil.TxStmt(txn, s.selectSearchStmt).QueryContext(ctx, afterID, pq.StringArray(types), limit)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "rows.close() failed")
var eventID string
var id int64
result := make(map[int64]gomatrixserverlib.HeaderedEvent)
for rows.Next() {
var ev gomatrixserverlib.HeaderedEvent
var eventBytes []byte
if err = rows.Scan(&id, &eventID, &eventBytes); err != nil {
return nil, err
}
if err = ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err
}
result[id] = ev
}
return result, rows.Err()
}

View file

@@ -173,7 +173,7 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology(
ctx context.Context, txn *sql.Tx, eventID string,
) (pos, spos types.StreamPosition, err error) {
err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos)
err = sqlutil.TxStmt(txn, s.selectPositionInTopologyStmt).QueryRowContext(ctx, eventID).Scan(&pos, &spos)
return
}
@@ -183,9 +183,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition(
ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
) (topoPos types.StreamPosition, err error) {
if backwardOrdering {
err = s.selectStreamToTopologicalPositionDescStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionDescStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
} else {
err = s.selectStreamToTopologicalPositionAscStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionAscStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
}
return
}
@@ -193,6 +193,6 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition(
func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
ctx context.Context, txn *sql.Tx, roomID string,
) (pos types.StreamPosition, spos types.StreamPosition, err error) {
err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
err = sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt).QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return
}

View file

@@ -152,9 +152,9 @@ func (s *peekStatements) SelectPeeksInRange(
}
func (s *peekStatements) SelectPeekingDevices(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
) (peekingDevices map[string][]types.PeekingDevice, err error) {
rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx)
rows, err := sqlutil.TxStmt(txn, s.selectPeekingDevicesStmt).QueryContext(ctx)
if err != nil {
return nil, err
}

View file

@@ -104,9 +104,9 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
return
}
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
var lastPos types.StreamPosition
rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos)
rows, err := sqlutil.TxStmt(txn, r.selectRoomReceipts).QueryContext(ctx, pq.Array(roomIDs), streamPos)
if err != nil {
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
}

View file

@@ -0,0 +1,588 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shared
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/tidwall/gjson"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)
// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite
// For now this contains the shared functions
type Database struct {
DB *sql.DB
Writer sqlutil.Writer
Invites tables.Invites
Peeks tables.Peeks
AccountData tables.AccountData
OutputEvents tables.Events
Topology tables.Topology
CurrentRoomState tables.CurrentRoomState
BackwardExtremities tables.BackwardsExtremities
SendToDevice tables.SendToDevice
Filter tables.Filter
Receipts tables.Receipts
Memberships tables.Memberships
NotificationData tables.NotificationData
Ignores tables.Ignores
Presence tables.Presence
}
func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransaction, error) {
return d.NewDatabaseTransaction(ctx)
/*
TODO: Repeatable read is probably the right thing to do here,
but it seems to cause some problems with the invite tests, so
need to investigate that further.
txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{
// Set the isolation level so that we see a snapshot of the database.
// In PostgreSQL repeatable read transactions will see a snapshot taken
// at the first query, and since the transaction is read-only it can't
// run into any serialisation errors.
// https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ
Isolation: sql.LevelRepeatableRead,
ReadOnly: true,
})
if err != nil {
return nil, err
}
return &DatabaseTransaction{
Database: d,
ctx: ctx,
txn: txn,
}, nil
*/
}
func (d *Database) NewDatabaseTransaction(ctx context.Context) (*DatabaseTransaction, error) {
txn, err := d.DB.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
return &DatabaseTransaction{
Database: d,
ctx: ctx,
txn: txn,
}, nil
}
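// A minimal caller-side sketch (illustrative only, not part of this file): a
// read path would typically take a snapshot, read from it and always release it:
//
//	snapshot, err := db.NewDatabaseSnapshot(ctx)
//	if err != nil {
//		return err
//	}
//	defer snapshot.Rollback() // nolint:errcheck
//	pos, err := snapshot.MaxStreamPositionForPDUs(ctx)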
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false)
if err != nil {
return nil, err
}
// We don't include a device here as we only include transaction IDs in
// incremental syncs.
return d.StreamEventsToEvents(nil, streamEvents), nil
}
// AddInviteEvent stores a new invite event for a user.
// If the invite was successfully stored this returns the stream ID it was stored at.
// Returns an error if there was a problem communicating with the database.
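// The write is wrapped in d.Writer.Do so that it is serialised through the
// single database writer, which SQLite requires since it allows only one
// write transaction at a time.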
func (d *Database) AddInviteEvent(
ctx context.Context, inviteEvent *gomatrixserverlib.HeaderedEvent,
) (sp types.StreamPosition, err error) {
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent)
return err
})
return
}
// RetireInviteEvent removes an old invite event from the database.
// Returns an error if there was a problem communicating with the database.
func (d *Database) RetireInviteEvent(
ctx context.Context, inviteEventID string,
) (sp types.StreamPosition, err error) {
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.Invites.DeleteInviteEvent(ctx, txn, inviteEventID)
return err
})
return
}
// AddPeek tracks the fact that a user has started peeking.
// If the peek was successfully stored this returns the stream ID it was stored at.
// Returns an error if there was a problem communicating with the database.
func (d *Database) AddPeek(
ctx context.Context, roomID, userID, deviceID string,
) (sp types.StreamPosition, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.Peeks.InsertPeek(ctx, txn, roomID, userID, deviceID)
return err
})
return
}
// DeletePeek tracks the fact that a user has stopped peeking from the specified
// device. If the peek was successfully deleted, this returns the stream ID it was
// stored at. Returns an error if there was a problem communicating with the database.
func (d *Database) DeletePeek(
ctx context.Context, roomID, userID, deviceID string,
) (sp types.StreamPosition, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.Peeks.DeletePeek(ctx, txn, roomID, userID, deviceID)
return err
})
if err == sql.ErrNoRows {
sp = 0
err = nil
}
return
}
// DeletePeeks tracks the fact that a user has stopped peeking from all devices.
// If the peeks were successfully deleted, this returns the stream ID they were stored at.
// Returns an error if there was a problem communicating with the database.
func (d *Database) DeletePeeks(
ctx context.Context, roomID, userID string,
) (sp types.StreamPosition, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.Peeks.DeletePeeks(ctx, txn, roomID, userID)
return err
})
if err == sql.ErrNoRows {
sp = 0
err = nil
}
return
}
// UpsertAccountData keeps track of new or updated account data, by saving the type
// of the new/updated data, and the user ID and room ID the data is related to (an
// empty room ID means the data isn't specific to any room).
// If no data with the given type, user ID and room ID exists in the database, this
// creates a new row, otherwise it updates the existing one.
// Returns an error if there was an issue with the upsert.
func (d *Database) UpsertAccountData(
ctx context.Context, userID, roomID, dataType string,
) (sp types.StreamPosition, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType)
return err
})
return
}
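// StreamEventsToEvents unwraps the headered events from a set of stream events
// and, where the requesting device sent the original event, copies the client
// transaction ID into the event's unsigned content so that /sync can echo it back.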
func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent {
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[i].HeaderedEvent
if device != nil && in[i].TransactionID != nil {
if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID {
err := out[i].SetUnsignedField(
"transaction_id", in[i].TransactionID.TransactionID,
)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
}
}
}
}
return out
}
// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.
// This function should always be called within a sqlutil.Writer for safety in SQLite.
func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error {
if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
return err
}
// Check if we have all of the event's previous events. If an event is
// missing, add it to the room's backward extremities.
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false)
if err != nil {
return err
}
var found bool
for _, eID := range ev.PrevEventIDs() {
found = false
for _, prevEv := range prevEvents {
if eID == prevEv.EventID() {
found = true
}
}
// If the event is missing, consider it a backward extremity.
if !found {
if err = d.BackwardExtremities.InsertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID(), eID); err != nil {
return err
}
}
}
return nil
}
func (d *Database) PurgeRoomState(
ctx context.Context, roomID string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// If the event is a create event then we'll delete all of the existing
// data for the room. The only reason that a create event would be replayed
// to us in this way is if we're about to receive the entire room state.
if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err)
}
return nil
})
}
func (d *Database) WriteEvent(
ctx context.Context,
ev *gomatrixserverlib.HeaderedEvent,
addStateEvents []*gomatrixserverlib.HeaderedEvent,
addStateEventIDs, removeStateEventIDs []string,
transactionID *api.TransactionID, excludeFromSync bool,
historyVisibility gomatrixserverlib.HistoryVisibility,
) (pduPosition types.StreamPosition, returnErr error) {
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
ev.Visibility = historyVisibility
pos, err := d.OutputEvents.InsertEvent(
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility,
)
if err != nil {
return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err)
}
pduPosition = pos
var topoPosition types.StreamPosition
if topoPosition, err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil {
return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err)
}
if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
return fmt.Errorf("d.handleBackwardExtremities: %w", err)
}
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
// Nothing to do, the event may have just been a message event.
return nil
}
for i := range addStateEvents {
addStateEvents[i].Visibility = historyVisibility
}
return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition, topoPosition)
})
return pduPosition, returnErr
}
// This function should always be called within a sqlutil.Writer for safety in SQLite.
func (d *Database) updateRoomState(
ctx context.Context, txn *sql.Tx,
removedEventIDs []string,
addedEvents []*gomatrixserverlib.HeaderedEvent,
pduPosition types.StreamPosition,
topoPosition types.StreamPosition,
) error {
// Remove first, then add: we never delete state outright, but we do replace state, which is a remove followed by an add.
for _, eventID := range removedEventIDs {
if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err)
}
}
for _, event := range addedEvents {
if event.StateKey() == nil {
// ignore non state events
continue
}
var membership *string
if event.Type() == "m.room.member" {
value, err := event.Membership()
if err != nil {
return fmt.Errorf("event.Membership: %w", err)
}
membership = &value
if err = d.Memberships.UpsertMembership(ctx, txn, event, pduPosition, topoPosition); err != nil {
return fmt.Errorf("d.Memberships.UpsertMembership: %w", err)
}
}
if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err)
}
}
return nil
}
func (d *Database) GetFilter(
ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
) error {
return d.Filter.SelectFilter(ctx, nil, target, localpart, filterID)
}
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
var filterID string
var err error
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
filterID, err = d.Filter.InsertFilter(ctx, txn, filter, localpart)
return err
})
return filterID, err
}
func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error {
redactedEvents, err := d.Events(ctx, []string{redactedEventID})
if err != nil {
return err
}
if len(redactedEvents) == 0 {
logrus.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction")
return nil
}
eventToRedact := redactedEvents[0].Unwrap()
redactionEvent := redactedBecause.Unwrap()
if err = eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil {
return err
}
newEvent := eventToRedact.Headered(redactedBecause.RoomVersion)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent)
})
return err
}
// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
// Returns a map of room ID to list of events.
func (d *Database) fetchStateEvents(
ctx context.Context, txn *sql.Tx,
roomIDToEventIDSet map[string]map[string]bool,
eventIDToEvent map[string]types.StreamEvent,
) (map[string][]types.StreamEvent, error) {
stateBetween := make(map[string][]types.StreamEvent)
missingEvents := make(map[string][]string)
for roomID, ids := range roomIDToEventIDSet {
events := stateBetween[roomID]
for id, need := range ids {
if !need {
continue // deleted state
}
e, ok := eventIDToEvent[id]
if ok {
events = append(events, e)
} else {
m := missingEvents[roomID]
m = append(m, id)
missingEvents[roomID] = m
}
}
stateBetween[roomID] = events
}
if len(missingEvents) > 0 {
// This happens when add_state_ids has an event ID which is not in the provided range.
// We need to explicitly fetch them.
allMissingEventIDs := []string{}
for _, missingEvIDs := range missingEvents {
allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...)
}
evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs)
if err != nil {
return nil, err
}
// we know we got them all otherwise an error would've been returned, so just loop the events
for _, ev := range evs {
roomID := ev.RoomID()
stateBetween[roomID] = append(stateBetween[roomID], ev)
}
}
return stateBetween, nil
}
func (d *Database) fetchMissingStateEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StreamEvent, error) {
// Fetch from the events table first so we pick up the stream ID for the
// event.
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false)
if err != nil {
return nil, err
}
have := map[string]bool{}
for _, event := range events {
have[event.EventID()] = true
}
var missing []string
for _, eventID := range eventIDs {
if !have[eventID] {
missing = append(missing, eventID)
}
}
if len(missing) == 0 {
return events, nil
}
// If they are missing from the events table then they should be state
// events that we received from outside the main event stream.
// These should be in the room state table.
stateEvents, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, missing)
if err != nil {
return nil, err
}
if len(stateEvents) != len(missing) {
logrus.WithContext(ctx).Warnf("Failed to map all event IDs to events (got %d, wanted %d)", len(stateEvents), len(missing))
// TODO: Why is this happening? It's probably the roomserver. Uncomment
// this error again when we work out what it is and fix it, otherwise we
// just end up returning lots of 500s to the client and that breaks
// pretty much everything, rather than just sending what we have.
//return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing))
}
events = append(events, stateEvents...)
return events, nil
}
func (d *Database) StoreNewSendForDeviceMessage(
ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
) (newPos types.StreamPosition, err error) {
j, err := json.Marshal(event)
if err != nil {
return 0, err
}
// Delegate the database write task to the writer. It'll guarantee
// that we don't lock the table for writes in more than one place.
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
newPos, err = d.SendToDevice.InsertSendToDeviceMessage(
ctx, txn, userID, deviceID, string(j),
)
return err
})
if err != nil {
return 0, err
}
return newPos, nil
}
func (d *Database) CleanSendToDeviceUpdates(
ctx context.Context,
userID, deviceID string, before types.StreamPosition,
) (err error) {
if err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, userID, deviceID, before)
}); err != nil {
logrus.WithError(err).Errorf("Failed to clean up old send-to-device messages for user %q device %q", userID, deviceID)
return err
}
return nil
}
// getMembershipFromEvent returns the values of content.membership and unsigned.prev_content.membership
// iff the event is a state event with type 'm.room.member' and a state_key of userID. Otherwise, empty
// strings are returned.
func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) (string, string) {
if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) {
return "", ""
}
membership, err := ev.Membership()
if err != nil {
return "", ""
}
prevMembership := gjson.GetBytes(ev.Unsigned(), "prev_content.membership").Str
return membership, prevMembership
}
// StoreReceipt stores user receipts
func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
pos, err = d.Receipts.UpsertReceipt(ctx, txn, roomId, receiptType, userId, eventId, timestamp)
return err
})
return
}
func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, txn, userID, roomID, notificationCount, highlightCount)
return err
})
return
}
func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID)
}
func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) {
return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter)
}
func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) {
return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter)
}
func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) {
return d.Ignores.SelectIgnores(ctx, nil, userID)
}
func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Ignores.UpsertIgnores(ctx, txn, userID, ignores)
})
}
func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) {
var pos types.StreamPosition
var err error
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
pos, err = d.Presence.UpsertPresence(ctx, txn, userID, statusMsg, presence, lastActiveTS, fromSync)
return nil
})
return pos, err
}
func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
return d.Presence.GetPresenceForUser(ctx, nil, userID)
}
func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) {
return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos)
}
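// ReIndex fetches a batch of room name, topic and message events after the
// given stream ID so that the fulltext search index can be rebuilt.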
func (d *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) {
return d.OutputEvents.ReIndex(ctx, nil, limit, afterID, []string{
gomatrixserverlib.MRoomName,
gomatrixserverlib.MRoomTopic,
"m.room.message",
})
}

View file

@@ -0,0 +1,575 @@
package shared
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
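// DatabaseTransaction wraps the shared Database together with a single *sql.Tx
// so that every read performed while answering one request is issued on the
// same transaction and therefore sees a consistent view of the database.
// Callers must Commit or Rollback the transaction when they are done with it.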
type DatabaseTransaction struct {
*Database
ctx context.Context
txn *sql.Tx
}
func (d *DatabaseTransaction) Commit() error {
if d.txn == nil {
return nil
}
return d.txn.Commit()
}
func (d *DatabaseTransaction) Rollback() error {
if d.txn == nil {
return nil
}
return d.txn.Rollback()
}
func (d *DatabaseTransaction) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) {
id, err := d.OutputEvents.SelectMaxEventID(ctx, d.txn)
if err != nil {
return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err)
}
return types.StreamPosition(id), nil
}
func (d *DatabaseTransaction) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) {
id, err := d.Receipts.SelectMaxReceiptID(ctx, d.txn)
if err != nil {
return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err)
}
return types.StreamPosition(id), nil
}
func (d *DatabaseTransaction) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) {
id, err := d.Invites.SelectMaxInviteID(ctx, d.txn)
if err != nil {
return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err)
}
return types.StreamPosition(id), nil
}
func (d *DatabaseTransaction) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) {
id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, d.txn)
if err != nil {
return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err)
}
return types.StreamPosition(id), nil
}
func (d *DatabaseTransaction) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) {
id, err := d.AccountData.SelectMaxAccountDataID(ctx, d.txn)
if err != nil {
return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err)
}
return types.StreamPosition(id), nil
}
func (d *DatabaseTransaction) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) {
id, err := d.NotificationData.SelectMaxID(ctx, d.txn)
if err != nil {
return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err)
}
return types.StreamPosition(id), nil
}
func (d *DatabaseTransaction) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
return d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilterPart, excludeEventIDs)
}
func (d *DatabaseTransaction) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) {
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, d.txn, userID, membership)
}
func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) {
return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos)
}
func (d *DatabaseTransaction) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) {
return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships)
}
func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
}
func (d *DatabaseTransaction) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) {
return d.Topology.SelectPositionInTopology(ctx, d.txn, eventID)
}
func (d *DatabaseTransaction) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) {
return d.Invites.SelectInviteEventsInRange(ctx, d.txn, targetUserID, r)
}
func (d *DatabaseTransaction) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) {
return d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, deviceID, r)
}
func (d *DatabaseTransaction) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
return d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos)
}
// Events looks up a list of events by their event IDs.
// Returns a list of events matching the requested IDs found in the database.
// If an event is not found in the database then it will be omitted from the list.
// Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events.
func (d *DatabaseTransaction) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
streamEvents, err := d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, false)
if err != nil {
return nil, err
}
// We don't include a device here as we only include transaction IDs in
// incremental syncs.
return d.StreamEventsToEvents(nil, streamEvents), nil
}
func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
return d.CurrentRoomState.SelectJoinedUsers(ctx, d.txn)
}
func (d *DatabaseTransaction) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) {
return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, d.txn, roomIDs)
}
func (d *DatabaseTransaction) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) {
return d.Peeks.SelectPeekingDevices(ctx, d.txn)
}
func (d *DatabaseTransaction) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
return d.CurrentRoomState.SelectSharedUsers(ctx, d.txn, userID, otherUserIDs)
}
func (d *DatabaseTransaction) GetStateEvent(
ctx context.Context, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) {
return d.CurrentRoomState.SelectStateEvent(ctx, d.txn, roomID, evType, stateKey)
}
func (d *DatabaseTransaction) GetStateEventsForRoom(
ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter,
) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) {
stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil)
return
}
// GetAccountDataInRange returns all account data for a given user inserted or
// updated between two given positions
// Returns a map following the format data[roomID] = []dataTypes
// If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error
func (d *DatabaseTransaction) GetAccountDataInRange(
ctx context.Context, userID string, r types.Range,
accountDataFilterPart *gomatrixserverlib.EventFilter,
) (map[string][]string, types.StreamPosition, error) {
return d.AccountData.SelectAccountDataInRange(ctx, d.txn, userID, r, accountDataFilterPart)
}
func (d *DatabaseTransaction) GetEventsInTopologicalRange(
ctx context.Context,
from, to *types.TopologyToken,
roomID string,
filter *gomatrixserverlib.RoomEventFilter,
backwardOrdering bool,
) (events []types.StreamEvent, err error) {
var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
if backwardOrdering {
// Backward ordering means the 'from' token has a higher depth than the 'to' token
minDepth = to.Depth
maxDepth = from.Depth
// for cases where we have say 5 events with the same depth, the TopologyToken needs to
// know which of the 5 the client has seen. This is done by using the PDU position.
// Events with the same maxDepth but less than this PDU position will be returned.
maxStreamPosForMaxDepth = from.PDUPosition
} else {
// Forward ordering means the 'from' token has a lower depth than the 'to' token.
minDepth = from.Depth
maxDepth = to.Depth
}
// Select the event IDs from the defined range.
var eIDs []string
eIDs, err = d.Topology.SelectEventIDsInRange(
ctx, d.txn, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
)
if err != nil {
return
}
// Retrieve the events' contents using their IDs.
events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eIDs, filter, true)
return
}
func (d *DatabaseTransaction) BackwardExtremitiesForRoom(
ctx context.Context, roomID string,
) (backwardExtremities map[string][]string, err error) {
return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID)
}
func (d *DatabaseTransaction) MaxTopologicalPosition(
ctx context.Context, roomID string,
) (types.TopologyToken, error) {
depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID)
if err != nil {
return types.TopologyToken{}, err
}
return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil
}
func (d *DatabaseTransaction) EventPositionInTopology(
ctx context.Context, eventID string,
) (types.TopologyToken, error) {
depth, stream, err := d.Topology.SelectPositionInTopology(ctx, d.txn, eventID)
if err != nil {
return types.TopologyToken{}, err
}
return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil
}
func (d *DatabaseTransaction) StreamToTopologicalPosition(
ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
) (types.TopologyToken, error) {
topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, d.txn, roomID, streamPos, backwardOrdering)
switch {
case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward
return types.TopologyToken{PDUPosition: streamPos}, nil
case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward
topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID)
if err != nil {
return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err)
}
return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil
case err != nil: // some other error happened
return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err)
default:
return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil
}
}
// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the
// oldest event in the room's topology.
func (d *DatabaseTransaction) GetBackwardTopologyPos(
ctx context.Context,
events []types.StreamEvent,
) (types.TopologyToken, error) {
zeroToken := types.TopologyToken{}
if len(events) == 0 {
return zeroToken, nil
}
pos, spos, err := d.Topology.SelectPositionInTopology(ctx, d.txn, events[0].EventID())
if err != nil {
return zeroToken, err
}
tok := types.TopologyToken{Depth: pos, PDUPosition: spos}
tok.Decrement()
return tok, nil
}
// GetStateDeltas returns the state deltas between fromPos and toPos,
// exclusive of oldPos, inclusive of newPos, for the rooms in which
// the user has new membership events.
// A list of joined room IDs is also returned in case the caller needs it.
func (d *DatabaseTransaction) GetStateDeltas(
ctx context.Context, device *userapi.Device,
r types.Range, userID string,
stateFilter *gomatrixserverlib.StateFilter,
) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
// - Get membership list changes for this user in this sync response
// - For each room which has membership list changes:
// * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO).
// If it is, then we need to send the full room state down (and 'limited' is always true).
// * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block.
// * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block.
// - Get all CURRENTLY joined rooms, and add them to 'joined' block.
// Look up all memberships for the user. We only care about rooms that a
// user has ever interacted with — joined to, kicked/banned from, left.
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.txn, userID)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
allRoomIDs := make([]string, 0, len(memberships))
joinedRoomIDs := make([]string, 0, len(memberships))
for roomID, membership := range memberships {
allRoomIDs = append(allRoomIDs, roomID)
if membership == gomatrixserverlib.Join {
joinedRoomIDs = append(joinedRoomIDs, roomID)
}
}
// get all the state events ever (i.e. for all available rooms) between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
state, err := d.fetchStateEvents(ctx, d.txn, stateNeeded, eventMap)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
// find out which rooms this user is peeking, if any.
// We do this before joins so any peeks get overwritten
peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r)
if err != nil && err != sql.ErrNoRows {
return nil, nil, err
}
// add peek blocks
for _, peek := range peeks {
if peek.New {
// send full room state down instead of a delta
var s []types.StreamEvent
s, err = d.currentStateStreamEventsForRoom(ctx, peek.RoomID, stateFilter)
if err != nil {
if err == sql.ErrNoRows {
continue
}
return nil, nil, err
}
state[peek.RoomID] = s
}
if !peek.Deleted {
deltas = append(deltas, types.StateDelta{
Membership: gomatrixserverlib.Peek,
StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]),
RoomID: peek.RoomID,
})
}
}
// handle newly joined rooms and non-joined rooms
newlyJoinedRooms := make(map[string]bool, len(state))
for roomID, stateStreamEvents := range state {
for _, ev := range stateStreamEvents {
if membership, prevMembership := getMembershipFromEvent(ev.Event, userID); membership != "" {
if membership == gomatrixserverlib.Join && prevMembership != membership {
// send full room state down instead of a delta
var s []types.StreamEvent
s, err = d.currentStateStreamEventsForRoom(ctx, roomID, stateFilter)
if err != nil {
if err == sql.ErrNoRows {
continue
}
return nil, nil, err
}
state[roomID] = s
newlyJoinedRooms[roomID] = true
continue // we'll add this room in when we do joined rooms
}
deltas = append(deltas, types.StateDelta{
Membership: membership,
MembershipPos: ev.StreamPosition,
StateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
RoomID: roomID,
})
break
}
}
}
// Add in currently joined rooms
for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, types.StateDelta{
Membership: gomatrixserverlib.Join,
StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
RoomID: joinedRoomID,
NewlyJoined: newlyJoinedRooms[joinedRoomID],
})
}
return deltas, joinedRoomIDs, nil
}
// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
// requests with full_state=true.
// Fetches full state for all joined rooms and uses selectStateInRange to get
// updates for other rooms.
func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
ctx context.Context, device *userapi.Device,
r types.Range, userID string,
stateFilter *gomatrixserverlib.StateFilter,
) ([]types.StateDelta, []string, error) {
// Look up all memberships for the user. We only care about rooms that a
// user has ever interacted with — joined to, kicked/banned from, left.
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.txn, userID)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
allRoomIDs := make([]string, 0, len(memberships))
joinedRoomIDs := make([]string, 0, len(memberships))
for roomID, membership := range memberships {
allRoomIDs = append(allRoomIDs, roomID)
if membership == gomatrixserverlib.Join {
joinedRoomIDs = append(joinedRoomIDs, roomID)
}
}
// Use a reasonable initial capacity
deltas := make(map[string]types.StateDelta)
peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r)
if err != nil && err != sql.ErrNoRows {
return nil, nil, err
}
// Add full states for all peeking rooms
for _, peek := range peeks {
if !peek.Deleted {
s, stateErr := d.currentStateStreamEventsForRoom(ctx, peek.RoomID, stateFilter)
if stateErr != nil {
if stateErr == sql.ErrNoRows {
continue
}
return nil, nil, stateErr
}
deltas[peek.RoomID] = types.StateDelta{
Membership: gomatrixserverlib.Peek,
StateEvents: d.StreamEventsToEvents(device, s),
RoomID: peek.RoomID,
}
}
}
// Get all the state events ever between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
state, err := d.fetchStateEvents(ctx, d.txn, stateNeeded, eventMap)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
for roomID, stateStreamEvents := range state {
for _, ev := range stateStreamEvents {
if membership, _ := getMembershipFromEvent(ev.Event, userID); membership != "" {
if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above.
deltas[roomID] = types.StateDelta{
Membership: membership,
MembershipPos: ev.StreamPosition,
StateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
RoomID: roomID,
}
}
break
}
}
}
// Add full states for all joined rooms
for _, joinedRoomID := range joinedRoomIDs {
s, stateErr := d.currentStateStreamEventsForRoom(ctx, joinedRoomID, stateFilter)
if stateErr != nil {
if stateErr == sql.ErrNoRows {
continue
}
return nil, nil, stateErr
}
deltas[joinedRoomID] = types.StateDelta{
Membership: gomatrixserverlib.Join,
StateEvents: d.StreamEventsToEvents(device, s),
RoomID: joinedRoomID,
}
}
// Create a response array.
result := make([]types.StateDelta, len(deltas))
i := 0
for _, delta := range deltas {
result[i] = delta
i++
}
return result, joinedRoomIDs, nil
}
func (d *DatabaseTransaction) currentStateStreamEventsForRoom(
ctx context.Context, roomID string,
stateFilter *gomatrixserverlib.StateFilter,
) ([]types.StreamEvent, error) {
allState, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil)
if err != nil {
return nil, err
}
s := make([]types.StreamEvent, len(allState))
for i := 0; i < len(s); i++ {
s[i] = types.StreamEvent{HeaderedEvent: allState[i], StreamPosition: 0}
}
return s, nil
}
func (d *DatabaseTransaction) SendToDeviceUpdatesForSync(
ctx context.Context,
userID, deviceID string,
from, to types.StreamPosition,
) (types.StreamPosition, []types.SendToDeviceEvent, error) {
// First of all, get our send-to-device updates for this user.
lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, d.txn, userID, deviceID, from, to)
if err != nil {
return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
}
// If there's nothing to do then stop here.
if len(events) == 0 {
return to, nil, nil
}
return lastPos, events, nil
}
func (d *DatabaseTransaction) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) {
_, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos)
return receipts, err
}
func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) {
roomIDs := make([]string, 0, len(rooms))
for roomID, membership := range rooms {
if membership != gomatrixserverlib.Join {
continue
}
roomIDs = append(roomIDs, roomID)
}
return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs)
}
func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
return d.Presence.GetPresenceForUser(ctx, d.txn, userID)
}
func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
return d.Presence.GetPresenceAfter(ctx, d.txn, after, filter)
}
func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
return d.Presence.GetMaxPresenceID(ctx, d.txn)
}
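A minimal usage sketch (not part of this commit) of the snapshot lifecycle used above, assuming the storage.Database and storage.DatabaseTransaction interfaces exposed by the sync API; the helper name and its arguments are illustrative only:
package example // hypothetical example, not part of the commit
import (
	"context"
	"github.com/matrix-org/dendrite/internal/eventutil"
	"github.com/matrix-org/dendrite/syncapi/storage"
	"github.com/matrix-org/gomatrixserverlib"
)
// unreadCountsForSync opens a read-only snapshot, queries unread counts for the
// user's joined rooms and always rolls the snapshot back, since snapshots are
// never used for writes.
func unreadCountsForSync(
	ctx context.Context, db storage.Database, userID string, joinedRoomIDs []string,
) (map[string]*eventutil.NotificationData, error) {
	snapshot, err := db.NewDatabaseSnapshot(ctx)
	if err != nil {
		return nil, err
	}
	defer snapshot.Rollback() // nolint:errcheck
	rooms := make(map[string]string, len(joinedRoomIDs))
	for _, roomID := range joinedRoomIDs {
		rooms[roomID] = gomatrixserverlib.Join // only joined rooms are counted
	}
	return snapshot.GetUserUnreadNotificationCountsForRooms(ctx, userID, rooms)
}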

File diff suppressed because it is too large

View file

@ -91,14 +91,14 @@ func (s *accountDataStatements) InsertAccountData(
}
func (s *accountDataStatements) SelectAccountDataInRange(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
userID string,
r types.Range,
filter *gomatrixserverlib.EventFilter,
) (data map[string][]string, pos types.StreamPosition, err error) {
data = make(map[string][]string)
stmt, params, err := prepareWithFilters(
s.db, nil, selectAccountDataInRangeSQL,
s.db, txn, selectAccountDataInRangeSQL,
[]interface{}{
userID, r.Low(), r.High(),
},

View file

@ -82,9 +82,9 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
}
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
ctx context.Context, roomID string,
ctx context.Context, txn *sql.Tx, roomID string,
) (bwExtrems map[string][]string, err error) {
rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID)
rows, err := sqlutil.TxStmt(txn, s.selectBackwardExtremitiesForRoomStmt).QueryContext(ctx, roomID)
if err != nil {
return
}
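The recurring pattern in these storage hunks is that each prepared statement is rebound to the caller's transaction before use. A minimal sketch, assuming sqlutil.TxStmt returns a transaction-bound copy of the statement when txn is non-nil and the bare statement otherwise; the exampleStatements table is hypothetical:
package example // hypothetical example, not part of the commit
import (
	"context"
	"database/sql"
	"github.com/matrix-org/dendrite/internal/sqlutil"
)
type exampleStatements struct {
	selectThingStmt *sql.Stmt
}
// SelectThing can run either inside the caller's transaction (e.g. a sync
// snapshot) or directly against the database when txn is nil.
func (s *exampleStatements) SelectThing(
	ctx context.Context, txn *sql.Tx, id string,
) (value string, err error) {
	stmt := sqlutil.TxStmt(txn, s.selectThingStmt)
	err = stmt.QueryRowContext(ctx, id).Scan(&value)
	return
}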

View file

@ -88,12 +88,7 @@ const selectStateEventSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
const selectEventsWithEventIDsSQL = "" +
// TODO: The session_id and transaction_id blanks are here because
// the rowsToStreamEvents expects there to be exactly seven columns. We need to
// figure out if these really need to be in the DB, and if so, we need a
// better permanent fix for this. - neilalexander, 2 Jan 2020
"SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id, history_visibility" +
" FROM syncapi_current_room_state WHERE event_id IN ($1)"
"SELECT event_id, added_at, headered_event_json, history_visibility FROM syncapi_current_room_state WHERE event_id IN ($1)"
const selectSharedUsersSQL = "" +
"SELECT state_key FROM syncapi_current_room_state WHERE room_id IN(" +
@ -163,9 +158,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
func (s *currentRoomStateStatements) SelectJoinedUsers(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
) (map[string][]string, error) {
rows, err := s.selectJoinedUsersStmt.QueryContext(ctx)
rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt).QueryContext(ctx)
if err != nil {
return nil, err
}
@ -187,7 +182,7 @@ func (s *currentRoomStateStatements) SelectJoinedUsers(
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
func (s *currentRoomStateStatements) SelectJoinedUsersInRoom(
ctx context.Context, roomIDs []string,
ctx context.Context, txn *sql.Tx, roomIDs []string,
) (map[string][]string, error) {
query := strings.Replace(selectJoinedUsersInRoomSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
params := make([]interface{}, 0, len(roomIDs))
@ -200,7 +195,7 @@ func (s *currentRoomStateStatements) SelectJoinedUsersInRoom(
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsersInRoom: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...)
if err != nil {
return nil, err
}
@ -367,12 +362,18 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
for start < len(eventIDs) {
n := minOfInts(len(eventIDs)-start, 999)
query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(n), 1)
rows, err := txn.QueryContext(ctx, query, iEventIDs[start:start+n]...)
var rows *sql.Rows
var err error
if txn == nil {
rows, err = s.db.QueryContext(ctx, query, iEventIDs[start:start+n]...)
} else {
rows, err = txn.QueryContext(ctx, query, iEventIDs[start:start+n]...)
}
if err != nil {
return nil, err
}
start = start + n
events, err := rowsToStreamEvents(rows)
events, err := currentRoomStateRowsToStreamEvents(rows)
internal.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed")
if err != nil {
return nil, err
@ -382,6 +383,35 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
return res, nil
}
func currentRoomStateRowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
var events []types.StreamEvent
for rows.Next() {
var (
eventID string
streamPos types.StreamPosition
eventBytes []byte
historyVisibility gomatrixserverlib.HistoryVisibility
)
if err := rows.Scan(&eventID, &streamPos, &eventBytes, &historyVisibility); err != nil {
return nil, err
}
// TODO: Handle redacted events
var ev gomatrixserverlib.HeaderedEvent
if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err
}
ev.Visibility = historyVisibility
events = append(events, types.StreamEvent{
HeaderedEvent: &ev,
StreamPosition: streamPos,
})
}
return events, nil
}
func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
result := []*gomatrixserverlib.HeaderedEvent{}
for rows.Next() {
@ -401,9 +431,9 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
}
func (s *currentRoomStateStatements) SelectStateEvent(
ctx context.Context, roomID, evType, stateKey string,
ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) {
stmt := s.selectStateEventStmt
stmt := sqlutil.TxStmt(txn, s.selectStateEventStmt)
var res []byte
err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res)
if err == sql.ErrNoRows {
@ -429,10 +459,17 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
params[k+1] = v
}
var provider sqlutil.QueryProvider
if txn == nil {
provider = s.db
} else {
provider = txn
}
result := make([]string, 0, len(otherUserIDs))
query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1)
err := sqlutil.RunLimitedVariablesQuery(
ctx, query, s.db, params, sqlutil.SQLite3MaxVariables,
ctx, query, provider, params, sqlutil.SQLite3MaxVariables,
func(rows *sql.Rows) error {
var stateKey string
for rows.Next() {

View file

@ -20,6 +20,7 @@ import (
"encoding/json"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
@ -77,11 +78,11 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
}
func (s *filterStatements) SelectFilter(
ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string,
) error {
// Retrieve filter from database (stored as canonical JSON)
var filterData []byte
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
err := sqlutil.TxStmt(txn, s.selectFilterStmt).QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
if err != nil {
return err
}
@ -94,7 +95,7 @@ func (s *filterStatements) SelectFilter(
}
func (s *filterStatements) InsertFilter(
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) {
var existingFilterID string
@ -115,8 +116,9 @@ func (s *filterStatements) InsertFilter(
// This can result in a race condition when two clients try to insert the
// same filter and localpart at the same time; this is not a problem, as
// both calls will result in the same filterID.
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
localpart, filterJSON).Scan(&existingFilterID)
err = sqlutil.TxStmt(txn, s.selectFilterIDByContentStmt).QueryRowContext(
ctx, localpart, filterJSON,
).Scan(&existingFilterID)
if err != nil && err != sql.ErrNoRows {
return "", err
}
@ -126,7 +128,7 @@ func (s *filterStatements) InsertFilter(
}
// Otherwise insert the filter and return the new ID
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
res, err := sqlutil.TxStmt(txn, s.insertFilterStmt).ExecContext(ctx, filterJSON, localpart)
if err != nil {
return "", err
}

View file

@ -50,7 +50,7 @@ const deleteInviteEventSQL = "" +
"UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2 AND deleted=false"
const selectInviteEventsInRangeSQL = "" +
"SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
"SELECT id, room_id, headered_event_json, deleted FROM syncapi_invite_events" +
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC"
@ -132,23 +132,28 @@ func (s *inviteEventsStatements) DeleteInviteEvent(
// active invites for the target user ID in the supplied range.
func (s *inviteEventsStatements) SelectInviteEventsInRange(
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) {
) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) {
var lastPos types.StreamPosition
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
if err != nil {
return nil, nil, err
return nil, nil, lastPos, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
result := map[string]*gomatrixserverlib.HeaderedEvent{}
retired := map[string]*gomatrixserverlib.HeaderedEvent{}
for rows.Next() {
var (
id types.StreamPosition
roomID string
eventJSON []byte
deleted bool
)
if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil {
return nil, nil, err
if err = rows.Scan(&id, &roomID, &eventJSON, &deleted); err != nil {
return nil, nil, lastPos, err
}
if id > lastPos {
lastPos = id
}
// if we have seen this room before, it has a higher stream position and hence takes priority
@ -161,15 +166,19 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange(
var event *gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(eventJSON, &event); err != nil {
return nil, nil, err
return nil, nil, lastPos, err
}
if deleted {
retired[roomID] = event
} else {
result[roomID] = event
}
}
return result, retired, nil
// If no invites were found in the range, advance to the end of the range
// anyway so that the invite stream position still moves forward.
if lastPos == 0 {
lastPos = r.To
}
return result, retired, lastPos, nil
}
func (s *inviteEventsStatements) SelectMaxInviteID(

View file

@ -17,6 +17,7 @@ package sqlite3
import (
"context"
"database/sql"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil"
@ -32,19 +33,21 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t
}
r := &notificationDataStatements{
streamIDStatements: streamID,
db: db,
}
return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
{&r.selectMaxID, selectMaxNotificationIDSQL},
// {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime
}.Prepare(db)
}
type notificationDataStatements struct {
db *sql.DB
streamIDStatements *StreamIDStatements
upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt
selectMaxID *sql.Stmt
//selectUserUnreadCountsForRooms *sql.Stmt
}
const notificationDataSchema = `
@ -63,12 +66,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
ON CONFLICT (user_id, room_id)
DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7`
const selectUserUnreadNotificationCountsSQL = `SELECT
id, room_id, notification_count, highlight_count
FROM syncapi_notification_data
WHERE
user_id = $1 AND
id BETWEEN $2 + 1 AND $3`
const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count
FROM syncapi_notification_data
WHERE user_id = $1 AND
room_id IN ($2)`
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
@ -81,20 +82,31 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context,
return
}
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl)
func (r *notificationDataStatements) SelectUserUnreadCountsForRooms(
ctx context.Context, txn *sql.Tx, userID string, roomIDs []string,
) (map[string]*eventutil.NotificationData, error) {
params := make([]interface{}, len(roomIDs)+1)
params[0] = userID
for i := range roomIDs {
params[i+1] = roomIDs[i]
}
query := strings.Replace(selectUserUnreadNotificationsForRooms, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
prep, err := r.db.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed")
defer internal.CloseAndLogIfError(ctx, prep, "SelectUserUnreadCountsForRooms: prep.close() failed")
rows, err := sqlutil.TxStmt(txn, prep).QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCountsForRooms: rows.close() failed")
roomCounts := map[string]*eventutil.NotificationData{}
var roomID string
var notificationCount, highlightCount int
for rows.Next() {
var id types.StreamPosition
var roomID string
var notificationCount, highlightCount int
if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil {
if err = rows.Scan(&roomID, &notificationCount, &highlightCount); err != nil {
return nil, err
}

View file

@ -115,6 +115,8 @@ const selectContextAfterEventSQL = "" +
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE type IN ($1) AND id > $2 ORDER BY id ASC LIMIT $3"
type outputRoomEventsStatements struct {
db *sql.DB
streamIDStatements *StreamIDStatements
@ -125,6 +127,7 @@ type outputRoomEventsStatements struct {
selectContextEventStmt *sql.Stmt
selectContextBeforeEventStmt *sql.Stmt
selectContextAfterEventStmt *sql.Stmt
//selectSearchStmt *sql.Stmt - prepared at runtime
}
func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) {
@ -157,15 +160,16 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even
{&s.selectContextEventStmt, selectContextEventSQL},
{&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL},
{&s.selectContextAfterEventStmt, selectContextAfterEventSQL},
//{&s.selectSearchStmt, selectSearchSQL}, - prepared at runtime
}.Prepare(db)
}
func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error {
func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error {
headeredJSON, err := json.Marshal(event)
if err != nil {
return err
}
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
_, err = sqlutil.TxStmt(txn, s.updateEventJSONStmt).ExecContext(ctx, headeredJSON, event.EventID())
return err
}
@ -628,3 +632,40 @@ func unmarshalStateIDs(addIDsJSON, delIDsJSON string) (addIDs []string, delIDs [
}
return
}
func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) {
params := make([]interface{}, len(types))
for i := range types {
params[i] = types[i]
}
params = append(params, afterID)
params = append(params, limit)
selectSQL := strings.Replace(selectSearchSQL, "($1)", sqlutil.QueryVariadic(len(types)), 1)
stmt, err := s.db.Prepare(selectSQL)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, stmt, "ReIndex: stmt.close() failed")
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "rows.close() failed")
var eventID string
var id int64
result := make(map[int64]gomatrixserverlib.HeaderedEvent)
for rows.Next() {
var ev gomatrixserverlib.HeaderedEvent
var eventBytes []byte
if err = rows.Scan(&id, &eventID, &eventBytes); err != nil {
return nil, err
}
if err = ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err
}
result[id] = ev
}
return result, rows.Err()
}
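A hedged sketch of how a reindexing job might page through ReIndex to rebuild the full-text search index; the batch size, event types and the indexEvent callback are illustrative and not taken from this commit:
package example // hypothetical example, not part of the commit
import (
	"context"
	"github.com/matrix-org/dendrite/syncapi/storage/tables"
	"github.com/matrix-org/gomatrixserverlib"
)
// reindexAll repeatedly asks the events table for the next batch of indexable
// events after the highest stream ID seen so far, stopping when a batch is empty.
func reindexAll(
	ctx context.Context,
	events tables.Events,
	indexEvent func(id int64, ev gomatrixserverlib.HeaderedEvent) error,
) error {
	const batchSize = 1000
	eventTypes := []string{"m.room.message", "m.room.name", "m.room.topic"}
	afterID := int64(0)
	for {
		batch, err := events.ReIndex(ctx, nil, batchSize, afterID, eventTypes)
		if err != nil {
			return err
		}
		if len(batch) == 0 {
			return nil
		}
		for id, ev := range batch {
			if err := indexEvent(id, ev); err != nil {
				return err
			}
			if id > afterID {
				afterID = id
			}
		}
	}
}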

View file

@ -176,9 +176,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition(
ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
) (topoPos types.StreamPosition, err error) {
if backwardOrdering {
err = s.selectStreamToTopologicalPositionDescStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionDescStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
} else {
err = s.selectStreamToTopologicalPositionAscStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionAscStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
}
return
}

View file

@ -172,9 +172,9 @@ func (s *peekStatements) SelectPeeksInRange(
}
func (s *peekStatements) SelectPeekingDevices(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
) (peekingDevices map[string][]types.PeekingDevice, err error) {
rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx)
rows, err := sqlutil.TxStmt(txn, s.selectPeekingDevicesStmt).QueryContext(ctx)
if err != nil {
return nil, err
}

View file

@ -108,7 +108,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
}
// SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
var lastPos types.StreamPosition
params := make([]interface{}, len(roomIDs)+1)
@ -116,7 +116,12 @@ func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs
for k, v := range roomIDs {
params[k+1] = v
}
rows, err := r.db.QueryContext(ctx, selectSQL, params...)
prep, err := r.db.Prepare(selectSQL)
if err != nil {
return 0, nil, fmt.Errorf("unable to prepare statement: %w", err)
}
defer internal.CloseAndLogIfError(ctx, prep, "SelectRoomReceiptsAfter: prep.close() failed")
rows, err := sqlutil.TxStmt(txn, prep).QueryContext(ctx, params...)
if err != nil {
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
}

View file

@ -49,6 +49,20 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
return &d, nil
}
func (d *SyncServerDatasource) NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseTransaction, error) {
return &shared.DatabaseTransaction{
Database: &d.Database,
// not setting a transaction because SQLite doesn't support it
}, nil
}
func (d *SyncServerDatasource) NewDatabaseTransaction(ctx context.Context) (*shared.DatabaseTransaction, error) {
return &shared.DatabaseTransaction{
Database: &d.Database,
// not setting a transaction because SQLite doesn't support it
}, nil
}
func (d *SyncServerDatasource) prepare(ctx context.Context) (err error) {
if err = d.streamID.Prepare(d.db); err != nil {
return err

View file

@ -60,6 +60,17 @@ func TestWriteEvents(t *testing.T) {
})
}
func WithSnapshot(t *testing.T, db storage.Database, f func(snapshot storage.DatabaseTransaction)) {
snapshot, err := db.NewDatabaseSnapshot(ctx)
if err != nil {
t.Fatal(err)
}
f(snapshot)
if err := snapshot.Rollback(); err != nil {
t.Fatal(err)
}
}
// These tests assert basic functionality of RecentEvents for PDUs
func TestRecentEventsPDU(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -79,10 +90,13 @@ func TestRecentEventsPDU(t *testing.T) {
// dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
latest, err := db.MaxStreamPositionForPDUs(ctx)
if err != nil {
t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
}
var latest types.StreamPosition
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
var err error
if latest, err = snapshot.MaxStreamPositionForPDUs(ctx); err != nil {
t.Fatal("failed to get MaxStreamPositionForPDUs: %w", err)
}
})
testCases := []struct {
Name string
@ -140,14 +154,19 @@ func TestRecentEventsPDU(t *testing.T) {
tc := testCases[i]
t.Run(tc.Name, func(st *testing.T) {
var filter gomatrixserverlib.RoomEventFilter
var gotEvents []types.StreamEvent
var limited bool
filter.Limit = tc.Limit
gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
From: tc.From,
To: tc.To,
}, &filter, !tc.ReverseOrder, true)
if err != nil {
st.Fatalf("failed to do sync: %s", err)
}
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
var err error
gotEvents, limited, err = snapshot.RecentEvents(ctx, r.ID, types.Range{
From: tc.From,
To: tc.To,
}, &filter, !tc.ReverseOrder, true)
if err != nil {
st.Fatalf("failed to do sync: %s", err)
}
})
if limited != tc.WantLimited {
st.Errorf("got limited=%v want %v", limited, tc.WantLimited)
}
@ -178,22 +197,24 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
events := r.Events()
_ = MustWriteEvents(t, db, events)
from, err := db.MaxTopologicalPosition(ctx, r.ID)
if err != nil {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
}
t.Logf("max topo pos = %+v", from)
// head towards the beginning of time
to := types.TopologyToken{}
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
from, err := snapshot.MaxTopologicalPosition(ctx, r.ID)
if err != nil {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
}
t.Logf("max topo pos = %+v", from)
// head towards the beginning of time
to := types.TopologyToken{}
// backpaginate 5 messages starting at the latest position.
filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
if err != nil {
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
}
gots := db.StreamEventsToEvents(nil, paginatedEvents)
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
// backpaginate 5 messages starting at the latest position.
filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
paginatedEvents, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
if err != nil {
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
}
gots := snapshot.StreamEventsToEvents(nil, paginatedEvents)
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
})
})
}
@ -414,13 +435,16 @@ func TestSendToDeviceBehaviour(t *testing.T) {
defer closeBase()
// At this point there should be no messages. We haven't sent anything
// yet.
_, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
if err != nil {
t.Fatal(err)
}
if len(events) != 0 {
t.Fatal("first call should have no updates")
}
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
_, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
if err != nil {
t.Fatal(err)
}
if len(events) != 0 {
t.Fatal("first call should have no updates")
}
})
// Try sending a message.
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{
@ -432,51 +456,58 @@ func TestSendToDeviceBehaviour(t *testing.T) {
t.Fatal(err)
}
// At this point we should get exactly one message. We're sending the sync position
// that we were given from the update and the send-to-device update will be updated
// in the database to reflect that this was the sync position we sent the message at.
streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
if err != nil {
t.Fatal(err)
}
if count := len(events); count != 1 {
t.Fatalf("second call should have one update, got %d", count)
}
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
// At this point we should get exactly one message. We're sending the sync position
// that we were given from the update and the send-to-device update will be updated
// in the database to reflect that this was the sync position we sent the message at.
var events []types.SendToDeviceEvent
streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
if err != nil {
t.Fatal(err)
}
if count := len(events); count != 1 {
t.Fatalf("second call should have one update, got %d", count)
}
// At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position.
streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
if err != nil {
t.Fatal(err)
}
if len(events) != 1 {
t.Fatal("third call should have one update still")
}
})
// At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position.
streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
if err != nil {
t.Fatal(err)
}
if len(events) != 1 {
t.Fatal("third call should have one update still")
}
err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos)
if err != nil {
t.Fatalf("CleanSendToDeviceUpdates failed: %s", err)
}
// At this point we should now have no updates, because we've progressed the sync
// position. Therefore the update from before will not be sent again.
_, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
if err != nil {
t.Fatal(err)
}
if len(events) != 0 {
t.Fatal("fourth call should have no updates")
}
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
// At this point we should now have no updates, because we've progressed the sync
// position. Therefore the update from before will not be sent again.
var events []types.SendToDeviceEvent
_, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
if err != nil {
t.Fatal(err)
}
if len(events) != 0 {
t.Fatal("fourth call should have no updates")
}
// At this point we should still have no updates, because no new updates have been
// sent.
_, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
if err != nil {
t.Fatal(err)
}
if len(events) != 0 {
t.Fatal("fifth call should have no updates")
}
// At this point we should still have no updates, because no new updates have been
// sent.
_, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
if err != nil {
t.Fatal(err)
}
if len(events) != 0 {
t.Fatal("fifth call should have no updates")
}
})
// Send some more messages and verify the ordering is correct ("in order of arrival")
var lastPos types.StreamPosition = 0
@ -492,18 +523,20 @@ func TestSendToDeviceBehaviour(t *testing.T) {
lastPos = streamPos
}
_, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos)
if err != nil {
t.Fatalf("unable to get events: %v", err)
}
for i := 0; i < 10; i++ {
want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i))
got := events[i].Content
if !bytes.Equal(got, want) {
t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got))
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
_, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos)
if err != nil {
t.Fatalf("unable to get events: %v", err)
}
}
for i := 0; i < 10; i++ {
want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i))
got := events[i].Content
if !bytes.Equal(got, want) {
t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got))
}
}
})
})
}

View file

@ -0,0 +1,88 @@
package tables_test
import (
"context"
"database/sql"
"fmt"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
)
func newCurrentRoomStateTable(t *testing.T, dbType test.DBType) (tables.CurrentRoomState, *sql.DB, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.CurrentRoomState
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresCurrentRoomStateTable(db)
case test.DBTypeSQLite:
var stream sqlite3.StreamIDStatements
if err = stream.Prepare(db); err != nil {
t.Fatalf("failed to prepare stream stmts: %s", err)
}
tab, err = sqlite3.NewSqliteCurrentRoomStateTable(db, &stream)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, db, close
}
func TestCurrentRoomStateTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newCurrentRoomStateTable(t, dbType)
defer close()
events := room.CurrentState()
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
for i, ev := range events {
err := tab.UpsertRoomState(ctx, txn, ev, nil, types.StreamPosition(i))
if err != nil {
return fmt.Errorf("failed to UpsertRoomState: %w", err)
}
}
wantEventIDs := []string{
events[0].EventID(), events[1].EventID(), events[2].EventID(), events[3].EventID(),
}
gotEvents, err := tab.SelectEventsWithEventIDs(ctx, txn, wantEventIDs)
if err != nil {
return fmt.Errorf("failed to SelectEventsWithEventIDs: %w", err)
}
if len(gotEvents) != len(wantEventIDs) {
return fmt.Errorf("SelectEventsWithEventIDs\ngot %d, want %d results", len(gotEvents), len(wantEventIDs))
}
gotEventIDs := make(map[string]struct{}, len(gotEvents))
for _, event := range gotEvents {
if event.ExcludeFromSync {
return fmt.Errorf("SelectEventsWithEventIDs ExcludeFromSync should be false for current room state event %+v", event)
}
gotEventIDs[event.EventID()] = struct{}{}
}
for _, id := range wantEventIDs {
if _, ok := gotEventIDs[id]; !ok {
return fmt.Errorf("SelectEventsWithEventIDs\nexpected id %q not returned", id)
}
}
return nil
})
if err != nil {
t.Fatalf("err: %v", err)
}
})
}

View file

@ -28,7 +28,7 @@ import (
type AccountData interface {
InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error)
// SelectAccountDataInRange returns a map of room ID to a list of `dataType`.
SelectAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, pos types.StreamPosition, err error)
SelectAccountDataInRange(ctx context.Context, txn *sql.Tx, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, pos types.StreamPosition, err error)
SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error)
}
@ -37,7 +37,7 @@ type Invites interface {
DeleteInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string) (types.StreamPosition, error)
// SelectInviteEventsInRange returns a map of room ID to invite events. If multiple invite/retired invites exist in the given range, return the latest value
// for the room.
SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, err error)
SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, maxID types.StreamPosition, err error)
SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error)
}
@ -46,7 +46,7 @@ type Peeks interface {
DeletePeek(ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string) (streamPos types.StreamPosition, err error)
DeletePeeks(ctx context.Context, txn *sql.Tx, roomID, userID string) (streamPos types.StreamPosition, err error)
SelectPeeksInRange(ctxt context.Context, txn *sql.Tx, userID, deviceID string, r types.Range) (peeks []types.Peek, err error)
SelectPeekingDevices(ctxt context.Context) (peekingDevices map[string][]types.PeekingDevice, err error)
SelectPeekingDevices(ctxt context.Context, txn *sql.Tx) (peekingDevices map[string][]types.PeekingDevice, err error)
SelectMaxPeekID(ctx context.Context, txn *sql.Tx) (id int64, err error)
}
@ -68,13 +68,14 @@ type Events interface {
// SelectEarlyEvents returns the earliest events in the given room.
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error)
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
ReIndex(ctx context.Context, txn *sql.Tx, limit, offset int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error)
}
// Topology keeps track of the depths and stream positions for all events.
@ -97,7 +98,7 @@ type Topology interface {
}
type CurrentRoomState interface {
SelectStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
SelectStateEvent(ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
SelectEventsWithEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
UpsertRoomState(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition) error
DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error
@ -109,9 +110,9 @@ type CurrentRoomState interface {
// SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user.
SelectRoomIDsWithAnyMembership(ctx context.Context, txn *sql.Tx, userID string) (map[string]string, error)
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
SelectJoinedUsers(ctx context.Context) (map[string][]string, error)
SelectJoinedUsers(ctx context.Context, txn *sql.Tx) (map[string][]string, error)
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
SelectJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error)
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
}
@ -141,7 +142,7 @@ type BackwardsExtremities interface {
// InsertsBackwardExtremity inserts a new backwards extremity.
InsertsBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string) (err error)
// SelectBackwardExtremitiesForRoom retrieves all backwards extremities for the room, as a map of event_id to list of prev_event_ids.
SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error)
SelectBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (bwExtrems map[string][]string, err error)
// DeleteBackwardExtremity removes a backwards extremity for a room, if one existed.
DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error)
}
@ -171,13 +172,13 @@ type SendToDevice interface {
}
type Filter interface {
SelectFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error
InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error)
SelectFilter(ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string) error
InsertFilter(ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error)
}
type Receipts interface {
UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error)
SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error)
SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error)
}
@ -190,7 +191,7 @@ type Memberships interface {
type NotificationData interface {
UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error)
SelectUserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error)
SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error)
}

View file

@ -3,23 +3,27 @@ package streams
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
type AccountDataStreamProvider struct {
StreamProvider
DefaultStreamProvider
userAPI userapi.SyncUserAPI
}
func (p *AccountDataStreamProvider) Setup() {
p.StreamProvider.Setup()
func (p *AccountDataStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseTransaction,
) {
p.DefaultStreamProvider.Setup(ctx, snapshot)
p.latestMutex.Lock()
defer p.latestMutex.Unlock()
id, err := p.DB.MaxStreamPositionForAccountData(context.Background())
id, err := snapshot.MaxStreamPositionForAccountData(ctx)
if err != nil {
panic(err)
}
@ -28,13 +32,15 @@ func (p *AccountDataStreamProvider) Setup() {
func (p *AccountDataStreamProvider) CompleteSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
}
func (p *AccountDataStreamProvider) IncrementalSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
@ -43,7 +49,7 @@ func (p *AccountDataStreamProvider) IncrementalSync(
To: to,
}
dataTypes, pos, err := p.DB.GetAccountDataInRange(
dataTypes, pos, err := snapshot.GetAccountDataInRange(
ctx, req.Device.UserID, r, &req.Filter.AccountData,
)
if err != nil {

View file

@ -6,17 +6,19 @@ import (
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
)
type DeviceListStreamProvider struct {
StreamProvider
DefaultStreamProvider
rsAPI api.SyncRoomserverAPI
keyAPI keyapi.SyncKeyAPI
}
func (p *DeviceListStreamProvider) CompleteSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
return p.LatestPosition(ctx)
@ -24,11 +26,12 @@ func (p *DeviceListStreamProvider) CompleteSync(
func (p *DeviceListStreamProvider) IncrementalSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
var err error
to, _, err = internal.DeviceListCatchup(context.Background(), p.DB, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
if err != nil {
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
return from

View file

@ -9,20 +9,23 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
)
type InviteStreamProvider struct {
StreamProvider
DefaultStreamProvider
}
func (p *InviteStreamProvider) Setup() {
p.StreamProvider.Setup()
func (p *InviteStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseTransaction,
) {
p.DefaultStreamProvider.Setup(ctx, snapshot)
p.latestMutex.Lock()
defer p.latestMutex.Unlock()
id, err := p.DB.MaxStreamPositionForInvites(context.Background())
id, err := snapshot.MaxStreamPositionForInvites(ctx)
if err != nil {
panic(err)
}
@ -31,13 +34,15 @@ func (p *InviteStreamProvider) Setup() {
func (p *InviteStreamProvider) CompleteSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
}
func (p *InviteStreamProvider) IncrementalSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
@ -46,7 +51,7 @@ func (p *InviteStreamProvider) IncrementalSync(
To: to,
}
invites, retiredInvites, err := p.DB.InviteEventsInRange(
invites, retiredInvites, maxID, err := snapshot.InviteEventsInRange(
ctx, req.Device.UserID, r,
)
if err != nil {
@ -86,5 +91,5 @@ func (p *InviteStreamProvider) IncrementalSync(
}
}
return to
return maxID
}

View file

@ -3,17 +3,23 @@ package streams
import (
"context"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
)
type NotificationDataStreamProvider struct {
StreamProvider
DefaultStreamProvider
}
func (p *NotificationDataStreamProvider) Setup() {
p.StreamProvider.Setup()
func (p *NotificationDataStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseTransaction,
) {
p.DefaultStreamProvider.Setup(ctx, snapshot)
id, err := p.DB.MaxStreamPositionForNotificationData(context.Background())
p.latestMutex.Lock()
defer p.latestMutex.Unlock()
id, err := snapshot.MaxStreamPositionForNotificationData(ctx)
if err != nil {
panic(err)
}
@ -22,34 +28,39 @@ func (p *NotificationDataStreamProvider) Setup() {
func (p *NotificationDataStreamProvider) CompleteSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
}
func (p *NotificationDataStreamProvider) IncrementalSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
from, _ types.StreamPosition,
) types.StreamPosition {
// We want counts for all possible rooms, so always start from zero.
countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to)
// Get the unread notifications for rooms in our join response.
// This is to ensure clients always have an unread notification section
// and can display the correct numbers.
countsByRoom, err := snapshot.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms)
if err != nil {
req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed")
req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed")
return from
}
// We're merely decorating existing rooms. Note that the Join map
// values are not pointers.
// We're merely decorating existing rooms.
for roomID, jr := range req.Response.Rooms.Join {
counts := countsByRoom[roomID]
if counts == nil {
continue
}
jr.UnreadNotifications.HighlightCount = counts.UnreadHighlightCount
jr.UnreadNotifications.NotificationCount = counts.UnreadNotificationCount
jr.UnreadNotifications = &types.UnreadNotifications{
HighlightCount: counts.UnreadHighlightCount,
NotificationCount: counts.UnreadNotificationCount,
}
req.Response.Rooms.Join[roomID] = jr
}
return to
return p.LatestPosition(ctx)
}

View file

@ -5,7 +5,6 @@ import (
"database/sql"
"fmt"
"sort"
"sync"
"time"
"github.com/matrix-org/dendrite/internal/caching"
@ -18,7 +17,6 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"go.uber.org/atomic"
"github.com/matrix-org/dendrite/syncapi/notifier"
)
@ -33,44 +31,23 @@ const PDU_STREAM_WORKERS = 256
const PDU_STREAM_QUEUESIZE = PDU_STREAM_WORKERS * 8
type PDUStreamProvider struct {
StreamProvider
DefaultStreamProvider
tasks chan func()
workers atomic.Int32
// userID+deviceID -> lazy loading cache
lazyLoadCache caching.LazyLoadCache
rsAPI roomserverAPI.SyncRoomserverAPI
notifier *notifier.Notifier
}
func (p *PDUStreamProvider) worker() {
defer p.workers.Dec()
for {
select {
case f := <-p.tasks:
f()
case <-time.After(time.Second * 10):
return
}
}
}
func (p *PDUStreamProvider) queue(f func()) {
if p.workers.Load() < PDU_STREAM_WORKERS {
p.workers.Inc()
go p.worker()
}
p.tasks <- f
}
func (p *PDUStreamProvider) Setup() {
p.StreamProvider.Setup()
p.tasks = make(chan func(), PDU_STREAM_QUEUESIZE)
func (p *PDUStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseTransaction,
) {
p.DefaultStreamProvider.Setup(ctx, snapshot)
p.latestMutex.Lock()
defer p.latestMutex.Unlock()
id, err := p.DB.MaxStreamPositionForPDUs(context.Background())
id, err := snapshot.MaxStreamPositionForPDUs(ctx)
if err != nil {
panic(err)
}
@ -79,6 +56,7 @@ func (p *PDUStreamProvider) Setup() {
func (p *PDUStreamProvider) CompleteSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
from := types.StreamPosition(0)
@ -94,7 +72,7 @@ func (p *PDUStreamProvider) CompleteSync(
}
// Extract room state and recent events for all rooms the user is joined to.
joinedRoomIDs, err := p.DB.RoomIDsWithMembership(ctx, req.Device.UserID, gomatrixserverlib.Join)
joinedRoomIDs, err := snapshot.RoomIDsWithMembership(ctx, req.Device.UserID, gomatrixserverlib.Join)
if err != nil {
req.Log.WithError(err).Error("p.DB.RoomIDsWithMembership failed")
return from
@ -103,7 +81,7 @@ func (p *PDUStreamProvider) CompleteSync(
stateFilter := req.Filter.Room.State
eventFilter := req.Filter.Room.Timeline
if err = p.addIgnoredUsersToFilter(ctx, req, &eventFilter); err != nil {
if err = p.addIgnoredUsersToFilter(ctx, snapshot, req, &eventFilter); err != nil {
req.Log.WithError(err).Error("unable to update event filter with ignored users")
}
@ -117,33 +95,23 @@ func (p *PDUStreamProvider) CompleteSync(
}
// Build up a /sync response. Add joined rooms.
var reqMutex sync.Mutex
var reqWaitGroup sync.WaitGroup
reqWaitGroup.Add(len(joinedRoomIDs))
for _, room := range joinedRoomIDs {
roomID := room
p.queue(func() {
defer reqWaitGroup.Done()
jr, jerr := p.getJoinResponseForCompleteSync(
ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false,
)
if jerr != nil {
req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed")
return
for _, roomID := range joinedRoomIDs {
jr, jerr := p.getJoinResponseForCompleteSync(
ctx, snapshot, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false,
)
if jerr != nil {
req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed")
if jerr == context.DeadlineExceeded || jerr == context.Canceled || jerr == sql.ErrTxDone {
return from
}
reqMutex.Lock()
defer reqMutex.Unlock()
req.Response.Rooms.Join[roomID] = *jr
req.Rooms[roomID] = gomatrixserverlib.Join
})
continue
}
req.Response.Rooms.Join[roomID] = *jr
req.Rooms[roomID] = gomatrixserverlib.Join
}
reqWaitGroup.Wait()
// Add peeked rooms.
peeks, err := p.DB.PeeksInRange(ctx, req.Device.UserID, req.Device.ID, r)
peeks, err := snapshot.PeeksInRange(ctx, req.Device.UserID, req.Device.ID, r)
if err != nil {
req.Log.WithError(err).Error("p.DB.PeeksInRange failed")
return from
@ -152,11 +120,14 @@ func (p *PDUStreamProvider) CompleteSync(
if !peek.Deleted {
var jr *types.JoinResponse
jr, err = p.getJoinResponseForCompleteSync(
ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true,
ctx, snapshot, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true,
)
if err != nil {
req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
return from
if err == context.DeadlineExceeded || err == context.Canceled || err == sql.ErrTxDone {
return from
}
continue
}
req.Response.Rooms.Peek[peek.RoomID] = *jr
}
@ -167,6 +138,7 @@ func (p *PDUStreamProvider) CompleteSync(
func (p *PDUStreamProvider) IncrementalSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) (newPos types.StreamPosition) {
@ -184,14 +156,14 @@ func (p *PDUStreamProvider) IncrementalSync(
eventFilter := req.Filter.Room.Timeline
if req.WantFullState {
if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed")
return
return from
}
} else {
if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltas failed")
return
return from
}
}
@ -203,7 +175,7 @@ func (p *PDUStreamProvider) IncrementalSync(
return to
}
if err = p.addIgnoredUsersToFilter(ctx, req, &eventFilter); err != nil {
if err = p.addIgnoredUsersToFilter(ctx, snapshot, req, &eventFilter); err != nil {
req.Log.WithError(err).Error("unable to update event filter with ignored users")
}
@ -222,9 +194,12 @@ func (p *PDUStreamProvider) IncrementalSync(
}
}
var pos types.StreamPosition
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil {
if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil {
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
return to
if err == context.DeadlineExceeded || err == context.Canceled || err == sql.ErrTxDone {
return newPos
}
continue
}
// Reset the position, as it is only for the special case of newly joined rooms
if delta.NewlyJoined {
@ -244,6 +219,7 @@ func (p *PDUStreamProvider) IncrementalSync(
// nolint:gocyclo
func (p *PDUStreamProvider) addRoomDeltaToResponse(
ctx context.Context,
snapshot storage.DatabaseTransaction,
device *userapi.Device,
r types.Range,
delta types.StateDelta,
@ -260,7 +236,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// This is all "okay" assuming history_visibility == "shared" which it is by default.
r.To = delta.MembershipPos
}
recentStreamEvents, limited, err := p.DB.RecentEvents(
recentStreamEvents, limited, err := snapshot.RecentEvents(
ctx, delta.RoomID, r,
eventFilter, true, true,
)
@ -270,9 +246,9 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
}
return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err)
}
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents)
delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back
prevBatch, err := p.DB.GetBackwardTopologyPos(ctx, recentStreamEvents)
prevBatch, err := snapshot.GetBackwardTopologyPos(ctx, recentStreamEvents)
if err != nil {
return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err)
}
@ -291,7 +267,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
latestPosition := r.To
updateLatestPosition := func(mostRecentEventID string) {
var pos types.StreamPosition
if _, pos, err = p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil {
if _, pos, err = snapshot.PositionInTopology(ctx, mostRecentEventID); err == nil {
switch {
case r.Backwards && pos < latestPosition:
fallthrough
@ -303,7 +279,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
if stateFilter.LazyLoadMembers {
delta.StateEvents, err = p.lazyLoadMembers(
ctx, delta.RoomID, true, limited, stateFilter,
ctx, snapshot, delta.RoomID, true, limited, stateFilter,
device, recentEvents, delta.StateEvents,
)
if err != nil && err != sql.ErrNoRows {
@ -320,7 +296,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
}
// Applies the history visibility rules
events, err := applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents)
events, err := applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
}
@ -336,7 +312,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
case gomatrixserverlib.Join:
jr := types.NewJoinResponse()
if hasMembershipChange {
p.addRoomSummary(ctx, jr, delta.RoomID, device.UserID, latestPosition)
p.addRoomSummary(ctx, snapshot, jr, delta.RoomID, device.UserID, latestPosition)
}
jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
@ -376,7 +352,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// sure we always return the required events in the timeline.
func applyHistoryVisibilityFilter(
ctx context.Context,
db storage.Database,
snapshot storage.DatabaseTransaction,
rsAPI roomserverAPI.SyncRoomserverAPI,
roomID, userID string,
limit int,
@ -384,7 +360,7 @@ func applyHistoryVisibilityFilter(
) ([]*gomatrixserverlib.HeaderedEvent, error) {
// We need to make sure we always include the latest state events, if they are in the timeline.
// We fetch at least limit * 2 events, to ensure we really get the needed events.
stateEvents, err := db.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil)
stateEvents, err := snapshot.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil)
if err != nil {
// Not a fatal error, we can continue without the stateEvents,
// they are only needed if there are state events in the timeline.
@ -395,7 +371,7 @@ func applyHistoryVisibilityFilter(
alwaysIncludeIDs[ev.EventID()] = struct{}{}
}
startTime := time.Now()
events, err := internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync")
events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync")
if err != nil {
return nil, err
}
@ -408,10 +384,10 @@ func applyHistoryVisibilityFilter(
return events, nil
}
func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage.DatabaseTransaction, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
// Work out how many members are in the room.
joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
invitedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition)
joinedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
invitedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition)
jr.Summary.JoinedMemberCount = &joinedCount
jr.Summary.InvitedMemberCount = &invitedCount
@ -439,7 +415,7 @@ func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinRe
}
}
}
heroes, err := p.DB.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"})
heroes, err := snapshot.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"})
if err != nil {
return
}
@ -449,6 +425,7 @@ func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinRe
func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
roomID string,
r types.Range,
stateFilter *gomatrixserverlib.StateFilter,
@ -460,7 +437,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
jr = types.NewJoinResponse()
// TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
recentStreamEvents, limited, err := p.DB.RecentEvents(
recentStreamEvents, limited, err := snapshot.RecentEvents(
ctx, roomID, r, eventFilter, true, true,
)
if err != nil {
@ -484,7 +461,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
}
}
stateEvents, err := p.DB.CurrentState(ctx, roomID, stateFilter, excludingEventIDs)
stateEvents, err := snapshot.CurrentState(ctx, roomID, stateFilter, excludingEventIDs)
if err != nil {
return
}
@ -494,7 +471,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
var prevBatch *types.TopologyToken
if len(recentStreamEvents) > 0 {
var backwardTopologyPos, backwardStreamPos types.StreamPosition
backwardTopologyPos, backwardStreamPos, err = p.DB.PositionInTopology(ctx, recentStreamEvents[0].EventID())
backwardTopologyPos, backwardStreamPos, err = snapshot.PositionInTopology(ctx, recentStreamEvents[0].EventID())
if err != nil {
return
}
@ -505,18 +482,18 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
prevBatch.Decrement()
}
p.addRoomSummary(ctx, jr, roomID, device.UserID, r.From)
p.addRoomSummary(ctx, snapshot, jr, roomID, device.UserID, r.From)
// We wouldn't need to include a device here, since transaction IDs aren't
// sent down for complete syncs, but we do it anyway because Sytest demands it for:
// "Can sync a room with a message with a transaction id" - which does a complete sync to check.
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents)
events := recentEvents
// Only apply history visibility checks if the response is for joined rooms
if !isPeek {
events, err = applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents)
events, err = applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
}
@ -530,7 +507,8 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
if err != nil {
return nil, err
}
stateEvents, err = p.lazyLoadMembers(ctx, roomID,
stateEvents, err = p.lazyLoadMembers(
ctx, snapshot, roomID,
false, limited, stateFilter,
device, recentEvents, stateEvents,
)
@ -549,7 +527,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
}
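// lazyLoadMembers loads the m.room.member state events needed for the
// relevant senders from the snapshot, honouring the lazy-loading state filter.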
func (p *PDUStreamProvider) lazyLoadMembers(
ctx context.Context, roomID string,
ctx context.Context, snapshot storage.DatabaseTransaction, roomID string,
incremental, limited bool, stateFilter *gomatrixserverlib.StateFilter,
device *userapi.Device,
timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
@ -598,7 +576,7 @@ func (p *PDUStreamProvider) lazyLoadMembers(
filter.Limit = stateFilter.Limit
filter.Senders = &wantUsers
filter.Types = &[]string{gomatrixserverlib.MRoomMember}
memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &filter)
memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter)
if err != nil {
return stateEvents, err
}
@ -612,8 +590,8 @@ func (p *PDUStreamProvider) lazyLoadMembers(
// addIgnoredUsersToFilter adds ignored users to the eventfilter and
// the syncreq itself for further use in streams.
func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error {
ignores, err := p.DB.IgnoresForUser(ctx, req.Device.UserID)
func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error {
ignores, err := snapshot.IgnoresForUser(ctx, req.Device.UserID)
if err != nil {
if err == sql.ErrNoRows {
return nil

View file

@ -23,20 +23,26 @@ import (
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
)
type PresenceStreamProvider struct {
StreamProvider
DefaultStreamProvider
// cache contains previously sent presence updates to avoid unneeded updates
cache sync.Map
notifier *notifier.Notifier
}
func (p *PresenceStreamProvider) Setup() {
p.StreamProvider.Setup()
func (p *PresenceStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseTransaction,
) {
p.DefaultStreamProvider.Setup(ctx, snapshot)
id, err := p.DB.MaxStreamPositionForPresence(context.Background())
p.latestMutex.Lock()
defer p.latestMutex.Unlock()
id, err := snapshot.MaxStreamPositionForPresence(ctx)
if err != nil {
panic(err)
}
@ -45,18 +51,20 @@ func (p *PresenceStreamProvider) Setup() {
func (p *PresenceStreamProvider) CompleteSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
}
func (p *PresenceStreamProvider) IncrementalSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
// We pull out a larger number than the filter asks for, since we're filtering out events later
presences, err := p.DB.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000})
presences, err := snapshot.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000})
if err != nil {
req.Log.WithError(err).Error("p.DB.PresenceAfter failed")
return from
@ -84,9 +92,10 @@ func (p *PresenceStreamProvider) IncrementalSync(
}
// Bear in mind that this might return nil, but at least populating
// a nil means that there's a map entry so we won't repeat this call.
presences[roomUsers[i]], err = p.DB.GetPresence(ctx, roomUsers[i])
presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i])
if err != nil {
req.Log.WithError(err).Error("unable to query presence for user")
_ = snapshot.Rollback()
return from
}
if len(presences) > req.Filter.Presence.Limit {

View file

@ -4,18 +4,24 @@ import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
type ReceiptStreamProvider struct {
StreamProvider
DefaultStreamProvider
}
func (p *ReceiptStreamProvider) Setup() {
p.StreamProvider.Setup()
func (p *ReceiptStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseTransaction,
) {
p.DefaultStreamProvider.Setup(ctx, snapshot)
id, err := p.DB.MaxStreamPositionForReceipts(context.Background())
p.latestMutex.Lock()
defer p.latestMutex.Unlock()
id, err := snapshot.MaxStreamPositionForReceipts(ctx)
if err != nil {
panic(err)
}
@ -24,13 +30,15 @@ func (p *ReceiptStreamProvider) Setup() {
func (p *ReceiptStreamProvider) CompleteSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
}
func (p *ReceiptStreamProvider) IncrementalSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
@ -41,7 +49,7 @@ func (p *ReceiptStreamProvider) IncrementalSync(
}
}
lastPos, receipts, err := p.DB.RoomReceiptsAfter(ctx, joinedRooms, from)
lastPos, receipts, err := snapshot.RoomReceiptsAfter(ctx, joinedRooms, from)
if err != nil {
req.Log.WithError(err).Error("p.DB.RoomReceiptsAfter failed")
return from

View file

@ -3,17 +3,23 @@ package streams
import (
"context"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
)
type SendToDeviceStreamProvider struct {
StreamProvider
DefaultStreamProvider
}
func (p *SendToDeviceStreamProvider) Setup() {
p.StreamProvider.Setup()
func (p *SendToDeviceStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseTransaction,
) {
p.DefaultStreamProvider.Setup(ctx, snapshot)
id, err := p.DB.MaxStreamPositionForSendToDeviceMessages(context.Background())
p.latestMutex.Lock()
defer p.latestMutex.Unlock()
id, err := snapshot.MaxStreamPositionForSendToDeviceMessages(ctx)
if err != nil {
panic(err)
}
@ -22,18 +28,20 @@ func (p *SendToDeviceStreamProvider) Setup() {
func (p *SendToDeviceStreamProvider) CompleteSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
}
func (p *SendToDeviceStreamProvider) IncrementalSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
// See if we have any new tasks to do for the send-to-device messaging.
lastPos, events, err := p.DB.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to)
lastPos, events, err := snapshot.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to)
if err != nil {
req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed")
return from

View file

@ -5,24 +5,27 @@ import (
"encoding/json"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
type TypingStreamProvider struct {
StreamProvider
DefaultStreamProvider
EDUCache *caching.EDUCache
}
func (p *TypingStreamProvider) CompleteSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
}
func (p *TypingStreamProvider) IncrementalSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {

View file

@ -0,0 +1,28 @@
package streams
import (
"context"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
)
type StreamProvider interface {
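// Setup is called once at startup to prime the stream, e.g. to load the
// latest stream position from the given database snapshot.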
Setup(ctx context.Context, snapshot storage.DatabaseTransaction)
// Advance will update the latest position of the stream based on
// an update and will wake callers waiting on StreamNotifyAfter.
Advance(latest types.StreamPosition)
// CompleteSync will update the response to include all updates as needed
// for a complete sync. It will always return immediately.
CompleteSync(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest) types.StreamPosition
// IncrementalSync will update the response to include all updates between
// the from and to sync positions. It will always return immediately,
// making no changes if the range contains no updates.
IncrementalSync(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition) types.StreamPosition
// LatestPosition returns the latest stream position for this stream.
LatestPosition(ctx context.Context) types.StreamPosition
}

View file

@ -4,6 +4,7 @@ import (
"context"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/notifier"
@ -13,15 +14,15 @@ import (
)
type Streams struct {
PDUStreamProvider types.StreamProvider
TypingStreamProvider types.StreamProvider
ReceiptStreamProvider types.StreamProvider
InviteStreamProvider types.StreamProvider
SendToDeviceStreamProvider types.StreamProvider
AccountDataStreamProvider types.StreamProvider
DeviceListStreamProvider types.StreamProvider
NotificationDataStreamProvider types.StreamProvider
PresenceStreamProvider types.StreamProvider
PDUStreamProvider StreamProvider
TypingStreamProvider StreamProvider
ReceiptStreamProvider StreamProvider
InviteStreamProvider StreamProvider
SendToDeviceStreamProvider StreamProvider
AccountDataStreamProvider StreamProvider
DeviceListStreamProvider StreamProvider
NotificationDataStreamProvider StreamProvider
PresenceStreamProvider StreamProvider
}
func NewSyncStreamProviders(
@ -31,52 +32,61 @@ func NewSyncStreamProviders(
) *Streams {
streams := &Streams{
PDUStreamProvider: &PDUStreamProvider{
StreamProvider: StreamProvider{DB: d},
lazyLoadCache: lazyLoadCache,
rsAPI: rsAPI,
notifier: notifier,
DefaultStreamProvider: DefaultStreamProvider{DB: d},
lazyLoadCache: lazyLoadCache,
rsAPI: rsAPI,
notifier: notifier,
},
TypingStreamProvider: &TypingStreamProvider{
StreamProvider: StreamProvider{DB: d},
EDUCache: eduCache,
DefaultStreamProvider: DefaultStreamProvider{DB: d},
EDUCache: eduCache,
},
ReceiptStreamProvider: &ReceiptStreamProvider{
StreamProvider: StreamProvider{DB: d},
DefaultStreamProvider: DefaultStreamProvider{DB: d},
},
InviteStreamProvider: &InviteStreamProvider{
StreamProvider: StreamProvider{DB: d},
DefaultStreamProvider: DefaultStreamProvider{DB: d},
},
SendToDeviceStreamProvider: &SendToDeviceStreamProvider{
StreamProvider: StreamProvider{DB: d},
DefaultStreamProvider: DefaultStreamProvider{DB: d},
},
AccountDataStreamProvider: &AccountDataStreamProvider{
StreamProvider: StreamProvider{DB: d},
userAPI: userAPI,
DefaultStreamProvider: DefaultStreamProvider{DB: d},
userAPI: userAPI,
},
NotificationDataStreamProvider: &NotificationDataStreamProvider{
StreamProvider: StreamProvider{DB: d},
DefaultStreamProvider: DefaultStreamProvider{DB: d},
},
DeviceListStreamProvider: &DeviceListStreamProvider{
StreamProvider: StreamProvider{DB: d},
rsAPI: rsAPI,
keyAPI: keyAPI,
DefaultStreamProvider: DefaultStreamProvider{DB: d},
rsAPI: rsAPI,
keyAPI: keyAPI,
},
PresenceStreamProvider: &PresenceStreamProvider{
StreamProvider: StreamProvider{DB: d},
notifier: notifier,
DefaultStreamProvider: DefaultStreamProvider{DB: d},
notifier: notifier,
},
}
streams.PDUStreamProvider.Setup()
streams.TypingStreamProvider.Setup()
streams.ReceiptStreamProvider.Setup()
streams.InviteStreamProvider.Setup()
streams.SendToDeviceStreamProvider.Setup()
streams.AccountDataStreamProvider.Setup()
streams.NotificationDataStreamProvider.Setup()
streams.DeviceListStreamProvider.Setup()
streams.PresenceStreamProvider.Setup()
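// Prime every stream provider from a single database snapshot so that the
// initial stream positions all come from one consistent view of the database.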
ctx := context.TODO()
snapshot, err := d.NewDatabaseSnapshot(ctx)
if err != nil {
panic(err)
}
var succeeded bool
defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
streams.PDUStreamProvider.Setup(ctx, snapshot)
streams.TypingStreamProvider.Setup(ctx, snapshot)
streams.ReceiptStreamProvider.Setup(ctx, snapshot)
streams.InviteStreamProvider.Setup(ctx, snapshot)
streams.SendToDeviceStreamProvider.Setup(ctx, snapshot)
streams.AccountDataStreamProvider.Setup(ctx, snapshot)
streams.NotificationDataStreamProvider.Setup(ctx, snapshot)
streams.DeviceListStreamProvider.Setup(ctx, snapshot)
streams.PresenceStreamProvider.Setup(ctx, snapshot)
succeeded = true
return streams
}

View file

@ -8,16 +8,18 @@ import (
"github.com/matrix-org/dendrite/syncapi/types"
)
type StreamProvider struct {
type DefaultStreamProvider struct {
DB storage.Database
latest types.StreamPosition
latestMutex sync.RWMutex
}
func (p *StreamProvider) Setup() {
func (p *DefaultStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseTransaction,
) {
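// The default provider has nothing to prime; concrete providers override
// Setup to load their latest stream position from the snapshot.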
}
func (p *StreamProvider) Advance(
func (p *DefaultStreamProvider) Advance(
latest types.StreamPosition,
) {
p.latestMutex.Lock()
@ -28,7 +30,7 @@ func (p *StreamProvider) Advance(
}
}
func (p *StreamProvider) LatestPosition(
func (p *DefaultStreamProvider) LatestPosition(
ctx context.Context,
) types.StreamPosition {
p.latestMutex.RLock()

View file

@ -31,6 +31,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
@ -305,78 +306,182 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately")
}
withTransaction := func(from types.StreamPosition, f func(snapshot storage.DatabaseTransaction) types.StreamPosition) types.StreamPosition {
var succeeded bool
snapshot, err := rp.db.NewDatabaseSnapshot(req.Context())
if err != nil {
logrus.WithError(err).Error("Failed to acquire database snapshot for sync request")
return from
}
defer func() {
succeeded = err == nil
sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
}()
return f(snapshot)
}
if syncReq.Since.IsEmpty() {
// Complete sync
syncReq.Response.NextBatch = types.StreamingToken{
// Get the current DeviceListPosition first, as the currentPosition
// might advance while processing other streams, resulting in flakey
// tests.
DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync(
syncReq.Context, syncReq,
DeviceListPosition: withTransaction(
syncReq.Since.DeviceListPosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.DeviceListStreamProvider.CompleteSync(
syncReq.Context, txn, syncReq,
)
},
),
PDUPosition: rp.streams.PDUStreamProvider.CompleteSync(
syncReq.Context, syncReq,
PDUPosition: withTransaction(
syncReq.Since.PDUPosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.PDUStreamProvider.CompleteSync(
syncReq.Context, txn, syncReq,
)
},
),
TypingPosition: rp.streams.TypingStreamProvider.CompleteSync(
syncReq.Context, syncReq,
TypingPosition: withTransaction(
syncReq.Since.TypingPosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.TypingStreamProvider.CompleteSync(
syncReq.Context, txn, syncReq,
)
},
),
ReceiptPosition: rp.streams.ReceiptStreamProvider.CompleteSync(
syncReq.Context, syncReq,
ReceiptPosition: withTransaction(
syncReq.Since.ReceiptPosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.ReceiptStreamProvider.CompleteSync(
syncReq.Context, txn, syncReq,
)
},
),
InvitePosition: rp.streams.InviteStreamProvider.CompleteSync(
syncReq.Context, syncReq,
InvitePosition: withTransaction(
syncReq.Since.InvitePosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.InviteStreamProvider.CompleteSync(
syncReq.Context, txn, syncReq,
)
},
),
SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.CompleteSync(
syncReq.Context, syncReq,
SendToDevicePosition: withTransaction(
syncReq.Since.SendToDevicePosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.SendToDeviceStreamProvider.CompleteSync(
syncReq.Context, txn, syncReq,
)
},
),
AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync(
syncReq.Context, syncReq,
AccountDataPosition: withTransaction(
syncReq.Since.AccountDataPosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.AccountDataStreamProvider.CompleteSync(
syncReq.Context, txn, syncReq,
)
},
),
NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync(
syncReq.Context, syncReq,
NotificationDataPosition: withTransaction(
syncReq.Since.NotificationDataPosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.NotificationDataStreamProvider.CompleteSync(
syncReq.Context, txn, syncReq,
)
},
),
PresencePosition: rp.streams.PresenceStreamProvider.CompleteSync(
syncReq.Context, syncReq,
PresencePosition: withTransaction(
syncReq.Since.PresencePosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.PresenceStreamProvider.CompleteSync(
syncReq.Context, txn, syncReq,
)
},
),
}
} else {
// Incremental sync
syncReq.Response.NextBatch = types.StreamingToken{
PDUPosition: rp.streams.PDUStreamProvider.IncrementalSync(
syncReq.Context, syncReq,
syncReq.Since.PDUPosition, currentPos.PDUPosition,
PDUPosition: withTransaction(
syncReq.Since.PDUPosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.PDUStreamProvider.IncrementalSync(
syncReq.Context, txn, syncReq,
syncReq.Since.PDUPosition, currentPos.PDUPosition,
)
},
),
TypingPosition: rp.streams.TypingStreamProvider.IncrementalSync(
syncReq.Context, syncReq,
syncReq.Since.TypingPosition, currentPos.TypingPosition,
TypingPosition: withTransaction(
syncReq.Since.TypingPosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.TypingStreamProvider.IncrementalSync(
syncReq.Context, txn, syncReq,
syncReq.Since.TypingPosition, currentPos.TypingPosition,
)
},
),
ReceiptPosition: rp.streams.ReceiptStreamProvider.IncrementalSync(
syncReq.Context, syncReq,
syncReq.Since.ReceiptPosition, currentPos.ReceiptPosition,
ReceiptPosition: withTransaction(
syncReq.Since.ReceiptPosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.ReceiptStreamProvider.IncrementalSync(
syncReq.Context, txn, syncReq,
syncReq.Since.ReceiptPosition, currentPos.ReceiptPosition,
)
},
),
InvitePosition: rp.streams.InviteStreamProvider.IncrementalSync(
syncReq.Context, syncReq,
syncReq.Since.InvitePosition, currentPos.InvitePosition,
InvitePosition: withTransaction(
syncReq.Since.InvitePosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.InviteStreamProvider.IncrementalSync(
syncReq.Context, txn, syncReq,
syncReq.Since.InvitePosition, currentPos.InvitePosition,
)
},
),
SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.IncrementalSync(
syncReq.Context, syncReq,
syncReq.Since.SendToDevicePosition, currentPos.SendToDevicePosition,
SendToDevicePosition: withTransaction(
syncReq.Since.SendToDevicePosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.SendToDeviceStreamProvider.IncrementalSync(
syncReq.Context, txn, syncReq,
syncReq.Since.SendToDevicePosition, currentPos.SendToDevicePosition,
)
},
),
AccountDataPosition: rp.streams.AccountDataStreamProvider.IncrementalSync(
syncReq.Context, syncReq,
syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition,
AccountDataPosition: withTransaction(
syncReq.Since.AccountDataPosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.AccountDataStreamProvider.IncrementalSync(
syncReq.Context, txn, syncReq,
syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition,
)
},
),
NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync(
syncReq.Context, syncReq,
syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition,
NotificationDataPosition: withTransaction(
syncReq.Since.NotificationDataPosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.NotificationDataStreamProvider.IncrementalSync(
syncReq.Context, txn, syncReq,
syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition,
)
},
),
DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync(
syncReq.Context, syncReq,
syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition,
DeviceListPosition: withTransaction(
syncReq.Since.DeviceListPosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.DeviceListStreamProvider.IncrementalSync(
syncReq.Context, txn, syncReq,
syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition,
)
},
),
PresencePosition: rp.streams.PresenceStreamProvider.IncrementalSync(
syncReq.Context, syncReq,
syncReq.Since.PresencePosition, currentPos.PresencePosition,
PresencePosition: withTransaction(
syncReq.Since.PresencePosition,
func(txn storage.DatabaseTransaction) types.StreamPosition {
return rp.streams.PresenceStreamProvider.IncrementalSync(
syncReq.Context, txn, syncReq,
syncReq.Since.PresencePosition, currentPos.PresencePosition,
)
},
),
}
// it's possible for there to be no updates for this user even though since < current pos,
@ -437,15 +542,23 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
util.GetLogger(req.Context()).WithError(err).Error("newSyncRequest failed")
return jsonerror.InternalServerError()
}
rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), syncReq, fromToken.PDUPosition, toToken.PDUPosition)
snapshot, err := rp.db.NewDatabaseSnapshot(req.Context())
if err != nil {
logrus.WithError(err).Error("Failed to acquire database snapshot for key change")
return jsonerror.InternalServerError()
}
var succeeded bool
defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), snapshot, syncReq, fromToken.PDUPosition, toToken.PDUPosition)
_, _, err = internal.DeviceListCatchup(
req.Context(), rp.db, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
req.Context(), snapshot, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition,
)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("Failed to DeviceListCatchup info")
return jsonerror.InternalServerError()
}
succeeded = true
return util.JSONResponse{
Code: 200,
JSON: struct {

View file

@ -77,16 +77,6 @@ func AddPublicRoutes(
logrus.WithError(err).Panicf("failed to start presence consumer")
}
userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{
JetStream: js,
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent),
}
userAPIReadUpdateProducer := &producers.UserAPIReadProducer{
JetStream: js,
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate),
}
keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
base.ProcessContext, cfg, cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent),
js, rsAPI, syncDB, notifier,
@ -98,15 +88,15 @@ func AddPublicRoutes(
roomConsumer := consumers.NewOutputRoomEventConsumer(
base.ProcessContext, cfg, js, syncDB, notifier, streams.PDUStreamProvider,
streams.InviteStreamProvider, rsAPI, userAPIStreamEventProducer,
streams.InviteStreamProvider, rsAPI, base.Fulltext,
)
if err = roomConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start room server consumer")
}
clientConsumer := consumers.NewOutputClientDataConsumer(
base.ProcessContext, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider,
userAPIReadUpdateProducer,
base.ProcessContext, cfg, js, natsClient, syncDB, notifier,
streams.AccountDataStreamProvider, base.Fulltext,
)
if err = clientConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start client data consumer")
@ -135,7 +125,6 @@ func AddPublicRoutes(
receiptConsumer := consumers.NewOutputReceiptEventConsumer(
base.ProcessContext, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider,
userAPIReadUpdateProducer,
)
if err = receiptConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start receipts consumer")
@ -143,6 +132,6 @@ func AddPublicRoutes(
routing.Setup(
base.PublicClientAPIMux, requestPool, syncDB, userAPI,
rsAPI, cfg, base.Caches,
rsAPI, cfg, base.Caches, base.Fulltext,
)
}

View file

@ -41,23 +41,3 @@ func (r *SyncRequest) IsRoomPresent(roomID string) bool {
return false
}
}
type StreamProvider interface {
Setup()
// Advance will update the latest position of the stream based on
// an update and will wake callers waiting on StreamNotifyAfter.
Advance(latest StreamPosition)
// CompleteSync will update the response to include all updates as needed
// for a complete sync. It will always return immediately.
CompleteSync(ctx context.Context, req *SyncRequest) StreamPosition
// IncrementalSync will update the response to include all updates between
// the from and to sync positions. It will always return immediately,
// making no changes if the range contains no updates.
IncrementalSync(ctx context.Context, req *SyncRequest, from, to StreamPosition) StreamPosition
// LatestPosition returns the latest stream position for this stream.
LatestPosition(ctx context.Context) StreamPosition
}

View file

@ -398,6 +398,11 @@ func (r *Response) IsEmpty() bool {
len(r.ToDevice.Events) == 0
}
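// UnreadNotifications contains the unread notification counts for a room.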
type UnreadNotifications struct {
HighlightCount int `json:"highlight_count"`
NotificationCount int `json:"notification_count"`
}
// JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key.
type JoinResponse struct {
Summary struct {
@ -419,10 +424,7 @@ type JoinResponse struct {
AccountData struct {
Events []gomatrixserverlib.ClientEvent `json:"events"`
} `json:"account_data"`
UnreadNotifications struct {
HighlightCount int `json:"highlight_count"`
NotificationCount int `json:"notification_count"`
} `json:"unread_notifications"`
*UnreadNotifications `json:"unread_notifications,omitempty"`
}
// NewJoinResponse creates an empty response with initialised arrays.
@ -503,19 +505,6 @@ type Peek struct {
Deleted bool
}
type ReadUpdate struct {
UserID string `json:"user_id"`
RoomID string `json:"room_id"`
Read StreamPosition `json:"read,omitempty"`
FullyRead StreamPosition `json:"fully_read,omitempty"`
}
// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
type StreamedEvent struct {
Event *gomatrixserverlib.HeaderedEvent `json:"event"`
StreamPosition StreamPosition `json:"stream_position"`
}
// OutputReceiptEvent is an entry in the receipt output kafka log
type OutputReceiptEvent struct {
UserID string `json:"user_id"`

View file

@ -0,0 +1,127 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/util"
)
// OutputReceiptEventConsumer consumes events that originated in the clientAPI.
type OutputReceiptEventConsumer struct {
ctx context.Context
jetstream nats.JetStreamContext
durable string
topic string
db storage.Database
serverName gomatrixserverlib.ServerName
syncProducer *producers.SyncAPI
pgClient pushgateway.Client
}
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
// Call Start() to begin consuming receipt events.
func NewOutputReceiptEventConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
store storage.Database,
syncProducer *producers.SyncAPI,
pgClient pushgateway.Client,
) *OutputReceiptEventConsumer {
return &OutputReceiptEventConsumer{
ctx: process.Context(),
jetstream: js,
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent),
durable: cfg.Matrix.JetStream.Durable("UserAPIReceiptConsumer"),
db: store,
serverName: cfg.Matrix.ServerName,
syncProducer: syncProducer,
pgClient: pgClient,
}
}
// Start consuming receipt events.
func (s *OutputReceiptEventConsumer) Start() error {
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
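// onMessage handles a single "m.read" receipt for a local user: it marks the
// user's notifications in that room as read, pushes fresh notification data
// to the sync API and, if the counts changed, notifies the push gateway.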
func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
userID := msg.Header.Get(jetstream.UserID)
roomID := msg.Header.Get(jetstream.RoomID)
readPos := msg.Header.Get(jetstream.EventID)
evType := msg.Header.Get("type")
if readPos == "" || evType != "m.read" {
return true
}
log := log.WithFields(log.Fields{
"room_id": roomID,
"user_id": userID,
})
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
log.WithError(err).Error("userapi clientapi consumer: SplitID failure")
return true
}
if domain != s.serverName {
return true
}
metadata, err := msg.Metadata()
if err != nil {
return false
}
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
if err != nil {
log.WithError(err).Error("userapi EDU consumer")
return false
}
if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil {
log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed")
return false
}
if !updated {
return true
}
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
return false
}
return true
}

View file

@ -26,7 +26,7 @@ import (
"github.com/matrix-org/dendrite/userapi/util"
)
type OutputStreamEventConsumer struct {
type OutputRoomEventConsumer struct {
ctx context.Context
cfg *config.UserAPI
rsAPI rsapi.UserRoomserverAPI
@ -38,7 +38,7 @@ type OutputStreamEventConsumer struct {
syncProducer *producers.SyncAPI
}
func NewOutputStreamEventConsumer(
func NewOutputRoomEventConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
@ -46,21 +46,21 @@ func NewOutputStreamEventConsumer(
pgClient pushgateway.Client,
rsAPI rsapi.UserRoomserverAPI,
syncProducer *producers.SyncAPI,
) *OutputStreamEventConsumer {
return &OutputStreamEventConsumer{
) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{
ctx: process.Context(),
cfg: cfg,
jetstream: js,
db: store,
durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent),
durable: cfg.Matrix.JetStream.Durable("UserAPIRoomServerConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent),
pgClient: pgClient,
rsAPI: rsAPI,
syncProducer: syncProducer,
}
}
func (s *OutputStreamEventConsumer) Start() error {
func (s *OutputRoomEventConsumer) Start() error {
if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
@ -70,35 +70,43 @@ func (s *OutputStreamEventConsumer) Start() error {
return nil
}
func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var output types.StreamedEvent
output.Event = &gomatrixserverlib.HeaderedEvent{}
var output rsapi.OutputEvent
if err := json.Unmarshal(msg.Data, &output); err != nil {
log.WithError(err).Errorf("userapi consumer: message parse failure")
// If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("roomserver output log: message parse failure")
return true
}
if output.Event.Event == nil {
if output.Type != rsapi.OutputTypeNewRoomEvent {
return true
}
event := output.NewRoomEvent.Event
if event == nil {
log.Errorf("userapi consumer: expected event")
return true
}
log.WithFields(log.Fields{
"event_id": output.Event.EventID(),
"event_type": output.Event.Type(),
"stream_pos": output.StreamPosition,
}).Tracef("Received message from sync API: %#v", output)
"event_id": event.EventID(),
"event_type": event.Type(),
}).Tracef("Received message from roomserver: %#v", output)
if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil {
metadata, err := msg.Metadata()
if err != nil {
return true
}
if err := s.processMessage(ctx, event, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp))); err != nil {
log.WithFields(log.Fields{
"event_id": output.Event.EventID(),
"event_id": event.EventID(),
}).WithError(err).Errorf("userapi consumer: process room event failure")
}
return true
}
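// processMessage works out which local room members should be notified about
// the event and delivers the resulting notifications, using the message
// timestamp as the notification position.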
func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error {
func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {
members, roomSize, err := s.localRoomMembers(ctx, event.RoomID())
if err != nil {
return fmt.Errorf("s.localRoomMembers: %w", err)
@ -138,10 +146,10 @@ func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *g
// removing it means we can send all notifications to
// e.g. Element's Push gateway in one go.
for _, mem := range members {
if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil {
if err := s.notifyLocal(ctx, event, mem, roomSize, roomName, streamPos); err != nil {
log.WithFields(log.Fields{
"localpart": mem.Localpart,
}).WithError(err).Debugf("Unable to push to local user")
}).WithError(err).Error("Unable to push to local user")
continue
}
}
@ -179,7 +187,7 @@ func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership,
// localRoomMembers fetches the current local members of a room, and
// the total number of members.
func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) {
func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) {
req := &rsapi.QueryMembershipsForRoomRequest{
RoomID: roomID,
JoinedOnly: true,
@ -219,7 +227,7 @@ func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID
// looks it up in roomserver. If there is no name,
// m.room.canonical_alias is consulted. Returns an empty string if the
// room has no name.
func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) {
func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) {
if event.Type() == gomatrixserverlib.MRoomName {
name, err := unmarshalRoomName(event)
if err != nil {
@ -287,7 +295,7 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er
}
// notifyLocal finds the right push actions for a local user, given an event.
func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error {
func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error {
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
if err != nil {
return err
@ -302,7 +310,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
"event_id": event.EventID(),
"room_id": event.RoomID(),
"localpart": mem.Localpart,
}).Debugf("Push rule evaluation rejected the event")
}).Tracef("Push rule evaluation rejected the event")
return nil
}
@ -325,7 +333,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
RoomID: event.RoomID(),
TS: gomatrixserverlib.AsTimestamp(time.Now()),
}
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil {
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil {
return err
}
@ -345,7 +353,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
"localpart": mem.Localpart,
"num_urls": len(devicesByURLAndFormat),
"num_unread": userNumUnreadNotifs,
}).Debugf("Notifying single member")
}).Trace("Notifying single member")
// Push gateways are out of our control, and we cannot risk
// looking up the server on a misbehaving push gateway. Each user
@ -396,7 +404,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
// evaluatePushRules fetches and evaluates the push rules of a local
// user. Returns actions (including dont_notify).
func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
if event.Sender() == mem.UserID {
// SPEC: Homeservers MUST NOT notify the Push Gateway for
// events that the user has sent themselves.
@ -447,7 +455,7 @@ func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event
"room_id": event.RoomID(),
"localpart": mem.Localpart,
"rule_id": rule.RuleID,
}).Tracef("Matched a push rule")
}).Trace("Matched a push rule")
return rule.Actions, nil
}
@ -491,7 +499,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
// localPushDevices returns the push devices configured for a local
// user. The map keys are [url][format].
func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
if err != nil {
return nil, "", err
@ -515,7 +523,7 @@ func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localp
}
// notifyHTTP performs a notification to a Push Gateway.
func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) {
func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) {
logger := log.WithFields(log.Fields{
"event_id": event.EventID(),
"url": url,
@ -561,13 +569,13 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat
}
}
logger.Debugf("Notifying push gateway %s", url)
logger.Tracef("Notifying push gateway %s", url)
var res pushgateway.NotifyResponse
if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil {
logger.WithError(err).Errorf("Failed to notify push gateway %s", url)
return nil, err
}
logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result")
logger.WithField("num_rejected", len(res.Rejected)).Trace("Push gateway result")
if len(res.Rejected) == 0 {
return nil, nil
@ -589,7 +597,7 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat
}
// deleteRejectedPushers deletes the pushers associated with the given devices.
func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
log.WithFields(log.Fields{
"localpart": localpart,
"app_id0": devices[0].AppID,

View file

@ -40,7 +40,7 @@ func Test_evaluatePushRules(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
consumer := OutputStreamEventConsumer{db: db}
consumer := OutputRoomEventConsumer{db: db}
testCases := []struct {
name string

View file

@ -1,137 +0,0 @@
package consumers
import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/types"
uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/util"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
type OutputReadUpdateConsumer struct {
ctx context.Context
cfg *config.UserAPI
jetstream nats.JetStreamContext
durable string
db storage.Database
pgClient pushgateway.Client
ServerName gomatrixserverlib.ServerName
topic string
userAPI uapi.UserInternalAPI
syncProducer *producers.SyncAPI
}
func NewOutputReadUpdateConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
store storage.Database,
pgClient pushgateway.Client,
userAPI uapi.UserInternalAPI,
syncProducer *producers.SyncAPI,
) *OutputReadUpdateConsumer {
return &OutputReadUpdateConsumer{
ctx: process.Context(),
cfg: cfg,
jetstream: js,
db: store,
ServerName: cfg.Matrix.ServerName,
durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate),
pgClient: pgClient,
userAPI: userAPI,
syncProducer: syncProducer,
}
}
func (s *OutputReadUpdateConsumer) Start() error {
if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
); err != nil {
return err
}
return nil
}
func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var read types.ReadUpdate
if err := json.Unmarshal(msg.Data, &read); err != nil {
log.WithError(err).Error("userapi clientapi consumer: message parse failure")
return true
}
if read.FullyRead == 0 && read.Read == 0 {
return true
}
userID := string(msg.Header.Get(jetstream.UserID))
roomID := string(msg.Header.Get(jetstream.RoomID))
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
log.WithError(err).Error("userapi clientapi consumer: SplitID failure")
return true
}
if domain != s.ServerName {
log.Error("userapi clientapi consumer: not a local user")
return true
}
log := log.WithFields(log.Fields{
"room_id": roomID,
"user_id": userID,
})
log.Tracef("Received read update from sync API: %#v", read)
if read.Read > 0 {
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true)
if err != nil {
log.WithError(err).Error("userapi EDU consumer")
return false
}
if updated {
if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil {
log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed")
return false
}
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
return false
}
}
}
if read.FullyRead > 0 {
deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead))
if err != nil {
log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed")
return false
}
if deleted {
if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed")
return false
}
if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil {
log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed")
return false
}
}
}
return true
}

View file

@ -30,6 +30,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
@ -39,6 +40,7 @@ import (
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables"
userapiUtil "github.com/matrix-org/dendrite/userapi/util"
)
type UserInternalAPI struct {
@ -51,6 +53,7 @@ type UserInternalAPI struct {
AppServices []config.ApplicationService
KeyAPI keyapi.UserKeyAPI
RSAPI rsapi.UserRoomserverAPI
PgClient pushgateway.Client
}
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
@ -73,6 +76,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
ignoredUsers = &synctypes.IgnoredUsers{}
_ = json.Unmarshal(req.AccountData, ignoredUsers)
}
if req.DataType == "m.fully_read" {
if err := a.setFullyRead(ctx, req); err != nil {
return err
}
}
if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{
RoomID: req.RoomID,
Type: req.DataType,
@ -84,6 +92,44 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
return nil
}
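// setFullyRead handles "m.fully_read" account data for local users: it clears
// the user's notifications in the room, sends updated notification data to
// the sync API and, if anything was deleted, refreshes the push gateway
// unread counts.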
func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccountDataRequest) error {
var output eventutil.ReadMarkerJSON
if err := json.Unmarshal(req.AccountData, &output); err != nil {
return err
}
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure")
return nil
}
if domain != a.ServerName {
return nil
}
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
if err != nil {
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
return err
}
if err = a.SyncProducer.GetAndSendNotificationData(ctx, req.UserID, req.RoomID); err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: GetAndSendNotificationData failed")
return err
}
// nothing changed, no need to notify the push gateway
if !deleted {
return nil
}
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
return err
}
return nil
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if err != nil {
@ -750,11 +796,6 @@ func (a *UserInternalAPI) PerformPushRulesPut(
if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil {
return err
}
if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{
Type: pushRulesAccountDataType,
}); err != nil {
util.GetLogger(ctx).WithError(err).Errorf("syncProducer.SendData failed")
}
return nil
}

View file

@ -4,12 +4,13 @@ import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/storage"
)
type JetStreamPublisher interface {

View file

@ -119,9 +119,9 @@ type ThreePID interface {
}
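// Notification persists and queries the push notifications generated for
// local users.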
type Notification interface {
InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error)
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, read bool) (affected bool, err error)
InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error)
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error)
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)

View file

@@ -20,12 +20,13 @@ import (
 	"encoding/json"
 	"time"
 
+	"github.com/matrix-org/gomatrixserverlib"
+	log "github.com/sirupsen/logrus"
+
 	"github.com/matrix-org/dendrite/internal"
 	"github.com/matrix-org/dendrite/internal/sqlutil"
 	"github.com/matrix-org/dendrite/userapi/api"
 	"github.com/matrix-org/dendrite/userapi/storage/tables"
-	"github.com/matrix-org/gomatrixserverlib"
-	log "github.com/sirupsen/logrus"
 )
 
 type notificationsStatements struct {
@@ -110,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
 }
 
 // Insert inserts a notification into the database.
-func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error {
+func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
 	roomID, tsMS := n.RoomID, n.TS
 	nn := *n
 	// Clears out fields that have their own columns to (1) shrink the
@@ -126,7 +127,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
 }
 
 // DeleteUpTo deletes all previous notifications, up to and including the event.
-func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) {
+func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
 	res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
 	if err != nil {
 		return false, err
@@ -140,7 +141,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
 }
 
 // UpdateRead updates the "read" value for an event.
-func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) {
+func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
 	res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
 	if err != nil {
 		return false, err
@@ -196,40 +197,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
 	return notifs, maxID, rows.Err()
 }
 
-func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) {
-	rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter))
-	if err != nil {
-		return 0, err
-	}
-	defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
-	if rows.Next() {
-		var count int64
-		if err := rows.Scan(&count); err != nil {
-			return 0, err
-		}
-		return count, nil
-	}
-	return 0, rows.Err()
+func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
+	err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
+	return
 }
 
-func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
-	rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
-	if err != nil {
-		return 0, 0, err
-	}
-	defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
-	if rows.Next() {
-		var total, highlight int64
-		if err := rows.Scan(&total, &highlight); err != nil {
-			return 0, 0, err
-		}
-		return total, highlight, nil
-	}
-	return 0, 0, rows.Err()
+func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
+	err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
+	return
 }
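Both count queries above collapse a QueryContext / rows.Next / rows.Scan loop into a single QueryRowContext call, which drops the manual rows.Close and rows.Err bookkeeping for statements that always return exactly one row. A minimal sketch of the pattern, using a hypothetical table and statement rather than Dendrite's prepared statements:

package storage // illustrative only

import (
	"context"
	"database/sql"
)

// countForUser shows the single-row aggregate pattern used above: an
// aggregate such as COUNT(*) always yields exactly one row, so
// QueryRowContext followed by Scan is sufficient, and any query error
// surfaces from Scan itself.
func countForUser(ctx context.Context, db *sql.DB, localpart string) (count int64, err error) {
	err = db.QueryRowContext(ctx,
		"SELECT COUNT(*) FROM notifications WHERE localpart = $1",
		localpart,
	).Scan(&count)
	return
}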

View file

@@ -19,11 +19,12 @@ import (
 	"database/sql"
 	"encoding/json"
 
+	"github.com/sirupsen/logrus"
+
 	"github.com/matrix-org/dendrite/internal"
 	"github.com/matrix-org/dendrite/internal/sqlutil"
 	"github.com/matrix-org/dendrite/userapi/api"
 	"github.com/matrix-org/dendrite/userapi/storage/tables"
-	"github.com/sirupsen/logrus"
 )
 
 // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@@ -136,7 +137,7 @@ func (s *pushersStatements) SelectPushers(
 		pushers = append(pushers, pusher)
 	}
 
-	logrus.Debugf("Database returned %d pushers", len(pushers))
+	logrus.Tracef("Database returned %d pushers", len(pushers))
 	return pushers, rows.Err()
 }

View file

@@ -700,13 +700,13 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
 	return d.LoginTokens.SelectLoginToken(ctx, token)
 }
 
-func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error {
+func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
 	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
 		return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
 	})
 }
 
-func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) {
+func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) {
 	err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
 		affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
 		return err
@@ -714,7 +714,7 @@ func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomI
 	return
 }
 
-func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) {
+func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) {
 	err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
 		affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
 		return err
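Each of these wrappers funnels its write through d.Writer.Do, so every notification write executes inside a transaction handed out by a single writer, which matters most for SQLite since it only tolerates one write transaction at a time. A rough, self-contained sketch of that shape (an illustration of the idea, not Dendrite's actual sqlutil.Writer):

package storage // illustrative only

import (
	"database/sql"
	"sync"
)

// exclusiveWriter is a toy version of the single-writer pattern: writes are
// serialised behind a mutex and executed inside one transaction, so the
// backend never sees two concurrent write transactions.
type exclusiveWriter struct {
	mu sync.Mutex
}

func (w *exclusiveWriter) Do(db *sql.DB, fn func(txn *sql.Tx) error) error {
	w.mu.Lock()
	defer w.mu.Unlock()

	txn, err := db.Begin()
	if err != nil {
		return err
	}
	if err := fn(txn); err != nil {
		txn.Rollback() // best effort; the original error is what matters
		return err
	}
	return txn.Commit()
}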

View file

@@ -20,12 +20,13 @@ import (
 	"encoding/json"
 	"time"
 
+	"github.com/matrix-org/gomatrixserverlib"
+	log "github.com/sirupsen/logrus"
+
 	"github.com/matrix-org/dendrite/internal"
 	"github.com/matrix-org/dendrite/internal/sqlutil"
 	"github.com/matrix-org/dendrite/userapi/api"
 	"github.com/matrix-org/dendrite/userapi/storage/tables"
-	"github.com/matrix-org/gomatrixserverlib"
-	log "github.com/sirupsen/logrus"
 )
 
 type notificationsStatements struct {
@@ -110,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
 }
 
 // Insert inserts a notification into the database.
-func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error {
+func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
 	roomID, tsMS := n.RoomID, n.TS
 	nn := *n
 	// Clears out fields that have their own columns to (1) shrink the
@@ -126,7 +127,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
 }
 
 // DeleteUpTo deletes all previous notifications, up to and including the event.
-func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) {
+func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
 	res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
 	if err != nil {
 		return false, err
@@ -140,7 +141,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
 }
 
 // UpdateRead updates the "read" value for an event.
-func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) {
+func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
 	res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
 	if err != nil {
 		return false, err
@@ -196,40 +197,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
 	return notifs, maxID, rows.Err()
 }
 
-func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) {
-	rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter))
-	if err != nil {
-		return 0, err
-	}
-	defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
-	if rows.Next() {
-		var count int64
-		if err := rows.Scan(&count); err != nil {
-			return 0, err
-		}
-		return count, nil
-	}
-	return 0, rows.Err()
+func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
+	err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
+	return
 }
 
-func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
-	rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
-	if err != nil {
-		return 0, 0, err
-	}
-	defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
-	if rows.Next() {
-		var total, highlight int64
-		if err := rows.Scan(&total, &highlight); err != nil {
-			return 0, 0, err
-		}
-		return total, highlight, nil
-	}
-	return 0, 0, rows.Err()
+func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
+	err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
+	return
 }

Some files were not shown because too many files have changed in this diff.