diff --git a/CHANGES.md b/CHANGES.md index 657ca1920..dbe2ccf02 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,38 @@ # Changelog +## Dendrite 0.10.1 (2022-09-30) + +### Features + +* The built-in NATS Server has been updated to version 2.9.2 + +### Fixes + +* A regression introduced in 0.10.0 in `/sync` as a result of transaction errors has been fixed +* Account data updates will no longer send duplicate output events + +## Dendrite 0.10.0 (2022-09-30) + +### Features + +* High performance full-text searching has been added to Dendrite + * Search must be enabled in the [`search` section of the `sync_api` config](https://github.com/matrix-org/dendrite/blob/6348486a1365c7469a498101f5035a9b6bd16d22/dendrite-sample.monolith.yaml#L279-L290) before it can be used + * The search index is stored on the filesystem rather than the sync API database, so a path to a suitable storage location on disk must be configured +* Sync requests should now complete faster and use considerably less database connections as a result of better transactional isolation +* The notifications code has been refactored to hopefully make notifications more reliable +* A new `/_dendrite/admin/refreshDevices/{userID}` admin endpoint has been added for forcing a refresh of a remote user's device lists without having to modify the database by hand +* A new `/_dendrite/admin/fulltext/reindex` admin endpoint has been added for rebuilding the search index (although this may take some time) + +### Fixes + +* A number of bugs in the device list updater have been fixed, which should help considerably with federated device list synchronisation and E2EE reliability +* A state resolution bug has been fixed which should help to prevent unexpected state resets +* The deprecated `"origin"` field in events will now be correctly ignored in all cases +* Room versions 8 and 9 will now correctly evaluate `"knock"` join rules and membership states +* A database index has been added to speed up finding room memberships in the sync API (contributed by [PiotrKozimor](https://github.com/PiotrKozimor)) +* The client API will now return an `M_UNRECOGNIZED` error for unknown endpoints/methods, which should help with client error handling +* A bug has been fixed when updating push rules which could result in `database is locked` on SQLite + ## Dendrite 0.9.9 (2022-09-22) ### Features diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 500403ae4..4a96e4bef 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -30,6 +30,8 @@ import ( "sync" "time" + "go.uber.org/atomic" + "github.com/gorilla/mux" "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/clientapi/userutil" @@ -66,6 +68,7 @@ const ( PeerTypeRemote = pineconeRouter.PeerTypeRemote PeerTypeMulticast = pineconeRouter.PeerTypeMulticast PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth + PeerTypeBonjour = pineconeRouter.PeerTypeBonjour ) type DendriteMonolith struct { @@ -82,6 +85,10 @@ type DendriteMonolith struct { userAPI userapiAPI.UserInternalAPI } +func (m *DendriteMonolith) PublicKey() string { + return m.PineconeRouter.PublicKey().String() +} + func (m *DendriteMonolith) BaseURL() string { return fmt.Sprintf("http://%s", m.listener.Addr().String()) } @@ -94,6 +101,20 @@ func (m *DendriteMonolith) SessionCount() int { return len(m.PineconeQUIC.Protocol("matrix").Sessions()) } +func (m *DendriteMonolith) RegisterNetworkInterface(name string, index int, mtu int, up bool, broadcast bool, loopback bool, pointToPoint bool, multicast bool, addrs string) { + m.PineconeMulticast.RegisterInterface(pineconeMulticast.InterfaceInfo{ + Name: name, + Index: index, + Mtu: mtu, + Up: up, + Broadcast: broadcast, + Loopback: loopback, + PointToPoint: pointToPoint, + Multicast: multicast, + Addrs: addrs, + }) +} + func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) { if enabled { m.PineconeMulticast.Start() @@ -105,7 +126,9 @@ func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) { func (m *DendriteMonolith) SetStaticPeer(uri string) { m.PineconeManager.RemovePeers() - m.PineconeManager.AddPeer(strings.TrimSpace(uri)) + for _, uri := range strings.Split(uri, ",") { + m.PineconeManager.AddPeer(strings.TrimSpace(uri)) + } } func (m *DendriteMonolith) DisconnectType(peertype int) { @@ -134,32 +157,21 @@ func (m *DendriteMonolith) Conduit(zone string, peertype int) (*Conduit, error) go func() { conduit.portMutex.Lock() defer conduit.portMutex.Unlock() - loop: - for i := 1; i <= 10; i++ { - logrus.Errorf("Attempting authenticated connect (attempt %d)", i) - var err error - conduit.port, err = m.PineconeRouter.Connect( - l, - pineconeRouter.ConnectionZone(zone), - pineconeRouter.ConnectionPeerType(peertype), - ) - switch err { - case io.ErrClosedPipe: - logrus.Errorf("Authenticated connect failed due to closed pipe (attempt %d)", i) - return - case io.EOF: - logrus.Errorf("Authenticated connect failed due to EOF (attempt %d)", i) - break loop - case nil: - logrus.Errorf("Authenticated connect succeeded, connected to port %d (attempt %d)", conduit.port, i) - return - default: - logrus.WithError(err).Errorf("Authenticated connect failed (attempt %d)", i) - time.Sleep(time.Second) - } + + logrus.Errorf("Attempting authenticated connect") + var err error + if conduit.port, err = m.PineconeRouter.Connect( + l, + pineconeRouter.ConnectionZone(zone), + pineconeRouter.ConnectionPeerType(peertype), + ); err != nil { + logrus.Errorf("Authenticated connect failed: %s", err) + _ = l.Close() + _ = r.Close() + _ = conduit.Close() + return } - _ = l.Close() - _ = r.Close() + logrus.Infof("Authenticated connect succeeded (port %d)", conduit.port) }() return conduit, nil } @@ -269,19 +281,21 @@ func (m *DendriteMonolith) Start() { cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) cfg.Global.PrivateKey = sk cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - cfg.Global.JetStream.InMemory = true - cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix)) - cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix)) - cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory)) - cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix)) - cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix)) - cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix)) - cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-federationsender.db", m.StorageDirectory, prefix)) - cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory)) - cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory)) + cfg.Global.JetStream.InMemory = false + cfg.Global.JetStream.StoragePath = config.Path(filepath.Join(m.CacheDirectory, prefix)) + cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", filepath.Join(m.StorageDirectory, prefix))) + cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", filepath.Join(m.StorageDirectory, prefix))) + cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", filepath.Join(m.StorageDirectory, prefix))) + cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(m.StorageDirectory, prefix))) + cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(m.StorageDirectory, prefix))) + cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", filepath.Join(m.StorageDirectory, prefix))) + cfg.MediaAPI.BasePath = config.Path(filepath.Join(m.CacheDirectory, "media")) + cfg.MediaAPI.AbsBasePath = config.Path(filepath.Join(m.CacheDirectory, "media")) cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} cfg.ClientAPI.RegistrationDisabled = false cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true + cfg.SyncAPI.Fulltext.Enabled = true + cfg.SyncAPI.Fulltext.IndexPath = config.Path(filepath.Join(m.CacheDirectory, "search")) if err = cfg.Derive(); err != nil { panic(err) } @@ -395,6 +409,7 @@ func (m *DendriteMonolith) Stop() { const MaxFrameSize = types.MaxFrameSize type Conduit struct { + closed atomic.Bool conn net.Conn port types.SwitchPortID portMutex sync.Mutex @@ -407,10 +422,16 @@ func (c *Conduit) Port() int { } func (c *Conduit) Read(b []byte) (int, error) { + if c.closed.Load() { + return 0, io.EOF + } return c.conn.Read(b) } func (c *Conduit) ReadCopy() ([]byte, error) { + if c.closed.Load() { + return nil, io.EOF + } var buf [65535 * 2]byte n, err := c.conn.Read(buf[:]) if err != nil { @@ -420,9 +441,16 @@ func (c *Conduit) ReadCopy() ([]byte, error) { } func (c *Conduit) Write(b []byte) (int, error) { + if c.closed.Load() { + return 0, io.EOF + } return c.conn.Write(b) } func (c *Conduit) Close() error { + if c.closed.Load() { + return io.ErrClosedPipe + } + c.closed.Store(true) return c.conn.Close() } diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 0c5f8c167..89c269f1a 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -2,16 +2,23 @@ package routing import ( "encoding/json" + "fmt" "net/http" + "time" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/internal/httputil" - roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/config" - userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/keyserver/api" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + userapi "github.com/matrix-org/dendrite/userapi/api" ) func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { @@ -138,3 +145,49 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap }, } } + +func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, natsClient *nats.Conn) util.JSONResponse { + _, err := natsClient.RequestMsg(nats.NewMsg(cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex)), time.Second*10) + if err != nil { + logrus.WithError(err).Error("failed to publish nats message") + return jsonerror.InternalServerError() + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.ClientKeyAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + userID := vars["userID"] + + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if domain == cfg.Matrix.ServerName { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidParam("Can not mark local device list as stale"), + } + } + + err = keyAPI.PerformMarkAsStaleIfNeeded(req.Context(), &api.PerformMarkAsStaleRequest{ + UserID: userID, + Domain: domain, + }, &struct{}{}) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown(fmt.Sprintf("Failed to mark device list as stale: %s", err)), + } + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index d7a48d228..7d1c434c4 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -20,6 +20,12 @@ import ( "strings" "github.com/gorilla/mux" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/nats-io/nats.go" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth" @@ -34,11 +40,6 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/nats-io/nats.go" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" ) // Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client @@ -161,6 +162,18 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/fulltext/reindex", + httputil.MakeAdminAPI("admin_fultext_reindex", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return AdminReindex(req, cfg, device, natsClient) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + dendriteAdminRouter.Handle("/admin/refreshDevices/{userID}", + httputil.MakeAdminAPI("admin_refresh_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return AdminMarkAsStale(req, cfg, keyAPI) + }), + ).Methods(http.MethodPost, http.MethodOptions) + // server notifications if cfg.Matrix.ServerNotices.Enabled { logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index a9357f6db..52301415f 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -64,7 +64,7 @@ var ( pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") isAdmin = flag.Bool("admin", false, "Create an admin account") resetPassword = flag.Bool("reset-password", false, "Deprecated") - serverURL = flag.String("url", "https://localhost:8448", "The URL to connect to.") + serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.") validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server") ) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index da63f9a2c..be34365b4 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -89,6 +89,7 @@ func main() { if configFlagSet { cfg = setup.ParseFlags(true) sk = cfg.Global.PrivateKey + pk = sk.Public().(ed25519.PublicKey) } else { keyfile := filepath.Join(*instanceDir, *instanceName) + ".pem" if _, err := os.Stat(keyfile); os.IsNotExist(err) { @@ -142,6 +143,9 @@ func main() { cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName))) cfg.ClientAPI.RegistrationDisabled = false cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true + cfg.MediaAPI.BasePath = config.Path(*instanceDir) + cfg.SyncAPI.Fulltext.Enabled = true + cfg.SyncAPI.Fulltext.IndexPath = config.Path(*instanceDir) if err := cfg.Derive(); err != nil { panic(err) } diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index cd0066679..38c25cdec 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -134,6 +134,9 @@ func main() { cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName))) cfg.ClientAPI.RegistrationDisabled = false cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true + cfg.MediaAPI.BasePath = config.Path(*instanceDir) + cfg.SyncAPI.Fulltext.Enabled = true + cfg.SyncAPI.Fulltext.IndexPath = config.Path(*instanceDir) if err := cfg.Derive(); err != nil { panic(err) } diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index c24e8153e..33b18c471 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -5,10 +5,11 @@ import ( "fmt" "path/filepath" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" "gopkg.in/yaml.v2" + + "github.com/matrix-org/dendrite/setup/config" ) func main() { @@ -82,6 +83,12 @@ func main() { EnableInbound: true, EnableOutbound: true, } + cfg.SyncAPI.Fulltext = config.Fulltext{ + Enabled: true, + IndexPath: config.Path(filepath.Join(*dirPath, "searchindex")), + InMemory: true, + Language: "en", + } } } else { var err error diff --git a/dendrite-sample.monolith.yaml b/dendrite-sample.monolith.yaml index f1758f54d..e41e83d7c 100644 --- a/dendrite-sample.monolith.yaml +++ b/dendrite-sample.monolith.yaml @@ -275,10 +275,19 @@ sync_api: # address of the client. This is likely required if Dendrite is running behind # a reverse proxy server. # real_ip_header: X-Real-IP - fulltext: + + # Configuration for the full-text search engine. + search: + # Whether or not search is enabled. enabled: false - index_path: "./fulltextindex" - language: "en" # more possible languages can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang + + # The path where the search index will be created in. + index_path: "./searchindex" + + # The language most likely to be used on the server - used when indexing, to + # ensure the returned results match expectations. A full list of possible languages + # can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang + language: "en" # Configuration for the User API. user_api: diff --git a/dendrite-sample.polylith.yaml b/dendrite-sample.polylith.yaml index 97d10825f..0ae4cc8fb 100644 --- a/dendrite-sample.polylith.yaml +++ b/dendrite-sample.polylith.yaml @@ -326,10 +326,19 @@ sync_api: max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 - fulltext: + + # Configuration for the full-text search engine. + search: + # Whether or not search is enabled. enabled: false - index_path: "./fulltextindex" - language: "en" # more possible languages can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang + + # The path where the search index will be created in. + index_path: "./searchindex" + + # The language most likely to be used on the server - used when indexing, to + # ensure the returned results match expectations. A full list of possible languages + # can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang + language: "en" # This option controls which HTTP header to inspect to find the real remote IP # address of the client. This is likely required if Dendrite is running behind diff --git a/docs/administration/1_createusers.md b/docs/administration/1_createusers.md index 3468398ac..94399a04a 100644 --- a/docs/administration/1_createusers.md +++ b/docs/administration/1_createusers.md @@ -1,3 +1,4 @@ + --- title: Creating user accounts parent: Administration @@ -31,11 +32,11 @@ To create a new **admin account**, add the `-admin` flag: ./bin/create-account -config /path/to/dendrite.yaml -username USERNAME -admin ``` -By default `create-account` uses `https://localhost:8448` to connect to Dendrite, this can be overwritten using +By default `create-account` uses `http://localhost:8008` to connect to Dendrite, this can be overwritten using the `-url` flag: ```bash -./bin/create-account -config /path/to/dendrite.yaml -username USERNAME -url http://localhost:8008 +./bin/create-account -config /path/to/dendrite.yaml -username USERNAME -url https://localhost:8448 ``` An example of using `create-account` when running in **Docker**, having found the `CONTAINERNAME` from `docker ps`: diff --git a/docs/administration/4_adminapi.md b/docs/administration/4_adminapi.md index a34bfde1f..56e19a8b4 100644 --- a/docs/administration/4_adminapi.md +++ b/docs/administration/4_adminapi.md @@ -57,6 +57,16 @@ Request body format: Reset the password of a local user. The `localpart` is the username only, i.e. if the full user ID is `@alice:domain.com` then the local part is `alice`. +## GET `/_dendrite/admin/fulltext/reindex` + +This endpoint instructs Dendrite to reindex all searchable events (`m.room.message`, `m.room.topic` and `m.room.name`). An empty JSON body will be returned immediately. +Indexing is done in the background, the server logs every 1000 events (or below) when they are being indexed. Once reindexing is done, you'll see something along the lines `Indexed 69586 events in 53.68223182s` in your debug logs. + +## POST `/_dendrite/admin/refreshDevices/{userID}` + +This endpoint instructs Dendrite to immediately query `/devices/{userID}` on a federated server. An empty JSON body will be returned on success, updating all locally stored user devices/keys. This can be used to possibly resolve E2EE issues, where the remote user can't decrypt messages. + + ## POST `/_synapse/admin/v1/send_server_notice` Request body format: diff --git a/docs/installation/2_domainname.md b/docs/installation/2_domainname.md index 7d7fc86bd..e7b3495f7 100644 --- a/docs/installation/2_domainname.md +++ b/docs/installation/2_domainname.md @@ -87,6 +87,12 @@ and contain the following JSON document: For example, this can be done with the following Caddy config: ``` +handle /.well-known/matrix/server { + header Content-Type application/json + header Access-Control-Allow-Origin * + respond `"m.server": "matrix.example.com:8448"` +} + handle /.well-known/matrix/client { header Content-Type application/json header Access-Control-Allow-Origin * diff --git a/docs/installation/7_configuration.md b/docs/installation/7_configuration.md index 8fbe71c40..19958c92f 100644 --- a/docs/installation/7_configuration.md +++ b/docs/installation/7_configuration.md @@ -138,16 +138,18 @@ room_server: conn_max_lifetime: -1 ``` -## Fulltext search +## Full-text search -Dendrite supports experimental fulltext indexing using [Bleve](https://github.com/blevesearch/bleve), it is configured in the `sync_api` section as follows. Depending on the language most likely to be used on the server, it might make sense to change the `language` used when indexing, to ensure the returned results match the expections. A full list of possible languages can be found [here](https://github.com/blevesearch/bleve/tree/master/analysis/lang). +Dendrite supports experimental full-text indexing using [Bleve](https://github.com/blevesearch/bleve). It is configured in the `sync_api` section as follows. + +Depending on the language most likely to be used on the server, it might make sense to change the `language` used when indexing, to ensure the returned results match the expectations. A full list of possible languages can be found [here](https://github.com/blevesearch/bleve/tree/master/analysis/lang). ```yaml sync_api: # ... - fulltext: + search: enabled: false - index_path: "./fulltextindex" + index_path: "./searchindex" language: "en" ``` diff --git a/federationapi/consumers/sendtodevice.go b/federationapi/consumers/sendtodevice.go index e44bad723..ffc1d8894 100644 --- a/federationapi/consumers/sendtodevice.go +++ b/federationapi/consumers/sendtodevice.go @@ -80,7 +80,6 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats return true } if originServerName != t.ServerName { - log.WithField("other_server", originServerName).Info("Suppressing send-to-device: originated elsewhere") return true } // Extract the send-to-device event from msg. diff --git a/go.mod b/go.mod index b682d9bc4..c82f76d41 100644 --- a/go.mod +++ b/go.mod @@ -22,11 +22,11 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5 - github.com/matrix-org/pinecone v0.0.0-20220927101513-d0beb180f44d + github.com/matrix-org/gomatrixserverlib v0.0.0-20220929190355-91d455cd3621 + github.com/matrix-org/pinecone v0.0.0-20220929155234-2ce51dd4a42c github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 - github.com/nats-io/nats-server/v2 v2.9.1-0.20220920152220-52d7b481c4b5 + github.com/nats-io/nats-server/v2 v2.9.2 github.com/nats-io/nats.go v1.17.0 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 @@ -43,7 +43,7 @@ require ( github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.5-0.20220901155642-4f2abece817c go.uber.org/atomic v1.10.0 - golang.org/x/crypto v0.0.0-20220919173607-35f4265a4bc0 + golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be golang.org/x/image v0.0.0-20220902085622-e7cb96979f69 golang.org/x/mobile v0.0.0-20220722155234-aaac322e2105 golang.org/x/net v0.0.0-20220919232410-f2f64ebce3c1 @@ -91,7 +91,7 @@ require ( github.com/h2non/filetype v1.1.3 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/juju/errors v1.0.0 // indirect - github.com/klauspost/compress v1.15.10 // indirect + github.com/klauspost/compress v1.15.11 // indirect github.com/kr/pretty v0.3.0 // indirect github.com/lucas-clemente/quic-go v0.29.0 // indirect github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect @@ -120,9 +120,9 @@ require ( go.etcd.io/bbolt v1.3.6 // indirect golang.org/x/exp v0.0.0-20220916125017-b168a2c6b86b // indirect golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect - golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 // indirect + golang.org/x/sys v0.0.0-20220928140112-f11e5e49a4ec // indirect golang.org/x/text v0.3.8-0.20211004125949-5bd84dd9b33b // indirect - golang.org/x/time v0.0.0-20220920022843-2ce7c2934d45 // indirect + golang.org/x/time v0.0.0-20220922220347-f3bd1da661af // indirect golang.org/x/tools v0.1.12 // indirect google.golang.org/protobuf v1.28.1 // indirect gopkg.in/macaroon.v2 v2.1.0 // indirect diff --git a/go.sum b/go.sum index 1afed73a5..a99599cb1 100644 --- a/go.sum +++ b/go.sum @@ -347,8 +347,8 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= -github.com/klauspost/compress v1.15.10 h1:Ai8UzuomSCDw90e1qNMtb15msBXsNpH6gzkkENQNcJo= -github.com/klauspost/compress v1.15.10/go.mod h1:QPwzmACJjUTFsnSHH934V6woptycfrDDJnH7hvFVbGM= +github.com/klauspost/compress v1.15.11 h1:Lcadnb3RKGin4FYM/orgq0qde+nc15E5Cbqg4B9Sx9c= +github.com/klauspost/compress v1.15.11/go.mod h1:QPwzmACJjUTFsnSHH934V6woptycfrDDJnH7hvFVbGM= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -384,10 +384,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5 h1:cQMA9hip0WSp6cv7CUfButa9Jl/9E6kqWmQyOjx5A5s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/pinecone v0.0.0-20220927101513-d0beb180f44d h1:kGPJ6Rg8nn5an2CbCZrRiuTNyNzE0rRMiqm4UXJYrRs= -github.com/matrix-org/pinecone v0.0.0-20220927101513-d0beb180f44d/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220929190355-91d455cd3621 h1:a8IaoSPDxevkgXnOUrtIW9AqVNvXBJAG0gtnX687S7g= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220929190355-91d455cd3621/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/pinecone v0.0.0-20220929155234-2ce51dd4a42c h1:iCHLYwwlPsf4TYFrvhKdhQoAM2lXzcmDZYqwBNWcnVk= +github.com/matrix-org/pinecone v0.0.0-20220929155234-2ce51dd4a42c/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= @@ -422,8 +422,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= -github.com/nats-io/nats-server/v2 v2.9.1-0.20220920152220-52d7b481c4b5 h1:G/YGSXcJ2bUofD8Ts49it4VNezaJLQldI6fZR+wIUts= -github.com/nats-io/nats-server/v2 v2.9.1-0.20220920152220-52d7b481c4b5/go.mod h1:BWKY6217RvhI+FDoOLZ2BH+hOC37xeKRBlQ1Lz7teKI= +github.com/nats-io/nats-server/v2 v2.9.2 h1:XNDgJgOYYaYlquLdbSHI3xssLipfKUOq3EmYIMNCOsE= +github.com/nats-io/nats-server/v2 v2.9.2/go.mod h1:4sq8wvrpbvSzL1n3ZfEYnH4qeUuIl5W990j3kw13rRk= github.com/nats-io/nats.go v1.17.0 h1:1jp5BThsdGlN91hW0k3YEfJbfACjiOYtUiLXG0RL4IE= github.com/nats-io/nats.go v1.17.0/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= @@ -625,8 +625,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220919173607-35f4265a4bc0 h1:a5Yg6ylndHHYJqIPrdq0AhvR6KTvDTAvgBtaidhEevY= -golang.org/x/crypto v0.0.0-20220919173607-35f4265a4bc0/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be h1:fmw3UbQh+nxngCAHrDCCztao/kbYFnWjoqop8dHx05A= +golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -810,8 +810,8 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220730100132-1609e554cd39/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 h1:h+EGohizhe9XlX18rfpa8k8RAc5XyaeamM+0VHRd4lc= -golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220928140112-f11e5e49a4ec h1:BkDtF2Ih9xZ7le9ndzTA7KJow28VbQW3odyk/8drmuI= +golang.org/x/sys v0.0.0-20220928140112-f11e5e49a4ec/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20220919170432-7a66f970e087 h1:tPwmk4vmvVCMdr98VgL4JH+qZxPL8fqlUOHnyOM8N3w= @@ -829,8 +829,8 @@ golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20220920022843-2ce7c2934d45 h1:yuLAip3bfURHClMG9VBdzPrQvCWjWiWUTBGV+/fCbUs= -golang.org/x/time v0.0.0-20220920022843-2ce7c2934d45/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20220922220347-f3bd1da661af h1:Yx9k8YCG3dvF87UAn2tu2HQLf2dt/eR1bXxpLMWeH+Y= +golang.org/x/time v0.0.0-20220922220347-f3bd1da661af/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/fulltext/bleve.go b/internal/fulltext/bleve.go index b07c0e51d..da8932f5c 100644 --- a/internal/fulltext/bleve.go +++ b/internal/fulltext/bleve.go @@ -22,8 +22,9 @@ import ( "github.com/blevesearch/bleve/v2" "github.com/blevesearch/bleve/v2/mapping" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/setup/config" ) // Search contains all existing bleve.Index diff --git a/internal/fulltext/bleve_test.go b/internal/fulltext/bleve_test.go index 84a282423..d16397a45 100644 --- a/internal/fulltext/bleve_test.go +++ b/internal/fulltext/bleve_test.go @@ -27,11 +27,11 @@ import ( func mustOpenIndex(t *testing.T, tempDir string) *fulltext.Search { t.Helper() - cfg := config.Fulltext{} - cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, - }) + cfg := config.Fulltext{ + Enabled: true, + InMemory: true, + Language: "en", + } if tempDir != "" { cfg.IndexPath = config.Path(tempDir) cfg.InMemory = false diff --git a/internal/version.go b/internal/version.go index f9b101702..d508517be 100644 --- a/internal/version.go +++ b/internal/version.go @@ -16,8 +16,8 @@ var build string const ( VersionMajor = 0 - VersionMinor = 9 - VersionPatch = 9 + VersionMinor = 10 + VersionPatch = 1 VersionTag = "" // example: "rc1" ) diff --git a/keyserver/api/api.go b/keyserver/api/api.go index c9ec59a75..14fced3e8 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -45,6 +45,7 @@ type ClientKeyAPI interface { PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error // PerformClaimKeys claims one-time keys for use in pre-key messages PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error + PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error } // API functions required by the userapi diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 525f8a99d..fcfcd092d 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -17,6 +17,7 @@ package internal import ( "context" "encoding/json" + "errors" "fmt" "hash/fnv" "net" @@ -31,6 +32,7 @@ import ( fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/setup/process" ) var ( @@ -45,6 +47,9 @@ var ( ) ) +const defaultWaitTime = time.Minute +const requestTimeout = time.Second * 30 + func init() { prometheus.MustRegister( deviceListUpdateCount, @@ -80,6 +85,7 @@ func init() { // In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is // set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried. type DeviceListUpdater struct { + process *process.ProcessContext // A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1 // request to the remote server and race. // TODO: Put in an LRU cache to bound growth @@ -131,10 +137,12 @@ type KeyChangeProducer interface { // NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. func NewDeviceListUpdater( - db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, + process *process.ProcessContext, db DeviceListUpdaterDatabase, + api DeviceListUpdaterAPI, producer KeyChangeProducer, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, ) *DeviceListUpdater { return &DeviceListUpdater{ + process: process, userIDToMutex: make(map[string]*sync.Mutex), mu: &sync.Mutex{}, db: db, @@ -234,7 +242,7 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib. "prev_ids": event.PrevID, "display_name": event.DeviceDisplayName, "deleted": event.Deleted, - }).Info("DeviceListUpdater.Update") + }).Trace("DeviceListUpdater.Update") // if we haven't missed anything update the database and notify users if exists || event.Deleted { @@ -378,111 +386,123 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) { } func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) { - deviceListUpdateCount.WithLabelValues(string(serverName)).Inc() - requestTimeout := time.Second * 30 // max amount of time we want to spend on each request - ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) - defer cancel() + ctx := u.process.Context() logger := util.GetLogger(ctx).WithField("server_name", serverName) - waitTime := 2 * time.Second - // fetch stale device lists + deviceListUpdateCount.WithLabelValues(string(serverName)).Inc() + + waitTime := defaultWaitTime // How long should we wait to try again? + successCount := 0 // How many user requests failed? + userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName}) if err != nil { logger.WithError(err).Error("Failed to load stale device lists") return waitTime, true } - failCount := 0 -userLoop: + defer func() { + for _, userID := range userIDs { + // always clear the channel to unblock Update calls regardless of success/failure + u.clearChannel(userID) + } + }() + for _, userID := range userIDs { - if ctx.Err() != nil { - // we've timed out, give up and go to the back of the queue to let another server be processed. - failCount += 1 - waitTime = time.Minute * 10 + userWait, err := u.processServerUser(ctx, serverName, userID) + if err != nil { + if userWait > waitTime { + waitTime = userWait + } break } - res, err := u.fedClient.GetUserDevices(ctx, serverName, userID) - if err != nil { - failCount += 1 - select { - case <-ctx.Done(): - // we've timed out, give up and go to the back of the queue to let another server be processed. - waitTime = time.Minute * 10 - break userLoop - default: - } - switch e := err.(type) { - case *fedsenderapi.FederationClientError: - if e.RetryAfter > 0 { - waitTime = e.RetryAfter - } else if e.Blacklisted { - waitTime = time.Hour * 8 - break userLoop - } else if e.Code >= 300 { - // We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError - // are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds. - waitTime = time.Hour - break userLoop - } - case net.Error: - // Use the default waitTime, if it's a timeout. - // It probably doesn't make sense to try further users. - if !e.Timeout() { - waitTime = time.Minute * 10 - logger.WithError(e).Error("GetUserDevices returned net.Error") - break userLoop - } - case gomatrix.HTTPError: - // The remote server returned an error, give it some time to recover. - // This is to avoid spamming remote servers, which may not be Matrix servers anymore. - if e.Code >= 300 { - waitTime = time.Hour - logger.WithError(e).Error("GetUserDevices returned gomatrix.HTTPError") - break userLoop - } - default: - // Something else failed - waitTime = time.Minute * 10 - logger.WithError(err).WithField("user_id", userID).Debugf("GetUserDevices returned unknown error type: %T", err) - break userLoop - } - continue - } - if res.MasterKey != nil || res.SelfSigningKey != nil { - uploadReq := &api.PerformUploadDeviceKeysRequest{ - UserID: userID, - } - uploadRes := &api.PerformUploadDeviceKeysResponse{} - if res.MasterKey != nil { - if err = sanityCheckKey(*res.MasterKey, userID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err == nil { - uploadReq.MasterKey = *res.MasterKey - } - } - if res.SelfSigningKey != nil { - if err = sanityCheckKey(*res.SelfSigningKey, userID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err == nil { - uploadReq.SelfSigningKey = *res.SelfSigningKey - } - } - _ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) - } - err = u.updateDeviceList(&res) - if err != nil { - logger.WithError(err).WithField("user_id", userID).Error("Fetched device list but failed to store/emit it") - failCount += 1 - } + successCount++ } - if failCount > 0 { + + allUsersSucceeded := successCount == len(userIDs) + if !allUsersSucceeded { logger.WithFields(logrus.Fields{ - "total": len(userIDs), - "failed": failCount, - "skipped": len(userIDs) - failCount, - "waittime": waitTime, + "total": len(userIDs), + "succeeded": successCount, + "failed": len(userIDs) - successCount, + "wait_time": waitTime, }).Warn("Failed to query device keys for some users") } - for _, userID := range userIDs { - // always clear the channel to unblock Update calls regardless of success/failure - u.clearChannel(userID) + return waitTime, !allUsersSucceeded +} + +func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) (time.Duration, error) { + ctx, cancel := context.WithTimeout(ctx, requestTimeout) + defer cancel() + logger := util.GetLogger(ctx).WithFields(logrus.Fields{ + "server_name": serverName, + "user_id": userID, + }) + + res, err := u.fedClient.GetUserDevices(ctx, serverName, userID) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return time.Minute * 10, err + } + switch e := err.(type) { + case *json.UnmarshalTypeError, *json.SyntaxError: + logger.WithError(err).Debugf("Device list update for %q contained invalid JSON", userID) + return defaultWaitTime, nil + case *fedsenderapi.FederationClientError: + if e.RetryAfter > 0 { + return e.RetryAfter, err + } else if e.Blacklisted { + return time.Hour * 8, err + } else if e.Code >= 300 { + // We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError + // are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds. + return time.Hour, err + } + case net.Error: + // Use the default waitTime, if it's a timeout. + // It probably doesn't make sense to try further users. + if !e.Timeout() { + logger.WithError(e).Debug("GetUserDevices returned net.Error") + return time.Minute * 10, err + } + case gomatrix.HTTPError: + // The remote server returned an error, give it some time to recover. + // This is to avoid spamming remote servers, which may not be Matrix servers anymore. + if e.Code >= 300 { + logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError") + return time.Hour, err + } + default: + // Something else failed + logger.WithError(err).Debugf("GetUserDevices returned unknown error type: %T", err) + return time.Minute * 10, err + } } - return waitTime, failCount > 0 + if res.UserID != userID { + logger.WithError(err).Debugf("User ID %q in device list update response doesn't match expected %q", res.UserID, userID) + return defaultWaitTime, nil + } + if res.MasterKey != nil || res.SelfSigningKey != nil { + uploadReq := &api.PerformUploadDeviceKeysRequest{ + UserID: userID, + } + uploadRes := &api.PerformUploadDeviceKeysResponse{} + if res.MasterKey != nil { + if err = sanityCheckKey(*res.MasterKey, userID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err == nil { + uploadReq.MasterKey = *res.MasterKey + } + } + if res.SelfSigningKey != nil { + if err = sanityCheckKey(*res.SelfSigningKey, userID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err == nil { + uploadReq.SelfSigningKey = *res.SelfSigningKey + } + } + _ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) + } + err = u.updateDeviceList(&res) + if err != nil { + logger.WithError(err).Error("Fetched device list but failed to store/emit it") + return defaultWaitTime, err + } + return defaultWaitTime, nil } func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error { diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index 0520a9e66..28a13a0a0 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -30,6 +30,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/setup/process" ) var ( @@ -146,7 +147,7 @@ func TestUpdateHavePrevID(t *testing.T) { } ap := &mockDeviceListUpdaterAPI{} producer := &mockKeyChangeProducer{} - updater := NewDeviceListUpdater(db, ap, producer, nil, 1) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1) event := gomatrixserverlib.DeviceListUpdateEvent{ DeviceDisplayName: "Foo Bar", Deleted: false, @@ -218,7 +219,7 @@ func TestUpdateNoPrevID(t *testing.T) { `)), }, nil }) - updater := NewDeviceListUpdater(db, ap, producer, fedClient, 2) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2) if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -287,7 +288,7 @@ func TestDebounce(t *testing.T) { close(incomingFedReq) return <-fedCh, nil }) - updater := NewDeviceListUpdater(db, ap, producer, fedClient, 1) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1) if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index a8d1128c4..017c29e84 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -228,14 +228,21 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query // PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present // in our database. func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error { - knownDevices, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{req.DeviceID}, true) + knownDevices, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{}, true) if err != nil { return err } if len(knownDevices) == 0 { - return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID) + return nil // fmt.Errorf("unknown user %s", req.UserID) } - return nil + + for i := range knownDevices { + if knownDevices[i].DeviceID == req.DeviceID { + return nil // we already know about this device + } + } + + return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID) } // nolint:gocyclo diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 5124b777e..9ae4f9ca3 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -58,7 +58,7 @@ func NewInternalAPI( FedClient: fedClient, Producer: keyChangeProducer, } - updater := internal.NewDeviceListUpdater(db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable + updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable ap.Updater = updater go func() { if err := updater.Start(); err != nil { diff --git a/setup/base/base.go b/setup/base/base.go index 0c7b222d0..0636c7b8d 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -37,16 +37,13 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" - "github.com/gorilla/mux" "github.com/kardianos/minwinsvc" @@ -61,6 +58,8 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" rsinthttp "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" userapi "github.com/matrix-org/dendrite/userapi/api" userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp" ) @@ -392,17 +391,26 @@ func (b *BaseDendrite) configureHTTPErrors() { _, _ = w.Write([]byte(fmt.Sprintf("405 %s not allowed on this endpoint", r.Method))) } + clientNotFoundHandler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}`)) // nolint:misspell + } + notFoundCORSHandler := httputil.WrapHandlerInCORS(http.NotFoundHandler()) notAllowedCORSHandler := httputil.WrapHandlerInCORS(http.HandlerFunc(notAllowedHandler)) for _, router := range []*mux.Router{ - b.PublicClientAPIMux, b.PublicMediaAPIMux, - b.DendriteAdminMux, b.SynapseAdminMux, - b.PublicWellKnownAPIMux, + b.PublicMediaAPIMux, b.DendriteAdminMux, + b.SynapseAdminMux, b.PublicWellKnownAPIMux, } { router.NotFoundHandler = notFoundCORSHandler router.MethodNotAllowedHandler = notAllowedCORSHandler } + + // Special case so that we don't upset clients on the CS API. + b.PublicClientAPIMux.NotFoundHandler = http.HandlerFunc(clientNotFoundHandler) + b.PublicClientAPIMux.MethodNotAllowedHandler = http.HandlerFunc(clientNotFoundHandler) } // SetupAndServeHTTP sets up the HTTP server to serve endpoints registered on diff --git a/setup/config/config_syncapi.go b/setup/config/config_syncapi.go index c890b0054..a87da3732 100644 --- a/setup/config/config_syncapi.go +++ b/setup/config/config_syncapi.go @@ -10,7 +10,7 @@ type SyncAPI struct { RealIPHeader string `yaml:"real_ip_header"` - Fulltext Fulltext `yaml:"fulltext"` + Fulltext Fulltext `yaml:"search"` } func (c *SyncAPI) Defaults(opts DefaultOpts) { @@ -50,18 +50,14 @@ type Fulltext struct { func (f *Fulltext) Defaults(opts DefaultOpts) { f.Enabled = false - f.IndexPath = "./fulltextindex" + f.IndexPath = "./searchindex" f.Language = "en" - if opts.Generate { - f.Enabled = true - f.InMemory = true - } } func (f *Fulltext) Verify(configErrs *ConfigErrors, isMonolith bool) { if !f.Enabled { return } - checkNotEmpty(configErrs, "syncapi.fulltext.index_path", string(f.IndexPath)) - checkNotEmpty(configErrs, "syncapi.fulltext.language", f.Language) + checkNotEmpty(configErrs, "syncapi.search.index_path", string(f.IndexPath)) + checkNotEmpty(configErrs, "syncapi.search.language", f.Language) } diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index 3660e91e3..7409fd6c8 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -9,9 +9,10 @@ import ( "time" "github.com/getsentry/sentry-go" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" - "github.com/sirupsen/logrus" natsserver "github.com/nats-io/nats-server/v2/server" natsclient "github.com/nats-io/nats.go" @@ -184,6 +185,8 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"}, OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"}, OutputRoomEvent: {"AppserviceRoomserverConsumer"}, + OutputStreamEvent: {"UserAPISyncAPIStreamEventConsumer"}, + OutputReadUpdate: {"UserAPISyncAPIReadUpdateConsumer"}, } { streamName := cfg.Matrix.JetStream.Prefixed(stream) for _, consumer := range consumers { diff --git a/setup/jetstream/streams.go b/setup/jetstream/streams.go index c07d3a0b4..ee9810dae 100644 --- a/setup/jetstream/streams.go +++ b/setup/jetstream/streams.go @@ -94,16 +94,6 @@ var streams = []*nats.StreamConfig{ Retention: nats.InterestPolicy, Storage: nats.FileStorage, }, - { - Name: OutputStreamEvent, - Retention: nats.InterestPolicy, - Storage: nats.FileStorage, - }, - { - Name: OutputReadUpdate, - Retention: nats.InterestPolicy, - Storage: nats.FileStorage, - }, { Name: OutputPresenceEvent, Retention: nats.InterestPolicy, diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index f0588cab8..735f6718c 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -16,37 +16,42 @@ package consumers import ( "context" - "database/sql" "encoding/json" - "fmt" + "strings" + "time" "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" - "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" ) // OutputClientDataConsumer consumes events that originated in the client API server. type OutputClientDataConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - stream types.StreamProvider - notifier *notifier.Notifier - serverName gomatrixserverlib.ServerName - producer *producers.UserAPIReadProducer + ctx context.Context + jetstream nats.JetStreamContext + nats *nats.Conn + durable string + topic string + topicReIndex string + db storage.Database + stream streams.StreamProvider + notifier *notifier.Notifier + serverName gomatrixserverlib.ServerName + fts *fulltext.Search + cfg *config.SyncAPI } // NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers. @@ -54,26 +59,93 @@ func NewOutputClientDataConsumer( process *process.ProcessContext, cfg *config.SyncAPI, js nats.JetStreamContext, + nats *nats.Conn, store storage.Database, notifier *notifier.Notifier, - stream types.StreamProvider, - producer *producers.UserAPIReadProducer, + stream streams.StreamProvider, + fts *fulltext.Search, ) *OutputClientDataConsumer { return &OutputClientDataConsumer{ - ctx: process.Context(), - jetstream: js, - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData), - durable: cfg.Matrix.JetStream.Durable("SyncAPIAccountDataConsumer"), - db: store, - notifier: notifier, - stream: stream, - serverName: cfg.Matrix.ServerName, - producer: producer, + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData), + topicReIndex: cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex), + durable: cfg.Matrix.JetStream.Durable("SyncAPIAccountDataConsumer"), + nats: nats, + db: store, + notifier: notifier, + stream: stream, + serverName: cfg.Matrix.ServerName, + fts: fts, + cfg: cfg, } } // Start consuming from room servers func (s *OutputClientDataConsumer) Start() error { + _, err := s.nats.Subscribe(s.topicReIndex, func(msg *nats.Msg) { + if err := msg.Ack(); err != nil { + return + } + if !s.cfg.Fulltext.Enabled { + logrus.Warn("Fulltext indexing is disabled") + return + } + ctx := context.Background() + logrus.Infof("Starting to index events") + var offset int + start := time.Now() + count := 0 + var id int64 = 0 + for { + evs, err := s.db.ReIndex(ctx, 1000, id) + if err != nil { + logrus.WithError(err).Errorf("unable to get events to index") + return + } + if len(evs) == 0 { + break + } + logrus.Debugf("Indexing %d events", len(evs)) + elements := make([]fulltext.IndexElement, 0, len(evs)) + + for streamPos, ev := range evs { + id = streamPos + e := fulltext.IndexElement{ + EventID: ev.EventID(), + RoomID: ev.RoomID(), + StreamPosition: streamPos, + } + e.SetContentType(ev.Type()) + + switch ev.Type() { + case "m.room.message": + e.Content = gjson.GetBytes(ev.Content(), "body").String() + case gomatrixserverlib.MRoomName: + e.Content = gjson.GetBytes(ev.Content(), "name").String() + case gomatrixserverlib.MRoomTopic: + e.Content = gjson.GetBytes(ev.Content(), "topic").String() + default: + continue + } + + if strings.TrimSpace(e.Content) == "" { + continue + } + elements = append(elements, e) + } + if err = s.fts.Index(elements...); err != nil { + logrus.WithError(err).Error("unable to index events") + continue + } + offset += len(evs) + count += len(elements) + } + logrus.Infof("Indexed %d events in %v", count, time.Since(start)) + }) + if err != nil { + return err + } return jetstream.JetStreamConsumer( s.ctx, s.jetstream, s.topic, s.durable, 1, s.onMessage, nats.DeliverAll(), nats.ManualAck(), @@ -113,15 +185,6 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M return false } - if err = s.sendReadUpdate(ctx, userID, output); err != nil { - log.WithError(err).WithFields(logrus.Fields{ - "user_id": userID, - "room_id": output.RoomID, - }).Errorf("Failed to generate read update") - sentry.CaptureException(err) - return false - } - if output.IgnoredUsers != nil { if err := s.db.UpdateIgnoresForUser(ctx, userID, output.IgnoredUsers); err != nil { log.WithError(err).WithFields(logrus.Fields{ @@ -136,34 +199,3 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M return true } - -func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error { - if output.Type != "m.fully_read" || output.ReadMarker == nil { - return nil - } - _, serverName, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) - } - if serverName != s.serverName { - return nil - } - var readPos types.StreamPosition - var fullyReadPos types.StreamPosition - if output.ReadMarker.Read != "" { - if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("s.db.PositionInTopology (Read): %w", err) - } - } - if output.ReadMarker.FullyRead != "" { - if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err) - } - } - if readPos > 0 || fullyReadPos > 0 { - if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil { - return fmt.Errorf("s.producer.SendReadUpdate: %w", err) - } - } - return nil -} diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index c42e71971..dc7d9e207 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" @@ -40,7 +41,7 @@ type OutputKeyChangeEventConsumer struct { topic string db storage.Database notifier *notifier.Notifier - stream types.StreamProvider + stream streams.StreamProvider serverName gomatrixserverlib.ServerName // our server name rsAPI roomserverAPI.SyncRoomserverAPI } @@ -55,7 +56,7 @@ func NewOutputKeyChangeEventConsumer( rsAPI roomserverAPI.SyncRoomserverAPI, store storage.Database, notifier *notifier.Notifier, - stream types.StreamProvider, + stream streams.StreamProvider, ) *OutputKeyChangeEventConsumer { s := &OutputKeyChangeEventConsumer{ ctx: process.Context(), diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index 61bdc13de..145059c2d 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -39,7 +40,7 @@ type PresenceConsumer struct { requestTopic string presenceTopic string db storage.Database - stream types.StreamProvider + stream streams.StreamProvider notifier *notifier.Notifier deviceAPI api.SyncUserAPI cfg *config.SyncAPI @@ -54,7 +55,7 @@ func NewPresenceConsumer( nats *nats.Conn, db storage.Database, notifier *notifier.Notifier, - stream types.StreamProvider, + stream streams.StreamProvider, deviceAPI api.SyncUserAPI, ) *PresenceConsumer { return &PresenceConsumer{ diff --git a/syncapi/consumers/receipts.go b/syncapi/consumers/receipts.go index a18244c44..8aaa65730 100644 --- a/syncapi/consumers/receipts.go +++ b/syncapi/consumers/receipts.go @@ -16,22 +16,20 @@ package consumers import ( "context" - "database/sql" - "fmt" "strconv" "github.com/getsentry/sentry-go" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" - "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" - log "github.com/sirupsen/logrus" ) // OutputReceiptEventConsumer consumes events that originated in the EDU server. @@ -41,10 +39,9 @@ type OutputReceiptEventConsumer struct { durable string topic string db storage.Database - stream types.StreamProvider + stream streams.StreamProvider notifier *notifier.Notifier serverName gomatrixserverlib.ServerName - producer *producers.UserAPIReadProducer } // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. @@ -55,8 +52,7 @@ func NewOutputReceiptEventConsumer( js nats.JetStreamContext, store storage.Database, notifier *notifier.Notifier, - stream types.StreamProvider, - producer *producers.UserAPIReadProducer, + stream streams.StreamProvider, ) *OutputReceiptEventConsumer { return &OutputReceiptEventConsumer{ ctx: process.Context(), @@ -67,7 +63,6 @@ func NewOutputReceiptEventConsumer( notifier: notifier, stream: stream, serverName: cfg.Matrix.ServerName, - producer: producer, } } @@ -111,42 +106,8 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats return true } - if err = s.sendReadUpdate(ctx, output); err != nil { - log.WithError(err).WithFields(logrus.Fields{ - "user_id": output.UserID, - "room_id": output.RoomID, - }).Errorf("Failed to generate read update") - sentry.CaptureException(err) - return false - } - s.stream.Advance(streamPos) s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) return true } - -func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output types.OutputReceiptEvent) error { - if output.Type != "m.read" { - return nil - } - _, serverName, err := gomatrixserverlib.SplitID('@', output.UserID) - if err != nil { - return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) - } - if serverName != s.serverName { - return nil - } - var readPos types.StreamPosition - if output.EventID != "" { - if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("s.db.PositionInTopology (Read): %w", err) - } - } - if readPos > 0 { - if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil { - return fmt.Errorf("s.producer.SendReadUpdate: %w", err) - } - } - return nil -} diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 6979eb484..c7a11dbb4 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -21,17 +21,21 @@ import ( "fmt" "github.com/getsentry/sentry-go" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + + "github.com/matrix-org/dendrite/internal/fulltext" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" - "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" ) // OutputRoomEventConsumer consumes events that originated in the room server. @@ -43,10 +47,10 @@ type OutputRoomEventConsumer struct { durable string topic string db storage.Database - pduStream types.StreamProvider - inviteStream types.StreamProvider + pduStream streams.StreamProvider + inviteStream streams.StreamProvider notifier *notifier.Notifier - producer *producers.UserAPIStreamEventProducer + fts *fulltext.Search } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -56,10 +60,10 @@ func NewOutputRoomEventConsumer( js nats.JetStreamContext, store storage.Database, notifier *notifier.Notifier, - pduStream types.StreamProvider, - inviteStream types.StreamProvider, + pduStream streams.StreamProvider, + inviteStream streams.StreamProvider, rsAPI api.SyncRoomserverAPI, - producer *producers.UserAPIStreamEventProducer, + fts *fulltext.Search, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ ctx: process.Context(), @@ -72,7 +76,7 @@ func NewOutputRoomEventConsumer( pduStream: pduStream, inviteStream: inviteStream, rsAPI: rsAPI, - producer: producer, + fts: fts, } } @@ -254,11 +258,11 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( }).Panicf("roomserver output log: write new event failure") return nil } - - if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil { - log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID()) - sentry.CaptureException(err) - return err + if err = s.writeFTS(ev, pduPos); err != nil { + log.WithFields(log.Fields{ + "event_id": ev.EventID(), + "type": ev.Type(), + }).WithError(err).Warn("failed to index fulltext element") } if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { @@ -304,6 +308,13 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent( return nil } + if err = s.writeFTS(ev, pduPos); err != nil { + log.WithFields(log.Fields{ + "event_id": ev.EventID(), + "type": ev.Type(), + }).WithError(err).Warn("failed to index fulltext element") + } + if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) return err @@ -440,8 +451,15 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.Head } stateKey := *event.StateKey() - prevEvent, err := s.db.GetStateEvent( - context.TODO(), event.RoomID(), event.Type(), stateKey, + snapshot, err := s.db.NewDatabaseSnapshot(s.ctx) + if err != nil { + return nil, err + } + var succeeded bool + defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) + + prevEvent, err := snapshot.GetStateEvent( + s.ctx, event.RoomID(), event.Type(), stateKey, ) if err != nil { return event, err @@ -458,5 +476,42 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.Head } event.Event, err = event.SetUnsigned(prev) + succeeded = true return event, err } + +func (s *OutputRoomEventConsumer) writeFTS(ev *gomatrixserverlib.HeaderedEvent, pduPosition types.StreamPosition) error { + if !s.cfg.Fulltext.Enabled { + return nil + } + e := fulltext.IndexElement{ + EventID: ev.EventID(), + RoomID: ev.RoomID(), + StreamPosition: int64(pduPosition), + } + e.SetContentType(ev.Type()) + + switch ev.Type() { + case "m.room.message": + e.Content = gjson.GetBytes(ev.Content(), "body").String() + case gomatrixserverlib.MRoomName: + e.Content = gjson.GetBytes(ev.Content(), "name").String() + case gomatrixserverlib.MRoomTopic: + e.Content = gjson.GetBytes(ev.Content(), "topic").String() + case gomatrixserverlib.MRoomRedaction: + log.Tracef("Redacting event: %s", ev.Redacts()) + if err := s.fts.Delete(ev.Redacts()); err != nil { + return fmt.Errorf("failed to delete entry from fulltext index: %w", err) + } + return nil + default: + return nil + } + if e.Content != "" { + log.Tracef("Indexing element: %+v", e) + if err := s.fts.Index(e); err != nil { + return err + } + } + return nil +} diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index c0b432256..49d84cca3 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -43,7 +44,7 @@ type OutputSendToDeviceEventConsumer struct { db storage.Database keyAPI keyapi.SyncKeyAPI serverName gomatrixserverlib.ServerName // our server name - stream types.StreamProvider + stream streams.StreamProvider notifier *notifier.Notifier } @@ -56,7 +57,7 @@ func NewOutputSendToDeviceEventConsumer( store storage.Database, keyAPI keyapi.SyncKeyAPI, notifier *notifier.Notifier, - stream types.StreamProvider, + stream streams.StreamProvider, ) *OutputSendToDeviceEventConsumer { return &OutputSendToDeviceEventConsumer{ ctx: process.Context(), diff --git a/syncapi/consumers/typing.go b/syncapi/consumers/typing.go index 88db80f8c..67a26239d 100644 --- a/syncapi/consumers/typing.go +++ b/syncapi/consumers/typing.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -36,7 +37,7 @@ type OutputTypingEventConsumer struct { durable string topic string eduCache *caching.EDUCache - stream types.StreamProvider + stream streams.StreamProvider notifier *notifier.Notifier } @@ -48,7 +49,7 @@ func NewOutputTypingEventConsumer( js nats.JetStreamContext, eduCache *caching.EDUCache, notifier *notifier.Notifier, - stream types.StreamProvider, + stream streams.StreamProvider, ) *OutputTypingEventConsumer { return &OutputTypingEventConsumer{ ctx: process.Context(), diff --git a/syncapi/consumers/userapi.go b/syncapi/consumers/userapi.go index 227823522..3c73dc1fc 100644 --- a/syncapi/consumers/userapi.go +++ b/syncapi/consumers/userapi.go @@ -19,15 +19,17 @@ import ( "encoding/json" "github.com/getsentry/sentry-go" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" ) // OutputNotificationDataConsumer consumes events that originated in @@ -39,7 +41,7 @@ type OutputNotificationDataConsumer struct { topic string db storage.Database notifier *notifier.Notifier - stream types.StreamProvider + stream streams.StreamProvider } // NewOutputNotificationDataConsumer creates a new consumer. Call @@ -50,7 +52,7 @@ func NewOutputNotificationDataConsumer( js nats.JetStreamContext, store storage.Database, notifier *notifier.Notifier, - stream types.StreamProvider, + stream streams.StreamProvider, ) *OutputNotificationDataConsumer { s := &OutputNotificationDataConsumer{ ctx: process.Context(), diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index e73c004e5..bbfe19f4c 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -100,7 +100,7 @@ func (ev eventVisibility) allowed() (allowed bool) { // Returns the filtered events and an error, if any. func ApplyHistoryVisibilityFilter( ctx context.Context, - syncDB storage.Database, + syncDB storage.DatabaseTransaction, rsAPI api.SyncRoomserverAPI, events []*gomatrixserverlib.HeaderedEvent, alwaysIncludeEventIDs map[string]struct{}, diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index 87f0d86d7..db18c6b77 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -19,6 +19,7 @@ import ( "sync" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -318,18 +319,27 @@ func (n *Notifier) GetListener(req types.SyncRequest) UserDeviceStreamListener { func (n *Notifier) Load(ctx context.Context, db storage.Database) error { n.lock.Lock() defer n.lock.Unlock() - roomToUsers, err := db.AllJoinedUsersInRooms(ctx) + + snapshot, err := db.NewDatabaseSnapshot(ctx) + if err != nil { + return err + } + var succeeded bool + defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) + + roomToUsers, err := snapshot.AllJoinedUsersInRooms(ctx) if err != nil { return err } n.setUsersJoinedToRooms(roomToUsers) - roomToPeekingDevices, err := db.AllPeekingDevicesInRooms(ctx) + roomToPeekingDevices, err := snapshot.AllPeekingDevicesInRooms(ctx) if err != nil { return err } n.setPeekingDevices(roomToPeekingDevices) + succeeded = true return nil } @@ -338,12 +348,20 @@ func (n *Notifier) LoadRooms(ctx context.Context, db storage.Database, roomIDs [ n.lock.Lock() defer n.lock.Unlock() - roomToUsers, err := db.AllJoinedUsersInRoom(ctx, roomIDs) + snapshot, err := db.NewDatabaseSnapshot(ctx) + if err != nil { + return err + } + var succeeded bool + defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) + + roomToUsers, err := snapshot.AllJoinedUsersInRoom(ctx, roomIDs) if err != nil { return err } n.setUsersJoinedToRooms(roomToUsers) + succeeded = true return nil } diff --git a/syncapi/producers/userapi_readupdate.go b/syncapi/producers/userapi_readupdate.go deleted file mode 100644 index d56cab776..000000000 --- a/syncapi/producers/userapi_readupdate.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package producers - -import ( - "encoding/json" - - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/syncapi/types" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" -) - -// UserAPIProducer produces events for the user API server to consume -type UserAPIReadProducer struct { - Topic string - JetStream nats.JetStreamContext -} - -// SendData sends account data to the user API server -func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error { - m := &nats.Msg{ - Subject: p.Topic, - Header: nats.Header{}, - } - m.Header.Set(jetstream.UserID, userID) - m.Header.Set(jetstream.RoomID, roomID) - - data := types.ReadUpdate{ - UserID: userID, - RoomID: roomID, - Read: readPos, - FullyRead: fullyReadPos, - } - var err error - m.Data, err = json.Marshal(data) - if err != nil { - return err - } - - log.WithFields(log.Fields{ - "user_id": userID, - "room_id": roomID, - "read_pos": readPos, - "fully_read_pos": fullyReadPos, - }).Tracef("Producing to topic '%s'", p.Topic) - - _, err = p.JetStream.PublishMsg(m) - return err -} diff --git a/syncapi/producers/userapi_streamevent.go b/syncapi/producers/userapi_streamevent.go deleted file mode 100644 index 2bbd19c0b..000000000 --- a/syncapi/producers/userapi_streamevent.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package producers - -import ( - "encoding/json" - - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" -) - -// UserAPIProducer produces events for the user API server to consume -type UserAPIStreamEventProducer struct { - Topic string - JetStream nats.JetStreamContext -} - -// SendData sends account data to the user API server -func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error { - m := &nats.Msg{ - Subject: p.Topic, - Header: nats.Header{}, - } - m.Header.Set(jetstream.RoomID, roomID) - - data := types.StreamedEvent{ - Event: event, - StreamPosition: pos, - } - var err error - m.Data, err = json.Marshal(data) - if err != nil { - return err - } - - log.WithFields(log.Fields{ - "room_id": roomID, - "event_id": event.EventID(), - "event_type": event.Type(), - "stream_pos": pos, - }).Tracef("Producing to topic '%s'", p.Topic) - - _, err = p.JetStream.PublishMsg(m) - return err -} diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 13c4e9d89..0ed164c7e 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" roomserver "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" @@ -37,11 +38,11 @@ import ( type ContextRespsonse struct { End string `json:"end"` - Event gomatrixserverlib.ClientEvent `json:"event"` + Event *gomatrixserverlib.ClientEvent `json:"event,omitempty"` EventsAfter []gomatrixserverlib.ClientEvent `json:"events_after,omitempty"` EventsBefore []gomatrixserverlib.ClientEvent `json:"events_before,omitempty"` Start string `json:"start"` - State []gomatrixserverlib.ClientEvent `json:"state"` + State []gomatrixserverlib.ClientEvent `json:"state,omitempty"` } func Context( @@ -51,6 +52,13 @@ func Context( roomID, eventID string, lazyLoadCache caching.LazyLoadCache, ) util.JSONResponse { + snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) + if err != nil { + return jsonerror.InternalServerError() + } + var succeeded bool + defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) + filter, err := parseRoomEventFilter(req) if err != nil { errMsg := "" @@ -97,7 +105,7 @@ func Context( ContainsURL: filter.ContainsURL, } - id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID) + id, requestedEvent, err := snapshot.SelectContextEvent(ctx, roomID, eventID) if err != nil { if err == sql.ErrNoRows { return util.JSONResponse{ @@ -111,7 +119,7 @@ func Context( // verify the user is allowed to see the context for this room/event startTime := time.Now() - filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context") + filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context") if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") return jsonerror.InternalServerError() @@ -127,20 +135,20 @@ func Context( } } - eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, filter) + eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, roomID, filter) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Error("unable to fetch before events") return jsonerror.InternalServerError() } - _, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, roomID, filter) + _, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, roomID, filter) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Error("unable to fetch after events") return jsonerror.InternalServerError() } startTime = time.Now() - eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, syncDB, rsAPI, eventsBefore, eventsAfter, device.UserID) + eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, device.UserID) if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") return jsonerror.InternalServerError() @@ -152,7 +160,7 @@ func Context( }).Debug("applied history visibility (context eventsBefore/eventsAfter)") // TODO: Get the actual state at the last event returned by SelectContextAfterEvent - state, err := syncDB.CurrentState(ctx, roomID, &stateFilter, nil) + state, err := snapshot.CurrentState(ctx, roomID, &stateFilter, nil) if err != nil { logrus.WithError(err).Error("unable to fetch current room state") return jsonerror.InternalServerError() @@ -162,8 +170,9 @@ func Context( eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll) newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache) + ev := gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll) response := ContextRespsonse{ - Event: gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll), + Event: &ev, EventsAfter: eventsAfterClient, EventsBefore: eventsBeforeClient, State: gomatrixserverlib.HeaderedToClientEvents(newState, gomatrixserverlib.FormatAll), @@ -172,11 +181,12 @@ func Context( if len(response.State) > filter.Limit { response.State = response.State[len(response.State)-filter.Limit:] } - start, end, err := getStartEnd(ctx, syncDB, eventsBefore, eventsAfter) + start, end, err := getStartEnd(ctx, snapshot, eventsBefore, eventsAfter) if err == nil { response.End = end.String() response.Start = start.String() } + succeeded = true return util.JSONResponse{ Code: http.StatusOK, JSON: response, @@ -187,7 +197,7 @@ func Context( // by combining the events before and after the context event. Returns the filtered events, // and an error, if any. func applyHistoryVisibilityOnContextEvents( - ctx context.Context, syncDB storage.Database, rsAPI roomserver.SyncRoomserverAPI, + ctx context.Context, snapshot storage.DatabaseTransaction, rsAPI roomserver.SyncRoomserverAPI, eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent, userID string, ) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) { @@ -204,7 +214,7 @@ func applyHistoryVisibilityOnContextEvents( } allEvents := append(eventsBefore, eventsAfter...) - filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, allEvents, nil, userID, "context") + filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, allEvents, nil, userID, "context") if err != nil { return nil, nil, err } @@ -221,15 +231,15 @@ func applyHistoryVisibilityOnContextEvents( return filteredBefore, filteredAfter, nil } -func getStartEnd(ctx context.Context, syncDB storage.Database, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { +func getStartEnd(ctx context.Context, snapshot storage.DatabaseTransaction, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { if len(startEvents) > 0 { - start, err = syncDB.EventPositionInTopology(ctx, startEvents[0].EventID()) + start, err = snapshot.EventPositionInTopology(ctx, startEvents[0].EventID()) if err != nil { return } } if len(endEvents) > 0 { - end, err = syncDB.EventPositionInTopology(ctx, endEvents[0].EventID()) + end, err = snapshot.EventPositionInTopology(ctx, endEvents[0].EventID()) } return } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 03614302c..8f3ed3f5b 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/internal" @@ -39,6 +40,7 @@ import ( type messagesReq struct { ctx context.Context db storage.Database + snapshot storage.DatabaseTransaction rsAPI api.SyncRoomserverAPI cfg *config.SyncAPI roomID string @@ -70,6 +72,16 @@ func OnIncomingMessagesRequest( ) util.JSONResponse { var err error + // NewDatabaseTransaction is used here instead of NewDatabaseSnapshot as we + // expect to be able to write to the database in response to a /messages + // request that requires backfilling from the roomserver or federation. + snapshot, err := db.NewDatabaseTransaction(req.Context()) + if err != nil { + return jsonerror.InternalServerError() + } + var succeeded bool + defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) + // check if the user has already forgotten about this room isForgotten, roomExists, err := checkIsRoomForgotten(req.Context(), roomID, device.UserID, rsAPI) if err != nil { @@ -132,7 +144,7 @@ func OnIncomingMessagesRequest( } } else { fromStream = &streamToken - from, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering) + from, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering) if err != nil { logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken) return jsonerror.InternalServerError() @@ -154,7 +166,7 @@ func OnIncomingMessagesRequest( JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()), } } else { - to, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering) + to, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering) if err != nil { logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken) return jsonerror.InternalServerError() @@ -165,7 +177,7 @@ func OnIncomingMessagesRequest( // If "to" isn't provided, it defaults to either the earliest stream // position (if we're going backward) or to the latest one (if we're // going forward). - to, err = setToDefault(req.Context(), db, backwardOrdering, roomID) + to, err = setToDefault(req.Context(), snapshot, backwardOrdering, roomID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("setToDefault failed") return jsonerror.InternalServerError() @@ -186,6 +198,7 @@ func OnIncomingMessagesRequest( mReq := messagesReq{ ctx: req.Context(), db: db, + snapshot: snapshot, rsAPI: rsAPI, cfg: cfg, roomID: roomID, @@ -217,7 +230,7 @@ func OnIncomingMessagesRequest( Start: start.String(), End: end.String(), } - res.applyLazyLoadMembers(req.Context(), db, roomID, device, filter.LazyLoadMembers, lazyLoadCache) + res.applyLazyLoadMembers(req.Context(), snapshot, roomID, device, filter.LazyLoadMembers, lazyLoadCache) // If we didn't return any events, set the end to an empty string, so it will be omitted // in the response JSON. @@ -229,6 +242,7 @@ func OnIncomingMessagesRequest( } // Respond with the events. + succeeded = true return util.JSONResponse{ Code: http.StatusOK, JSON: res, @@ -239,7 +253,7 @@ func OnIncomingMessagesRequest( // LazyLoadMembers enabled. func (m *messagesResp) applyLazyLoadMembers( ctx context.Context, - db storage.Database, + db storage.DatabaseTransaction, roomID string, device *userapi.Device, lazyLoad bool, @@ -292,7 +306,7 @@ func (r *messagesReq) retrieveEvents() ( end types.TopologyToken, err error, ) { // Retrieve the events from the local database. - streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering) + streamEvents, err := r.snapshot.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering) if err != nil { err = fmt.Errorf("GetEventsInRange: %w", err) return @@ -348,7 +362,7 @@ func (r *messagesReq) retrieveEvents() ( // Apply room history visibility filter startTime := time.Now() - filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, r.rsAPI, events, nil, r.device.UserID, "messages") + filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.device.UserID, "messages") logrus.WithFields(logrus.Fields{ "duration": time.Since(startTime), "room_id": r.roomID, @@ -366,7 +380,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st // else to go. This seems to fix Element iOS from looping on /messages endlessly. end = types.TopologyToken{} } else { - end, err = r.db.EventPositionInTopology( + end, err = r.snapshot.EventPositionInTopology( r.ctx, events[0].EventID(), ) // A stream/topological position is a cursor located between two events. @@ -378,7 +392,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st } } else { start = *r.from - end, err = r.db.EventPositionInTopology( + end, err = r.snapshot.EventPositionInTopology( r.ctx, events[len(events)-1].EventID(), ) } @@ -399,7 +413,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st func (r *messagesReq) handleEmptyEventsSlice() ( events []*gomatrixserverlib.HeaderedEvent, err error, ) { - backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID) + backwardExtremities, err := r.snapshot.BackwardExtremitiesForRoom(r.ctx, r.roomID) // Check if we have backward extremities for this room. if len(backwardExtremities) > 0 { @@ -443,7 +457,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent } // Check if the slice contains a backward extremity. - backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID) + backwardExtremities, err := r.snapshot.BackwardExtremitiesForRoom(r.ctx, r.roomID) if err != nil { return } @@ -463,7 +477,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent } // Append the events ve previously retrieved locally. - events = append(events, r.db.StreamEventsToEvents(nil, streamEvents)...) + events = append(events, r.snapshot.StreamEventsToEvents(nil, streamEvents)...) sort.Sort(eventsByDepth(events)) return @@ -553,7 +567,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] // Returns an error if there was an issue with retrieving the latest position // from the database func setToDefault( - ctx context.Context, db storage.Database, backwardOrdering bool, + ctx context.Context, snapshot storage.DatabaseTransaction, backwardOrdering bool, roomID string, ) (to types.TopologyToken, err error) { if backwardOrdering { @@ -561,7 +575,7 @@ func setToDefault( // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. to = types.TopologyToken{} } else { - to, err = db.MaxTopologicalPosition(ctx, roomID) + to, err = snapshot.MaxTopologicalPosition(ctx, roomID) } return diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 6bc495d8d..8f84a1341 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -18,15 +18,18 @@ import ( "net/http" "github.com/gorilla/mux" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) // Setup configures the given mux with sync-server listeners @@ -40,6 +43,7 @@ func Setup( rsAPI api.SyncRoomserverAPI, cfg *config.SyncAPI, lazyLoadCache caching.LazyLoadCache, + fts *fulltext.Search, ) { v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() @@ -95,4 +99,24 @@ func Setup( ) }), ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/search", + httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if !cfg.Fulltext.Enabled { + return util.JSONResponse{ + Code: http.StatusNotImplemented, + JSON: jsonerror.Unknown("Search has been disabled by the server administrator."), + } + } + var nextBatch *string + if err := req.ParseForm(); err != nil { + return jsonerror.InternalServerError() + } + if req.Form.Has("next_batch") { + nb := req.FormValue("next_batch") + nextBatch = &nb + } + return Search(req, device, syncDB, fts, nextBatch) + }), + ).Methods(http.MethodPost, http.MethodOptions) } diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go new file mode 100644 index 000000000..aef355def --- /dev/null +++ b/syncapi/routing/search.go @@ -0,0 +1,353 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "net/http" + "sort" + "strconv" + "strings" + "time" + + "github.com/blevesearch/bleve/v2/search" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/fulltext" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/userapi/api" +) + +// nolint:gocyclo +func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts *fulltext.Search, from *string) util.JSONResponse { + start := time.Now() + var ( + searchReq SearchRequest + err error + ctx = req.Context() + ) + resErr := httputil.UnmarshalJSONRequest(req, &searchReq) + if resErr != nil { + logrus.Error("failed to unmarshal search request") + return *resErr + } + + nextBatch := 0 + if from != nil && *from != "" { + nextBatch, err = strconv.Atoi(*from) + if err != nil { + return jsonerror.InternalServerError() + } + } + + if searchReq.SearchCategories.RoomEvents.Filter.Limit == 0 { + searchReq.SearchCategories.RoomEvents.Filter.Limit = 5 + } + + snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) + if err != nil { + return jsonerror.InternalServerError() + } + var succeeded bool + defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) + + // only search rooms the user is actually joined to + joinedRooms, err := snapshot.RoomIDsWithMembership(ctx, device.UserID, "join") + if err != nil { + return jsonerror.InternalServerError() + } + if len(joinedRooms) == 0 { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("User not joined to any rooms."), + } + } + joinedRoomsMap := make(map[string]struct{}, len(joinedRooms)) + for _, roomID := range joinedRooms { + joinedRoomsMap[roomID] = struct{}{} + } + rooms := []string{} + if searchReq.SearchCategories.RoomEvents.Filter.Rooms != nil { + for _, roomID := range *searchReq.SearchCategories.RoomEvents.Filter.Rooms { + if _, ok := joinedRoomsMap[roomID]; ok { + rooms = append(rooms, roomID) + } + } + } else { + rooms = joinedRooms + } + + if len(rooms) == 0 { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Unknown("User not allowed to search in this room(s)."), + } + } + + orderByTime := searchReq.SearchCategories.RoomEvents.OrderBy == "recent" + + result, err := fts.Search( + searchReq.SearchCategories.RoomEvents.SearchTerm, + rooms, + searchReq.SearchCategories.RoomEvents.Keys, + searchReq.SearchCategories.RoomEvents.Filter.Limit, + nextBatch, + orderByTime, + ) + if err != nil { + logrus.WithError(err).Error("failed to search fulltext") + return jsonerror.InternalServerError() + } + logrus.Debugf("Search took %s", result.Took) + + // From was specified but empty, return no results, only the count + if from != nil && *from == "" { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: SearchResponse{ + SearchCategories: SearchCategories{ + RoomEvents: RoomEvents{ + Count: int(result.Total), + NextBatch: nil, + }, + }, + }, + } + } + + results := []Result{} + + wantEvents := make([]string, 0, len(result.Hits)) + eventScore := make(map[string]*search.DocumentMatch) + + for _, hit := range result.Hits { + wantEvents = append(wantEvents, hit.ID) + eventScore[hit.ID] = hit + } + + // Filter on m.room.message, as otherwise we also get events like m.reaction + // which "breaks" displaying results in Element Web. + types := []string{"m.room.message"} + roomFilter := &gomatrixserverlib.RoomEventFilter{ + Rooms: &rooms, + Types: &types, + } + + evs, err := syncDB.Events(ctx, wantEvents) + if err != nil { + logrus.WithError(err).Error("failed to get events from database") + return jsonerror.InternalServerError() + } + + groups := make(map[string]RoomResult) + knownUsersProfiles := make(map[string]ProfileInfo) + + // Sort the events by depth, as the returned values aren't ordered + if orderByTime { + sort.Slice(evs, func(i, j int) bool { + return evs[i].Depth() > evs[j].Depth() + }) + } + + stateForRooms := make(map[string][]gomatrixserverlib.ClientEvent) + for _, event := range evs { + eventsBefore, eventsAfter, err := contextEvents(ctx, snapshot, event, roomFilter, searchReq) + if err != nil { + logrus.WithError(err).Error("failed to get context events") + return jsonerror.InternalServerError() + } + startToken, endToken, err := getStartEnd(ctx, snapshot, eventsBefore, eventsAfter) + if err != nil { + logrus.WithError(err).Error("failed to get start/end") + return jsonerror.InternalServerError() + } + + profileInfos := make(map[string]ProfileInfo) + for _, ev := range append(eventsBefore, eventsAfter...) { + profile, ok := knownUsersProfiles[event.Sender()] + if !ok { + stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), gomatrixserverlib.MRoomMember, ev.Sender()) + if err != nil { + logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile") + continue + } + if stateEvent == nil { + continue + } + profile = ProfileInfo{ + AvatarURL: gjson.GetBytes(stateEvent.Content(), "avatar_url").Str, + DisplayName: gjson.GetBytes(stateEvent.Content(), "displayname").Str, + } + knownUsersProfiles[event.Sender()] = profile + } + profileInfos[ev.Sender()] = profile + } + + results = append(results, Result{ + Context: SearchContextResponse{ + Start: startToken.String(), + End: endToken.String(), + EventsAfter: gomatrixserverlib.HeaderedToClientEvents(eventsAfter, gomatrixserverlib.FormatSync), + EventsBefore: gomatrixserverlib.HeaderedToClientEvents(eventsBefore, gomatrixserverlib.FormatSync), + ProfileInfo: profileInfos, + }, + Rank: eventScore[event.EventID()].Score, + Result: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll), + }) + roomGroup := groups[event.RoomID()] + roomGroup.Results = append(roomGroup.Results, event.EventID()) + groups[event.RoomID()] = roomGroup + if _, ok := stateForRooms[event.RoomID()]; searchReq.SearchCategories.RoomEvents.IncludeState && !ok { + stateFilter := gomatrixserverlib.DefaultStateFilter() + state, err := snapshot.CurrentState(ctx, event.RoomID(), &stateFilter, nil) + if err != nil { + logrus.WithError(err).Error("unable to get current state") + return jsonerror.InternalServerError() + } + stateForRooms[event.RoomID()] = gomatrixserverlib.HeaderedToClientEvents(state, gomatrixserverlib.FormatSync) + } + } + + var nextBatchResult *string = nil + if int(result.Total) > nextBatch+len(results) { + nb := strconv.Itoa(len(results) + nextBatch) + nextBatchResult = &nb + } else if int(result.Total) == nextBatch+len(results) { + // Sytest expects a next_batch even if we don't actually have any more results + nb := "" + nextBatchResult = &nb + } + + res := SearchResponse{ + SearchCategories: SearchCategories{ + RoomEvents: RoomEvents{ + Count: int(result.Total), + Groups: Groups{RoomID: groups}, + Results: results, + NextBatch: nextBatchResult, + Highlights: strings.Split(searchReq.SearchCategories.RoomEvents.SearchTerm, " "), + State: stateForRooms, + }, + }, + } + + logrus.Debugf("Full search request took %v", time.Since(start)) + + succeeded = true + return util.JSONResponse{ + Code: http.StatusOK, + JSON: res, + } +} + +// contextEvents returns the events around a given eventID +func contextEvents( + ctx context.Context, + snapshot storage.DatabaseTransaction, + event *gomatrixserverlib.HeaderedEvent, + roomFilter *gomatrixserverlib.RoomEventFilter, + searchReq SearchRequest, +) ([]*gomatrixserverlib.HeaderedEvent, []*gomatrixserverlib.HeaderedEvent, error) { + id, _, err := snapshot.SelectContextEvent(ctx, event.RoomID(), event.EventID()) + if err != nil { + logrus.WithError(err).Error("failed to query context event") + return nil, nil, err + } + roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.BeforeLimit + eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, event.RoomID(), roomFilter) + if err != nil { + logrus.WithError(err).Error("failed to query before context event") + return nil, nil, err + } + roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.AfterLimit + _, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, event.RoomID(), roomFilter) + if err != nil { + logrus.WithError(err).Error("failed to query after context event") + return nil, nil, err + } + return eventsBefore, eventsAfter, err +} + +type SearchRequest struct { + SearchCategories struct { + RoomEvents struct { + EventContext struct { + AfterLimit int `json:"after_limit,omitempty"` + BeforeLimit int `json:"before_limit,omitempty"` + IncludeProfile bool `json:"include_profile,omitempty"` + } `json:"event_context"` + Filter gomatrixserverlib.StateFilter `json:"filter"` + Groupings struct { + GroupBy []struct { + Key string `json:"key"` + } `json:"group_by"` + } `json:"groupings"` + IncludeState bool `json:"include_state"` + Keys []string `json:"keys"` + OrderBy string `json:"order_by"` + SearchTerm string `json:"search_term"` + } `json:"room_events"` + } `json:"search_categories"` +} + +type SearchResponse struct { + SearchCategories SearchCategories `json:"search_categories"` +} +type RoomResult struct { + NextBatch *string `json:"next_batch,omitempty"` + Order int `json:"order"` + Results []string `json:"results"` +} + +type Groups struct { + RoomID map[string]RoomResult `json:"room_id"` +} + +type Result struct { + Context SearchContextResponse `json:"context"` + Rank float64 `json:"rank"` + Result gomatrixserverlib.ClientEvent `json:"result"` +} + +type SearchContextResponse struct { + End string `json:"end"` + EventsAfter []gomatrixserverlib.ClientEvent `json:"events_after"` + EventsBefore []gomatrixserverlib.ClientEvent `json:"events_before"` + Start string `json:"start"` + ProfileInfo map[string]ProfileInfo `json:"profile_info"` +} + +type ProfileInfo struct { + AvatarURL string `json:"avatar_url"` + DisplayName string `json:"display_name"` +} + +type RoomEvents struct { + Count int `json:"count"` + Groups Groups `json:"groups"` + Highlights []string `json:"highlights"` + NextBatch *string `json:"next_batch,omitempty"` + Results []Result `json:"results"` + State map[string][]gomatrixserverlib.ClientEvent `json:"state,omitempty"` +} +type SearchCategories struct { + RoomEvents RoomEvents `json:"room_events"` +} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 0c8ba4e3d..4a03aca74 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -17,17 +17,18 @@ package storage import ( "context" - "github.com/matrix-org/dendrite/internal/eventutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" ) -type Database interface { - Presence +type DatabaseTransaction interface { + sqlutil.Transaction SharedUsers MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) @@ -36,6 +37,7 @@ type Database interface { MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) + MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) @@ -43,23 +45,77 @@ type Database interface { RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) - RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) - GetBackwardTopologyPos(ctx context.Context, events []types.StreamEvent) (types.TopologyToken, error) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) - - InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) + InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) - // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) // AllJoinedUsersInRoom returns a map of room ID to a list of all joined user IDs for a given room. AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) - // AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices. AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) + // Events lookups a list of event by their event ID. + // Returns a list of events matching the requested IDs found in the database. + // If an event is not found in the database then it will be omitted from the list. + // Returns an error if there was a problem talking with the database. + // Does not include any transaction IDs in the returned events. + Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) + // GetStateEvent returns the Matrix state event of a given type for a given room with a given state key + // If no event could be found, returns nil + // If there was an issue during the retrieval, returns an error + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + // GetStateEventsForRoom fetches the state events for a given room. + // Returns an empty slice if no state events could be found for this room. + // Returns an error if there was an issue with the retrieval. + GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) + // GetAccountDataInRange returns all account data for a given user inserted or + // updated between two given positions + // Returns a map following the format data[roomID] = []dataTypes + // If no data is retrieved, returns an empty map + // If there was an issue with the retrieval, returns an error + GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, types.StreamPosition, error) + // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last. + GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error) + // EventPositionInTopology returns the depth and stream position of the given event. + EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) + // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. + BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error) + // MaxTopologicalPosition returns the highest topological position for a given room. + MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error) + // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and + // matches the streamevent.transactionID device then the transaction ID gets + // added to the unsigned section of the output event. + StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent + // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the + // relevant events within the given ranges for the supplied user ID and device ID. + SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error) + // GetRoomReceipts gets all receipts for a given roomID + GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) + SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) + SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) + SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) + StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error) + IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) + // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found + // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty + // string as the membership. + SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) + // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms + GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) + GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) + PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) +} + +type Database interface { + Presence + Notifications + + NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseTransaction, error) + NewDatabaseTransaction(ctx context.Context) (*shared.DatabaseTransaction, error) + // Events lookups a list of event by their event ID. // Returns a list of events matching the requested IDs found in the database. // If an event is not found in the database then it will be omitted from the list. @@ -76,20 +132,6 @@ type Database interface { // PurgeRoomState completely purges room state from the sync API. This is done when // receiving an output event that completely resets the state. PurgeRoomState(ctx context.Context, roomID string) error - // GetStateEvent returns the Matrix state event of a given type for a given room with a given state key - // If no event could be found, returns nil - // If there was an issue during the retrieval, returns an error - GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) - // GetStateEventsForRoom fetches the state events for a given room. - // Returns an empty slice if no state events could be found for this room. - // Returns an error if there was an issue with the retrieval. - GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) - // GetAccountDataInRange returns all account data for a given user inserted or - // updated between two given positions - // Returns a map following the format data[roomID] = []dataTypes - // If no data is retrieved, returns an empty map - // If there was an issue with the retrieval, returns an error - GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, types.StreamPosition, error) // UpsertAccountData keeps track of new or updated account data, by saving the type // of the new/updated data, and the user ID and room ID the data is related to (empty) // room ID means the data isn't specific to any room) @@ -113,21 +155,6 @@ type Database interface { // DeletePeek deletes all peeks for a given room by a given user // Returns an error if there was a problem communicating with the database. DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) - // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last. - GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error) - // EventPositionInTopology returns the depth and stream position of the given event. - EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) - // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. - BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error) - // MaxTopologicalPosition returns the highest topological position for a given room. - MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error) - // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and - // matches the streamevent.transactionID device then the transaction ID gets - // added to the unsigned section of the output event. - StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent - // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the - // relevant events within the given ranges for the supplied user ID and device ID. - SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error) // StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device. StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) // CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified @@ -145,37 +172,21 @@ type Database interface { RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error // StoreReceipt stores new receipt events StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) - // GetRoomReceipts gets all receipts for a given roomID - GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) - - // UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key. - UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) - - // GetUserUnreadNotificationCounts returns statistics per room a user is interested in. - GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) - - SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) - SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) - SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) - - StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error) - - IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error - // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found - // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty - // string as the membership. - SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) + ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) } type Presence interface { - UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) - PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) - MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) + UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) } type SharedUsers interface { // SharedUsers returns a subset of otherUserIDs that share a room with userID. SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) } + +type Notifications interface { + // UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key. + UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) +} diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index e9c72058b..aa54cb08f 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -99,14 +99,15 @@ func (s *accountDataStatements) InsertAccountData( } func (s *accountDataStatements) SelectAccountDataInRange( - ctx context.Context, + ctx context.Context, txn *sql.Tx, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter, ) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) - rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(), + rows, err := sqlutil.TxStmt(txn, s.selectAccountDataInRangeStmt).QueryContext( + ctx, userID, r.Low(), r.High(), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.NotTypes)), accountDataEventFilter.Limit, diff --git a/syncapi/storage/postgres/backwards_extremities_table.go b/syncapi/storage/postgres/backwards_extremities_table.go index d4515735c..8fc92091f 100644 --- a/syncapi/storage/postgres/backwards_extremities_table.go +++ b/syncapi/storage/postgres/backwards_extremities_table.go @@ -79,9 +79,9 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( } func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (bwExtrems map[string][]string, err error) { - rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) + rows, err := sqlutil.TxStmt(txn, s.selectBackwardExtremitiesForRoomStmt).QueryContext(ctx, roomID) if err != nil { return } diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 5e6daaaf8..2ccf0be1a 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -104,12 +104,7 @@ const selectStateEventSQL = "" + "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" const selectEventsWithEventIDsSQL = "" + - // TODO: The session_id and transaction_id blanks are here because - // the rowsToStreamEvents expects there to be exactly seven columns. We need to - // figure out if these really need to be in the DB, and if so, we need a - // better permanent fix for this. - neilalexander, 2 Jan 2020 - "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id, history_visibility" + - " FROM syncapi_current_room_state WHERE event_id = ANY($1)" + "SELECT event_id, added_at, headered_event_json, history_visibility FROM syncapi_current_room_state WHERE event_id = ANY($1)" const selectSharedUsersSQL = "" + "SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" + @@ -185,9 +180,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. func (s *currentRoomStateStatements) SelectJoinedUsers( - ctx context.Context, + ctx context.Context, txn *sql.Tx, ) (map[string][]string, error) { - rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) + rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt).QueryContext(ctx) if err != nil { return nil, err } @@ -209,9 +204,9 @@ func (s *currentRoomStateStatements) SelectJoinedUsers( // SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. func (s *currentRoomStateStatements) SelectJoinedUsersInRoom( - ctx context.Context, roomIDs []string, + ctx context.Context, txn *sql.Tx, roomIDs []string, ) (map[string][]string, error) { - rows, err := s.selectJoinedUsersInRoomStmt.QueryContext(ctx, pq.StringArray(roomIDs)) + rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersInRoomStmt).QueryContext(ctx, pq.StringArray(roomIDs)) if err != nil { return nil, err } @@ -365,7 +360,36 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs( return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed") - return rowsToStreamEvents(rows) + return currentRoomStateRowsToStreamEvents(rows) +} + +func currentRoomStateRowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { + var events []types.StreamEvent + for rows.Next() { + var ( + eventID string + streamPos types.StreamPosition + eventBytes []byte + historyVisibility gomatrixserverlib.HistoryVisibility + ) + if err := rows.Scan(&eventID, &streamPos, &eventBytes, &historyVisibility); err != nil { + return nil, err + } + // TODO: Handle redacted events + var ev gomatrixserverlib.HeaderedEvent + if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + return nil, err + } + + ev.Visibility = historyVisibility + + events = append(events, types.StreamEvent{ + HeaderedEvent: &ev, + StreamPosition: streamPos, + }) + } + + return events, nil } func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { @@ -387,9 +411,9 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { } func (s *currentRoomStateStatements) SelectStateEvent( - ctx context.Context, roomID, evType, stateKey string, + ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string, ) (*gomatrixserverlib.HeaderedEvent, error) { - stmt := s.selectStateEventStmt + stmt := sqlutil.TxStmt(txn, s.selectStateEventStmt) var res []byte err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) if err == sql.ErrNoRows { diff --git a/syncapi/storage/postgres/filter_table.go b/syncapi/storage/postgres/filter_table.go index c82ef092f..86cec3625 100644 --- a/syncapi/storage/postgres/filter_table.go +++ b/syncapi/storage/postgres/filter_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -73,11 +74,11 @@ func NewPostgresFilterTable(db *sql.DB) (tables.Filter, error) { } func (s *filterStatements) SelectFilter( - ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, + ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string, ) error { // Retrieve filter from database (stored as canonical JSON) var filterData []byte - err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) + err := sqlutil.TxStmt(txn, s.selectFilterStmt).QueryRowContext(ctx, localpart, filterID).Scan(&filterData) if err != nil { return err } @@ -90,7 +91,7 @@ func (s *filterStatements) SelectFilter( } func (s *filterStatements) InsertFilter( - ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, + ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string, ) (filterID string, err error) { var existingFilterID string @@ -111,8 +112,9 @@ func (s *filterStatements) InsertFilter( // This can result in a race condition when two clients try to insert the // same filter and localpart at the same time, however this is not a // problem as both calls will result in the same filterID - err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, - localpart, filterJSON).Scan(&existingFilterID) + err = sqlutil.TxStmt(txn, s.selectFilterIDByContentStmt).QueryRowContext( + ctx, localpart, filterJSON, + ).Scan(&existingFilterID) if err != nil && err != sql.ErrNoRows { return "", err } @@ -122,7 +124,7 @@ func (s *filterStatements) InsertFilter( } // Otherwise insert the filter and return the new ID - err = s.insertFilterStmt.QueryRowContext(ctx, filterJSON, localpart). + err = sqlutil.TxStmt(txn, s.insertFilterStmt).QueryRowContext(ctx, filterJSON, localpart). Scan(&filterID) return } diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index 97001ae2c..aada70d5e 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -55,7 +55,7 @@ const deleteInviteEventSQL = "" + "UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 AND deleted=FALSE RETURNING id" const selectInviteEventsInRangeSQL = "" + - "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" + + "SELECT id, room_id, headered_event_json, deleted FROM syncapi_invite_events" + " WHERE target_user_id = $1 AND id > $2 AND id <= $3" + " ORDER BY id DESC" @@ -99,7 +99,7 @@ func (s *inviteEventsStatements) InsertInviteEvent( return } - err = s.insertInviteEventStmt.QueryRowContext( + err = sqlutil.TxStmt(txn, s.insertInviteEventStmt).QueryRowContext( ctx, inviteEvent.RoomID(), inviteEvent.EventID(), @@ -121,23 +121,28 @@ func (s *inviteEventsStatements) DeleteInviteEvent( // active invites for the target user ID in the supplied range. func (s *inviteEventsStatements) SelectInviteEventsInRange( ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, -) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) { +) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) { + var lastPos types.StreamPosition stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) if err != nil { - return nil, nil, err + return nil, nil, lastPos, err } defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") result := map[string]*gomatrixserverlib.HeaderedEvent{} retired := map[string]*gomatrixserverlib.HeaderedEvent{} for rows.Next() { var ( + id types.StreamPosition roomID string eventJSON []byte deleted bool ) - if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil { - return nil, nil, err + if err = rows.Scan(&id, &roomID, &eventJSON, &deleted); err != nil { + return nil, nil, lastPos, err + } + if id > lastPos { + lastPos = id } // if we have seen this room before, it has a higher stream position and hence takes priority @@ -150,7 +155,7 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange( var event *gomatrixserverlib.HeaderedEvent if err := json.Unmarshal(eventJSON, &event); err != nil { - return nil, nil, err + return nil, nil, lastPos, err } if deleted { @@ -159,7 +164,10 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange( result[roomID] = event } } - return result, retired, rows.Err() + if lastPos == 0 { + lastPos = r.To + } + return result, retired, lastPos, rows.Err() } func (s *inviteEventsStatements) SelectMaxInviteID( diff --git a/syncapi/storage/postgres/notification_data_table.go b/syncapi/storage/postgres/notification_data_table.go index 708c3a9b4..2c7b24800 100644 --- a/syncapi/storage/postgres/notification_data_table.go +++ b/syncapi/storage/postgres/notification_data_table.go @@ -18,6 +18,8 @@ import ( "context" "database/sql" + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -33,15 +35,15 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro r := ¬ificationDataStatements{} return r, sqlutil.StatementList{ {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, - {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, + {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, {&r.selectMaxID, selectMaxNotificationIDSQL}, }.Prepare(db) } type notificationDataStatements struct { - upsertRoomUnreadCounts *sql.Stmt - selectUserUnreadCounts *sql.Stmt - selectMaxID *sql.Stmt + upsertRoomUnreadCounts *sql.Stmt + selectUserUnreadCountsForRooms *sql.Stmt + selectMaxID *sql.Stmt } const notificationDataSchema = ` @@ -61,12 +63,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_ DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4 RETURNING id` -const selectUserUnreadNotificationCountsSQL = `SELECT - id, room_id, notification_count, highlight_count - FROM syncapi_notification_data - WHERE - user_id = $1 AND - id BETWEEN $2 + 1 AND $3` +const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count + FROM syncapi_notification_data + WHERE user_id = $1 AND + room_id = ANY($2)` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` @@ -75,20 +75,20 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, return } -func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { - rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) +func (r *notificationDataStatements) SelectUserUnreadCountsForRooms( + ctx context.Context, txn *sql.Tx, userID string, roomIDs []string, +) (map[string]*eventutil.NotificationData, error) { + rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCountsForRooms).QueryContext(ctx, userID, pq.Array(roomIDs)) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCountsForRooms: rows.close() failed") roomCounts := map[string]*eventutil.NotificationData{} + var roomID string + var notificationCount, highlightCount int for rows.Next() { - var id types.StreamPosition - var roomID string - var notificationCount, highlightCount int - - if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil { + if err = rows.Scan(&roomID, ¬ificationCount, &highlightCount); err != nil { return nil, err } diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 041f99061..cb092150d 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -166,6 +166,8 @@ const selectContextAfterEventSQL = "" + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + " ORDER BY id ASC LIMIT $3" +const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type = ANY($2) ORDER BY id ASC LIMIT $3" + type outputRoomEventsStatements struct { insertEventStmt *sql.Stmt selectEventsStmt *sql.Stmt @@ -180,6 +182,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + selectSearchStmt *sql.Stmt } func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { @@ -215,15 +218,16 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + {&s.selectSearchStmt, selectSearchSQL}, }.Prepare(db) } -func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { +func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error { headeredJSON, err := json.Marshal(event) if err != nil { return err } - _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) + _, err = sqlutil.TxStmt(txn, s.updateEventJSONStmt).ExecContext(ctx, headeredJSON, event.EventID()) return err } @@ -632,3 +636,27 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { } return result, rows.Err() } + +func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { + rows, err := sqlutil.TxStmt(txn, s.selectSearchStmt).QueryContext(ctx, afterID, pq.StringArray(types), limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "rows.close() failed") + + var eventID string + var id int64 + result := make(map[int64]gomatrixserverlib.HeaderedEvent) + for rows.Next() { + var ev gomatrixserverlib.HeaderedEvent + var eventBytes []byte + if err = rows.Scan(&id, &eventID, &eventBytes); err != nil { + return nil, err + } + if err = ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + return nil, err + } + result[id] = ev + } + return result, rows.Err() +} diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index a1fc9b2a3..6fab900eb 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -173,7 +173,7 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology( ctx context.Context, txn *sql.Tx, eventID string, ) (pos, spos types.StreamPosition, err error) { - err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos) + err = sqlutil.TxStmt(txn, s.selectPositionInTopologyStmt).QueryRowContext(ctx, eventID).Scan(&pos, &spos) return } @@ -183,9 +183,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, backwardOrdering bool, ) (topoPos types.StreamPosition, err error) { if backwardOrdering { - err = s.selectStreamToTopologicalPositionDescStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) + err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionDescStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) } else { - err = s.selectStreamToTopologicalPositionAscStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) + err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionAscStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) } return } @@ -193,6 +193,6 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( ctx context.Context, txn *sql.Tx, roomID string, ) (pos types.StreamPosition, spos types.StreamPosition, err error) { - err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) + err = sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt).QueryRowContext(ctx, roomID).Scan(&pos, &spos) return } diff --git a/syncapi/storage/postgres/peeks_table.go b/syncapi/storage/postgres/peeks_table.go index 75eeac986..e20a4882f 100644 --- a/syncapi/storage/postgres/peeks_table.go +++ b/syncapi/storage/postgres/peeks_table.go @@ -152,9 +152,9 @@ func (s *peekStatements) SelectPeeksInRange( } func (s *peekStatements) SelectPeekingDevices( - ctx context.Context, + ctx context.Context, txn *sql.Tx, ) (peekingDevices map[string][]types.PeekingDevice, err error) { - rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx) + rows, err := sqlutil.TxStmt(txn, s.selectPeekingDevicesStmt).QueryContext(ctx) if err != nil { return nil, err } diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go index bbddaa939..327a7a372 100644 --- a/syncapi/storage/postgres/receipt_table.go +++ b/syncapi/storage/postgres/receipt_table.go @@ -104,9 +104,9 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room return } -func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { +func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { var lastPos types.StreamPosition - rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos) + rows, err := sqlutil.TxStmt(txn, r.selectRoomReceipts).QueryContext(ctx, pq.Array(roomIDs), streamPos) if err != nil { return 0, nil, fmt.Errorf("unable to query room receipts: %w", err) } diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go new file mode 100644 index 000000000..937ced3a2 --- /dev/null +++ b/syncapi/storage/shared/storage_consumer.go @@ -0,0 +1,588 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/tidwall/gjson" + + userapi "github.com/matrix-org/dendrite/userapi/api" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite +// For now this contains the shared functions +type Database struct { + DB *sql.DB + Writer sqlutil.Writer + Invites tables.Invites + Peeks tables.Peeks + AccountData tables.AccountData + OutputEvents tables.Events + Topology tables.Topology + CurrentRoomState tables.CurrentRoomState + BackwardExtremities tables.BackwardsExtremities + SendToDevice tables.SendToDevice + Filter tables.Filter + Receipts tables.Receipts + Memberships tables.Memberships + NotificationData tables.NotificationData + Ignores tables.Ignores + Presence tables.Presence +} + +func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransaction, error) { + return d.NewDatabaseTransaction(ctx) + + /* + TODO: Repeatable read is probably the right thing to do here, + but it seems to cause some problems with the invite tests, so + need to investigate that further. + + txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{ + // Set the isolation level so that we see a snapshot of the database. + // In PostgreSQL repeatable read transactions will see a snapshot taken + // at the first query, and since the transaction is read-only it can't + // run into any serialisation errors. + // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ + Isolation: sql.LevelRepeatableRead, + ReadOnly: true, + }) + if err != nil { + return nil, err + } + return &DatabaseTransaction{ + Database: d, + ctx: ctx, + txn: txn, + }, nil + */ +} + +func (d *Database) NewDatabaseTransaction(ctx context.Context) (*DatabaseTransaction, error) { + txn, err := d.DB.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + return &DatabaseTransaction{ + Database: d, + ctx: ctx, + txn: txn, + }, nil +} + +func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { + streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false) + if err != nil { + return nil, err + } + + // We don't include a device here as we only include transaction IDs in + // incremental syncs. + return d.StreamEventsToEvents(nil, streamEvents), nil +} + +// AddInviteEvent stores a new invite event for a user. +// If the invite was successfully stored this returns the stream ID it was stored at. +// Returns an error if there was a problem communicating with the database. +func (d *Database) AddInviteEvent( + ctx context.Context, inviteEvent *gomatrixserverlib.HeaderedEvent, +) (sp types.StreamPosition, err error) { + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) + return err + }) + return +} + +// RetireInviteEvent removes an old invite event from the database. +// Returns an error if there was a problem communicating with the database. +func (d *Database) RetireInviteEvent( + ctx context.Context, inviteEventID string, +) (sp types.StreamPosition, err error) { + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.Invites.DeleteInviteEvent(ctx, txn, inviteEventID) + return err + }) + return +} + +// AddPeek tracks the fact that a user has started peeking. +// If the peek was successfully stored this returns the stream ID it was stored at. +// Returns an error if there was a problem communicating with the database. +func (d *Database) AddPeek( + ctx context.Context, roomID, userID, deviceID string, +) (sp types.StreamPosition, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.Peeks.InsertPeek(ctx, txn, roomID, userID, deviceID) + return err + }) + return +} + +// DeletePeek tracks the fact that a user has stopped peeking from the specified +// device. If the peeks was successfully deleted this returns the stream ID it was +// stored at. Returns an error if there was a problem communicating with the database. +func (d *Database) DeletePeek( + ctx context.Context, roomID, userID, deviceID string, +) (sp types.StreamPosition, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.Peeks.DeletePeek(ctx, txn, roomID, userID, deviceID) + return err + }) + if err == sql.ErrNoRows { + sp = 0 + err = nil + } + return +} + +// DeletePeeks tracks the fact that a user has stopped peeking from all devices +// If the peeks was successfully deleted this returns the stream ID it was stored at. +// Returns an error if there was a problem communicating with the database. +func (d *Database) DeletePeeks( + ctx context.Context, roomID, userID string, +) (sp types.StreamPosition, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.Peeks.DeletePeeks(ctx, txn, roomID, userID) + return err + }) + if err == sql.ErrNoRows { + sp = 0 + err = nil + } + return +} + +// UpsertAccountData keeps track of new or updated account data, by saving the type +// of the new/updated data, and the user ID and room ID the data is related to (empty) +// room ID means the data isn't specific to any room) +// If no data with the given type, user ID and room ID exists in the database, +// creates a new row, else update the existing one +// Returns an error if there was an issue with the upsert +func (d *Database) UpsertAccountData( + ctx context.Context, userID, roomID, dataType string, +) (sp types.StreamPosition, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType) + return err + }) + return +} + +func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent { + out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) + for i := 0; i < len(in); i++ { + out[i] = in[i].HeaderedEvent + if device != nil && in[i].TransactionID != nil { + if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID { + err := out[i].SetUnsignedField( + "transaction_id", in[i].TransactionID.TransactionID, + ) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + } + } + } + } + return out +} + +// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of +// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table +// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. +// This function should always be called within a sqlutil.Writer for safety in SQLite. +func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error { + if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { + return err + } + + // Check if we have all of the event's previous events. If an event is + // missing, add it to the room's backward extremities. + prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false) + if err != nil { + return err + } + var found bool + for _, eID := range ev.PrevEventIDs() { + found = false + for _, prevEv := range prevEvents { + if eID == prevEv.EventID() { + found = true + } + } + + // If the event is missing, consider it a backward extremity. + if !found { + if err = d.BackwardExtremities.InsertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID(), eID); err != nil { + return err + } + } + } + + return nil +} + +func (d *Database) PurgeRoomState( + ctx context.Context, roomID string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // If the event is a create event then we'll delete all of the existing + // data for the room. The only reason that a create event would be replayed + // to us in this way is if we're about to receive the entire room state. + if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) + } + return nil + }) +} + +func (d *Database) WriteEvent( + ctx context.Context, + ev *gomatrixserverlib.HeaderedEvent, + addStateEvents []*gomatrixserverlib.HeaderedEvent, + addStateEventIDs, removeStateEventIDs []string, + transactionID *api.TransactionID, excludeFromSync bool, + historyVisibility gomatrixserverlib.HistoryVisibility, +) (pduPosition types.StreamPosition, returnErr error) { + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error + ev.Visibility = historyVisibility + pos, err := d.OutputEvents.InsertEvent( + ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility, + ) + if err != nil { + return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err) + } + pduPosition = pos + var topoPosition types.StreamPosition + if topoPosition, err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil { + return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err) + } + + if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { + return fmt.Errorf("d.handleBackwardExtremities: %w", err) + } + + if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { + // Nothing to do, the event may have just been a message event. + return nil + } + for i := range addStateEvents { + addStateEvents[i].Visibility = historyVisibility + } + return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition, topoPosition) + }) + + return pduPosition, returnErr +} + +// This function should always be called within a sqlutil.Writer for safety in SQLite. +func (d *Database) updateRoomState( + ctx context.Context, txn *sql.Tx, + removedEventIDs []string, + addedEvents []*gomatrixserverlib.HeaderedEvent, + pduPosition types.StreamPosition, + topoPosition types.StreamPosition, +) error { + // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. + for _, eventID := range removedEventIDs { + if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil { + return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err) + } + } + + for _, event := range addedEvents { + if event.StateKey() == nil { + // ignore non state events + continue + } + var membership *string + if event.Type() == "m.room.member" { + value, err := event.Membership() + if err != nil { + return fmt.Errorf("event.Membership: %w", err) + } + membership = &value + if err = d.Memberships.UpsertMembership(ctx, txn, event, pduPosition, topoPosition); err != nil { + return fmt.Errorf("d.Memberships.UpsertMembership: %w", err) + } + } + + if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { + return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err) + } + } + + return nil +} + +func (d *Database) GetFilter( + ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, +) error { + return d.Filter.SelectFilter(ctx, nil, target, localpart, filterID) +} + +func (d *Database) PutFilter( + ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, +) (string, error) { + var filterID string + var err error + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + filterID, err = d.Filter.InsertFilter(ctx, txn, filter, localpart) + return err + }) + return filterID, err +} + +func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error { + redactedEvents, err := d.Events(ctx, []string{redactedEventID}) + if err != nil { + return err + } + if len(redactedEvents) == 0 { + logrus.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction") + return nil + } + eventToRedact := redactedEvents[0].Unwrap() + redactionEvent := redactedBecause.Unwrap() + if err = eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil { + return err + } + + newEvent := eventToRedact.Headered(redactedBecause.RoomVersion) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent) + }) + return err +} + +// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. +// Returns a map of room ID to list of events. +func (d *Database) fetchStateEvents( + ctx context.Context, txn *sql.Tx, + roomIDToEventIDSet map[string]map[string]bool, + eventIDToEvent map[string]types.StreamEvent, +) (map[string][]types.StreamEvent, error) { + stateBetween := make(map[string][]types.StreamEvent) + missingEvents := make(map[string][]string) + for roomID, ids := range roomIDToEventIDSet { + events := stateBetween[roomID] + for id, need := range ids { + if !need { + continue // deleted state + } + e, ok := eventIDToEvent[id] + if ok { + events = append(events, e) + } else { + m := missingEvents[roomID] + m = append(m, id) + missingEvents[roomID] = m + } + } + stateBetween[roomID] = events + } + + if len(missingEvents) > 0 { + // This happens when add_state_ids has an event ID which is not in the provided range. + // We need to explicitly fetch them. + allMissingEventIDs := []string{} + for _, missingEvIDs := range missingEvents { + allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...) + } + evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs) + if err != nil { + return nil, err + } + // we know we got them all otherwise an error would've been returned, so just loop the events + for _, ev := range evs { + roomID := ev.RoomID() + stateBetween[roomID] = append(stateBetween[roomID], ev) + } + } + return stateBetween, nil +} + +func (d *Database) fetchMissingStateEvents( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) ([]types.StreamEvent, error) { + // Fetch from the events table first so we pick up the stream ID for the + // event. + events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false) + if err != nil { + return nil, err + } + + have := map[string]bool{} + for _, event := range events { + have[event.EventID()] = true + } + var missing []string + for _, eventID := range eventIDs { + if !have[eventID] { + missing = append(missing, eventID) + } + } + if len(missing) == 0 { + return events, nil + } + + // If they are missing from the events table then they should be state + // events that we received from outside the main event stream. + // These should be in the room state table. + stateEvents, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, missing) + + if err != nil { + return nil, err + } + if len(stateEvents) != len(missing) { + logrus.WithContext(ctx).Warnf("Failed to map all event IDs to events (got %d, wanted %d)", len(stateEvents), len(missing)) + + // TODO: Why is this happening? It's probably the roomserver. Uncomment + // this error again when we work out what it is and fix it, otherwise we + // just end up returning lots of 500s to the client and that breaks + // pretty much everything, rather than just sending what we have. + //return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing)) + } + events = append(events, stateEvents...) + return events, nil +} + +func (d *Database) StoreNewSendForDeviceMessage( + ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent, +) (newPos types.StreamPosition, err error) { + j, err := json.Marshal(event) + if err != nil { + return 0, err + } + // Delegate the database write task to the SendToDeviceWriter. It'll guarantee + // that we don't lock the table for writes in more than one place. + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + newPos, err = d.SendToDevice.InsertSendToDeviceMessage( + ctx, txn, userID, deviceID, string(j), + ) + return err + }) + if err != nil { + return 0, err + } + return newPos, nil +} + +func (d *Database) CleanSendToDeviceUpdates( + ctx context.Context, + userID, deviceID string, before types.StreamPosition, +) (err error) { + if err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, userID, deviceID, before) + }); err != nil { + logrus.WithError(err).Errorf("Failed to clean up old send-to-device messages for user %q device %q", userID, deviceID) + return err + } + return nil +} + +// getMembershipFromEvent returns the value of content.membership iff the event is a state event +// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned. +func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) (string, string) { + if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) { + return "", "" + } + membership, err := ev.Membership() + if err != nil { + return "", "" + } + prevMembership := gjson.GetBytes(ev.Unsigned(), "prev_content.membership").Str + return membership, prevMembership +} + +// StoreReceipt stores user receipts +func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + pos, err = d.Receipts.UpsertReceipt(ctx, txn, roomId, receiptType, userId, eventId, timestamp) + return err + }) + return +} + +func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, txn, userID, roomID, notificationCount, highlightCount) + return err + }) + return +} + +func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { + return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) +} + +func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) { + return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter) +} +func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) { + return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter) +} + +func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) { + return d.Ignores.SelectIgnores(ctx, nil, userID) +} + +func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Ignores.UpsertIgnores(ctx, txn, userID, ignores) + }) +} + +func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { + var pos types.StreamPosition + var err error + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + pos, err = d.Presence.UpsertPresence(ctx, txn, userID, statusMsg, presence, lastActiveTS, fromSync) + return nil + }) + return pos, err +} + +func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUser(ctx, nil, userID) +} + +func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { + return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos) +} + +func (s *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) { + return s.OutputEvents.ReIndex(ctx, nil, limit, afterID, []string{ + gomatrixserverlib.MRoomName, + gomatrixserverlib.MRoomTopic, + "m.room.message", + }) +} diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go new file mode 100644 index 000000000..0e19d97d2 --- /dev/null +++ b/syncapi/storage/shared/storage_sync.go @@ -0,0 +1,575 @@ +package shared + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" +) + +type DatabaseTransaction struct { + *Database + ctx context.Context + txn *sql.Tx +} + +func (d *DatabaseTransaction) Commit() error { + if d.txn == nil { + return nil + } + return d.txn.Commit() +} + +func (d *DatabaseTransaction) Rollback() error { + if d.txn == nil { + return nil + } + return d.txn.Rollback() +} + +func (d *DatabaseTransaction) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) { + id, err := d.OutputEvents.SelectMaxEventID(ctx, d.txn) + if err != nil { + return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err) + } + return types.StreamPosition(id), nil +} + +func (d *DatabaseTransaction) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) { + id, err := d.Receipts.SelectMaxReceiptID(ctx, d.txn) + if err != nil { + return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err) + } + return types.StreamPosition(id), nil +} + +func (d *DatabaseTransaction) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) { + id, err := d.Invites.SelectMaxInviteID(ctx, d.txn) + if err != nil { + return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err) + } + return types.StreamPosition(id), nil +} + +func (d *DatabaseTransaction) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) { + id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, d.txn) + if err != nil { + return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err) + } + return types.StreamPosition(id), nil +} + +func (d *DatabaseTransaction) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) { + id, err := d.AccountData.SelectMaxAccountDataID(ctx, d.txn) + if err != nil { + return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err) + } + return types.StreamPosition(id), nil +} + +func (d *DatabaseTransaction) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) { + id, err := d.NotificationData.SelectMaxID(ctx, d.txn) + if err != nil { + return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err) + } + return types.StreamPosition(id), nil +} + +func (d *DatabaseTransaction) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { + return d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilterPart, excludeEventIDs) +} + +func (d *DatabaseTransaction) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) { + return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, d.txn, userID, membership) +} + +func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) { + return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos) +} + +func (d *DatabaseTransaction) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) { + return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships) +} + +func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { + return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) +} + +func (d *DatabaseTransaction) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) { + return d.Topology.SelectPositionInTopology(ctx, d.txn, eventID) +} + +func (d *DatabaseTransaction) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) { + return d.Invites.SelectInviteEventsInRange(ctx, d.txn, targetUserID, r) +} + +func (d *DatabaseTransaction) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) { + return d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, deviceID, r) +} + +func (d *DatabaseTransaction) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { + return d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos) +} + +// Events lookups a list of event by their event ID. +// Returns a list of events matching the requested IDs found in the database. +// If an event is not found in the database then it will be omitted from the list. +// Returns an error if there was a problem talking with the database. +// Does not include any transaction IDs in the returned events. +func (d *DatabaseTransaction) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { + streamEvents, err := d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, false) + if err != nil { + return nil, err + } + + // We don't include a device here as we only include transaction IDs in + // incremental syncs. + return d.StreamEventsToEvents(nil, streamEvents), nil +} + +func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { + return d.CurrentRoomState.SelectJoinedUsers(ctx, d.txn) +} + +func (d *DatabaseTransaction) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) { + return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, d.txn, roomIDs) +} + +func (d *DatabaseTransaction) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { + return d.Peeks.SelectPeekingDevices(ctx, d.txn) +} + +func (d *DatabaseTransaction) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) { + return d.CurrentRoomState.SelectSharedUsers(ctx, d.txn, userID, otherUserIDs) +} + +func (d *DatabaseTransaction) GetStateEvent( + ctx context.Context, roomID, evType, stateKey string, +) (*gomatrixserverlib.HeaderedEvent, error) { + return d.CurrentRoomState.SelectStateEvent(ctx, d.txn, roomID, evType, stateKey) +} + +func (d *DatabaseTransaction) GetStateEventsForRoom( + ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, +) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) { + stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil) + return +} + +// GetAccountDataInRange returns all account data for a given user inserted or +// updated between two given positions +// Returns a map following the format data[roomID] = []dataTypes +// If no data is retrieved, returns an empty map +// If there was an issue with the retrieval, returns an error +func (d *DatabaseTransaction) GetAccountDataInRange( + ctx context.Context, userID string, r types.Range, + accountDataFilterPart *gomatrixserverlib.EventFilter, +) (map[string][]string, types.StreamPosition, error) { + return d.AccountData.SelectAccountDataInRange(ctx, d.txn, userID, r, accountDataFilterPart) +} + +func (d *DatabaseTransaction) GetEventsInTopologicalRange( + ctx context.Context, + from, to *types.TopologyToken, + roomID string, + filter *gomatrixserverlib.RoomEventFilter, + backwardOrdering bool, +) (events []types.StreamEvent, err error) { + var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition + if backwardOrdering { + // Backward ordering means the 'from' token has a higher depth than the 'to' token + minDepth = to.Depth + maxDepth = from.Depth + // for cases where we have say 5 events with the same depth, the TopologyToken needs to + // know which of the 5 the client has seen. This is done by using the PDU position. + // Events with the same maxDepth but less than this PDU position will be returned. + maxStreamPosForMaxDepth = from.PDUPosition + } else { + // Forward ordering means the 'from' token has a lower depth than the 'to' token. + minDepth = from.Depth + maxDepth = to.Depth + } + + // Select the event IDs from the defined range. + var eIDs []string + eIDs, err = d.Topology.SelectEventIDsInRange( + ctx, d.txn, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering, + ) + if err != nil { + return + } + + // Retrieve the events' contents using their IDs. + events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eIDs, filter, true) + return +} + +func (d *DatabaseTransaction) BackwardExtremitiesForRoom( + ctx context.Context, roomID string, +) (backwardExtremities map[string][]string, err error) { + return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID) +} + +func (d *DatabaseTransaction) MaxTopologicalPosition( + ctx context.Context, roomID string, +) (types.TopologyToken, error) { + depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) + if err != nil { + return types.TopologyToken{}, err + } + return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil +} + +func (d *DatabaseTransaction) EventPositionInTopology( + ctx context.Context, eventID string, +) (types.TopologyToken, error) { + depth, stream, err := d.Topology.SelectPositionInTopology(ctx, d.txn, eventID) + if err != nil { + return types.TopologyToken{}, err + } + return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil +} + +func (d *DatabaseTransaction) StreamToTopologicalPosition( + ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool, +) (types.TopologyToken, error) { + topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, d.txn, roomID, streamPos, backwardOrdering) + switch { + case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward + return types.TopologyToken{PDUPosition: streamPos}, nil + case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward + topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) + if err != nil { + return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err) + } + return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil + case err != nil: // some other error happened + return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err) + default: + return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil + } +} + +// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the +// oldest event in the room's topology. +func (d *DatabaseTransaction) GetBackwardTopologyPos( + ctx context.Context, + events []types.StreamEvent, +) (types.TopologyToken, error) { + zeroToken := types.TopologyToken{} + if len(events) == 0 { + return zeroToken, nil + } + pos, spos, err := d.Topology.SelectPositionInTopology(ctx, d.txn, events[0].EventID()) + if err != nil { + return zeroToken, err + } + tok := types.TopologyToken{Depth: pos, PDUPosition: spos} + tok.Decrement() + return tok, nil +} + +// GetStateDeltas returns the state deltas between fromPos and toPos, +// exclusive of oldPos, inclusive of newPos, for the rooms in which +// the user has new membership events. +// A list of joined room IDs is also returned in case the caller needs it. +func (d *DatabaseTransaction) GetStateDeltas( + ctx context.Context, device *userapi.Device, + r types.Range, userID string, + stateFilter *gomatrixserverlib.StateFilter, +) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) { + // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 + // - Get membership list changes for this user in this sync response + // - For each room which has membership list changes: + // * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO). + // If it is, then we need to send the full room state down (and 'limited' is always true). + // * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block. + // * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block. + // - Get all CURRENTLY joined rooms, and add them to 'joined' block. + + // Look up all memberships for the user. We only care about rooms that a + // user has ever interacted with — joined to, kicked/banned from, left. + memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.txn, userID) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + + allRoomIDs := make([]string, 0, len(memberships)) + joinedRoomIDs := make([]string, 0, len(memberships)) + for roomID, membership := range memberships { + allRoomIDs = append(allRoomIDs, roomID) + if membership == gomatrixserverlib.Join { + joinedRoomIDs = append(joinedRoomIDs, roomID) + } + } + + // get all the state events ever (i.e. for all available rooms) between these two positions + stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + state, err := d.fetchStateEvents(ctx, d.txn, stateNeeded, eventMap) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + + // find out which rooms this user is peeking, if any. + // We do this before joins so any peeks get overwritten + peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r) + if err != nil && err != sql.ErrNoRows { + return nil, nil, err + } + + // add peek blocks + for _, peek := range peeks { + if peek.New { + // send full room state down instead of a delta + var s []types.StreamEvent + s, err = d.currentStateStreamEventsForRoom(ctx, peek.RoomID, stateFilter) + if err != nil { + if err == sql.ErrNoRows { + continue + } + return nil, nil, err + } + state[peek.RoomID] = s + } + if !peek.Deleted { + deltas = append(deltas, types.StateDelta{ + Membership: gomatrixserverlib.Peek, + StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]), + RoomID: peek.RoomID, + }) + } + } + + // handle newly joined rooms and non-joined rooms + newlyJoinedRooms := make(map[string]bool, len(state)) + for roomID, stateStreamEvents := range state { + for _, ev := range stateStreamEvents { + if membership, prevMembership := getMembershipFromEvent(ev.Event, userID); membership != "" { + if membership == gomatrixserverlib.Join && prevMembership != membership { + // send full room state down instead of a delta + var s []types.StreamEvent + s, err = d.currentStateStreamEventsForRoom(ctx, roomID, stateFilter) + if err != nil { + if err == sql.ErrNoRows { + continue + } + return nil, nil, err + } + state[roomID] = s + newlyJoinedRooms[roomID] = true + continue // we'll add this room in when we do joined rooms + } + + deltas = append(deltas, types.StateDelta{ + Membership: membership, + MembershipPos: ev.StreamPosition, + StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), + RoomID: roomID, + }) + break + } + } + } + + // Add in currently joined rooms + for _, joinedRoomID := range joinedRoomIDs { + deltas = append(deltas, types.StateDelta{ + Membership: gomatrixserverlib.Join, + StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), + RoomID: joinedRoomID, + NewlyJoined: newlyJoinedRooms[joinedRoomID], + }) + } + + return deltas, joinedRoomIDs, nil +} + +// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync +// requests with full_state=true. +// Fetches full state for all joined rooms and uses selectStateInRange to get +// updates for other rooms. +func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( + ctx context.Context, device *userapi.Device, + r types.Range, userID string, + stateFilter *gomatrixserverlib.StateFilter, +) ([]types.StateDelta, []string, error) { + // Look up all memberships for the user. We only care about rooms that a + // user has ever interacted with — joined to, kicked/banned from, left. + memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.txn, userID) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + + allRoomIDs := make([]string, 0, len(memberships)) + joinedRoomIDs := make([]string, 0, len(memberships)) + for roomID, membership := range memberships { + allRoomIDs = append(allRoomIDs, roomID) + if membership == gomatrixserverlib.Join { + joinedRoomIDs = append(joinedRoomIDs, roomID) + } + } + + // Use a reasonable initial capacity + deltas := make(map[string]types.StateDelta) + + peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r) + if err != nil && err != sql.ErrNoRows { + return nil, nil, err + } + + // Add full states for all peeking rooms + for _, peek := range peeks { + if !peek.Deleted { + s, stateErr := d.currentStateStreamEventsForRoom(ctx, peek.RoomID, stateFilter) + if stateErr != nil { + if stateErr == sql.ErrNoRows { + continue + } + return nil, nil, stateErr + } + deltas[peek.RoomID] = types.StateDelta{ + Membership: gomatrixserverlib.Peek, + StateEvents: d.StreamEventsToEvents(device, s), + RoomID: peek.RoomID, + } + } + } + + // Get all the state events ever between these two positions + stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + state, err := d.fetchStateEvents(ctx, d.txn, stateNeeded, eventMap) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + + for roomID, stateStreamEvents := range state { + for _, ev := range stateStreamEvents { + if membership, _ := getMembershipFromEvent(ev.Event, userID); membership != "" { + if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. + deltas[roomID] = types.StateDelta{ + Membership: membership, + MembershipPos: ev.StreamPosition, + StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), + RoomID: roomID, + } + } + + break + } + } + } + + // Add full states for all joined rooms + for _, joinedRoomID := range joinedRoomIDs { + s, stateErr := d.currentStateStreamEventsForRoom(ctx, joinedRoomID, stateFilter) + if stateErr != nil { + if stateErr == sql.ErrNoRows { + continue + } + return nil, nil, stateErr + } + deltas[joinedRoomID] = types.StateDelta{ + Membership: gomatrixserverlib.Join, + StateEvents: d.StreamEventsToEvents(device, s), + RoomID: joinedRoomID, + } + } + + // Create a response array. + result := make([]types.StateDelta, len(deltas)) + i := 0 + for _, delta := range deltas { + result[i] = delta + i++ + } + + return result, joinedRoomIDs, nil +} + +func (d *DatabaseTransaction) currentStateStreamEventsForRoom( + ctx context.Context, roomID string, + stateFilter *gomatrixserverlib.StateFilter, +) ([]types.StreamEvent, error) { + allState, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil) + if err != nil { + return nil, err + } + s := make([]types.StreamEvent, len(allState)) + for i := 0; i < len(s); i++ { + s[i] = types.StreamEvent{HeaderedEvent: allState[i], StreamPosition: 0} + } + return s, nil +} + +func (d *DatabaseTransaction) SendToDeviceUpdatesForSync( + ctx context.Context, + userID, deviceID string, + from, to types.StreamPosition, +) (types.StreamPosition, []types.SendToDeviceEvent, error) { + // First of all, get our send-to-device updates for this user. + lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, d.txn, userID, deviceID, from, to) + if err != nil { + return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) + } + // If there's nothing to do then stop here. + if len(events) == 0 { + return to, nil, nil + } + return lastPos, events, nil +} + +func (d *DatabaseTransaction) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) { + _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos) + return receipts, err +} + +func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) { + roomIDs := make([]string, 0, len(rooms)) + for roomID, membership := range rooms { + if membership != gomatrixserverlib.Join { + continue + } + roomIDs = append(roomIDs, roomID) + } + return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs) +} + +func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUser(ctx, d.txn, userID) +} + +func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { + return d.Presence.GetPresenceAfter(ctx, d.txn, after, filter) +} + +func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { + return d.Presence.GetMaxPresenceID(ctx, d.txn) +} diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go deleted file mode 100644 index 778ad8b18..000000000 --- a/syncapi/storage/shared/syncserver.go +++ /dev/null @@ -1,1088 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package shared - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - - "github.com/tidwall/gjson" - - userapi "github.com/matrix-org/dendrite/userapi/api" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" - - "github.com/matrix-org/dendrite/internal/eventutil" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/syncapi/storage/tables" - "github.com/matrix-org/dendrite/syncapi/types" -) - -// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite -// For now this contains the shared functions -type Database struct { - DB *sql.DB - Writer sqlutil.Writer - Invites tables.Invites - Peeks tables.Peeks - AccountData tables.AccountData - OutputEvents tables.Events - Topology tables.Topology - CurrentRoomState tables.CurrentRoomState - BackwardExtremities tables.BackwardsExtremities - SendToDevice tables.SendToDevice - Filter tables.Filter - Receipts tables.Receipts - Memberships tables.Memberships - NotificationData tables.NotificationData - Ignores tables.Ignores - Presence tables.Presence -} - -func (d *Database) readOnlySnapshot(ctx context.Context) (*sql.Tx, error) { - return d.DB.BeginTx(ctx, &sql.TxOptions{ - // Set the isolation level so that we see a snapshot of the database. - // In PostgreSQL repeatable read transactions will see a snapshot taken - // at the first query, and since the transaction is read-only it can't - // run into any serialisation errors. - // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ - Isolation: sql.LevelRepeatableRead, - ReadOnly: true, - }) -} - -func (d *Database) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) { - id, err := d.OutputEvents.SelectMaxEventID(ctx, nil) - if err != nil { - return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err) - } - return types.StreamPosition(id), nil -} - -func (d *Database) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) { - id, err := d.Receipts.SelectMaxReceiptID(ctx, nil) - if err != nil { - return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err) - } - return types.StreamPosition(id), nil -} - -func (d *Database) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) { - id, err := d.Invites.SelectMaxInviteID(ctx, nil) - if err != nil { - return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err) - } - return types.StreamPosition(id), nil -} - -func (d *Database) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) { - id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, nil) - if err != nil { - return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err) - } - return types.StreamPosition(id), nil -} - -func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) { - id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil) - if err != nil { - return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err) - } - return types.StreamPosition(id), nil -} - -func (d *Database) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) { - id, err := d.NotificationData.SelectMaxID(ctx, nil) - if err != nil { - return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err) - } - return types.StreamPosition(id), nil -} - -func (d *Database) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { - return d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilterPart, excludeEventIDs) -} - -func (d *Database) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) { - return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) -} - -func (d *Database) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) { - return d.Memberships.SelectMembershipCount(ctx, nil, roomID, membership, pos) -} - -func (d *Database) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) { - return d.Memberships.SelectHeroes(ctx, nil, roomID, userID, memberships) -} - -func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { - return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) -} - -func (d *Database) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) { - return d.Topology.SelectPositionInTopology(ctx, nil, eventID) -} - -func (d *Database) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) { - return d.Invites.SelectInviteEventsInRange(ctx, nil, targetUserID, r) -} - -func (d *Database) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) { - return d.Peeks.SelectPeeksInRange(ctx, nil, userID, deviceID, r) -} - -func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { - return d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos) -} - -// Events lookups a list of event by their event ID. -// Returns a list of events matching the requested IDs found in the database. -// If an event is not found in the database then it will be omitted from the list. -// Returns an error if there was a problem talking with the database. -// Does not include any transaction IDs in the returned events. -func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { - streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false) - if err != nil { - return nil, err - } - - // We don't include a device here as we only include transaction IDs in - // incremental syncs. - return d.StreamEventsToEvents(nil, streamEvents), nil -} - -func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { - return d.CurrentRoomState.SelectJoinedUsers(ctx) -} - -func (d *Database) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) { - return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, roomIDs) -} - -func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { - return d.Peeks.SelectPeekingDevices(ctx) -} - -func (d *Database) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) { - return d.CurrentRoomState.SelectSharedUsers(ctx, nil, userID, otherUserIDs) -} - -func (d *Database) GetStateEvent( - ctx context.Context, roomID, evType, stateKey string, -) (*gomatrixserverlib.HeaderedEvent, error) { - return d.CurrentRoomState.SelectStateEvent(ctx, roomID, evType, stateKey) -} - -func (d *Database) GetStateEventsForRoom( - ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, -) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) { - stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter, nil) - return -} - -// AddInviteEvent stores a new invite event for a user. -// If the invite was successfully stored this returns the stream ID it was stored at. -// Returns an error if there was a problem communicating with the database. -func (d *Database) AddInviteEvent( - ctx context.Context, inviteEvent *gomatrixserverlib.HeaderedEvent, -) (sp types.StreamPosition, err error) { - _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) - return err - }) - return -} - -// RetireInviteEvent removes an old invite event from the database. -// Returns an error if there was a problem communicating with the database. -func (d *Database) RetireInviteEvent( - ctx context.Context, inviteEventID string, -) (sp types.StreamPosition, err error) { - _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - sp, err = d.Invites.DeleteInviteEvent(ctx, txn, inviteEventID) - return err - }) - return -} - -// AddPeek tracks the fact that a user has started peeking. -// If the peek was successfully stored this returns the stream ID it was stored at. -// Returns an error if there was a problem communicating with the database. -func (d *Database) AddPeek( - ctx context.Context, roomID, userID, deviceID string, -) (sp types.StreamPosition, err error) { - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - sp, err = d.Peeks.InsertPeek(ctx, txn, roomID, userID, deviceID) - return err - }) - return -} - -// DeletePeek tracks the fact that a user has stopped peeking from the specified -// device. If the peeks was successfully deleted this returns the stream ID it was -// stored at. Returns an error if there was a problem communicating with the database. -func (d *Database) DeletePeek( - ctx context.Context, roomID, userID, deviceID string, -) (sp types.StreamPosition, err error) { - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - sp, err = d.Peeks.DeletePeek(ctx, txn, roomID, userID, deviceID) - return err - }) - if err == sql.ErrNoRows { - sp = 0 - err = nil - } - return -} - -// DeletePeeks tracks the fact that a user has stopped peeking from all devices -// If the peeks was successfully deleted this returns the stream ID it was stored at. -// Returns an error if there was a problem communicating with the database. -func (d *Database) DeletePeeks( - ctx context.Context, roomID, userID string, -) (sp types.StreamPosition, err error) { - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - sp, err = d.Peeks.DeletePeeks(ctx, txn, roomID, userID) - return err - }) - if err == sql.ErrNoRows { - sp = 0 - err = nil - } - return -} - -// GetAccountDataInRange returns all account data for a given user inserted or -// updated between two given positions -// Returns a map following the format data[roomID] = []dataTypes -// If no data is retrieved, returns an empty map -// If there was an issue with the retrieval, returns an error -func (d *Database) GetAccountDataInRange( - ctx context.Context, userID string, r types.Range, - accountDataFilterPart *gomatrixserverlib.EventFilter, -) (map[string][]string, types.StreamPosition, error) { - return d.AccountData.SelectAccountDataInRange(ctx, userID, r, accountDataFilterPart) -} - -// UpsertAccountData keeps track of new or updated account data, by saving the type -// of the new/updated data, and the user ID and room ID the data is related to (empty) -// room ID means the data isn't specific to any room) -// If no data with the given type, user ID and room ID exists in the database, -// creates a new row, else update the existing one -// Returns an error if there was an issue with the upsert -func (d *Database) UpsertAccountData( - ctx context.Context, userID, roomID, dataType string, -) (sp types.StreamPosition, err error) { - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType) - return err - }) - return -} - -func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent { - out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) - for i := 0; i < len(in); i++ { - out[i] = in[i].HeaderedEvent - if device != nil && in[i].TransactionID != nil { - if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID { - err := out[i].SetUnsignedField( - "transaction_id", in[i].TransactionID.TransactionID, - ) - if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": out[i].EventID(), - }).WithError(err).Warnf("Failed to add transaction ID to event") - } - } - } - } - return out -} - -// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of -// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table -// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. -// This function should always be called within a sqlutil.Writer for safety in SQLite. -func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error { - if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { - return err - } - - // Check if we have all of the event's previous events. If an event is - // missing, add it to the room's backward extremities. - prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false) - if err != nil { - return err - } - var found bool - for _, eID := range ev.PrevEventIDs() { - found = false - for _, prevEv := range prevEvents { - if eID == prevEv.EventID() { - found = true - } - } - - // If the event is missing, consider it a backward extremity. - if !found { - if err = d.BackwardExtremities.InsertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID(), eID); err != nil { - return err - } - } - } - - return nil -} - -func (d *Database) PurgeRoomState( - ctx context.Context, roomID string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - // If the event is a create event then we'll delete all of the existing - // data for the room. The only reason that a create event would be replayed - // to us in this way is if we're about to receive the entire room state. - if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { - return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) - } - return nil - }) -} - -func (d *Database) WriteEvent( - ctx context.Context, - ev *gomatrixserverlib.HeaderedEvent, - addStateEvents []*gomatrixserverlib.HeaderedEvent, - addStateEventIDs, removeStateEventIDs []string, - transactionID *api.TransactionID, excludeFromSync bool, - historyVisibility gomatrixserverlib.HistoryVisibility, -) (pduPosition types.StreamPosition, returnErr error) { - returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - var err error - ev.Visibility = historyVisibility - pos, err := d.OutputEvents.InsertEvent( - ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility, - ) - if err != nil { - return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err) - } - pduPosition = pos - var topoPosition types.StreamPosition - if topoPosition, err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil { - return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err) - } - - if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { - return fmt.Errorf("d.handleBackwardExtremities: %w", err) - } - - if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { - // Nothing to do, the event may have just been a message event. - return nil - } - for i := range addStateEvents { - addStateEvents[i].Visibility = historyVisibility - } - return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition, topoPosition) - }) - - return pduPosition, returnErr -} - -// This function should always be called within a sqlutil.Writer for safety in SQLite. -func (d *Database) updateRoomState( - ctx context.Context, txn *sql.Tx, - removedEventIDs []string, - addedEvents []*gomatrixserverlib.HeaderedEvent, - pduPosition types.StreamPosition, - topoPosition types.StreamPosition, -) error { - // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. - for _, eventID := range removedEventIDs { - if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil { - return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err) - } - } - - for _, event := range addedEvents { - if event.StateKey() == nil { - // ignore non state events - continue - } - var membership *string - if event.Type() == "m.room.member" { - value, err := event.Membership() - if err != nil { - return fmt.Errorf("event.Membership: %w", err) - } - membership = &value - if err = d.Memberships.UpsertMembership(ctx, txn, event, pduPosition, topoPosition); err != nil { - return fmt.Errorf("d.Memberships.UpsertMembership: %w", err) - } - } - - if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { - return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err) - } - } - - return nil -} - -func (d *Database) GetEventsInTopologicalRange( - ctx context.Context, - from, to *types.TopologyToken, - roomID string, - filter *gomatrixserverlib.RoomEventFilter, - backwardOrdering bool, -) (events []types.StreamEvent, err error) { - var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition - if backwardOrdering { - // Backward ordering means the 'from' token has a higher depth than the 'to' token - minDepth = to.Depth - maxDepth = from.Depth - // for cases where we have say 5 events with the same depth, the TopologyToken needs to - // know which of the 5 the client has seen. This is done by using the PDU position. - // Events with the same maxDepth but less than this PDU position will be returned. - maxStreamPosForMaxDepth = from.PDUPosition - } else { - // Forward ordering means the 'from' token has a lower depth than the 'to' token. - minDepth = from.Depth - maxDepth = to.Depth - } - - // Select the event IDs from the defined range. - var eIDs []string - eIDs, err = d.Topology.SelectEventIDsInRange( - ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering, - ) - if err != nil { - return - } - - // Retrieve the events' contents using their IDs. - events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, filter, true) - return -} - -func (d *Database) BackwardExtremitiesForRoom( - ctx context.Context, roomID string, -) (backwardExtremities map[string][]string, err error) { - return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, roomID) -} - -func (d *Database) MaxTopologicalPosition( - ctx context.Context, roomID string, -) (types.TopologyToken, error) { - depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID) - if err != nil { - return types.TopologyToken{}, err - } - return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil -} - -func (d *Database) EventPositionInTopology( - ctx context.Context, eventID string, -) (types.TopologyToken, error) { - depth, stream, err := d.Topology.SelectPositionInTopology(ctx, nil, eventID) - if err != nil { - return types.TopologyToken{}, err - } - return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil -} - -func (d *Database) StreamToTopologicalPosition( - ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool, -) (types.TopologyToken, error) { - topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, nil, roomID, streamPos, backwardOrdering) - switch { - case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward - return types.TopologyToken{PDUPosition: streamPos}, nil - case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward - topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID) - if err != nil { - return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err) - } - return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil - case err != nil: // some other error happened - return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err) - default: - return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil - } -} - -func (d *Database) GetFilter( - ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, -) error { - return d.Filter.SelectFilter(ctx, target, localpart, filterID) -} - -func (d *Database) PutFilter( - ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, -) (string, error) { - var filterID string - var err error - err = d.Writer.Do(nil, nil, func(txn *sql.Tx) error { - filterID, err = d.Filter.InsertFilter(ctx, filter, localpart) - return err - }) - return filterID, err -} - -func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error { - redactedEvents, err := d.Events(ctx, []string{redactedEventID}) - if err != nil { - return err - } - if len(redactedEvents) == 0 { - logrus.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction") - return nil - } - eventToRedact := redactedEvents[0].Unwrap() - redactionEvent := redactedBecause.Unwrap() - if err = eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil { - return err - } - - newEvent := eventToRedact.Headered(redactedBecause.RoomVersion) - err = d.Writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.OutputEvents.UpdateEventJSON(ctx, newEvent) - }) - return err -} - -// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the -// oldest event in the room's topology. -func (d *Database) GetBackwardTopologyPos( - ctx context.Context, - events []types.StreamEvent, -) (types.TopologyToken, error) { - zeroToken := types.TopologyToken{} - if len(events) == 0 { - return zeroToken, nil - } - pos, spos, err := d.Topology.SelectPositionInTopology(ctx, nil, events[0].EventID()) - if err != nil { - return zeroToken, err - } - tok := types.TopologyToken{Depth: pos, PDUPosition: spos} - tok.Decrement() - return tok, nil -} - -// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. -// Returns a map of room ID to list of events. -func (d *Database) fetchStateEvents( - ctx context.Context, txn *sql.Tx, - roomIDToEventIDSet map[string]map[string]bool, - eventIDToEvent map[string]types.StreamEvent, -) (map[string][]types.StreamEvent, error) { - stateBetween := make(map[string][]types.StreamEvent) - missingEvents := make(map[string][]string) - for roomID, ids := range roomIDToEventIDSet { - events := stateBetween[roomID] - for id, need := range ids { - if !need { - continue // deleted state - } - e, ok := eventIDToEvent[id] - if ok { - events = append(events, e) - } else { - m := missingEvents[roomID] - m = append(m, id) - missingEvents[roomID] = m - } - } - stateBetween[roomID] = events - } - - if len(missingEvents) > 0 { - // This happens when add_state_ids has an event ID which is not in the provided range. - // We need to explicitly fetch them. - allMissingEventIDs := []string{} - for _, missingEvIDs := range missingEvents { - allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...) - } - evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs) - if err != nil { - return nil, err - } - // we know we got them all otherwise an error would've been returned, so just loop the events - for _, ev := range evs { - roomID := ev.RoomID() - stateBetween[roomID] = append(stateBetween[roomID], ev) - } - } - return stateBetween, nil -} - -func (d *Database) fetchMissingStateEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, -) ([]types.StreamEvent, error) { - // Fetch from the events table first so we pick up the stream ID for the - // event. - events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false) - if err != nil { - return nil, err - } - - have := map[string]bool{} - for _, event := range events { - have[event.EventID()] = true - } - var missing []string - for _, eventID := range eventIDs { - if !have[eventID] { - missing = append(missing, eventID) - } - } - if len(missing) == 0 { - return events, nil - } - - // If they are missing from the events table then they should be state - // events that we received from outside the main event stream. - // These should be in the room state table. - stateEvents, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, missing) - - if err != nil { - return nil, err - } - if len(stateEvents) != len(missing) { - logrus.WithContext(ctx).Warnf("Failed to map all event IDs to events (got %d, wanted %d)", len(stateEvents), len(missing)) - - // TODO: Why is this happening? It's probably the roomserver. Uncomment - // this error again when we work out what it is and fix it, otherwise we - // just end up returning lots of 500s to the client and that breaks - // pretty much everything, rather than just sending what we have. - //return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing)) - } - events = append(events, stateEvents...) - return events, nil -} - -// GetStateDeltas returns the state deltas between fromPos and toPos, -// exclusive of oldPos, inclusive of newPos, for the rooms in which -// the user has new membership events. -// A list of joined room IDs is also returned in case the caller needs it. -func (d *Database) GetStateDeltas( - ctx context.Context, device *userapi.Device, - r types.Range, userID string, - stateFilter *gomatrixserverlib.StateFilter, -) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) { - // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 - // - Get membership list changes for this user in this sync response - // - For each room which has membership list changes: - // * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO). - // If it is, then we need to send the full room state down (and 'limited' is always true). - // * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block. - // * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block. - // - Get all CURRENTLY joined rooms, and add them to 'joined' block. - txn, err := d.readOnlySnapshot(ctx) - if err != nil { - return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err) - } - var succeeded bool - defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) - - // Look up all memberships for the user. We only care about rooms that a - // user has ever interacted with — joined to, kicked/banned from, left. - memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil - } - return nil, nil, err - } - - allRoomIDs := make([]string, 0, len(memberships)) - joinedRoomIDs := make([]string, 0, len(memberships)) - for roomID, membership := range memberships { - allRoomIDs = append(allRoomIDs, roomID) - if membership == gomatrixserverlib.Join { - joinedRoomIDs = append(joinedRoomIDs, roomID) - } - } - - // get all the state events ever (i.e. for all available rooms) between these two positions - stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil - } - return nil, nil, err - } - state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil - } - return nil, nil, err - } - - // find out which rooms this user is peeking, if any. - // We do this before joins so any peeks get overwritten - peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) - if err != nil && err != sql.ErrNoRows { - return nil, nil, err - } - - // add peek blocks - for _, peek := range peeks { - if peek.New { - // send full room state down instead of a delta - var s []types.StreamEvent - s, err = d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter) - if err != nil { - if err == sql.ErrNoRows { - continue - } - return nil, nil, err - } - state[peek.RoomID] = s - } - if !peek.Deleted { - deltas = append(deltas, types.StateDelta{ - Membership: gomatrixserverlib.Peek, - StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]), - RoomID: peek.RoomID, - }) - } - } - - // handle newly joined rooms and non-joined rooms - newlyJoinedRooms := make(map[string]bool, len(state)) - for roomID, stateStreamEvents := range state { - for _, ev := range stateStreamEvents { - if membership, prevMembership := getMembershipFromEvent(ev.Event, userID); membership != "" { - if membership == gomatrixserverlib.Join && prevMembership != membership { - // send full room state down instead of a delta - var s []types.StreamEvent - s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter) - if err != nil { - if err == sql.ErrNoRows { - continue - } - return nil, nil, err - } - state[roomID] = s - newlyJoinedRooms[roomID] = true - continue // we'll add this room in when we do joined rooms - } - - deltas = append(deltas, types.StateDelta{ - Membership: membership, - MembershipPos: ev.StreamPosition, - StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), - RoomID: roomID, - }) - break - } - } - } - - // Add in currently joined rooms - for _, joinedRoomID := range joinedRoomIDs { - deltas = append(deltas, types.StateDelta{ - Membership: gomatrixserverlib.Join, - StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), - RoomID: joinedRoomID, - NewlyJoined: newlyJoinedRooms[joinedRoomID], - }) - } - - succeeded = true - return deltas, joinedRoomIDs, nil -} - -// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync -// requests with full_state=true. -// Fetches full state for all joined rooms and uses selectStateInRange to get -// updates for other rooms. -func (d *Database) GetStateDeltasForFullStateSync( - ctx context.Context, device *userapi.Device, - r types.Range, userID string, - stateFilter *gomatrixserverlib.StateFilter, -) ([]types.StateDelta, []string, error) { - txn, err := d.readOnlySnapshot(ctx) - if err != nil { - return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err) - } - var succeeded bool - defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) - - // Look up all memberships for the user. We only care about rooms that a - // user has ever interacted with — joined to, kicked/banned from, left. - memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil - } - return nil, nil, err - } - - allRoomIDs := make([]string, 0, len(memberships)) - joinedRoomIDs := make([]string, 0, len(memberships)) - for roomID, membership := range memberships { - allRoomIDs = append(allRoomIDs, roomID) - if membership == gomatrixserverlib.Join { - joinedRoomIDs = append(joinedRoomIDs, roomID) - } - } - - // Use a reasonable initial capacity - deltas := make(map[string]types.StateDelta) - - peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) - if err != nil && err != sql.ErrNoRows { - return nil, nil, err - } - - // Add full states for all peeking rooms - for _, peek := range peeks { - if !peek.Deleted { - s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter) - if stateErr != nil { - if stateErr == sql.ErrNoRows { - continue - } - return nil, nil, stateErr - } - deltas[peek.RoomID] = types.StateDelta{ - Membership: gomatrixserverlib.Peek, - StateEvents: d.StreamEventsToEvents(device, s), - RoomID: peek.RoomID, - } - } - } - - // Get all the state events ever between these two positions - stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil - } - return nil, nil, err - } - state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil - } - return nil, nil, err - } - - for roomID, stateStreamEvents := range state { - for _, ev := range stateStreamEvents { - if membership, _ := getMembershipFromEvent(ev.Event, userID); membership != "" { - if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. - deltas[roomID] = types.StateDelta{ - Membership: membership, - MembershipPos: ev.StreamPosition, - StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), - RoomID: roomID, - } - } - - break - } - } - } - - // Add full states for all joined rooms - for _, joinedRoomID := range joinedRoomIDs { - s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter) - if stateErr != nil { - if stateErr == sql.ErrNoRows { - continue - } - return nil, nil, stateErr - } - deltas[joinedRoomID] = types.StateDelta{ - Membership: gomatrixserverlib.Join, - StateEvents: d.StreamEventsToEvents(device, s), - RoomID: joinedRoomID, - } - } - - // Create a response array. - result := make([]types.StateDelta, len(deltas)) - i := 0 - for _, delta := range deltas { - result[i] = delta - i++ - } - - succeeded = true - return result, joinedRoomIDs, nil -} - -func (d *Database) currentStateStreamEventsForRoom( - ctx context.Context, txn *sql.Tx, roomID string, - stateFilter *gomatrixserverlib.StateFilter, -) ([]types.StreamEvent, error) { - allState, err := d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter, nil) - if err != nil { - return nil, err - } - s := make([]types.StreamEvent, len(allState)) - for i := 0; i < len(s); i++ { - s[i] = types.StreamEvent{HeaderedEvent: allState[i], StreamPosition: 0} - } - return s, nil -} - -func (d *Database) StoreNewSendForDeviceMessage( - ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent, -) (newPos types.StreamPosition, err error) { - j, err := json.Marshal(event) - if err != nil { - return 0, err - } - // Delegate the database write task to the SendToDeviceWriter. It'll guarantee - // that we don't lock the table for writes in more than one place. - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - newPos, err = d.SendToDevice.InsertSendToDeviceMessage( - ctx, txn, userID, deviceID, string(j), - ) - return err - }) - if err != nil { - return 0, err - } - return newPos, nil -} - -func (d *Database) SendToDeviceUpdatesForSync( - ctx context.Context, - userID, deviceID string, - from, to types.StreamPosition, -) (types.StreamPosition, []types.SendToDeviceEvent, error) { - // First of all, get our send-to-device updates for this user. - lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, from, to) - if err != nil { - return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) - } - // If there's nothing to do then stop here. - if len(events) == 0 { - return to, nil, nil - } - return lastPos, events, nil -} - -func (d *Database) CleanSendToDeviceUpdates( - ctx context.Context, - userID, deviceID string, before types.StreamPosition, -) (err error) { - if err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, userID, deviceID, before) - }); err != nil { - logrus.WithError(err).Errorf("Failed to clean up old send-to-device messages for user %q device %q", userID, deviceID) - return err - } - return nil -} - -// getMembershipFromEvent returns the value of content.membership iff the event is a state event -// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned. -func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) (string, string) { - if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) { - return "", "" - } - membership, err := ev.Membership() - if err != nil { - return "", "" - } - prevMembership := gjson.GetBytes(ev.Unsigned(), "prev_content.membership").Str - return membership, prevMembership -} - -// StoreReceipt stores user receipts -func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - pos, err = d.Receipts.UpsertReceipt(ctx, txn, roomId, receiptType, userId, eventId, timestamp) - return err - }) - return -} - -func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) { - _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos) - return receipts, err -} - -func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, txn, userID, roomID, notificationCount, highlightCount) - return err - }) - return -} - -func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) { - return d.NotificationData.SelectUserUnreadCounts(ctx, nil, userID, from, to) -} - -func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { - return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) -} - -func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) { - return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter) -} -func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) { - return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter) -} - -func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) { - return d.Ignores.SelectIgnores(ctx, nil, userID) -} - -func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Ignores.UpsertIgnores(ctx, txn, userID, ignores) - }) -} - -func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { - var pos types.StreamPosition - var err error - _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - pos, err = d.Presence.UpsertPresence(ctx, txn, userID, statusMsg, presence, lastActiveTS, fromSync) - return nil - }) - return pos, err -} - -func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return d.Presence.GetPresenceForUser(ctx, nil, userID) -} - -func (d *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { - return d.Presence.GetPresenceAfter(ctx, nil, after, filter) -} - -func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { - return d.Presence.GetMaxPresenceID(ctx, nil) -} - -func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { - return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos) -} diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 21a16dcd3..d8967113a 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -91,14 +91,14 @@ func (s *accountDataStatements) InsertAccountData( } func (s *accountDataStatements) SelectAccountDataInRange( - ctx context.Context, + ctx context.Context, txn *sql.Tx, userID string, r types.Range, filter *gomatrixserverlib.EventFilter, ) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) stmt, params, err := prepareWithFilters( - s.db, nil, selectAccountDataInRangeSQL, + s.db, txn, selectAccountDataInRangeSQL, []interface{}{ userID, r.Low(), r.High(), }, diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index c5674dded..3a5fd6be3 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -82,9 +82,9 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( } func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (bwExtrems map[string][]string, err error) { - rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) + rows, err := sqlutil.TxStmt(txn, s.selectBackwardExtremitiesForRoomStmt).QueryContext(ctx, roomID) if err != nil { return } diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index bd1271dd6..ff45e786e 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -88,12 +88,7 @@ const selectStateEventSQL = "" + "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" const selectEventsWithEventIDsSQL = "" + - // TODO: The session_id and transaction_id blanks are here because - // the rowsToStreamEvents expects there to be exactly seven columns. We need to - // figure out if these really need to be in the DB, and if so, we need a - // better permanent fix for this. - neilalexander, 2 Jan 2020 - "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id, history_visibility" + - " FROM syncapi_current_room_state WHERE event_id IN ($1)" + "SELECT event_id, added_at, headered_event_json, history_visibility FROM syncapi_current_room_state WHERE event_id IN ($1)" const selectSharedUsersSQL = "" + "SELECT state_key FROM syncapi_current_room_state WHERE room_id IN(" + @@ -163,9 +158,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. func (s *currentRoomStateStatements) SelectJoinedUsers( - ctx context.Context, + ctx context.Context, txn *sql.Tx, ) (map[string][]string, error) { - rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) + rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt).QueryContext(ctx) if err != nil { return nil, err } @@ -187,7 +182,7 @@ func (s *currentRoomStateStatements) SelectJoinedUsers( // SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. func (s *currentRoomStateStatements) SelectJoinedUsersInRoom( - ctx context.Context, roomIDs []string, + ctx context.Context, txn *sql.Tx, roomIDs []string, ) (map[string][]string, error) { query := strings.Replace(selectJoinedUsersInRoomSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) params := make([]interface{}, 0, len(roomIDs)) @@ -200,7 +195,7 @@ func (s *currentRoomStateStatements) SelectJoinedUsersInRoom( } defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsersInRoom: stmt.close() failed") - rows, err := stmt.QueryContext(ctx, params...) + rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...) if err != nil { return nil, err } @@ -367,12 +362,18 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs( for start < len(eventIDs) { n := minOfInts(len(eventIDs)-start, 999) query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(n), 1) - rows, err := txn.QueryContext(ctx, query, iEventIDs[start:start+n]...) + var rows *sql.Rows + var err error + if txn == nil { + rows, err = s.db.QueryContext(ctx, query, iEventIDs[start:start+n]...) + } else { + rows, err = txn.QueryContext(ctx, query, iEventIDs[start:start+n]...) + } if err != nil { return nil, err } start = start + n - events, err := rowsToStreamEvents(rows) + events, err := currentRoomStateRowsToStreamEvents(rows) internal.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed") if err != nil { return nil, err @@ -382,6 +383,35 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs( return res, nil } +func currentRoomStateRowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { + var events []types.StreamEvent + for rows.Next() { + var ( + eventID string + streamPos types.StreamPosition + eventBytes []byte + historyVisibility gomatrixserverlib.HistoryVisibility + ) + if err := rows.Scan(&eventID, &streamPos, &eventBytes, &historyVisibility); err != nil { + return nil, err + } + // TODO: Handle redacted events + var ev gomatrixserverlib.HeaderedEvent + if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + return nil, err + } + + ev.Visibility = historyVisibility + + events = append(events, types.StreamEvent{ + HeaderedEvent: &ev, + StreamPosition: streamPos, + }) + } + + return events, nil +} + func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { result := []*gomatrixserverlib.HeaderedEvent{} for rows.Next() { @@ -401,9 +431,9 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { } func (s *currentRoomStateStatements) SelectStateEvent( - ctx context.Context, roomID, evType, stateKey string, + ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string, ) (*gomatrixserverlib.HeaderedEvent, error) { - stmt := s.selectStateEventStmt + stmt := sqlutil.TxStmt(txn, s.selectStateEventStmt) var res []byte err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) if err == sql.ErrNoRows { @@ -429,10 +459,17 @@ func (s *currentRoomStateStatements) SelectSharedUsers( params[k+1] = v } + var provider sqlutil.QueryProvider + if txn == nil { + provider = s.db + } else { + provider = txn + } + result := make([]string, 0, len(otherUserIDs)) query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1) err := sqlutil.RunLimitedVariablesQuery( - ctx, query, s.db, params, sqlutil.SQLite3MaxVariables, + ctx, query, provider, params, sqlutil.SQLite3MaxVariables, func(rows *sql.Rows) error { var stateKey string for rows.Next() { diff --git a/syncapi/storage/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go index 6081a48b1..5f1e980eb 100644 --- a/syncapi/storage/sqlite3/filter_table.go +++ b/syncapi/storage/sqlite3/filter_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -77,11 +78,11 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) { } func (s *filterStatements) SelectFilter( - ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, + ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string, ) error { // Retrieve filter from database (stored as canonical JSON) var filterData []byte - err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) + err := sqlutil.TxStmt(txn, s.selectFilterStmt).QueryRowContext(ctx, localpart, filterID).Scan(&filterData) if err != nil { return err } @@ -94,7 +95,7 @@ func (s *filterStatements) SelectFilter( } func (s *filterStatements) InsertFilter( - ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, + ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string, ) (filterID string, err error) { var existingFilterID string @@ -115,8 +116,9 @@ func (s *filterStatements) InsertFilter( // This can result in a race condition when two clients try to insert the // same filter and localpart at the same time, however this is not a // problem as both calls will result in the same filterID - err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, - localpart, filterJSON).Scan(&existingFilterID) + err = sqlutil.TxStmt(txn, s.selectFilterIDByContentStmt).QueryRowContext( + ctx, localpart, filterJSON, + ).Scan(&existingFilterID) if err != nil && err != sql.ErrNoRows { return "", err } @@ -126,7 +128,7 @@ func (s *filterStatements) InsertFilter( } // Otherwise insert the filter and return the new ID - res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) + res, err := sqlutil.TxStmt(txn, s.insertFilterStmt).ExecContext(ctx, filterJSON, localpart) if err != nil { return "", err } diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 58ab8461e..e2dbcd5c8 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -50,7 +50,7 @@ const deleteInviteEventSQL = "" + "UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2 AND deleted=false" const selectInviteEventsInRangeSQL = "" + - "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" + + "SELECT id, room_id, headered_event_json, deleted FROM syncapi_invite_events" + " WHERE target_user_id = $1 AND id > $2 AND id <= $3" + " ORDER BY id DESC" @@ -132,23 +132,28 @@ func (s *inviteEventsStatements) DeleteInviteEvent( // active invites for the target user ID in the supplied range. func (s *inviteEventsStatements) SelectInviteEventsInRange( ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, -) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) { +) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) { + var lastPos types.StreamPosition stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) if err != nil { - return nil, nil, err + return nil, nil, lastPos, err } defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") result := map[string]*gomatrixserverlib.HeaderedEvent{} retired := map[string]*gomatrixserverlib.HeaderedEvent{} for rows.Next() { var ( + id types.StreamPosition roomID string eventJSON []byte deleted bool ) - if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil { - return nil, nil, err + if err = rows.Scan(&id, &roomID, &eventJSON, &deleted); err != nil { + return nil, nil, lastPos, err + } + if id > lastPos { + lastPos = id } // if we have seen this room before, it has a higher stream position and hence takes priority @@ -161,15 +166,19 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange( var event *gomatrixserverlib.HeaderedEvent if err := json.Unmarshal(eventJSON, &event); err != nil { - return nil, nil, err + return nil, nil, lastPos, err } + if deleted { retired[roomID] = event } else { result[roomID] = event } } - return result, retired, nil + if lastPos == 0 { + lastPos = r.To + } + return result, retired, lastPos, nil } func (s *inviteEventsStatements) SelectMaxInviteID( diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go index 66d4d4381..6242898e1 100644 --- a/syncapi/storage/sqlite3/notification_data_table.go +++ b/syncapi/storage/sqlite3/notification_data_table.go @@ -17,6 +17,7 @@ package sqlite3 import ( "context" "database/sql" + "strings" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" @@ -32,19 +33,21 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t } r := ¬ificationDataStatements{ streamIDStatements: streamID, + db: db, } return r, sqlutil.StatementList{ {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, - {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL}, + // {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime }.Prepare(db) } type notificationDataStatements struct { + db *sql.DB streamIDStatements *StreamIDStatements upsertRoomUnreadCounts *sql.Stmt - selectUserUnreadCounts *sql.Stmt selectMaxID *sql.Stmt + //selectUserUnreadCountsForRooms *sql.Stmt } const notificationDataSchema = ` @@ -63,12 +66,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_ ON CONFLICT (user_id, room_id) DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7` -const selectUserUnreadNotificationCountsSQL = `SELECT - id, room_id, notification_count, highlight_count - FROM syncapi_notification_data - WHERE - user_id = $1 AND - id BETWEEN $2 + 1 AND $3` +const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count + FROM syncapi_notification_data + WHERE user_id = $1 AND + room_id IN ($2)` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` @@ -81,20 +82,31 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, return } -func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { - rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) +func (r *notificationDataStatements) SelectUserUnreadCountsForRooms( + ctx context.Context, txn *sql.Tx, userID string, roomIDs []string, +) (map[string]*eventutil.NotificationData, error) { + params := make([]interface{}, len(roomIDs)+1) + params[0] = userID + for i := range roomIDs { + params[i+1] = roomIDs[i] + } + sql := strings.Replace(selectUserUnreadNotificationsForRooms, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) + prep, err := r.db.PrepareContext(ctx, sql) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, prep, "SelectUserUnreadCountsForRooms: prep.close() failed") + rows, err := sqlutil.TxStmt(txn, prep).QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCountsForRooms: rows.close() failed") roomCounts := map[string]*eventutil.NotificationData{} + var roomID string + var notificationCount, highlightCount int for rows.Next() { - var id types.StreamPosition - var roomID string - var notificationCount, highlightCount int - - if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil { + if err = rows.Scan(&roomID, ¬ificationCount, &highlightCount); err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 1626e32ef..165943027 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -115,6 +115,8 @@ const selectContextAfterEventSQL = "" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters +const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE type IN ($1) AND id > $2 LIMIT $3 ORDER BY id ASC" + type outputRoomEventsStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -125,6 +127,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + //selectSearchStmt *sql.Stmt - prepared at runtime } func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) { @@ -157,15 +160,16 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + //{&s.selectSearchStmt, selectSearchSQL}, - prepared at runtime }.Prepare(db) } -func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { +func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error { headeredJSON, err := json.Marshal(event) if err != nil { return err } - _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) + _, err = sqlutil.TxStmt(txn, s.updateEventJSONStmt).ExecContext(ctx, headeredJSON, event.EventID()) return err } @@ -628,3 +632,40 @@ func unmarshalStateIDs(addIDsJSON, delIDsJSON string) (addIDs []string, delIDs [ } return } + +func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { + params := make([]interface{}, len(types)) + for i := range types { + params[i] = types[i] + } + params = append(params, afterID) + params = append(params, limit) + selectSQL := strings.Replace(selectSearchSQL, "($1)", sqlutil.QueryVariadic(len(types)), 1) + + stmt, err := s.db.Prepare(selectSQL) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, stmt, "selectEvents: stmt.close() failed") + rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "rows.close() failed") + + var eventID string + var id int64 + result := make(map[int64]gomatrixserverlib.HeaderedEvent) + for rows.Next() { + var ev gomatrixserverlib.HeaderedEvent + var eventBytes []byte + if err = rows.Scan(&id, &eventID, &eventBytes); err != nil { + return nil, err + } + if err = ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + return nil, err + } + result[id] = ev + } + return result, rows.Err() +} diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index b2fb77417..81b264988 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -176,9 +176,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, backwardOrdering bool, ) (topoPos types.StreamPosition, err error) { if backwardOrdering { - err = s.selectStreamToTopologicalPositionDescStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) + err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionDescStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) } else { - err = s.selectStreamToTopologicalPositionAscStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) + err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionAscStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) } return } diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go index 5ee86448c..4ef51b103 100644 --- a/syncapi/storage/sqlite3/peeks_table.go +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -172,9 +172,9 @@ func (s *peekStatements) SelectPeeksInRange( } func (s *peekStatements) SelectPeekingDevices( - ctx context.Context, + ctx context.Context, txn *sql.Tx, ) (peekingDevices map[string][]types.PeekingDevice, err error) { - rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx) + rows, err := sqlutil.TxStmt(txn, s.selectPeekingDevicesStmt).QueryContext(ctx) if err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index 31adb005b..a4a9b4395 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -108,7 +108,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room } // SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp -func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { +func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) var lastPos types.StreamPosition params := make([]interface{}, len(roomIDs)+1) @@ -116,7 +116,12 @@ func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs for k, v := range roomIDs { params[k+1] = v } - rows, err := r.db.QueryContext(ctx, selectSQL, params...) + prep, err := r.db.Prepare(selectSQL) + if err != nil { + return 0, nil, fmt.Errorf("unable to prepare statement: %w", err) + } + defer internal.CloseAndLogIfError(ctx, prep, "SelectRoomReceiptsAfter: prep.close() failed") + rows, err := sqlutil.TxStmt(txn, prep).QueryContext(ctx, params...) if err != nil { return 0, nil, fmt.Errorf("unable to query room receipts: %w", err) } diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index a84e2bd16..0879030a6 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -49,6 +49,20 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) return &d, nil } +func (d *SyncServerDatasource) NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseTransaction, error) { + return &shared.DatabaseTransaction{ + Database: &d.Database, + // not setting a transaction because SQLite doesn't support it + }, nil +} + +func (d *SyncServerDatasource) NewDatabaseTransaction(ctx context.Context) (*shared.DatabaseTransaction, error) { + return &shared.DatabaseTransaction{ + Database: &d.Database, + // not setting a transaction because SQLite doesn't support it + }, nil +} + func (d *SyncServerDatasource) prepare(ctx context.Context) (err error) { if err = d.streamID.Prepare(d.db); err != nil { return err diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index a62818e9b..5ff185a32 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -60,6 +60,17 @@ func TestWriteEvents(t *testing.T) { }) } +func WithSnapshot(t *testing.T, db storage.Database, f func(snapshot storage.DatabaseTransaction)) { + snapshot, err := db.NewDatabaseSnapshot(ctx) + if err != nil { + t.Fatal(err) + } + f(snapshot) + if err := snapshot.Rollback(); err != nil { + t.Fatal(err) + } +} + // These tests assert basic functionality of RecentEvents for PDUs func TestRecentEventsPDU(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -79,10 +90,13 @@ func TestRecentEventsPDU(t *testing.T) { // dummy room to make sure SQL queries are filtering on room ID MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) - latest, err := db.MaxStreamPositionForPDUs(ctx) - if err != nil { - t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err) - } + var latest types.StreamPosition + WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { + var err error + if latest, err = snapshot.MaxStreamPositionForPDUs(ctx); err != nil { + t.Fatal("failed to get MaxStreamPositionForPDUs: %w", err) + } + }) testCases := []struct { Name string @@ -140,14 +154,19 @@ func TestRecentEventsPDU(t *testing.T) { tc := testCases[i] t.Run(tc.Name, func(st *testing.T) { var filter gomatrixserverlib.RoomEventFilter + var gotEvents []types.StreamEvent + var limited bool filter.Limit = tc.Limit - gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{ - From: tc.From, - To: tc.To, - }, &filter, !tc.ReverseOrder, true) - if err != nil { - st.Fatalf("failed to do sync: %s", err) - } + WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { + var err error + gotEvents, limited, err = snapshot.RecentEvents(ctx, r.ID, types.Range{ + From: tc.From, + To: tc.To, + }, &filter, !tc.ReverseOrder, true) + if err != nil { + st.Fatalf("failed to do sync: %s", err) + } + }) if limited != tc.WantLimited { st.Errorf("got limited=%v want %v", limited, tc.WantLimited) } @@ -178,22 +197,24 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { events := r.Events() _ = MustWriteEvents(t, db, events) - from, err := db.MaxTopologicalPosition(ctx, r.ID) - if err != nil { - t.Fatalf("failed to get MaxTopologicalPosition: %s", err) - } - t.Logf("max topo pos = %+v", from) - // head towards the beginning of time - to := types.TopologyToken{} + WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { + from, err := snapshot.MaxTopologicalPosition(ctx, r.ID) + if err != nil { + t.Fatalf("failed to get MaxTopologicalPosition: %s", err) + } + t.Logf("max topo pos = %+v", from) + // head towards the beginning of time + to := types.TopologyToken{} - // backpaginate 5 messages starting at the latest position. - filter := &gomatrixserverlib.RoomEventFilter{Limit: 5} - paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true) - if err != nil { - t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) - } - gots := db.StreamEventsToEvents(nil, paginatedEvents) - test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:])) + // backpaginate 5 messages starting at the latest position. + filter := &gomatrixserverlib.RoomEventFilter{Limit: 5} + paginatedEvents, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true) + if err != nil { + t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) + } + gots := snapshot.StreamEventsToEvents(nil, paginatedEvents) + test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:])) + }) }) } @@ -414,13 +435,16 @@ func TestSendToDeviceBehaviour(t *testing.T) { defer closeBase() // At this point there should be no messages. We haven't sent anything // yet. - _, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100) - if err != nil { - t.Fatal(err) - } - if len(events) != 0 { - t.Fatal("first call should have no updates") - } + + WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { + _, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100) + if err != nil { + t.Fatal(err) + } + if len(events) != 0 { + t.Fatal("first call should have no updates") + } + }) // Try sending a message. streamPos, err := db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{ @@ -432,51 +456,58 @@ func TestSendToDeviceBehaviour(t *testing.T) { t.Fatal(err) } - // At this point we should get exactly one message. We're sending the sync position - // that we were given from the update and the send-to-device update will be updated - // in the database to reflect that this was the sync position we sent the message at. - streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos) - if err != nil { - t.Fatal(err) - } - if count := len(events); count != 1 { - t.Fatalf("second call should have one update, got %d", count) - } + WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { + // At this point we should get exactly one message. We're sending the sync position + // that we were given from the update and the send-to-device update will be updated + // in the database to reflect that this was the sync position we sent the message at. + var events []types.SendToDeviceEvent + streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos) + if err != nil { + t.Fatal(err) + } + if count := len(events); count != 1 { + t.Fatalf("second call should have one update, got %d", count) + } + + // At this point we should still have one message because we haven't progressed the + // sync position yet. This is equivalent to the client failing to /sync and retrying + // with the same position. + streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos) + if err != nil { + t.Fatal(err) + } + if len(events) != 1 { + t.Fatal("third call should have one update still") + } + }) - // At this point we should still have one message because we haven't progressed the - // sync position yet. This is equivalent to the client failing to /sync and retrying - // with the same position. - streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos) - if err != nil { - t.Fatal(err) - } - if len(events) != 1 { - t.Fatal("third call should have one update still") - } err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos) if err != nil { return } - // At this point we should now have no updates, because we've progressed the sync - // position. Therefore the update from before will not be sent again. - _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10) - if err != nil { - t.Fatal(err) - } - if len(events) != 0 { - t.Fatal("fourth call should have no updates") - } + WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { + // At this point we should now have no updates, because we've progressed the sync + // position. Therefore the update from before will not be sent again. + var events []types.SendToDeviceEvent + _, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10) + if err != nil { + t.Fatal(err) + } + if len(events) != 0 { + t.Fatal("fourth call should have no updates") + } - // At this point we should still have no updates, because no new updates have been - // sent. - _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10) - if err != nil { - t.Fatal(err) - } - if len(events) != 0 { - t.Fatal("fifth call should have no updates") - } + // At this point we should still have no updates, because no new updates have been + // sent. + _, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10) + if err != nil { + t.Fatal(err) + } + if len(events) != 0 { + t.Fatal("fifth call should have no updates") + } + }) // Send some more messages and verify the ordering is correct ("in order of arrival") var lastPos types.StreamPosition = 0 @@ -492,18 +523,20 @@ func TestSendToDeviceBehaviour(t *testing.T) { lastPos = streamPos } - _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos) - if err != nil { - t.Fatalf("unable to get events: %v", err) - } - - for i := 0; i < 10; i++ { - want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i)) - got := events[i].Content - if !bytes.Equal(got, want) { - t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got)) + WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { + _, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos) + if err != nil { + t.Fatalf("unable to get events: %v", err) } - } + + for i := 0; i < 10; i++ { + want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i)) + got := events[i].Content + if !bytes.Equal(got, want) { + t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got)) + } + } + }) }) } diff --git a/syncapi/storage/tables/current_room_state_test.go b/syncapi/storage/tables/current_room_state_test.go new file mode 100644 index 000000000..23287c500 --- /dev/null +++ b/syncapi/storage/tables/current_room_state_test.go @@ -0,0 +1,88 @@ +package tables_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" +) + +func newCurrentRoomStateTable(t *testing.T, dbType test.DBType) (tables.CurrentRoomState, *sql.DB, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + var tab tables.CurrentRoomState + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresCurrentRoomStateTable(db) + case test.DBTypeSQLite: + var stream sqlite3.StreamIDStatements + if err = stream.Prepare(db); err != nil { + t.Fatalf("failed to prepare stream stmts: %s", err) + } + tab, err = sqlite3.NewSqliteCurrentRoomStateTable(db, &stream) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + return tab, db, close +} + +func TestCurrentRoomStateTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, db, close := newCurrentRoomStateTable(t, dbType) + defer close() + events := room.CurrentState() + err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error { + for i, ev := range events { + err := tab.UpsertRoomState(ctx, txn, ev, nil, types.StreamPosition(i)) + if err != nil { + return fmt.Errorf("failed to UpsertRoomState: %w", err) + } + } + wantEventIDs := []string{ + events[0].EventID(), events[1].EventID(), events[2].EventID(), events[3].EventID(), + } + gotEvents, err := tab.SelectEventsWithEventIDs(ctx, txn, wantEventIDs) + if err != nil { + return fmt.Errorf("failed to SelectEventsWithEventIDs: %w", err) + } + if len(gotEvents) != len(wantEventIDs) { + return fmt.Errorf("SelectEventsWithEventIDs\ngot %d, want %d results", len(gotEvents), len(wantEventIDs)) + } + gotEventIDs := make(map[string]struct{}, len(gotEvents)) + for _, event := range gotEvents { + if event.ExcludeFromSync { + return fmt.Errorf("SelectEventsWithEventIDs ExcludeFromSync should be false for current room state event %+v", event) + } + gotEventIDs[event.EventID()] = struct{}{} + } + for _, id := range wantEventIDs { + if _, ok := gotEventIDs[id]; !ok { + return fmt.Errorf("SelectEventsWithEventIDs\nexpected id %q not returned", id) + } + } + return nil + }) + if err != nil { + t.Fatalf("err: %v", err) + } + }) +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 193881b44..2fdc3cfbb 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -28,7 +28,7 @@ import ( type AccountData interface { InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error) // SelectAccountDataInRange returns a map of room ID to a list of `dataType`. - SelectAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, pos types.StreamPosition, err error) + SelectAccountDataInRange(ctx context.Context, txn *sql.Tx, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, pos types.StreamPosition, err error) SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error) } @@ -37,7 +37,7 @@ type Invites interface { DeleteInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string) (types.StreamPosition, error) // SelectInviteEventsInRange returns a map of room ID to invite events. If multiple invite/retired invites exist in the given range, return the latest value // for the room. - SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, err error) + SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, maxID types.StreamPosition, err error) SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error) } @@ -46,7 +46,7 @@ type Peeks interface { DeletePeek(ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string) (streamPos types.StreamPosition, err error) DeletePeeks(ctx context.Context, txn *sql.Tx, roomID, userID string) (streamPos types.StreamPosition, err error) SelectPeeksInRange(ctxt context.Context, txn *sql.Tx, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) - SelectPeekingDevices(ctxt context.Context) (peekingDevices map[string][]types.PeekingDevice, err error) + SelectPeekingDevices(ctxt context.Context, txn *sql.Tx) (peekingDevices map[string][]types.PeekingDevice, err error) SelectMaxPeekID(ctx context.Context, txn *sql.Tx) (id int64, err error) } @@ -68,13 +68,14 @@ type Events interface { // SelectEarlyEvents returns the earliest events in the given room. SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error) - UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error + UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) + ReIndex(ctx context.Context, txn *sql.Tx, limit, offset int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) } // Topology keeps track of the depths and stream positions for all events. @@ -97,7 +98,7 @@ type Topology interface { } type CurrentRoomState interface { - SelectStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + SelectStateEvent(ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) SelectEventsWithEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) UpsertRoomState(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition) error DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error @@ -109,9 +110,9 @@ type CurrentRoomState interface { // SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user. SelectRoomIDsWithAnyMembership(ctx context.Context, txn *sql.Tx, userID string) (map[string]string, error) // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. - SelectJoinedUsers(ctx context.Context) (map[string][]string, error) + SelectJoinedUsers(ctx context.Context, txn *sql.Tx) (map[string][]string, error) // SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. - SelectJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) + SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error) // SelectSharedUsers returns a subset of otherUserIDs that share a room with userID. SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error) } @@ -141,7 +142,7 @@ type BackwardsExtremities interface { // InsertsBackwardExtremity inserts a new backwards extremity. InsertsBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string) (err error) // SelectBackwardExtremitiesForRoom retrieves all backwards extremities for the room, as a map of event_id to list of prev_event_ids. - SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error) + SelectBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (bwExtrems map[string][]string, err error) // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) } @@ -171,13 +172,13 @@ type SendToDevice interface { } type Filter interface { - SelectFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error - InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error) + SelectFilter(ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string) error + InsertFilter(ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error) } type Receipts interface { UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) - SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) + SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error) } @@ -190,7 +191,7 @@ type Memberships interface { type NotificationData interface { UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) - SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) + SelectUserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) } diff --git a/syncapi/streams/stream_accountdata.go b/syncapi/streams/stream_accountdata.go index 9c19b846b..3f2f7d134 100644 --- a/syncapi/streams/stream_accountdata.go +++ b/syncapi/streams/stream_accountdata.go @@ -3,23 +3,27 @@ package streams import ( "context" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) type AccountDataStreamProvider struct { - StreamProvider + DefaultStreamProvider userAPI userapi.SyncUserAPI } -func (p *AccountDataStreamProvider) Setup() { - p.StreamProvider.Setup() +func (p *AccountDataStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { + p.DefaultStreamProvider.Setup(ctx, snapshot) p.latestMutex.Lock() defer p.latestMutex.Unlock() - id, err := p.DB.MaxStreamPositionForAccountData(context.Background()) + id, err := snapshot.MaxStreamPositionForAccountData(ctx) if err != nil { panic(err) } @@ -28,13 +32,15 @@ func (p *AccountDataStreamProvider) Setup() { func (p *AccountDataStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) } func (p *AccountDataStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { @@ -43,7 +49,7 @@ func (p *AccountDataStreamProvider) IncrementalSync( To: to, } - dataTypes, pos, err := p.DB.GetAccountDataInRange( + dataTypes, pos, err := snapshot.GetAccountDataInRange( ctx, req.Device.UserID, r, &req.Filter.AccountData, ) if err != nil { diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go index 5448ee5bd..7996c2038 100644 --- a/syncapi/streams/stream_devicelist.go +++ b/syncapi/streams/stream_devicelist.go @@ -6,17 +6,19 @@ import ( keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/internal" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" ) type DeviceListStreamProvider struct { - StreamProvider + DefaultStreamProvider rsAPI api.SyncRoomserverAPI keyAPI keyapi.SyncKeyAPI } func (p *DeviceListStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { return p.LatestPosition(ctx) @@ -24,11 +26,12 @@ func (p *DeviceListStreamProvider) CompleteSync( func (p *DeviceListStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { var err error - to, _, err = internal.DeviceListCatchup(context.Background(), p.DB, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) + to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) if err != nil { req.Log.WithError(err).Error("internal.DeviceListCatchup failed") return from diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 925da32f2..17b3b8434 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -9,20 +9,23 @@ import ( "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" ) type InviteStreamProvider struct { - StreamProvider + DefaultStreamProvider } -func (p *InviteStreamProvider) Setup() { - p.StreamProvider.Setup() +func (p *InviteStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { + p.DefaultStreamProvider.Setup(ctx, snapshot) p.latestMutex.Lock() defer p.latestMutex.Unlock() - id, err := p.DB.MaxStreamPositionForInvites(context.Background()) + id, err := snapshot.MaxStreamPositionForInvites(ctx) if err != nil { panic(err) } @@ -31,13 +34,15 @@ func (p *InviteStreamProvider) Setup() { func (p *InviteStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) } func (p *InviteStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { @@ -46,7 +51,7 @@ func (p *InviteStreamProvider) IncrementalSync( To: to, } - invites, retiredInvites, err := p.DB.InviteEventsInRange( + invites, retiredInvites, maxID, err := snapshot.InviteEventsInRange( ctx, req.Device.UserID, r, ) if err != nil { @@ -86,5 +91,5 @@ func (p *InviteStreamProvider) IncrementalSync( } } - return to + return maxID } diff --git a/syncapi/streams/stream_notificationdata.go b/syncapi/streams/stream_notificationdata.go index 8ba9e07ca..5a81fd09a 100644 --- a/syncapi/streams/stream_notificationdata.go +++ b/syncapi/streams/stream_notificationdata.go @@ -3,17 +3,23 @@ package streams import ( "context" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" ) type NotificationDataStreamProvider struct { - StreamProvider + DefaultStreamProvider } -func (p *NotificationDataStreamProvider) Setup() { - p.StreamProvider.Setup() +func (p *NotificationDataStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { + p.DefaultStreamProvider.Setup(ctx, snapshot) - id, err := p.DB.MaxStreamPositionForNotificationData(context.Background()) + p.latestMutex.Lock() + defer p.latestMutex.Unlock() + + id, err := snapshot.MaxStreamPositionForNotificationData(ctx) if err != nil { panic(err) } @@ -22,34 +28,39 @@ func (p *NotificationDataStreamProvider) Setup() { func (p *NotificationDataStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) } func (p *NotificationDataStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, - from, to types.StreamPosition, + from, _ types.StreamPosition, ) types.StreamPosition { - // We want counts for all possible rooms, so always start from zero. - countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to) + // Get the unread notifications for rooms in our join response. + // This is to ensure clients always have an unread notification section + // and can display the correct numbers. + countsByRoom, err := snapshot.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms) if err != nil { - req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed") + req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed") return from } - // We're merely decorating existing rooms. Note that the Join map - // values are not pointers. + // We're merely decorating existing rooms. for roomID, jr := range req.Response.Rooms.Join { counts := countsByRoom[roomID] if counts == nil { continue } - - jr.UnreadNotifications.HighlightCount = counts.UnreadHighlightCount - jr.UnreadNotifications.NotificationCount = counts.UnreadNotificationCount + jr.UnreadNotifications = &types.UnreadNotifications{ + HighlightCount: counts.UnreadHighlightCount, + NotificationCount: counts.UnreadNotificationCount, + } req.Response.Rooms.Join[roomID] = jr } - return to + + return p.LatestPosition(ctx) } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 0ab6de886..d252265ff 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -5,7 +5,6 @@ import ( "database/sql" "fmt" "sort" - "sync" "time" "github.com/matrix-org/dendrite/internal/caching" @@ -18,7 +17,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - "go.uber.org/atomic" "github.com/matrix-org/dendrite/syncapi/notifier" ) @@ -33,44 +31,23 @@ const PDU_STREAM_WORKERS = 256 const PDU_STREAM_QUEUESIZE = PDU_STREAM_WORKERS * 8 type PDUStreamProvider struct { - StreamProvider + DefaultStreamProvider - tasks chan func() - workers atomic.Int32 // userID+deviceID -> lazy loading cache lazyLoadCache caching.LazyLoadCache rsAPI roomserverAPI.SyncRoomserverAPI notifier *notifier.Notifier } -func (p *PDUStreamProvider) worker() { - defer p.workers.Dec() - for { - select { - case f := <-p.tasks: - f() - case <-time.After(time.Second * 10): - return - } - } -} - -func (p *PDUStreamProvider) queue(f func()) { - if p.workers.Load() < PDU_STREAM_WORKERS { - p.workers.Inc() - go p.worker() - } - p.tasks <- f -} - -func (p *PDUStreamProvider) Setup() { - p.StreamProvider.Setup() - p.tasks = make(chan func(), PDU_STREAM_QUEUESIZE) +func (p *PDUStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { + p.DefaultStreamProvider.Setup(ctx, snapshot) p.latestMutex.Lock() defer p.latestMutex.Unlock() - id, err := p.DB.MaxStreamPositionForPDUs(context.Background()) + id, err := snapshot.MaxStreamPositionForPDUs(ctx) if err != nil { panic(err) } @@ -79,6 +56,7 @@ func (p *PDUStreamProvider) Setup() { func (p *PDUStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { from := types.StreamPosition(0) @@ -94,7 +72,7 @@ func (p *PDUStreamProvider) CompleteSync( } // Extract room state and recent events for all rooms the user is joined to. - joinedRoomIDs, err := p.DB.RoomIDsWithMembership(ctx, req.Device.UserID, gomatrixserverlib.Join) + joinedRoomIDs, err := snapshot.RoomIDsWithMembership(ctx, req.Device.UserID, gomatrixserverlib.Join) if err != nil { req.Log.WithError(err).Error("p.DB.RoomIDsWithMembership failed") return from @@ -103,7 +81,7 @@ func (p *PDUStreamProvider) CompleteSync( stateFilter := req.Filter.Room.State eventFilter := req.Filter.Room.Timeline - if err = p.addIgnoredUsersToFilter(ctx, req, &eventFilter); err != nil { + if err = p.addIgnoredUsersToFilter(ctx, snapshot, req, &eventFilter); err != nil { req.Log.WithError(err).Error("unable to update event filter with ignored users") } @@ -117,33 +95,23 @@ func (p *PDUStreamProvider) CompleteSync( } // Build up a /sync response. Add joined rooms. - var reqMutex sync.Mutex - var reqWaitGroup sync.WaitGroup - reqWaitGroup.Add(len(joinedRoomIDs)) - for _, room := range joinedRoomIDs { - roomID := room - p.queue(func() { - defer reqWaitGroup.Done() - - jr, jerr := p.getJoinResponseForCompleteSync( - ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false, - ) - if jerr != nil { - req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed") - return + for _, roomID := range joinedRoomIDs { + jr, jerr := p.getJoinResponseForCompleteSync( + ctx, snapshot, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false, + ) + if jerr != nil { + req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed") + if err == context.DeadlineExceeded || err == context.Canceled || err == sql.ErrTxDone { + return from } - - reqMutex.Lock() - defer reqMutex.Unlock() - req.Response.Rooms.Join[roomID] = *jr - req.Rooms[roomID] = gomatrixserverlib.Join - }) + continue + } + req.Response.Rooms.Join[roomID] = *jr + req.Rooms[roomID] = gomatrixserverlib.Join } - reqWaitGroup.Wait() - // Add peeked rooms. - peeks, err := p.DB.PeeksInRange(ctx, req.Device.UserID, req.Device.ID, r) + peeks, err := snapshot.PeeksInRange(ctx, req.Device.UserID, req.Device.ID, r) if err != nil { req.Log.WithError(err).Error("p.DB.PeeksInRange failed") return from @@ -152,11 +120,14 @@ func (p *PDUStreamProvider) CompleteSync( if !peek.Deleted { var jr *types.JoinResponse jr, err = p.getJoinResponseForCompleteSync( - ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true, + ctx, snapshot, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true, ) if err != nil { req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") - return from + if err == context.DeadlineExceeded || err == context.Canceled || err == sql.ErrTxDone { + return from + } + continue } req.Response.Rooms.Peek[peek.RoomID] = *jr } @@ -167,6 +138,7 @@ func (p *PDUStreamProvider) CompleteSync( func (p *PDUStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) (newPos types.StreamPosition) { @@ -184,14 +156,14 @@ func (p *PDUStreamProvider) IncrementalSync( eventFilter := req.Filter.Room.Timeline if req.WantFullState { - if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { + if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed") - return + return from } } else { - if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { + if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { req.Log.WithError(err).Error("p.DB.GetStateDeltas failed") - return + return from } } @@ -203,7 +175,7 @@ func (p *PDUStreamProvider) IncrementalSync( return to } - if err = p.addIgnoredUsersToFilter(ctx, req, &eventFilter); err != nil { + if err = p.addIgnoredUsersToFilter(ctx, snapshot, req, &eventFilter); err != nil { req.Log.WithError(err).Error("unable to update event filter with ignored users") } @@ -222,9 +194,12 @@ func (p *PDUStreamProvider) IncrementalSync( } } var pos types.StreamPosition - if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil { + if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil { req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed") - return to + if err == context.DeadlineExceeded || err == context.Canceled || err == sql.ErrTxDone { + return newPos + } + continue } // Reset the position, as it is only for the special case of newly joined rooms if delta.NewlyJoined { @@ -244,6 +219,7 @@ func (p *PDUStreamProvider) IncrementalSync( // nolint:gocyclo func (p *PDUStreamProvider) addRoomDeltaToResponse( ctx context.Context, + snapshot storage.DatabaseTransaction, device *userapi.Device, r types.Range, delta types.StateDelta, @@ -260,7 +236,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( // This is all "okay" assuming history_visibility == "shared" which it is by default. r.To = delta.MembershipPos } - recentStreamEvents, limited, err := p.DB.RecentEvents( + recentStreamEvents, limited, err := snapshot.RecentEvents( ctx, delta.RoomID, r, eventFilter, true, true, ) @@ -270,9 +246,9 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err) } - recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) + recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents) delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back - prevBatch, err := p.DB.GetBackwardTopologyPos(ctx, recentStreamEvents) + prevBatch, err := snapshot.GetBackwardTopologyPos(ctx, recentStreamEvents) if err != nil { return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err) } @@ -291,7 +267,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( latestPosition := r.To updateLatestPosition := func(mostRecentEventID string) { var pos types.StreamPosition - if _, pos, err = p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil { + if _, pos, err = snapshot.PositionInTopology(ctx, mostRecentEventID); err == nil { switch { case r.Backwards && pos < latestPosition: fallthrough @@ -303,7 +279,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( if stateFilter.LazyLoadMembers { delta.StateEvents, err = p.lazyLoadMembers( - ctx, delta.RoomID, true, limited, stateFilter, + ctx, snapshot, delta.RoomID, true, limited, stateFilter, device, recentEvents, delta.StateEvents, ) if err != nil && err != sql.ErrNoRows { @@ -320,7 +296,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } // Applies the history visibility rules - events, err := applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents) + events, err := applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents) if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") } @@ -336,7 +312,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case gomatrixserverlib.Join: jr := types.NewJoinResponse() if hasMembershipChange { - p.addRoomSummary(ctx, jr, delta.RoomID, device.UserID, latestPosition) + p.addRoomSummary(ctx, snapshot, jr, delta.RoomID, device.UserID, latestPosition) } jr.Timeline.PrevBatch = &prevBatch jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) @@ -376,7 +352,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( // sure we always return the required events in the timeline. func applyHistoryVisibilityFilter( ctx context.Context, - db storage.Database, + snapshot storage.DatabaseTransaction, rsAPI roomserverAPI.SyncRoomserverAPI, roomID, userID string, limit int, @@ -384,7 +360,7 @@ func applyHistoryVisibilityFilter( ) ([]*gomatrixserverlib.HeaderedEvent, error) { // We need to make sure we always include the latest states events, if they are in the timeline. // We grep at least limit * 2 events, to ensure we really get the needed events. - stateEvents, err := db.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil) + stateEvents, err := snapshot.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil) if err != nil { // Not a fatal error, we can continue without the stateEvents, // they are only needed if there are state events in the timeline. @@ -395,7 +371,7 @@ func applyHistoryVisibilityFilter( alwaysIncludeIDs[ev.EventID()] = struct{}{} } startTime := time.Now() - events, err := internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync") + events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync") if err != nil { return nil, err } @@ -408,10 +384,10 @@ func applyHistoryVisibilityFilter( return events, nil } -func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) { +func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage.DatabaseTransaction, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) { // Work out how many members are in the room. - joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition) - invitedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition) + joinedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition) + invitedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition) jr.Summary.JoinedMemberCount = &joinedCount jr.Summary.InvitedMemberCount = &invitedCount @@ -439,7 +415,7 @@ func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinRe } } } - heroes, err := p.DB.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"}) + heroes, err := snapshot.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"}) if err != nil { return } @@ -449,6 +425,7 @@ func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinRe func (p *PDUStreamProvider) getJoinResponseForCompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, roomID string, r types.Range, stateFilter *gomatrixserverlib.StateFilter, @@ -460,7 +437,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( jr = types.NewJoinResponse() // TODO: When filters are added, we may need to call this multiple times to get enough events. // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 - recentStreamEvents, limited, err := p.DB.RecentEvents( + recentStreamEvents, limited, err := snapshot.RecentEvents( ctx, roomID, r, eventFilter, true, true, ) if err != nil { @@ -484,7 +461,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } } - stateEvents, err := p.DB.CurrentState(ctx, roomID, stateFilter, excludingEventIDs) + stateEvents, err := snapshot.CurrentState(ctx, roomID, stateFilter, excludingEventIDs) if err != nil { return } @@ -494,7 +471,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( var prevBatch *types.TopologyToken if len(recentStreamEvents) > 0 { var backwardTopologyPos, backwardStreamPos types.StreamPosition - backwardTopologyPos, backwardStreamPos, err = p.DB.PositionInTopology(ctx, recentStreamEvents[0].EventID()) + backwardTopologyPos, backwardStreamPos, err = snapshot.PositionInTopology(ctx, recentStreamEvents[0].EventID()) if err != nil { return } @@ -505,18 +482,18 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( prevBatch.Decrement() } - p.addRoomSummary(ctx, jr, roomID, device.UserID, r.From) + p.addRoomSummary(ctx, snapshot, jr, roomID, device.UserID, r.From) // We don't include a device here as we don't need to send down // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: // "Can sync a room with a message with a transaction id" - which does a complete sync to check. - recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) + recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) events := recentEvents // Only apply history visibility checks if the response is for joined rooms if !isPeek { - events, err = applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents) + events, err = applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents) if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") } @@ -530,7 +507,8 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( if err != nil { return nil, err } - stateEvents, err = p.lazyLoadMembers(ctx, roomID, + stateEvents, err = p.lazyLoadMembers( + ctx, snapshot, roomID, false, limited, stateFilter, device, recentEvents, stateEvents, ) @@ -549,7 +527,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } func (p *PDUStreamProvider) lazyLoadMembers( - ctx context.Context, roomID string, + ctx context.Context, snapshot storage.DatabaseTransaction, roomID string, incremental, limited bool, stateFilter *gomatrixserverlib.StateFilter, device *userapi.Device, timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent, @@ -598,7 +576,7 @@ func (p *PDUStreamProvider) lazyLoadMembers( filter.Limit = stateFilter.Limit filter.Senders = &wantUsers filter.Types = &[]string{gomatrixserverlib.MRoomMember} - memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &filter) + memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter) if err != nil { return stateEvents, err } @@ -612,8 +590,8 @@ func (p *PDUStreamProvider) lazyLoadMembers( // addIgnoredUsersToFilter adds ignored users to the eventfilter and // the syncreq itself for further use in streams. -func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error { - ignores, err := p.DB.IgnoresForUser(ctx, req.Device.UserID) +func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error { + ignores, err := snapshot.IgnoresForUser(ctx, req.Device.UserID) if err != nil { if err == sql.ErrNoRows { return nil diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index 15db4d30e..8b87af452 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -23,20 +23,26 @@ import ( "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" ) type PresenceStreamProvider struct { - StreamProvider + DefaultStreamProvider // cache contains previously sent presence updates to avoid unneeded updates cache sync.Map notifier *notifier.Notifier } -func (p *PresenceStreamProvider) Setup() { - p.StreamProvider.Setup() +func (p *PresenceStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { + p.DefaultStreamProvider.Setup(ctx, snapshot) - id, err := p.DB.MaxStreamPositionForPresence(context.Background()) + p.latestMutex.Lock() + defer p.latestMutex.Unlock() + + id, err := snapshot.MaxStreamPositionForPresence(ctx) if err != nil { panic(err) } @@ -45,18 +51,20 @@ func (p *PresenceStreamProvider) Setup() { func (p *PresenceStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) } func (p *PresenceStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { // We pull out a larger number than the filter asks for, since we're filtering out events later - presences, err := p.DB.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000}) + presences, err := snapshot.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000}) if err != nil { req.Log.WithError(err).Error("p.DB.PresenceAfter failed") return from @@ -84,9 +92,10 @@ func (p *PresenceStreamProvider) IncrementalSync( } // Bear in mind that this might return nil, but at least populating // a nil means that there's a map entry so we won't repeat this call. - presences[roomUsers[i]], err = p.DB.GetPresence(ctx, roomUsers[i]) + presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i]) if err != nil { req.Log.WithError(err).Error("unable to query presence for user") + _ = snapshot.Rollback() return from } if len(presences) > req.Filter.Presence.Limit { diff --git a/syncapi/streams/stream_receipt.go b/syncapi/streams/stream_receipt.go index f4e84c7d0..8818a5533 100644 --- a/syncapi/streams/stream_receipt.go +++ b/syncapi/streams/stream_receipt.go @@ -4,18 +4,24 @@ import ( "context" "encoding/json" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) type ReceiptStreamProvider struct { - StreamProvider + DefaultStreamProvider } -func (p *ReceiptStreamProvider) Setup() { - p.StreamProvider.Setup() +func (p *ReceiptStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { + p.DefaultStreamProvider.Setup(ctx, snapshot) - id, err := p.DB.MaxStreamPositionForReceipts(context.Background()) + p.latestMutex.Lock() + defer p.latestMutex.Unlock() + + id, err := snapshot.MaxStreamPositionForReceipts(ctx) if err != nil { panic(err) } @@ -24,13 +30,15 @@ func (p *ReceiptStreamProvider) Setup() { func (p *ReceiptStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) } func (p *ReceiptStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { @@ -41,7 +49,7 @@ func (p *ReceiptStreamProvider) IncrementalSync( } } - lastPos, receipts, err := p.DB.RoomReceiptsAfter(ctx, joinedRooms, from) + lastPos, receipts, err := snapshot.RoomReceiptsAfter(ctx, joinedRooms, from) if err != nil { req.Log.WithError(err).Error("p.DB.RoomReceiptsAfter failed") return from diff --git a/syncapi/streams/stream_sendtodevice.go b/syncapi/streams/stream_sendtodevice.go index 31c6187cb..00b67cc42 100644 --- a/syncapi/streams/stream_sendtodevice.go +++ b/syncapi/streams/stream_sendtodevice.go @@ -3,17 +3,23 @@ package streams import ( "context" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" ) type SendToDeviceStreamProvider struct { - StreamProvider + DefaultStreamProvider } -func (p *SendToDeviceStreamProvider) Setup() { - p.StreamProvider.Setup() +func (p *SendToDeviceStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { + p.DefaultStreamProvider.Setup(ctx, snapshot) - id, err := p.DB.MaxStreamPositionForSendToDeviceMessages(context.Background()) + p.latestMutex.Lock() + defer p.latestMutex.Unlock() + + id, err := snapshot.MaxStreamPositionForSendToDeviceMessages(ctx) if err != nil { panic(err) } @@ -22,18 +28,20 @@ func (p *SendToDeviceStreamProvider) Setup() { func (p *SendToDeviceStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) } func (p *SendToDeviceStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { // See if we have any new tasks to do for the send-to-device messaging. - lastPos, events, err := p.DB.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to) + lastPos, events, err := snapshot.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to) if err != nil { req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed") return from diff --git a/syncapi/streams/stream_typing.go b/syncapi/streams/stream_typing.go index f781065be..a6f7c7a06 100644 --- a/syncapi/streams/stream_typing.go +++ b/syncapi/streams/stream_typing.go @@ -5,24 +5,27 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) type TypingStreamProvider struct { - StreamProvider + DefaultStreamProvider EDUCache *caching.EDUCache } func (p *TypingStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) } func (p *TypingStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { diff --git a/syncapi/streams/streamprovider.go b/syncapi/streams/streamprovider.go new file mode 100644 index 000000000..8b12e2eba --- /dev/null +++ b/syncapi/streams/streamprovider.go @@ -0,0 +1,28 @@ +package streams + +import ( + "context" + + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/types" +) + +type StreamProvider interface { + Setup(ctx context.Context, snapshot storage.DatabaseTransaction) + + // Advance will update the latest position of the stream based on + // an update and will wake callers waiting on StreamNotifyAfter. + Advance(latest types.StreamPosition) + + // CompleteSync will update the response to include all updates as needed + // for a complete sync. It will always return immediately. + CompleteSync(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest) types.StreamPosition + + // IncrementalSync will update the response to include all updates between + // the from and to sync positions. It will always return immediately, + // making no changes if the range contains no updates. + IncrementalSync(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition) types.StreamPosition + + // LatestPosition returns the latest stream position for this stream. + LatestPosition(ctx context.Context) types.StreamPosition +} diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index dbc053bd8..dc8547621 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -4,6 +4,7 @@ import ( "context" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/notifier" @@ -13,15 +14,15 @@ import ( ) type Streams struct { - PDUStreamProvider types.StreamProvider - TypingStreamProvider types.StreamProvider - ReceiptStreamProvider types.StreamProvider - InviteStreamProvider types.StreamProvider - SendToDeviceStreamProvider types.StreamProvider - AccountDataStreamProvider types.StreamProvider - DeviceListStreamProvider types.StreamProvider - NotificationDataStreamProvider types.StreamProvider - PresenceStreamProvider types.StreamProvider + PDUStreamProvider StreamProvider + TypingStreamProvider StreamProvider + ReceiptStreamProvider StreamProvider + InviteStreamProvider StreamProvider + SendToDeviceStreamProvider StreamProvider + AccountDataStreamProvider StreamProvider + DeviceListStreamProvider StreamProvider + NotificationDataStreamProvider StreamProvider + PresenceStreamProvider StreamProvider } func NewSyncStreamProviders( @@ -31,52 +32,61 @@ func NewSyncStreamProviders( ) *Streams { streams := &Streams{ PDUStreamProvider: &PDUStreamProvider{ - StreamProvider: StreamProvider{DB: d}, - lazyLoadCache: lazyLoadCache, - rsAPI: rsAPI, - notifier: notifier, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, + lazyLoadCache: lazyLoadCache, + rsAPI: rsAPI, + notifier: notifier, }, TypingStreamProvider: &TypingStreamProvider{ - StreamProvider: StreamProvider{DB: d}, - EDUCache: eduCache, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, + EDUCache: eduCache, }, ReceiptStreamProvider: &ReceiptStreamProvider{ - StreamProvider: StreamProvider{DB: d}, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, }, InviteStreamProvider: &InviteStreamProvider{ - StreamProvider: StreamProvider{DB: d}, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, }, SendToDeviceStreamProvider: &SendToDeviceStreamProvider{ - StreamProvider: StreamProvider{DB: d}, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, }, AccountDataStreamProvider: &AccountDataStreamProvider{ - StreamProvider: StreamProvider{DB: d}, - userAPI: userAPI, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, + userAPI: userAPI, }, NotificationDataStreamProvider: &NotificationDataStreamProvider{ - StreamProvider: StreamProvider{DB: d}, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, }, DeviceListStreamProvider: &DeviceListStreamProvider{ - StreamProvider: StreamProvider{DB: d}, - rsAPI: rsAPI, - keyAPI: keyAPI, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, + rsAPI: rsAPI, + keyAPI: keyAPI, }, PresenceStreamProvider: &PresenceStreamProvider{ - StreamProvider: StreamProvider{DB: d}, - notifier: notifier, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, + notifier: notifier, }, } - streams.PDUStreamProvider.Setup() - streams.TypingStreamProvider.Setup() - streams.ReceiptStreamProvider.Setup() - streams.InviteStreamProvider.Setup() - streams.SendToDeviceStreamProvider.Setup() - streams.AccountDataStreamProvider.Setup() - streams.NotificationDataStreamProvider.Setup() - streams.DeviceListStreamProvider.Setup() - streams.PresenceStreamProvider.Setup() + ctx := context.TODO() + snapshot, err := d.NewDatabaseSnapshot(ctx) + if err != nil { + panic(err) + } + var succeeded bool + defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) + streams.PDUStreamProvider.Setup(ctx, snapshot) + streams.TypingStreamProvider.Setup(ctx, snapshot) + streams.ReceiptStreamProvider.Setup(ctx, snapshot) + streams.InviteStreamProvider.Setup(ctx, snapshot) + streams.SendToDeviceStreamProvider.Setup(ctx, snapshot) + streams.AccountDataStreamProvider.Setup(ctx, snapshot) + streams.NotificationDataStreamProvider.Setup(ctx, snapshot) + streams.DeviceListStreamProvider.Setup(ctx, snapshot) + streams.PresenceStreamProvider.Setup(ctx, snapshot) + + succeeded = true return streams } diff --git a/syncapi/streams/template_stream.go b/syncapi/streams/template_stream.go index 15074cc10..f208d84e4 100644 --- a/syncapi/streams/template_stream.go +++ b/syncapi/streams/template_stream.go @@ -8,16 +8,18 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" ) -type StreamProvider struct { +type DefaultStreamProvider struct { DB storage.Database latest types.StreamPosition latestMutex sync.RWMutex } -func (p *StreamProvider) Setup() { +func (p *DefaultStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { } -func (p *StreamProvider) Advance( +func (p *DefaultStreamProvider) Advance( latest types.StreamPosition, ) { p.latestMutex.Lock() @@ -28,7 +30,7 @@ func (p *StreamProvider) Advance( } } -func (p *StreamProvider) LatestPosition( +func (p *DefaultStreamProvider) LatestPosition( ctx context.Context, ) types.StreamPosition { p.latestMutex.RLock() diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index b2ea105ff..a71d32ab8 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -31,6 +31,7 @@ import ( "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/sqlutil" keyapi "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -305,78 +306,182 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately") } + withTransaction := func(from types.StreamPosition, f func(snapshot storage.DatabaseTransaction) types.StreamPosition) types.StreamPosition { + var succeeded bool + snapshot, err := rp.db.NewDatabaseSnapshot(req.Context()) + if err != nil { + logrus.WithError(err).Error("Failed to acquire database snapshot for sync request") + return from + } + defer func() { + succeeded = err == nil + sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) + }() + return f(snapshot) + } + if syncReq.Since.IsEmpty() { // Complete sync syncReq.Response.NextBatch = types.StreamingToken{ // Get the current DeviceListPosition first, as the currentPosition // might advance while processing other streams, resulting in flakey // tests. - DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync( - syncReq.Context, syncReq, + DeviceListPosition: withTransaction( + syncReq.Since.DeviceListPosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.DeviceListStreamProvider.CompleteSync( + syncReq.Context, txn, syncReq, + ) + }, ), - PDUPosition: rp.streams.PDUStreamProvider.CompleteSync( - syncReq.Context, syncReq, + PDUPosition: withTransaction( + syncReq.Since.PDUPosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.PDUStreamProvider.CompleteSync( + syncReq.Context, txn, syncReq, + ) + }, ), - TypingPosition: rp.streams.TypingStreamProvider.CompleteSync( - syncReq.Context, syncReq, + TypingPosition: withTransaction( + syncReq.Since.TypingPosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.TypingStreamProvider.CompleteSync( + syncReq.Context, txn, syncReq, + ) + }, ), - ReceiptPosition: rp.streams.ReceiptStreamProvider.CompleteSync( - syncReq.Context, syncReq, + ReceiptPosition: withTransaction( + syncReq.Since.ReceiptPosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.ReceiptStreamProvider.CompleteSync( + syncReq.Context, txn, syncReq, + ) + }, ), - InvitePosition: rp.streams.InviteStreamProvider.CompleteSync( - syncReq.Context, syncReq, + InvitePosition: withTransaction( + syncReq.Since.InvitePosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.InviteStreamProvider.CompleteSync( + syncReq.Context, txn, syncReq, + ) + }, ), - SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.CompleteSync( - syncReq.Context, syncReq, + SendToDevicePosition: withTransaction( + syncReq.Since.SendToDevicePosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.SendToDeviceStreamProvider.CompleteSync( + syncReq.Context, txn, syncReq, + ) + }, ), - AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync( - syncReq.Context, syncReq, + AccountDataPosition: withTransaction( + syncReq.Since.AccountDataPosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.AccountDataStreamProvider.CompleteSync( + syncReq.Context, txn, syncReq, + ) + }, ), - NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync( - syncReq.Context, syncReq, + NotificationDataPosition: withTransaction( + syncReq.Since.NotificationDataPosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.NotificationDataStreamProvider.CompleteSync( + syncReq.Context, txn, syncReq, + ) + }, ), - PresencePosition: rp.streams.PresenceStreamProvider.CompleteSync( - syncReq.Context, syncReq, + PresencePosition: withTransaction( + syncReq.Since.PresencePosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.PresenceStreamProvider.CompleteSync( + syncReq.Context, txn, syncReq, + ) + }, ), } } else { // Incremental sync syncReq.Response.NextBatch = types.StreamingToken{ - PDUPosition: rp.streams.PDUStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.PDUPosition, currentPos.PDUPosition, + PDUPosition: withTransaction( + syncReq.Since.PDUPosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.PDUStreamProvider.IncrementalSync( + syncReq.Context, txn, syncReq, + syncReq.Since.PDUPosition, currentPos.PDUPosition, + ) + }, ), - TypingPosition: rp.streams.TypingStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.TypingPosition, currentPos.TypingPosition, + TypingPosition: withTransaction( + syncReq.Since.TypingPosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.TypingStreamProvider.IncrementalSync( + syncReq.Context, txn, syncReq, + syncReq.Since.TypingPosition, currentPos.TypingPosition, + ) + }, ), - ReceiptPosition: rp.streams.ReceiptStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.ReceiptPosition, currentPos.ReceiptPosition, + ReceiptPosition: withTransaction( + syncReq.Since.ReceiptPosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.ReceiptStreamProvider.IncrementalSync( + syncReq.Context, txn, syncReq, + syncReq.Since.ReceiptPosition, currentPos.ReceiptPosition, + ) + }, ), - InvitePosition: rp.streams.InviteStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.InvitePosition, currentPos.InvitePosition, + InvitePosition: withTransaction( + syncReq.Since.InvitePosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.InviteStreamProvider.IncrementalSync( + syncReq.Context, txn, syncReq, + syncReq.Since.InvitePosition, currentPos.InvitePosition, + ) + }, ), - SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.SendToDevicePosition, currentPos.SendToDevicePosition, + SendToDevicePosition: withTransaction( + syncReq.Since.SendToDevicePosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.SendToDeviceStreamProvider.IncrementalSync( + syncReq.Context, txn, syncReq, + syncReq.Since.SendToDevicePosition, currentPos.SendToDevicePosition, + ) + }, ), - AccountDataPosition: rp.streams.AccountDataStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition, + AccountDataPosition: withTransaction( + syncReq.Since.AccountDataPosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.AccountDataStreamProvider.IncrementalSync( + syncReq.Context, txn, syncReq, + syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition, + ) + }, ), - NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition, + NotificationDataPosition: withTransaction( + syncReq.Since.NotificationDataPosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.NotificationDataStreamProvider.IncrementalSync( + syncReq.Context, txn, syncReq, + syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition, + ) + }, ), - DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition, + DeviceListPosition: withTransaction( + syncReq.Since.DeviceListPosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.DeviceListStreamProvider.IncrementalSync( + syncReq.Context, txn, syncReq, + syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition, + ) + }, ), - PresencePosition: rp.streams.PresenceStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.PresencePosition, currentPos.PresencePosition, + PresencePosition: withTransaction( + syncReq.Since.PresencePosition, + func(txn storage.DatabaseTransaction) types.StreamPosition { + return rp.streams.PresenceStreamProvider.IncrementalSync( + syncReq.Context, txn, syncReq, + syncReq.Since.PresencePosition, currentPos.PresencePosition, + ) + }, ), } // it's possible for there to be no updates for this user even though since < current pos, @@ -437,15 +542,23 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use util.GetLogger(req.Context()).WithError(err).Error("newSyncRequest failed") return jsonerror.InternalServerError() } - rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), syncReq, fromToken.PDUPosition, toToken.PDUPosition) + snapshot, err := rp.db.NewDatabaseSnapshot(req.Context()) + if err != nil { + logrus.WithError(err).Error("Failed to acquire database snapshot for key change") + return jsonerror.InternalServerError() + } + var succeeded bool + defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) + rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), snapshot, syncReq, fromToken.PDUPosition, toToken.PDUPosition) _, _, err = internal.DeviceListCatchup( - req.Context(), rp.db, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID, + req.Context(), snapshot, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID, syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition, ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("Failed to DeviceListCatchup info") return jsonerror.InternalServerError() } + succeeded = true return util.JSONResponse{ Code: 200, JSON: struct { diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 68537bc45..be19310f2 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -77,16 +77,6 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start presence consumer") } - userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{ - JetStream: js, - Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent), - } - - userAPIReadUpdateProducer := &producers.UserAPIReadProducer{ - JetStream: js, - Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate), - } - keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( base.ProcessContext, cfg, cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), js, rsAPI, syncDB, notifier, @@ -98,15 +88,15 @@ func AddPublicRoutes( roomConsumer := consumers.NewOutputRoomEventConsumer( base.ProcessContext, cfg, js, syncDB, notifier, streams.PDUStreamProvider, - streams.InviteStreamProvider, rsAPI, userAPIStreamEventProducer, + streams.InviteStreamProvider, rsAPI, base.Fulltext, ) if err = roomConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start room server consumer") } clientConsumer := consumers.NewOutputClientDataConsumer( - base.ProcessContext, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider, - userAPIReadUpdateProducer, + base.ProcessContext, cfg, js, natsClient, syncDB, notifier, + streams.AccountDataStreamProvider, base.Fulltext, ) if err = clientConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start client data consumer") @@ -135,7 +125,6 @@ func AddPublicRoutes( receiptConsumer := consumers.NewOutputReceiptEventConsumer( base.ProcessContext, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider, - userAPIReadUpdateProducer, ) if err = receiptConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start receipts consumer") @@ -143,6 +132,6 @@ func AddPublicRoutes( routing.Setup( base.PublicClientAPIMux, requestPool, syncDB, userAPI, - rsAPI, cfg, base.Caches, + rsAPI, cfg, base.Caches, base.Fulltext, ) } diff --git a/syncapi/types/provider.go b/syncapi/types/provider.go index a9ea234d0..378cafe99 100644 --- a/syncapi/types/provider.go +++ b/syncapi/types/provider.go @@ -41,23 +41,3 @@ func (r *SyncRequest) IsRoomPresent(roomID string) bool { return false } } - -type StreamProvider interface { - Setup() - - // Advance will update the latest position of the stream based on - // an update and will wake callers waiting on StreamNotifyAfter. - Advance(latest StreamPosition) - - // CompleteSync will update the response to include all updates as needed - // for a complete sync. It will always return immediately. - CompleteSync(ctx context.Context, req *SyncRequest) StreamPosition - - // IncrementalSync will update the response to include all updates between - // the from and to sync positions. It will always return immediately, - // making no changes if the range contains no updates. - IncrementalSync(ctx context.Context, req *SyncRequest, from, to StreamPosition) StreamPosition - - // LatestPosition returns the latest stream position for this stream. - LatestPosition(ctx context.Context) StreamPosition -} diff --git a/syncapi/types/types.go b/syncapi/types/types.go index d75d53ca9..3b85db4a4 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -398,6 +398,11 @@ func (r *Response) IsEmpty() bool { len(r.ToDevice.Events) == 0 } +type UnreadNotifications struct { + HighlightCount int `json:"highlight_count"` + NotificationCount int `json:"notification_count"` +} + // JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key. type JoinResponse struct { Summary struct { @@ -419,10 +424,7 @@ type JoinResponse struct { AccountData struct { Events []gomatrixserverlib.ClientEvent `json:"events"` } `json:"account_data"` - UnreadNotifications struct { - HighlightCount int `json:"highlight_count"` - NotificationCount int `json:"notification_count"` - } `json:"unread_notifications"` + *UnreadNotifications `json:"unread_notifications,omitempty"` } // NewJoinResponse creates an empty response with initialised arrays. @@ -503,19 +505,6 @@ type Peek struct { Deleted bool } -type ReadUpdate struct { - UserID string `json:"user_id"` - RoomID string `json:"room_id"` - Read StreamPosition `json:"read,omitempty"` - FullyRead StreamPosition `json:"fully_read,omitempty"` -} - -// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. -type StreamedEvent struct { - Event *gomatrixserverlib.HeaderedEvent `json:"event"` - StreamPosition StreamPosition `json:"stream_position"` -} - // OutputReceiptEvent is an entry in the receipt output kafka log type OutputReceiptEvent struct { UserID string `json:"user_id"` diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go new file mode 100644 index 000000000..c220d35cb --- /dev/null +++ b/userapi/consumers/clientapi.go @@ -0,0 +1,127 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/userapi/storage" + + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/producers" + "github.com/matrix-org/dendrite/userapi/util" +) + +// OutputReceiptEventConsumer consumes events that originated in the clientAPI. +type OutputReceiptEventConsumer struct { + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + serverName gomatrixserverlib.ServerName + syncProducer *producers.SyncAPI + pgClient pushgateway.Client +} + +// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. +// Call Start() to begin consuming from the EDU server. +func NewOutputReceiptEventConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + store storage.Database, + syncProducer *producers.SyncAPI, + pgClient pushgateway.Client, +) *OutputReceiptEventConsumer { + return &OutputReceiptEventConsumer{ + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), + durable: cfg.Matrix.JetStream.Durable("UserAPIReceiptConsumer"), + db: store, + serverName: cfg.Matrix.ServerName, + syncProducer: syncProducer, + pgClient: pgClient, + } +} + +// Start consuming receipts events. +func (s *OutputReceiptEventConsumer) Start() error { + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), + ) +} + +func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + userID := msg.Header.Get(jetstream.UserID) + roomID := msg.Header.Get(jetstream.RoomID) + readPos := msg.Header.Get(jetstream.EventID) + evType := msg.Header.Get("type") + + if readPos == "" || evType != "m.read" { + return true + } + + log := log.WithFields(log.Fields{ + "room_id": roomID, + "user_id": userID, + }) + + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + log.WithError(err).Error("userapi clientapi consumer: SplitID failure") + return true + } + if domain != s.serverName { + return true + } + + metadata, err := msg.Metadata() + if err != nil { + return false + } + + updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true) + if err != nil { + log.WithError(err).Error("userapi EDU consumer") + return false + } + + if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil { + log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed") + return false + } + + if !updated { + return true + } + if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { + log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed") + return false + } + + return true +} diff --git a/userapi/consumers/syncapi_streamevent.go b/userapi/consumers/roomserver.go similarity index 85% rename from userapi/consumers/syncapi_streamevent.go rename to userapi/consumers/roomserver.go index f3b2bf27f..952de98f7 100644 --- a/userapi/consumers/syncapi_streamevent.go +++ b/userapi/consumers/roomserver.go @@ -26,7 +26,7 @@ import ( "github.com/matrix-org/dendrite/userapi/util" ) -type OutputStreamEventConsumer struct { +type OutputRoomEventConsumer struct { ctx context.Context cfg *config.UserAPI rsAPI rsapi.UserRoomserverAPI @@ -38,7 +38,7 @@ type OutputStreamEventConsumer struct { syncProducer *producers.SyncAPI } -func NewOutputStreamEventConsumer( +func NewOutputRoomEventConsumer( process *process.ProcessContext, cfg *config.UserAPI, js nats.JetStreamContext, @@ -46,21 +46,21 @@ func NewOutputStreamEventConsumer( pgClient pushgateway.Client, rsAPI rsapi.UserRoomserverAPI, syncProducer *producers.SyncAPI, -) *OutputStreamEventConsumer { - return &OutputStreamEventConsumer{ +) *OutputRoomEventConsumer { + return &OutputRoomEventConsumer{ ctx: process.Context(), cfg: cfg, jetstream: js, db: store, - durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent), + durable: cfg.Matrix.JetStream.Durable("UserAPIRoomServerConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent), pgClient: pgClient, rsAPI: rsAPI, syncProducer: syncProducer, } } -func (s *OutputStreamEventConsumer) Start() error { +func (s *OutputRoomEventConsumer) Start() error { if err := jetstream.JetStreamConsumer( s.ctx, s.jetstream, s.topic, s.durable, 1, s.onMessage, nats.DeliverAll(), nats.ManualAck(), @@ -70,35 +70,43 @@ func (s *OutputStreamEventConsumer) Start() error { return nil } -func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { msg := msgs[0] // Guaranteed to exist if onMessage is called - var output types.StreamedEvent - output.Event = &gomatrixserverlib.HeaderedEvent{} + var output rsapi.OutputEvent if err := json.Unmarshal(msg.Data, &output); err != nil { - log.WithError(err).Errorf("userapi consumer: message parse failure") + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") return true } - if output.Event.Event == nil { + if output.Type != rsapi.OutputTypeNewRoomEvent { + return true + } + event := output.NewRoomEvent.Event + if event == nil { log.Errorf("userapi consumer: expected event") return true } log.WithFields(log.Fields{ - "event_id": output.Event.EventID(), - "event_type": output.Event.Type(), - "stream_pos": output.StreamPosition, - }).Tracef("Received message from sync API: %#v", output) + "event_id": event.EventID(), + "event_type": event.Type(), + }).Tracef("Received message from roomserver: %#v", output) - if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil { + metadata, err := msg.Metadata() + if err != nil { + return true + } + + if err := s.processMessage(ctx, event, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp))); err != nil { log.WithFields(log.Fields{ - "event_id": output.Event.EventID(), + "event_id": event.EventID(), }).WithError(err).Errorf("userapi consumer: process room event failure") } return true } -func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error { +func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error { members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) if err != nil { return fmt.Errorf("s.localRoomMembers: %w", err) @@ -138,10 +146,10 @@ func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *g // removing it means we can send all notifications to // e.g. Element's Push gateway in one go. for _, mem := range members { - if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil { + if err := s.notifyLocal(ctx, event, mem, roomSize, roomName, streamPos); err != nil { log.WithFields(log.Fields{ "localpart": mem.Localpart, - }).WithError(err).Debugf("Unable to push to local user") + }).WithError(err).Error("Unable to push to local user") continue } } @@ -179,7 +187,7 @@ func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership, // localRoomMembers fetches the current local members of a room, and // the total number of members. -func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { +func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { req := &rsapi.QueryMembershipsForRoomRequest{ RoomID: roomID, JoinedOnly: true, @@ -219,7 +227,7 @@ func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID // looks it up in roomserver. If there is no name, // m.room.canonical_alias is consulted. Returns an empty string if the // room has no name. -func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { +func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { if event.Type() == gomatrixserverlib.MRoomName { name, err := unmarshalRoomName(event) if err != nil { @@ -287,7 +295,7 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er } // notifyLocal finds the right push actions for a local user, given an event. -func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error { +func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error { actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) if err != nil { return err @@ -302,7 +310,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma "event_id": event.EventID(), "room_id": event.RoomID(), "localpart": mem.Localpart, - }).Debugf("Push rule evaluation rejected the event") + }).Tracef("Push rule evaluation rejected the event") return nil } @@ -325,7 +333,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma RoomID: event.RoomID(), TS: gomatrixserverlib.AsTimestamp(time.Now()), } - if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil { + if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil { return err } @@ -345,7 +353,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma "localpart": mem.Localpart, "num_urls": len(devicesByURLAndFormat), "num_unread": userNumUnreadNotifs, - }).Debugf("Notifying single member") + }).Trace("Notifying single member") // Push gateways are out of our control, and we cannot risk // looking up the server on a misbehaving push gateway. Each user @@ -396,7 +404,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma // evaluatePushRules fetches and evaluates the push rules of a local // user. Returns actions (including dont_notify). -func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { +func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { if event.Sender() == mem.UserID { // SPEC: Homeservers MUST NOT notify the Push Gateway for // events that the user has sent themselves. @@ -447,7 +455,7 @@ func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event "room_id": event.RoomID(), "localpart": mem.Localpart, "rule_id": rule.RuleID, - }).Tracef("Matched a push rule") + }).Trace("Matched a push rule") return rule.Actions, nil } @@ -491,7 +499,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err // localPushDevices pushes to the configured devices of a local // user. The map keys are [url][format]. -func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { +func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) if err != nil { return nil, "", err @@ -515,7 +523,7 @@ func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localp } // notifyHTTP performs a notificatation to a Push Gateway. -func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { +func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { logger := log.WithFields(log.Fields{ "event_id": event.EventID(), "url": url, @@ -561,13 +569,13 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat } } - logger.Debugf("Notifying push gateway %s", url) + logger.Tracef("Notifying push gateway %s", url) var res pushgateway.NotifyResponse if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil { logger.WithError(err).Errorf("Failed to notify push gateway %s", url) return nil, err } - logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result") + logger.WithField("num_rejected", len(res.Rejected)).Trace("Push gateway result") if len(res.Rejected) == 0 { return nil, nil @@ -589,7 +597,7 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat } // deleteRejectedPushers deletes the pushers associated with the given devices. -func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { +func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { log.WithFields(log.Fields{ "localpart": localpart, "app_id0": devices[0].AppID, diff --git a/userapi/consumers/syncapi_streamevent_test.go b/userapi/consumers/roomserver_test.go similarity index 98% rename from userapi/consumers/syncapi_streamevent_test.go rename to userapi/consumers/roomserver_test.go index 48ea0fe11..3bbeb439a 100644 --- a/userapi/consumers/syncapi_streamevent_test.go +++ b/userapi/consumers/roomserver_test.go @@ -40,7 +40,7 @@ func Test_evaluatePushRules(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() - consumer := OutputStreamEventConsumer{db: db} + consumer := OutputRoomEventConsumer{db: db} testCases := []struct { name string diff --git a/userapi/consumers/syncapi_readupdate.go b/userapi/consumers/syncapi_readupdate.go deleted file mode 100644 index 54654f757..000000000 --- a/userapi/consumers/syncapi_readupdate.go +++ /dev/null @@ -1,137 +0,0 @@ -package consumers - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/dendrite/internal/pushgateway" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/dendrite/syncapi/types" - uapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/producers" - "github.com/matrix-org/dendrite/userapi/storage" - "github.com/matrix-org/dendrite/userapi/util" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" -) - -type OutputReadUpdateConsumer struct { - ctx context.Context - cfg *config.UserAPI - jetstream nats.JetStreamContext - durable string - db storage.Database - pgClient pushgateway.Client - ServerName gomatrixserverlib.ServerName - topic string - userAPI uapi.UserInternalAPI - syncProducer *producers.SyncAPI -} - -func NewOutputReadUpdateConsumer( - process *process.ProcessContext, - cfg *config.UserAPI, - js nats.JetStreamContext, - store storage.Database, - pgClient pushgateway.Client, - userAPI uapi.UserInternalAPI, - syncProducer *producers.SyncAPI, -) *OutputReadUpdateConsumer { - return &OutputReadUpdateConsumer{ - ctx: process.Context(), - cfg: cfg, - jetstream: js, - db: store, - ServerName: cfg.Matrix.ServerName, - durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate), - pgClient: pgClient, - userAPI: userAPI, - syncProducer: syncProducer, - } -} - -func (s *OutputReadUpdateConsumer) Start() error { - if err := jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, 1, - s.onMessage, nats.DeliverAll(), nats.ManualAck(), - ); err != nil { - return err - } - return nil -} - -func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { - msg := msgs[0] // Guaranteed to exist if onMessage is called - var read types.ReadUpdate - if err := json.Unmarshal(msg.Data, &read); err != nil { - log.WithError(err).Error("userapi clientapi consumer: message parse failure") - return true - } - if read.FullyRead == 0 && read.Read == 0 { - return true - } - - userID := string(msg.Header.Get(jetstream.UserID)) - roomID := string(msg.Header.Get(jetstream.RoomID)) - - localpart, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - log.WithError(err).Error("userapi clientapi consumer: SplitID failure") - return true - } - if domain != s.ServerName { - log.Error("userapi clientapi consumer: not a local user") - return true - } - - log := log.WithFields(log.Fields{ - "room_id": roomID, - "user_id": userID, - }) - log.Tracef("Received read update from sync API: %#v", read) - - if read.Read > 0 { - updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true) - if err != nil { - log.WithError(err).Error("userapi EDU consumer") - return false - } - - if updated { - if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil { - log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed") - return false - } - if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { - log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed") - return false - } - } - } - - if read.FullyRead > 0 { - deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead)) - if err != nil { - log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed") - return false - } - - if deleted { - if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { - log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed") - return false - } - - if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil { - log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed") - return false - } - } - } - - return true -} diff --git a/userapi/internal/api.go b/userapi/internal/api.go index dcbb73614..591faffd6 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -30,6 +30,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/sqlutil" keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" @@ -39,6 +40,7 @@ import ( "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" + userapiUtil "github.com/matrix-org/dendrite/userapi/util" ) type UserInternalAPI struct { @@ -51,6 +53,7 @@ type UserInternalAPI struct { AppServices []config.ApplicationService KeyAPI keyapi.UserKeyAPI RSAPI rsapi.UserRoomserverAPI + PgClient pushgateway.Client } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { @@ -73,6 +76,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc ignoredUsers = &synctypes.IgnoredUsers{} _ = json.Unmarshal(req.AccountData, ignoredUsers) } + if req.DataType == "m.fully_read" { + if err := a.setFullyRead(ctx, req); err != nil { + return err + } + } if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{ RoomID: req.RoomID, Type: req.DataType, @@ -84,6 +92,44 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc return nil } +func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccountDataRequest) error { + var output eventutil.ReadMarkerJSON + + if err := json.Unmarshal(req.AccountData, &output); err != nil { + return err + } + localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure") + return nil + } + if domain != a.ServerName { + return nil + } + + deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) + if err != nil { + logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed") + return err + } + + if err = a.SyncProducer.GetAndSendNotificationData(ctx, req.UserID, req.RoomID); err != nil { + logrus.WithError(err).Error("UserInternalAPI.setFullyRead: GetAndSendNotificationData failed") + return err + } + + // nothing changed, no need to notify the push gateway + if !deleted { + return nil + } + + if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil { + logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed") + return err + } + return nil +} + func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) if err != nil { @@ -750,11 +796,6 @@ func (a *UserInternalAPI) PerformPushRulesPut( if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil { return err } - if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{ - Type: pushRulesAccountDataType, - }); err != nil { - util.GetLogger(ctx).WithError(err).Errorf("syncProducer.SendData failed") - } return nil } diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go index 27cfc2848..f556ea352 100644 --- a/userapi/producers/syncapi.go +++ b/userapi/producers/syncapi.go @@ -4,12 +4,13 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/internal/eventutil" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/userapi/storage" ) type JetStreamPublisher interface { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index fbac463e2..02efe7afe 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -119,9 +119,9 @@ type ThreePID interface { } type Notification interface { - InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error - DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) - SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, read bool) (affected bool, err error) + InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) + SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go index a27c1125e..24a30b2f5 100644 --- a/userapi/storage/postgres/notifications_table.go +++ b/userapi/storage/postgres/notifications_table.go @@ -20,12 +20,13 @@ import ( "encoding/json" "time" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" - log "github.com/sirupsen/logrus" ) type notificationsStatements struct { @@ -110,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -126,7 +127,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) if err != nil { return false, err @@ -140,7 +141,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) if err != nil { return false, err @@ -196,40 +197,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { - rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) - - if err != nil { - return 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") - - if rows.Next() { - var count int64 - if err := rows.Scan(&count); err != nil { - return 0, err - } - - return count, nil - } - return 0, rows.Err() +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) { + err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count) + return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { - rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) - - if err != nil { - return 0, 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") - - if rows.Next() { - var total, highlight int64 - if err := rows.Scan(&total, &highlight); err != nil { - return 0, 0, err - } - - return total, highlight, nil - } - return 0, 0, rows.Err() +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) { + err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight) + return } diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go index 2eb379ae4..6fb714fba 100644 --- a/userapi/storage/postgres/pusher_table.go +++ b/userapi/storage/postgres/pusher_table.go @@ -19,11 +19,12 @@ import ( "database/sql" "encoding/json" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/sirupsen/logrus" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -136,7 +137,7 @@ func (s *pushersStatements) SelectPushers( pushers = append(pushers, pusher) } - logrus.Debugf("Database returned %d pushers", len(pushers)) + logrus.Tracef("Database returned %d pushers", len(pushers)) return pushers, rows.Err() } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index e32a442d0..3ff299f1b 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -700,13 +700,13 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) ( return d.LoginTokens.SelectLoginToken(ctx, token) } -func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error { +func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) }) } -func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) { +func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) return err @@ -714,7 +714,7 @@ func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomI return } -func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) { +func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) return err diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go index df8260251..a35ec7be5 100644 --- a/userapi/storage/sqlite3/notifications_table.go +++ b/userapi/storage/sqlite3/notifications_table.go @@ -20,12 +20,13 @@ import ( "encoding/json" "time" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" - log "github.com/sirupsen/logrus" ) type notificationsStatements struct { @@ -110,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -126,7 +127,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) if err != nil { return false, err @@ -140,7 +141,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) if err != nil { return false, err @@ -196,40 +197,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { - rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) - - if err != nil { - return 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") - - if rows.Next() { - var count int64 - if err := rows.Scan(&count); err != nil { - return 0, err - } - - return count, nil - } - return 0, rows.Err() +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) { + err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count) + return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { - rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) - - if err != nil { - return 0, 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") - - if rows.Next() { - var total, highlight int64 - if err := rows.Scan(&total, &highlight); err != nil { - return 0, 0, err - } - - return total, highlight, nil - } - return 0, 0, rows.Err() +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) { + err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight) + return } diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go index dba97c3d4..4de0a9f06 100644 --- a/userapi/storage/sqlite3/pusher_table.go +++ b/userapi/storage/sqlite3/pusher_table.go @@ -19,11 +19,12 @@ import ( "database/sql" "encoding/json" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/sirupsen/logrus" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -136,7 +137,7 @@ func (s *pushersStatements) SelectPushers( pushers = append(pushers, pusher) } - logrus.Debugf("Database returned %d pushers", len(pushers)) + logrus.Tracef("Database returned %d pushers", len(pushers)) return pushers, rows.Err() } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index a26097338..ca7c1bfd2 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -7,6 +7,11 @@ import ( "testing" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/setup/config" @@ -14,10 +19,6 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/stretchr/testify/assert" - "golang.org/x/crypto/bcrypt" ) const loginTokenLifetime = time.Minute @@ -513,7 +514,7 @@ func Test_Notification(t *testing.T) { RoomID: roomID, TS: gomatrixserverlib.AsTimestamp(ts), } - err = db.InsertNotification(ctx, aliceLocalpart, eventID, int64(i+1), nil, notification) + err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification) assert.NoError(t, err, "unable to insert notification") } diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 2fe955670..cc4287997 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -105,9 +105,9 @@ type PusherTable interface { type NotificationTable interface { Clean(ctx context.Context, txn *sql.Tx) error - Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error - DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) - UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) + Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error + DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) + UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) diff --git a/userapi/userapi.go b/userapi/userapi.go index 23855a89f..d26b4e19a 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -81,16 +81,17 @@ func NewInternalAPI( KeyAPI: keyAPI, RSAPI: rsAPI, DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, + PgClient: pgClient, } - readConsumer := consumers.NewOutputReadUpdateConsumer( - base.ProcessContext, cfg, js, db, pgClient, userAPI, syncProducer, + receiptConsumer := consumers.NewOutputReceiptEventConsumer( + base.ProcessContext, cfg, js, db, syncProducer, pgClient, ) - if err := readConsumer.Start(); err != nil { - logrus.WithError(err).Panic("failed to start user API read update consumer") + if err := receiptConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start user API receipt consumer") } - eventConsumer := consumers.NewOutputStreamEventConsumer( + eventConsumer := consumers.NewOutputRoomEventConsumer( base.ProcessContext, cfg, js, db, pgClient, rsAPI, syncProducer, ) if err := eventConsumer.Start(); err != nil {