diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4d413a29c..0bcd2bb1e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,34 +20,40 @@ should pick up any unit test and run it). There are also [scripts](scripts) for [linting](scripts/find-lint.sh) and doing a [build/test/lint run](scripts/build-test-lint.sh). +As of February 2020, we are deprecating support for Go 1.11 and Go 1.12 and are +now targeting Go 1.13 or later. Please ensure that you are using at least Go +1.13 when developing for Dendrite - our CI will lint and run tests against this +version. + ## Continuous Integration When a Pull Request is submitted, continuous integration jobs are run -automatically to ensure the code builds and is relatively well-written. The -jobs are run on [Buildkite](https://buildkite.com/matrix-dot-org/dendrite/), -and the Buildkite pipeline configuration can be found in Matrix.org's -[pipelines repository](https://github.com/matrix-org/pipelines). +automatically to ensure the code builds and is relatively well-written. The jobs +are run on [Buildkite](https://buildkite.com/matrix-dot-org/dendrite/), and the +Buildkite pipeline configuration can be found in Matrix.org's [pipelines +repository](https://github.com/matrix-org/pipelines). If a job fails, click the "details" button and you should be taken to the job's logs. -![Click the details button on the failing build step](docs/images/details-button-location.jpg) +![Click the details button on the failing build +step](docs/images/details-button-location.jpg) -Scroll down to the failing step and you should see some log output. Scan -the logs until you find what it's complaining about, fix it, submit a new -commit, then rinse and repeat until CI passes. +Scroll down to the failing step and you should see some log output. Scan the +logs until you find what it's complaining about, fix it, submit a new commit, +then rinse and repeat until CI passes. ### Running CI Tests Locally To save waiting for CI to finish after every commit, it is ideal to run the -checks locally before pushing, fixing errors first. This also saves other -people time as only so many PRs can be tested at a given time. +checks locally before pushing, fixing errors first. This also saves other people +time as only so many PRs can be tested at a given time. -To execute what Buildkite tests, first run `./scripts/build-test-lint.sh`; -this script will build the code, lint it, and run `go test ./...` with race -condition checking enabled. If something needs to be changed, fix it and then -run the script again until it no longer complains. Be warned that the linting -can take a significant amount of CPU and RAM. +To execute what Buildkite tests, first run `./scripts/build-test-lint.sh`; this +script will build the code, lint it, and run `go test ./...` with race condition +checking enabled. If something needs to be changed, fix it and then run the +script again until it no longer complains. Be warned that the linting can take a +significant amount of CPU and RAM. Once the code builds, run [Sytest](https://github.com/matrix-org/sytest) according to the guide in @@ -61,16 +67,18 @@ tests. ## Picking Things To Do -If you're new then feel free to pick up an issue labelled [good first issue](https://github.com/matrix-org/dendrite/labels/good%20first%20issue). +If you're new then feel free to pick up an issue labelled [good first +issue](https://github.com/matrix-org/dendrite/labels/good%20first%20issue). These should be well-contained, small pieces of work that can be picked up to help you get familiar with the code base. Once you're comfortable with hacking on Dendrite there are issues lablled as -[help wanted](https://github.com/matrix-org/dendrite/labels/help%20wanted), these -are often slightly larger or more complicated pieces of work but are hopefully -nonetheless fairly well-contained. +[help wanted](https://github.com/matrix-org/dendrite/labels/help%20wanted), +these are often slightly larger or more complicated pieces of work but are +hopefully nonetheless fairly well-contained. -We ask people who are familiar with Dendrite to leave the [good first issue](https://github.com/matrix-org/dendrite/labels/good%20first%20issue) +We ask people who are familiar with Dendrite to leave the [good first +issue](https://github.com/matrix-org/dendrite/labels/good%20first%20issue) issues so that there is always a way for new people to come and get involved. ## Getting Help @@ -79,9 +87,11 @@ For questions related to developing on Dendrite we have a dedicated room on Matrix [#dendrite-dev:matrix.org](https://matrix.to/#/#dendrite-dev:matrix.org) where we're happy to help. -For more general questions please use [#dendrite:matrix.org](https://matrix.to/#/#dendrite:matrix.org). +For more general questions please use +[#dendrite:matrix.org](https://matrix.to/#/#dendrite:matrix.org). ## Sign off We ask that everyone who contributes to the project signs off their -contributions, in accordance with the [DCO](https://github.com/matrix-org/matrix-doc/blob/master/CONTRIBUTING.rst#sign-off). +contributions, in accordance with the +[DCO](https://github.com/matrix-org/matrix-doc/blob/master/CONTRIBUTING.rst#sign-off). diff --git a/INSTALL.md b/INSTALL.md index 0fb0c08e5..4173e705e 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -12,7 +12,7 @@ Dendrite can be run in one of two configurations: ## Requirements - - Go 1.11+ + - Go 1.13+ - Postgres 9.5+ - For Kafka (optional if using the monolith server): - Unix-based system (https://kafka.apache.org/documentation/#os) @@ -22,7 +22,7 @@ Dendrite can be run in one of two configurations: ## Setting up a development environment -Assumes Go 1.10+ and JDK 1.8+ are already installed and are on PATH. +Assumes Go 1.13+ and JDK 1.8+ are already installed and are on PATH. ```bash # Get the code @@ -101,7 +101,7 @@ Create config file, based on `dendrite-config.yaml`. Call it `dendrite.yaml`. Th It is possible to use 'naffka' as an in-process replacement to Kafka when using the monolith server. To do this, set `use_naffka: true` in `dendrite.yaml` and uncomment -the necessary line related to naffka in the `database` section. Be sure to update the +the necessary line related to naffka in the `database` section. Be sure to update the database username and password if needed. The monolith server can be started as shown below. By default it listens for @@ -255,7 +255,7 @@ you want to support federation. ./bin/dendrite-federation-sender-server --config dendrite.yaml ``` -### Run an appservice server +### Run an appservice server This sends events from the network to [application services](https://matrix.org/docs/spec/application_service/unstable.html) diff --git a/README.md b/README.md index 2dadb1f4f..801d0e3ca 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,30 @@ # Dendrite [![Build Status](https://badge.buildkite.com/4be40938ab19f2bbc4a6c6724517353ee3ec1422e279faf374.svg?branch=master)](https://buildkite.com/matrix-dot-org/dendrite) [![Dendrite Dev on Matrix](https://img.shields.io/matrix/dendrite-dev:matrix.org.svg?label=%23dendrite-dev%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite-dev:matrix.org) [![Dendrite on Matrix](https://img.shields.io/matrix/dendrite:matrix.org.svg?label=%23dendrite%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite:matrix.org) -Dendrite will be a matrix homeserver written in go. +Dendrite will be a second-generation Matrix homeserver written in Go. -It's still very much a work in progress, but installation instructions can -be found in [INSTALL.md](INSTALL.md) +It's still very much a work in progress, but installation instructions can be +found in [INSTALL.md](INSTALL.md). It is not recommended to use Dendrite as a +production homeserver at this time. -An overview of the design can be found in [DESIGN.md](DESIGN.md) +An overview of the design can be found in [DESIGN.md](DESIGN.md). # Contributing -Everyone is welcome to help out and contribute! See [CONTRIBUTING.md](CONTRIBUTING.md) -to get started! +Everyone is welcome to help out and contribute! See +[CONTRIBUTING.md](CONTRIBUTING.md) to get started! -We aim to try and make it as easy as possible to jump in. +Please note that, as of February 2020, Dendrite now only targets Go 1.13 or +later. Please ensure that you are using at least Go 1.13 when developing for +Dendrite. # Discussion For questions about Dendrite we have a dedicated room on Matrix -[#dendrite:matrix.org](https://matrix.to/#/#dendrite:matrix.org). -Development discussion should happen in +[#dendrite:matrix.org](https://matrix.to/#/#dendrite:matrix.org). Development +discussion should happen in [#dendrite-dev:matrix.org](https://matrix.to/#/#dendrite-dev:matrix.org). # Progress -There's plenty still to do to make Dendrite usable! We're tracking progress in -a [project board](https://github.com/matrix-org/dendrite/projects/2). +There's plenty still to do to make Dendrite usable! We're tracking progress in a +[project board](https://github.com/matrix-org/dendrite/projects/2). diff --git a/appservice/appservice.go b/appservice/appservice.go index 6013b5b33..181799879 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -100,7 +100,7 @@ func SetupAppServiceAPIComponent( // Set up HTTP Endpoints routing.Setup( - base.APIMux, *base.Cfg, roomserverQueryAPI, roomserverAliasAPI, + base.APIMux, base.Cfg, roomserverQueryAPI, roomserverAliasAPI, accountsDB, federation, transactionsCache, ) diff --git a/appservice/routing/routing.go b/appservice/routing/routing.go index 13f8ed724..42fa80520 100644 --- a/appservice/routing/routing.go +++ b/appservice/routing/routing.go @@ -36,7 +36,7 @@ const pathPrefixApp = "/_matrix/app/v1" // applied: // nolint: gocyclo func Setup( - apiMux *mux.Router, cfg config.Dendrite, // nolint: unparam + apiMux *mux.Router, cfg *config.Dendrite, // nolint: unparam queryAPI api.RoomserverQueryAPI, aliasAPI api.RoomserverAliasAPI, // nolint: unparam accountDB accounts.Database, // nolint: unparam federation *gomatrixserverlib.FederationClient, // nolint: unparam diff --git a/clientapi/auth/storage/accounts/postgres/account_data_table.go b/clientapi/auth/storage/accounts/postgres/account_data_table.go index 14d9c9d95..d0cfcc0cf 100644 --- a/clientapi/auth/storage/accounts/postgres/account_data_table.go +++ b/clientapi/auth/storage/accounts/postgres/account_data_table.go @@ -90,6 +90,7 @@ func (s *accountDataStatements) selectAccountData( if err != nil { return } + defer rows.Close() // nolint: errcheck global = []gomatrixserverlib.ClientEvent{} rooms = make(map[string][]gomatrixserverlib.ClientEvent) @@ -114,8 +115,7 @@ func (s *accountDataStatements) selectAccountData( global = append(global, ac) } } - - return + return global, rooms, rows.Err() } func (s *accountDataStatements) selectAccountDataByType( diff --git a/clientapi/auth/storage/accounts/postgres/membership_table.go b/clientapi/auth/storage/accounts/postgres/membership_table.go index 24ccff370..426c2d6ac 100644 --- a/clientapi/auth/storage/accounts/postgres/membership_table.go +++ b/clientapi/auth/storage/accounts/postgres/membership_table.go @@ -122,11 +122,10 @@ func (s *membershipStatements) selectMembershipsByLocalpart( for rows.Next() { var m authtypes.Membership m.Localpart = localpart - if err := rows.Scan(&m.RoomID, &m.EventID); err != nil { - return nil, err + if err = rows.Scan(&m.RoomID, &m.EventID); err != nil { + return } memberships = append(memberships, m) } - - return + return memberships, rows.Err() } diff --git a/clientapi/auth/storage/accounts/sqlite3/threepid_table.go b/clientapi/auth/storage/accounts/sqlite3/threepid_table.go index 762bced42..53f6408d1 100644 --- a/clientapi/auth/storage/accounts/sqlite3/threepid_table.go +++ b/clientapi/auth/storage/accounts/sqlite3/threepid_table.go @@ -97,6 +97,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( if err != nil { return } + defer rows.Close() // nolint: errcheck threepids = []authtypes.ThreePID{} for rows.Next() { @@ -110,8 +111,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( Medium: medium, }) } - - return + return threepids, rows.Err() } func (s *threepidStatements) insertThreePID( diff --git a/clientapi/auth/storage/devices/postgres/devices_table.go b/clientapi/auth/storage/devices/postgres/devices_table.go index c27c699e9..349bf1ef7 100644 --- a/clientapi/auth/storage/devices/postgres/devices_table.go +++ b/clientapi/auth/storage/devices/postgres/devices_table.go @@ -19,10 +19,10 @@ import ( "database/sql" "time" - "github.com/matrix-org/dendrite/common" - + "github.com/lib/pq" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/gomatrixserverlib" ) @@ -80,6 +80,9 @@ const deleteDeviceSQL = "" + const deleteDevicesByLocalpartSQL = "" + "DELETE FROM device_devices WHERE localpart = $1" +const deleteDevicesSQL = "" + + "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" + type devicesStatements struct { insertDeviceStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt @@ -88,6 +91,7 @@ type devicesStatements struct { updateDeviceNameStmt *sql.Stmt deleteDeviceStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt + deleteDevicesStmt *sql.Stmt serverName gomatrixserverlib.ServerName } @@ -117,6 +121,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { return } + if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { + return + } s.serverName = server return } @@ -142,6 +149,7 @@ func (s *devicesStatements) insertDevice( }, nil } +// deleteDevice removes a single device by id and user localpart. func (s *devicesStatements) deleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { @@ -150,6 +158,18 @@ func (s *devicesStatements) deleteDevice( return err } +// deleteDevices removes a single or multiple devices by ids and user localpart. +// Returns an error if the execution failed. +func (s *devicesStatements) deleteDevices( + ctx context.Context, txn *sql.Tx, localpart string, devices []string, +) error { + stmt := common.TxStmt(txn, s.deleteDevicesStmt) + _, err := stmt.ExecContext(ctx, localpart, pq.Array(devices)) + return err +} + +// deleteDevicesByLocalpart removes all devices for the +// given user localpart. func (s *devicesStatements) deleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart string, ) error { @@ -206,6 +226,7 @@ func (s *devicesStatements) selectDevicesByLocalpart( if err != nil { return devices, err } + defer rows.Close() // nolint: errcheck for rows.Next() { var dev authtypes.Device @@ -217,5 +238,5 @@ func (s *devicesStatements) selectDevicesByLocalpart( devices = append(devices, dev) } - return devices, nil + return devices, rows.Err() } diff --git a/clientapi/auth/storage/devices/postgres/storage.go b/clientapi/auth/storage/devices/postgres/storage.go index baf9186d6..221c3998e 100644 --- a/clientapi/auth/storage/devices/postgres/storage.go +++ b/clientapi/auth/storage/devices/postgres/storage.go @@ -152,6 +152,21 @@ func (d *Database) RemoveDevice( }) } +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + // RemoveAllDevices revokes devices by deleting the entry in the // database matching the given user ID localpart. // If something went wrong during the deletion, it will return the SQL error. diff --git a/clientapi/auth/storage/devices/sqlite3/devices_table.go b/clientapi/auth/storage/devices/sqlite3/devices_table.go index d4349c99f..55b8b5f4e 100644 --- a/clientapi/auth/storage/devices/sqlite3/devices_table.go +++ b/clientapi/auth/storage/devices/sqlite3/devices_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "time" + "github.com/lib/pq" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -68,6 +69,9 @@ const deleteDeviceSQL = "" + const deleteDevicesByLocalpartSQL = "" + "DELETE FROM device_devices WHERE localpart = $1" +const deleteDevicesSQL = "" + + "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" + type devicesStatements struct { insertDeviceStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt @@ -77,6 +81,7 @@ type devicesStatements struct { updateDeviceNameStmt *sql.Stmt deleteDeviceStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt + deleteDevicesStmt *sql.Stmt serverName gomatrixserverlib.ServerName } @@ -109,6 +114,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { return } + if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { + return + } s.serverName = server return } @@ -147,6 +155,14 @@ func (s *devicesStatements) deleteDevice( return err } +func (s *devicesStatements) deleteDevices( + ctx context.Context, txn *sql.Tx, localpart string, devices []string, +) error { + stmt := common.TxStmt(txn, s.deleteDevicesStmt) + _, err := stmt.ExecContext(ctx, localpart, pq.Array(devices)) + return err +} + func (s *devicesStatements) deleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart string, ) error { diff --git a/clientapi/auth/storage/devices/sqlite3/storage.go b/clientapi/auth/storage/devices/sqlite3/storage.go index 3141617c0..e1ce6f00d 100644 --- a/clientapi/auth/storage/devices/sqlite3/storage.go +++ b/clientapi/auth/storage/devices/sqlite3/storage.go @@ -154,6 +154,21 @@ func (d *Database) RemoveDevice( }) } +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + // RemoveAllDevices revokes devices by deleting the entry in the // database matching the given user ID localpart. // If something went wrong during the deletion, it will return the SQL error. diff --git a/clientapi/auth/storage/devices/storage.go b/clientapi/auth/storage/devices/storage.go index 84de573bc..82f756401 100644 --- a/clientapi/auth/storage/devices/storage.go +++ b/clientapi/auth/storage/devices/storage.go @@ -17,6 +17,7 @@ type Database interface { CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *authtypes.Device, returnErr error) UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error RemoveDevice(ctx context.Context, deviceID, localpart string) error + RemoveDevices(ctx context.Context, localpart string, devices []string) error RemoveAllDevices(ctx context.Context, localpart string) error } diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 854f098c0..bb44e016a 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -67,7 +67,7 @@ func SetupClientAPIComponent( } routing.Setup( - base.APIMux, *base.Cfg, roomserverProducer, queryAPI, aliasAPI, asAPI, + base.APIMux, base.Cfg, roomserverProducer, queryAPI, aliasAPI, asAPI, accountsDB, deviceDB, federation, *keyRing, userUpdateProducer, syncProducer, typingProducer, transactionsCache, fedSenderAPI, ) diff --git a/clientapi/routing/auth_fallback.go b/clientapi/routing/auth_fallback.go index cd4530d1b..5332226c4 100644 --- a/clientapi/routing/auth_fallback.go +++ b/clientapi/routing/auth_fallback.go @@ -102,7 +102,7 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s // AuthFallback implements GET and POST /auth/{authType}/fallback/web?session={sessionID} func AuthFallback( w http.ResponseWriter, req *http.Request, authType string, - cfg config.Dendrite, + cfg *config.Dendrite, ) *util.JSONResponse { sessionID := req.URL.Query().Get("session") @@ -130,7 +130,7 @@ func AuthFallback( if req.Method == http.MethodGet { // Handle Recaptcha if authType == authtypes.LoginTypeRecaptcha { - if err := checkRecaptchaEnabled(&cfg, w, req); err != nil { + if err := checkRecaptchaEnabled(cfg, w, req); err != nil { return err } @@ -144,7 +144,7 @@ func AuthFallback( } else if req.Method == http.MethodPost { // Handle Recaptcha if authType == authtypes.LoginTypeRecaptcha { - if err := checkRecaptchaEnabled(&cfg, w, req); err != nil { + if err := checkRecaptchaEnabled(cfg, w, req); err != nil { return err } @@ -156,7 +156,7 @@ func AuthFallback( } response := req.Form.Get("g-recaptcha-response") - if err := validateRecaptcha(&cfg, response, clientIP); err != nil { + if err := validateRecaptcha(cfg, response, clientIP); err != nil { util.GetLogger(req.Context()).Error(err) return err } diff --git a/clientapi/routing/capabilities.go b/clientapi/routing/capabilities.go new file mode 100644 index 000000000..c8743386f --- /dev/null +++ b/clientapi/routing/capabilities.go @@ -0,0 +1,51 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + + "github.com/matrix-org/dendrite/clientapi/httputil" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + + "github.com/matrix-org/util" +) + +// SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite) +// by building a m.room.member event then sending it to the room server +func GetCapabilities( + req *http.Request, queryAPI roomserverAPI.RoomserverQueryAPI, +) util.JSONResponse { + roomVersionsQueryReq := roomserverAPI.QueryRoomVersionCapabilitiesRequest{} + var roomVersionsQueryRes roomserverAPI.QueryRoomVersionCapabilitiesResponse + if err := queryAPI.QueryRoomVersionCapabilities( + req.Context(), + &roomVersionsQueryReq, + &roomVersionsQueryRes, + ); err != nil { + return httputil.LogThenError(req, err) + } + + response := map[string]interface{}{ + "capabilities": map[string]interface{}{ + "m.room_versions": roomVersionsQueryRes, + }, + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: response, + } +} diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 30a8c32cd..2b1245b9a 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -134,7 +134,7 @@ type fledglingEvent struct { // CreateRoom implements /createRoom func CreateRoom( req *http.Request, device *authtypes.Device, - cfg config.Dendrite, producer *producers.RoomserverProducer, + cfg *config.Dendrite, producer *producers.RoomserverProducer, accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -148,7 +148,7 @@ func CreateRoom( // nolint: gocyclo func createRoom( req *http.Request, device *authtypes.Device, - cfg config.Dendrite, roomID string, producer *producers.RoomserverProducer, + cfg *config.Dendrite, roomID string, producer *producers.RoomserverProducer, accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -344,7 +344,7 @@ func createRoom( func buildEvent( builder *gomatrixserverlib.EventBuilder, provider gomatrixserverlib.AuthEventProvider, - cfg config.Dendrite, + cfg *config.Dendrite, evTime time.Time, ) (*gomatrixserverlib.Event, error) { eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index e1c618ba3..9b8647cd4 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -40,6 +40,10 @@ type deviceUpdateJSON struct { DisplayName *string `json:"display_name"` } +type devicesDeleteJSON struct { + Devices []string `json:"devices"` +} + // GetDeviceByID handles /devices/{deviceID} func GetDeviceByID( req *http.Request, deviceDB devices.Database, device *authtypes.Device, @@ -146,3 +150,54 @@ func UpdateDeviceByID( JSON: struct{}{}, } } + +// DeleteDeviceById handles DELETE requests to /devices/{deviceId} +func DeleteDeviceById( + req *http.Request, deviceDB devices.Database, device *authtypes.Device, + deviceID string, +) util.JSONResponse { + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + return httputil.LogThenError(req, err) + } + ctx := req.Context() + + defer req.Body.Close() // nolint: errcheck + + if err := deviceDB.RemoveDevice(ctx, deviceID, localpart); err != nil { + return httputil.LogThenError(req, err) + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +// DeleteDevices handles POST requests to /delete_devices +func DeleteDevices( + req *http.Request, deviceDB devices.Database, device *authtypes.Device, +) util.JSONResponse { + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + return httputil.LogThenError(req, err) + } + + ctx := req.Context() + payload := devicesDeleteJSON{} + + if err := json.NewDecoder(req.Body).Decode(&payload); err != nil { + return httputil.LogThenError(req, err) + } + + defer req.Body.Close() // nolint: errcheck + + if err := deviceDB.RemoveDevices(ctx, localpart, payload.Devices); err != nil { + return httputil.LogThenError(req, err) + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} diff --git a/clientapi/routing/getevent.go b/clientapi/routing/getevent.go index 7071d16f0..115286bd6 100644 --- a/clientapi/routing/getevent.go +++ b/clientapi/routing/getevent.go @@ -31,7 +31,7 @@ type getEventRequest struct { device *authtypes.Device roomID string eventID string - cfg config.Dendrite + cfg *config.Dendrite federation *gomatrixserverlib.FederationClient keyRing gomatrixserverlib.KeyRing requestedEvent gomatrixserverlib.Event @@ -44,7 +44,7 @@ func GetEvent( device *authtypes.Device, roomID string, eventID string, - cfg config.Dendrite, + cfg *config.Dendrite, queryAPI api.RoomserverQueryAPI, federation *gomatrixserverlib.FederationClient, keyRing gomatrixserverlib.KeyRing, diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index e6220a308..5e6f3e559 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -39,7 +39,7 @@ func JoinRoomByIDOrAlias( req *http.Request, device *authtypes.Device, roomIDOrAlias string, - cfg config.Dendrite, + cfg *config.Dendrite, federation *gomatrixserverlib.FederationClient, producer *producers.RoomserverProducer, queryAPI roomserverAPI.RoomserverQueryAPI, @@ -98,7 +98,7 @@ type joinRoomReq struct { evTime time.Time content map[string]interface{} userID string - cfg config.Dendrite + cfg *config.Dendrite federation *gomatrixserverlib.FederationClient producer *producers.RoomserverProducer queryAPI roomserverAPI.RoomserverQueryAPI diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 0259f6e92..b8364ed9d 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -71,7 +71,7 @@ func passwordLogin() loginFlows { // Login implements GET and POST /login func Login( req *http.Request, accountDB accounts.Database, deviceDB devices.Database, - cfg config.Dendrite, + cfg *config.Dendrite, ) util.JSONResponse { if req.Method == http.MethodGet { // TODO: support other forms of login other than password, depending on config options return util.JSONResponse{ diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 7b480cc98..68c131a2b 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -41,7 +41,7 @@ var errMissingUserID = errors.New("'user_id' must be supplied") // by building a m.room.member event then sending it to the room server func SendMembership( req *http.Request, accountDB accounts.Database, device *authtypes.Device, - roomID string, membership string, cfg config.Dendrite, + roomID string, membership string, cfg *config.Dendrite, queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI, producer *producers.RoomserverProducer, ) util.JSONResponse { @@ -119,7 +119,7 @@ func buildMembershipEvent( body threepid.MembershipRequest, accountDB accounts.Database, device *authtypes.Device, membership, roomID string, - cfg config.Dendrite, evTime time.Time, + cfg *config.Dendrite, evTime time.Time, queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) (*gomatrixserverlib.Event, error) { stateKey, reason, err := getMembershipStateKey(body, device, membership) @@ -165,7 +165,7 @@ func buildMembershipEvent( func loadProfile( ctx context.Context, userID string, - cfg config.Dendrite, + cfg *config.Dendrite, accountDB accounts.Database, asAPI appserviceAPI.AppServiceQueryAPI, ) (*authtypes.Profile, error) { @@ -214,7 +214,7 @@ func checkAndProcessThreepid( req *http.Request, device *authtypes.Device, body *threepid.MembershipRequest, - cfg config.Dendrite, + cfg *config.Dendrite, queryAPI roomserverAPI.RoomserverQueryAPI, accountDB accounts.Database, producer *producers.RoomserverProducer, diff --git a/clientapi/routing/memberships.go b/clientapi/routing/memberships.go index 5b8903287..e6fca505f 100644 --- a/clientapi/routing/memberships.go +++ b/clientapi/routing/memberships.go @@ -33,7 +33,7 @@ type response struct { // GetMemberships implements GET /rooms/{roomId}/members func GetMemberships( req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool, - _ config.Dendrite, + _ *config.Dendrite, queryAPI api.RoomserverQueryAPI, ) util.JSONResponse { queryReq := api.QueryMembershipsForRoomRequest{ diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index aaea49d26..9b091ddf7 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -343,7 +343,7 @@ func buildMembershipEvents( return nil, err } - event, err := common.BuildEvent(ctx, &builder, *cfg, evTime, queryAPI, nil) + event, err := common.BuildEvent(ctx, &builder, cfg, evTime, queryAPI, nil) if err != nil { return nil, err } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 05d481da3..9d67d9982 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -43,6 +43,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/tokens" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" @@ -449,6 +450,9 @@ func Register( if resErr != nil { return *resErr } + if req.URL.Query().Get("kind") == "guest" { + return handleGuestRegistration(req, r, cfg, accountDB, deviceDB) + } // Retrieve or generate the sessionID sessionID := r.Auth.Session @@ -505,6 +509,59 @@ func Register( return handleRegistrationFlow(req, r, sessionID, cfg, accountDB, deviceDB) } +func handleGuestRegistration( + req *http.Request, + r registerRequest, + cfg *config.Dendrite, + accountDB accounts.Database, + deviceDB devices.Database, +) util.JSONResponse { + + //Generate numeric local part for guest user + id, err := accountDB.GetNewNumericLocalpart(req.Context()) + if err != nil { + return httputil.LogThenError(req, err) + } + + localpart := strconv.FormatInt(id, 10) + acc, err := accountDB.CreateAccount(req.Context(), localpart, "", "") + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("failed to create account: " + err.Error()), + } + } + token, err := tokens.GenerateLoginToken(tokens.TokenOptions{ + ServerPrivateKey: cfg.Matrix.PrivateKey.Seed(), + ServerName: string(acc.ServerName), + UserID: acc.UserID, + }) + + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("Failed to generate access token"), + } + } + //we don't allow guests to specify their own device_id + dev, err := deviceDB.CreateDevice(req.Context(), acc.Localpart, nil, token, r.InitialDisplayName) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("failed to create device: " + err.Error()), + } + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: registerResponse{ + UserID: dev.UserID, + AccessToken: dev.AccessToken, + HomeServer: acc.ServerName, + DeviceID: dev.ID, + }, + } +} + // handleRegistrationFlow will direct and complete registration flow stages // that the client has requested. // nolint: gocyclo @@ -934,7 +991,7 @@ type availableResponse struct { // RegisterAvailable checks if the username is already taken or invalid. func RegisterAvailable( req *http.Request, - cfg config.Dendrite, + cfg *config.Dendrite, accountDB accounts.Database, ) util.JSONResponse { username := req.URL.Query().Get("username") diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index f7b94914a..08f45e551 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -47,7 +47,7 @@ const pathPrefixUnstable = "/_matrix/client/unstable" // applied: // nolint: gocyclo func Setup( - apiMux *mux.Router, cfg config.Dendrite, + apiMux *mux.Router, cfg *config.Dendrite, producer *producers.RoomserverProducer, queryAPI roomserverAPI.RoomserverQueryAPI, aliasAPI roomserverAPI.RoomserverAliasAPI, @@ -161,11 +161,11 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { - return Register(req, accountDB, deviceDB, &cfg) + return Register(req, accountDB, deviceDB, cfg) })).Methods(http.MethodPost, http.MethodOptions) v1mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { - return LegacyRegister(req, accountDB, deviceDB, &cfg) + return LegacyRegister(req, accountDB, deviceDB, cfg) })).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/register/available", common.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { @@ -178,7 +178,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return DirectoryRoom(req, vars["roomAlias"], federation, &cfg, aliasAPI, federationSender) + return DirectoryRoom(req, vars["roomAlias"], federation, cfg, aliasAPI, federationSender) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -188,7 +188,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SetLocalAlias(req, device, vars["roomAlias"], &cfg, aliasAPI) + return SetLocalAlias(req, device, vars["roomAlias"], cfg, aliasAPI) }), ).Methods(http.MethodPut, http.MethodOptions) @@ -292,7 +292,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetProfile(req, accountDB, &cfg, vars["userID"], asAPI, federation) + return GetProfile(req, accountDB, cfg, vars["userID"], asAPI, federation) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -302,7 +302,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetAvatarURL(req, accountDB, &cfg, vars["userID"], asAPI, federation) + return GetAvatarURL(req, accountDB, cfg, vars["userID"], asAPI, federation) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -312,7 +312,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SetAvatarURL(req, accountDB, device, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI) + return SetAvatarURL(req, accountDB, device, vars["userID"], userUpdateProducer, cfg, producer, queryAPI) }), ).Methods(http.MethodPut, http.MethodOptions) // Browsers use the OPTIONS HTTP method to check if the CORS policy allows @@ -324,7 +324,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetDisplayName(req, accountDB, &cfg, vars["userID"], asAPI, federation) + return GetDisplayName(req, accountDB, cfg, vars["userID"], asAPI, federation) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -334,7 +334,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SetDisplayName(req, accountDB, device, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI) + return SetDisplayName(req, accountDB, device, vars["userID"], userUpdateProducer, cfg, producer, queryAPI) }), ).Methods(http.MethodPut, http.MethodOptions) // Browsers use the OPTIONS HTTP method to check if the CORS policy allows @@ -494,6 +494,22 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) + r0mux.Handle("/devices/{deviceID}", + common.MakeAuthAPI("delete_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars, err := common.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return DeleteDeviceById(req, deviceDB, device, vars["deviceID"]) + }), + ).Methods(http.MethodDelete, http.MethodOptions) + + r0mux.Handle("/delete_devices", + common.MakeAuthAPI("delete_devices", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + return DeleteDevices(req, deviceDB, device) + }), + ).Methods(http.MethodPost, http.MethodOptions) + // Stub implementations for sytest r0mux.Handle("/events", common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { @@ -542,4 +558,10 @@ func Setup( return DeleteTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) }), ).Methods(http.MethodDelete, http.MethodOptions) + + r0mux.Handle("/capabilities", + common.MakeAuthAPI("capabilities", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + return GetCapabilities(req, queryAPI) + }), + ).Methods(http.MethodGet) } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 76e36cd46..e6de187f2 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -43,7 +43,7 @@ func SendEvent( req *http.Request, device *authtypes.Device, roomID, eventType string, txnID, stateKey *string, - cfg config.Dendrite, + cfg *config.Dendrite, queryAPI api.RoomserverQueryAPI, producer *producers.RoomserverProducer, txnCache *transactions.Cache, @@ -93,7 +93,7 @@ func generateSendEvent( req *http.Request, device *authtypes.Device, roomID, eventType string, stateKey *string, - cfg config.Dendrite, + cfg *config.Dendrite, queryAPI api.RoomserverQueryAPI, ) (*gomatrixserverlib.Event, *util.JSONResponse) { // parse the incoming http request diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index 92505c46b..69383cdf7 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -39,7 +39,7 @@ type threePIDsResponse struct { // RequestEmailToken implements: // POST /account/3pid/email/requestToken // POST /register/email/requestToken -func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg config.Dendrite) util.JSONResponse { +func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.Dendrite) util.JSONResponse { var body threepid.EmailAssociationRequest if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { return *reqErr @@ -83,7 +83,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg confi // CheckAndSave3PIDAssociation implements POST /account/3pid func CheckAndSave3PIDAssociation( req *http.Request, accountDB accounts.Database, device *authtypes.Device, - cfg config.Dendrite, + cfg *config.Dendrite, ) util.JSONResponse { var body threepid.EmailAssociationCheckRequest if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { diff --git a/clientapi/routing/voip.go b/clientapi/routing/voip.go index b9121633f..872e64473 100644 --- a/clientapi/routing/voip.go +++ b/clientapi/routing/voip.go @@ -31,7 +31,7 @@ import ( // RequestTurnServer implements: // GET /voip/turnServer -func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg config.Dendrite) util.JSONResponse { +func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg *config.Dendrite) util.JSONResponse { turnConfig := cfg.TURN // TODO Guest Support diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 9e28df4bc..aa54aa9fa 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -86,7 +86,7 @@ var ( // can be emitted. func CheckAndProcessInvite( ctx context.Context, - device *authtypes.Device, body *MembershipRequest, cfg config.Dendrite, + device *authtypes.Device, body *MembershipRequest, cfg *config.Dendrite, queryAPI api.RoomserverQueryAPI, db accounts.Database, producer *producers.RoomserverProducer, membership string, roomID string, evTime time.Time, @@ -137,7 +137,7 @@ func CheckAndProcessInvite( // Returns an error if a check or a request failed. func queryIDServer( ctx context.Context, - db accounts.Database, cfg config.Dendrite, device *authtypes.Device, + db accounts.Database, cfg *config.Dendrite, device *authtypes.Device, body *MembershipRequest, roomID string, ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { if err = isTrusted(body.IDServer, cfg); err != nil { @@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe // Returns an error if the request failed to send or if the response couldn't be parsed. func queryIDServerStoreInvite( ctx context.Context, - db accounts.Database, cfg config.Dendrite, device *authtypes.Device, + db accounts.Database, cfg *config.Dendrite, device *authtypes.Device, body *MembershipRequest, roomID string, ) (*idServerStoreInviteResponse, error) { // Retrieve the sender's profile to get their display name @@ -330,7 +330,7 @@ func checkIDServerSignatures( func emit3PIDInviteEvent( ctx context.Context, body *MembershipRequest, res *idServerStoreInviteResponse, - device *authtypes.Device, roomID string, cfg config.Dendrite, + device *authtypes.Device, roomID string, cfg *config.Dendrite, queryAPI api.RoomserverQueryAPI, producer *producers.RoomserverProducer, evTime time.Time, ) error { diff --git a/clientapi/threepid/threepid.go b/clientapi/threepid/threepid.go index e5b3305e3..a7f26c295 100644 --- a/clientapi/threepid/threepid.go +++ b/clientapi/threepid/threepid.go @@ -53,7 +53,7 @@ type Credentials struct { // Returns an error if there was a problem sending the request or decoding the // response, or if the identity server responded with a non-OK status. func CreateSession( - ctx context.Context, req EmailAssociationRequest, cfg config.Dendrite, + ctx context.Context, req EmailAssociationRequest, cfg *config.Dendrite, ) (string, error) { if err := isTrusted(req.IDServer, cfg); err != nil { return "", err @@ -101,7 +101,7 @@ func CreateSession( // Returns an error if there was a problem sending the request or decoding the // response, or if the identity server responded with a non-OK status. func CheckAssociation( - ctx context.Context, creds Credentials, cfg config.Dendrite, + ctx context.Context, creds Credentials, cfg *config.Dendrite, ) (bool, string, string, error) { if err := isTrusted(creds.IDServer, cfg); err != nil { return false, "", "", err @@ -142,7 +142,7 @@ func CheckAssociation( // identifier and a Matrix ID. // Returns an error if there was a problem sending the request or decoding the // response, or if the identity server responded with a non-OK status. -func PublishAssociation(creds Credentials, userID string, cfg config.Dendrite) error { +func PublishAssociation(creds Credentials, userID string, cfg *config.Dendrite) error { if err := isTrusted(creds.IDServer, cfg); err != nil { return err } @@ -177,7 +177,7 @@ func PublishAssociation(creds Credentials, userID string, cfg config.Dendrite) e // isTrusted checks if a given identity server is part of the list of trusted // identity servers in the configuration file. // Returns an error if the server isn't trusted. -func isTrusted(idServer string, cfg config.Dendrite) error { +func isTrusted(idServer string, cfg *config.Dendrite) error { for _, server := range cfg.Matrix.TrustedIDServers { if idServer == server { return nil diff --git a/common/events.go b/common/events.go index 5c87c0e56..3c060ee65 100644 --- a/common/events.go +++ b/common/events.go @@ -39,7 +39,7 @@ var ErrRoomNoExists = errors.New("Room does not exist") // Returns an error if something else went wrong func BuildEvent( ctx context.Context, - builder *gomatrixserverlib.EventBuilder, cfg config.Dendrite, evTime time.Time, + builder *gomatrixserverlib.EventBuilder, cfg *config.Dendrite, evTime time.Time, queryAPI api.RoomserverQueryAPI, queryRes *api.QueryLatestEventsAndStateResponse, ) (*gomatrixserverlib.Event, error) { err := AddPrevEventsToEvent(ctx, builder, queryAPI, queryRes) diff --git a/common/keydb/postgres/server_key_table.go b/common/keydb/postgres/server_key_table.go index 8fb9a0ee9..6b13cc3c2 100644 --- a/common/keydb/postgres/server_key_table.go +++ b/common/keydb/postgres/server_key_table.go @@ -117,7 +117,7 @@ func (s *serverKeyStatements) bulkSelectServerKeys( ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), } } - return results, nil + return results, rows.Err() } func (s *serverKeyStatements) upsertServerKeys( diff --git a/common/partition_offset_table.go b/common/partition_offset_table.go index d60971239..6bc066a69 100644 --- a/common/partition_offset_table.go +++ b/common/partition_offset_table.go @@ -99,7 +99,7 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets( } results = append(results, offset) } - return results, nil + return results, rows.Err() } // UpsertPartitionOffset updates or inserts the partition offset for the given topic. diff --git a/common/test/config.go b/common/test/config.go index 693555619..0fed252ae 100644 --- a/common/test/config.go +++ b/common/test/config.go @@ -111,6 +111,7 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con // Bind to the same address as the listen address // All microservices are run on the same host in testing cfg.Bind.ClientAPI = cfg.Listen.ClientAPI + cfg.Bind.AppServiceAPI = cfg.Listen.AppServiceAPI cfg.Bind.FederationAPI = cfg.Listen.FederationAPI cfg.Bind.MediaAPI = cfg.Listen.MediaAPI cfg.Bind.RoomServer = cfg.Listen.RoomServer diff --git a/docker/Dockerfile b/docker/Dockerfile index 5810825a4..29b27dde2 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,4 +1,8 @@ +<<<<<<< HEAD FROM docker.io/golang:1.13.7-alpine3.11 +======= +FROM docker.io/golang:1.13.6-alpine +>>>>>>> master RUN mkdir /build diff --git a/docs/sytest.md b/docs/sytest.md index 6d03270bb..9385ebff3 100644 --- a/docs/sytest.md +++ b/docs/sytest.md @@ -44,6 +44,7 @@ args: user: dendrite database: dendrite host: 127.0.0.1 + sslmode: disable type: pg EOF ``` diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index aad88362c..ef57da881 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -45,8 +45,8 @@ func SetupFederationAPIComponent( roomserverProducer := producers.NewRoomserverProducer(inputAPI) routing.Setup( - base.APIMux, *base.Cfg, queryAPI, aliasAPI, asAPI, - roomserverProducer, federationSenderAPI, *keyRing, federation, accountsDB, - deviceDB, + base.APIMux, base.Cfg, queryAPI, aliasAPI, asAPI, + roomserverProducer, federationSenderAPI, *keyRing, + federation, accountsDB, deviceDB, ) } diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 5c6b0087f..cb388f50d 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -34,7 +34,7 @@ func Backfill( request *gomatrixserverlib.FederationRequest, query api.RoomserverQueryAPI, roomID string, - cfg config.Dendrite, + cfg *config.Dendrite, ) util.JSONResponse { var res api.QueryBackfillResponse var eIDs []string diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 01a1bed23..9a04a0880 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -32,7 +32,7 @@ func Invite( request *gomatrixserverlib.FederationRequest, roomID string, eventID string, - cfg config.Dendrite, + cfg *config.Dendrite, producer *producers.RoomserverProducer, keys gomatrixserverlib.KeyRing, ) util.JSONResponse { diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index e2885dd99..325b99374 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -33,7 +33,7 @@ import ( func MakeJoin( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, - cfg config.Dendrite, + cfg *config.Dendrite, query api.RoomserverQueryAPI, roomID, userID string, ) util.JSONResponse { @@ -97,7 +97,7 @@ func MakeJoin( func SendJoin( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, - cfg config.Dendrite, + cfg *config.Dendrite, query api.RoomserverQueryAPI, producer *producers.RoomserverProducer, keys gomatrixserverlib.KeyRing, diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 9c53d177e..3eb88567d 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -27,7 +27,7 @@ import ( // LocalKeys returns the local keys for the server. // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys -func LocalKeys(cfg config.Dendrite) util.JSONResponse { +func LocalKeys(cfg *config.Dendrite) util.JSONResponse { keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)) if err != nil { return util.ErrorResponse(err) @@ -35,7 +35,7 @@ func LocalKeys(cfg config.Dendrite) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: keys} } -func localKeys(cfg config.Dendrite, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) { +func localKeys(cfg *config.Dendrite, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) { var keys gomatrixserverlib.ServerKeys keys.ServerName = cfg.Matrix.ServerName diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index a982b87f8..958158084 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -31,7 +31,7 @@ import ( func MakeLeave( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, - cfg config.Dendrite, + cfg *config.Dendrite, query api.RoomserverQueryAPI, roomID, userID string, ) util.JSONResponse { @@ -95,7 +95,7 @@ func MakeLeave( func SendLeave( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, - cfg config.Dendrite, + cfg *config.Dendrite, producer *producers.RoomserverProducer, keys gomatrixserverlib.KeyRing, roomID, eventID string, diff --git a/federationapi/routing/profile.go b/federationapi/routing/profile.go index 452c2c7d8..31b7a343f 100644 --- a/federationapi/routing/profile.go +++ b/federationapi/routing/profile.go @@ -31,7 +31,7 @@ import ( func GetProfile( httpReq *http.Request, accountDB accounts.Database, - cfg config.Dendrite, + cfg *config.Dendrite, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { userID, field := httpReq.FormValue("user_id"), httpReq.FormValue("field") diff --git a/federationapi/routing/query.go b/federationapi/routing/query.go index ed2d8b741..5277f0acd 100644 --- a/federationapi/routing/query.go +++ b/federationapi/routing/query.go @@ -32,7 +32,7 @@ import ( func RoomAliasToID( httpReq *http.Request, federation *gomatrixserverlib.FederationClient, - cfg config.Dendrite, + cfg *config.Dendrite, aliasAPI roomserverAPI.RoomserverAliasAPI, senderAPI federationSenderAPI.FederationSenderQueryAPI, ) util.JSONResponse { diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 8f43fcd15..3b119301a 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -43,7 +43,7 @@ const ( // nolint: gocyclo func Setup( apiMux *mux.Router, - cfg config.Dendrite, + cfg *config.Dendrite, query roomserverAPI.RoomserverQueryAPI, aliasAPI roomserverAPI.RoomserverAliasAPI, asAPI appserviceAPI.AppServiceQueryAPI, diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index eab248745..5513a088f 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -34,7 +34,7 @@ func Send( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, txnID gomatrixserverlib.TransactionID, - cfg config.Dendrite, + cfg *config.Dendrite, query api.RoomserverQueryAPI, producer *producers.RoomserverProducer, keys gomatrixserverlib.KeyRing, diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index 5f56427de..a22685f25 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -59,7 +59,7 @@ var ( // CreateInvitesFrom3PIDInvites implements POST /_matrix/federation/v1/3pid/onbind func CreateInvitesFrom3PIDInvites( req *http.Request, queryAPI roomserverAPI.RoomserverQueryAPI, - asAPI appserviceAPI.AppServiceQueryAPI, cfg config.Dendrite, + asAPI appserviceAPI.AppServiceQueryAPI, cfg *config.Dendrite, producer *producers.RoomserverProducer, federation *gomatrixserverlib.FederationClient, accountDB accounts.Database, ) util.JSONResponse { @@ -98,7 +98,7 @@ func ExchangeThirdPartyInvite( request *gomatrixserverlib.FederationRequest, roomID string, queryAPI roomserverAPI.RoomserverQueryAPI, - cfg config.Dendrite, + cfg *config.Dendrite, federation *gomatrixserverlib.FederationClient, producer *producers.RoomserverProducer, ) util.JSONResponse { @@ -172,7 +172,7 @@ func ExchangeThirdPartyInvite( // necessary data to do so. func createInviteFrom3PIDInvite( ctx context.Context, queryAPI roomserverAPI.RoomserverQueryAPI, - asAPI appserviceAPI.AppServiceQueryAPI, cfg config.Dendrite, + asAPI appserviceAPI.AppServiceQueryAPI, cfg *config.Dendrite, inv invite, federation *gomatrixserverlib.FederationClient, accountDB accounts.Database, ) (*gomatrixserverlib.Event, error) { @@ -230,7 +230,7 @@ func createInviteFrom3PIDInvite( func buildMembershipEvent( ctx context.Context, builder *gomatrixserverlib.EventBuilder, queryAPI roomserverAPI.RoomserverQueryAPI, - cfg config.Dendrite, + cfg *config.Dendrite, ) (*gomatrixserverlib.Event, error) { eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) if err != nil { @@ -290,7 +290,7 @@ func buildMembershipEvent( // them responded with an error. func sendToRemoteServer( ctx context.Context, inv invite, - federation *gomatrixserverlib.FederationClient, _ config.Dendrite, + federation *gomatrixserverlib.FederationClient, _ *config.Dendrite, builder gomatrixserverlib.EventBuilder, ) (err error) { remoteServers := make([]gomatrixserverlib.ServerName, 2) diff --git a/federationsender/storage/postgres/joined_hosts_table.go b/federationsender/storage/postgres/joined_hosts_table.go index bd580e3b5..e5c30a010 100644 --- a/federationsender/storage/postgres/joined_hosts_table.go +++ b/federationsender/storage/postgres/joined_hosts_table.go @@ -132,5 +132,5 @@ func joinedHostsFromStmt( }) } - return result, nil + return result, rows.Err() } diff --git a/go.sum b/go.sum index 261d4f4c4..7c8732f63 100644 --- a/go.sum +++ b/go.sum @@ -185,6 +185,7 @@ gopkg.in/h2non/bimg.v1 v1.0.18 h1:qn6/RpBHt+7WQqoBcK+aF2puc6nC78eZj5LexxoalT4= gopkg.in/h2non/bimg.v1 v1.0.18/go.mod h1:PgsZL7dLwUbsGm1NYps320GxGgvQNTnecMCZqxV11So= gopkg.in/h2non/gock.v1 v1.0.14 h1:fTeu9fcUvSnLNacYvYI54h+1/XEteDyHvrVCZEEEYNM= gopkg.in/h2non/gock.v1 v1.0.14/go.mod h1:sX4zAkdYX1TRGJ2JY156cFspQn4yRWn6p9EMdODlynE= +gopkg.in/macaroon.v2 v2.1.0 h1:HZcsjBCzq9t0eBPMKqTN/uSN6JOm78ZJ2INbqcBQOUI= gopkg.in/macaroon.v2 v2.1.0/go.mod h1:OUb+TQP/OP0WOerC2Jp/3CwhIKyIa9kQjuc7H24e6/o= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= diff --git a/mediaapi/storage/postgres/thumbnail_table.go b/mediaapi/storage/postgres/thumbnail_table.go index 167e37957..127b86bb9 100644 --- a/mediaapi/storage/postgres/thumbnail_table.go +++ b/mediaapi/storage/postgres/thumbnail_table.go @@ -144,6 +144,7 @@ func (s *thumbnailStatements) selectThumbnails( if err != nil { return nil, err } + defer rows.Close() // nolint: errcheck var thumbnails []*types.ThumbnailMetadata for rows.Next() { @@ -167,5 +168,5 @@ func (s *thumbnailStatements) selectThumbnails( thumbnails = append(thumbnails, &thumbnailMetadata) } - return thumbnails, err + return thumbnails, rows.Err() } diff --git a/publicroomsapi/storage/postgres/public_rooms_table.go b/publicroomsapi/storage/postgres/public_rooms_table.go index 852afe770..edf9ad2ab 100644 --- a/publicroomsapi/storage/postgres/public_rooms_table.go +++ b/publicroomsapi/storage/postgres/public_rooms_table.go @@ -203,6 +203,7 @@ func (s *publicRoomsStatements) selectPublicRooms( if err != nil { return []types.PublicRoom{}, nil } + defer rows.Close() // nolint: errcheck rooms := []types.PublicRoom{} for rows.Next() { @@ -222,7 +223,7 @@ func (s *publicRoomsStatements) selectPublicRooms( rooms = append(rooms, r) } - return rooms, nil + return rooms, rows.Err() } func (s *publicRoomsStatements) selectRoomVisibility( diff --git a/publicroomsapi/storage/sqlite3/prepare.go b/publicroomsapi/storage/sqlite3/prepare.go new file mode 100644 index 000000000..482dfa2b9 --- /dev/null +++ b/publicroomsapi/storage/sqlite3/prepare.go @@ -0,0 +1,36 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "database/sql" +) + +// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. +type statementList []struct { + statement **sql.Stmt + sql string +} + +// prepare the SQL for each statement in the list and assign the result to the prepared statement. +func (s statementList) prepare(db *sql.DB) (err error) { + for _, statement := range s { + if *statement.statement, err = db.Prepare(statement.sql); err != nil { + return + } + } + return +} diff --git a/publicroomsapi/storage/sqlite3/public_rooms_table.go b/publicroomsapi/storage/sqlite3/public_rooms_table.go new file mode 100644 index 000000000..06c74a331 --- /dev/null +++ b/publicroomsapi/storage/sqlite3/public_rooms_table.go @@ -0,0 +1,277 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/publicroomsapi/types" +) + +var editableAttributes = []string{ + "aliases", + "canonical_alias", + "name", + "topic", + "world_readable", + "guest_can_join", + "avatar_url", + "visibility", +} + +const publicRoomsSchema = ` +-- Stores all of the rooms with data needed to create the server's room directory +CREATE TABLE IF NOT EXISTS publicroomsapi_public_rooms( + -- The room's ID + room_id TEXT NOT NULL PRIMARY KEY, + -- Number of joined members in the room + joined_members INTEGER NOT NULL DEFAULT 0, + -- Aliases of the room (empty array if none) + aliases TEXT[] NOT NULL DEFAULT '{}'::TEXT[], + -- Canonical alias of the room (empty string if none) + canonical_alias TEXT NOT NULL DEFAULT '', + -- Name of the room (empty string if none) + name TEXT NOT NULL DEFAULT '', + -- Topic of the room (empty string if none) + topic TEXT NOT NULL DEFAULT '', + -- Is the room world readable? + world_readable BOOLEAN NOT NULL DEFAULT false, + -- Can guest join the room? + guest_can_join BOOLEAN NOT NULL DEFAULT false, + -- URL of the room avatar (empty string if none) + avatar_url TEXT NOT NULL DEFAULT '', + -- Visibility of the room: true means the room is publicly visible, false + -- means the room is private + visibility BOOLEAN NOT NULL DEFAULT false +); +` + +const countPublicRoomsSQL = "" + + "SELECT COUNT(*) FROM publicroomsapi_public_rooms" + + " WHERE visibility = true" + +const selectPublicRoomsSQL = "" + + "SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" + + " FROM publicroomsapi_public_rooms WHERE visibility = true" + + " ORDER BY joined_members DESC" + + " OFFSET $1" + +const selectPublicRoomsWithLimitSQL = "" + + "SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" + + " FROM publicroomsapi_public_rooms WHERE visibility = true" + + " ORDER BY joined_members DESC" + + " OFFSET $1 LIMIT $2" + +const selectPublicRoomsWithFilterSQL = "" + + "SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" + + " FROM publicroomsapi_public_rooms" + + " WHERE visibility = true" + + " AND (LOWER(name) LIKE LOWER($1)" + + " OR LOWER(topic) LIKE LOWER($1)" + + " OR LOWER(ARRAY_TO_STRING(aliases, ',')) LIKE LOWER($1))" + + " ORDER BY joined_members DESC" + + " OFFSET $2" + +const selectPublicRoomsWithLimitAndFilterSQL = "" + + "SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" + + " FROM publicroomsapi_public_rooms" + + " WHERE visibility = true" + + " AND (LOWER(name) LIKE LOWER($1)" + + " OR LOWER(topic) LIKE LOWER($1)" + + " OR LOWER(ARRAY_TO_STRING(aliases, ',')) LIKE LOWER($1))" + + " ORDER BY joined_members DESC" + + " OFFSET $2 LIMIT $3" + +const selectRoomVisibilitySQL = "" + + "SELECT visibility FROM publicroomsapi_public_rooms" + + " WHERE room_id = $1" + +const insertNewRoomSQL = "" + + "INSERT INTO publicroomsapi_public_rooms(room_id)" + + " VALUES ($1)" + +const incrementJoinedMembersInRoomSQL = "" + + "UPDATE publicroomsapi_public_rooms" + + " SET joined_members = joined_members + 1" + + " WHERE room_id = $1" + +const decrementJoinedMembersInRoomSQL = "" + + "UPDATE publicroomsapi_public_rooms" + + " SET joined_members = joined_members - 1" + + " WHERE room_id = $1" + +const updateRoomAttributeSQL = "" + + "UPDATE publicroomsapi_public_rooms" + + " SET %s = $1" + + " WHERE room_id = $2" + +type publicRoomsStatements struct { + countPublicRoomsStmt *sql.Stmt + selectPublicRoomsStmt *sql.Stmt + selectPublicRoomsWithLimitStmt *sql.Stmt + selectPublicRoomsWithFilterStmt *sql.Stmt + selectPublicRoomsWithLimitAndFilterStmt *sql.Stmt + selectRoomVisibilityStmt *sql.Stmt + insertNewRoomStmt *sql.Stmt + incrementJoinedMembersInRoomStmt *sql.Stmt + decrementJoinedMembersInRoomStmt *sql.Stmt + updateRoomAttributeStmts map[string]*sql.Stmt +} + +func (s *publicRoomsStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(publicRoomsSchema) + if err != nil { + return + } + + stmts := statementList{ + {&s.countPublicRoomsStmt, countPublicRoomsSQL}, + {&s.selectPublicRoomsStmt, selectPublicRoomsSQL}, + {&s.selectPublicRoomsWithLimitStmt, selectPublicRoomsWithLimitSQL}, + {&s.selectPublicRoomsWithFilterStmt, selectPublicRoomsWithFilterSQL}, + {&s.selectPublicRoomsWithLimitAndFilterStmt, selectPublicRoomsWithLimitAndFilterSQL}, + {&s.selectRoomVisibilityStmt, selectRoomVisibilitySQL}, + {&s.insertNewRoomStmt, insertNewRoomSQL}, + {&s.incrementJoinedMembersInRoomStmt, incrementJoinedMembersInRoomSQL}, + {&s.decrementJoinedMembersInRoomStmt, decrementJoinedMembersInRoomSQL}, + } + + if err = stmts.prepare(db); err != nil { + return + } + + s.updateRoomAttributeStmts = make(map[string]*sql.Stmt) + for _, editable := range editableAttributes { + stmt := fmt.Sprintf(updateRoomAttributeSQL, editable) + if s.updateRoomAttributeStmts[editable], err = db.Prepare(stmt); err != nil { + return + } + } + + return +} + +func (s *publicRoomsStatements) countPublicRooms(ctx context.Context) (nb int64, err error) { + err = s.countPublicRoomsStmt.QueryRowContext(ctx).Scan(&nb) + return +} + +func (s *publicRoomsStatements) selectPublicRooms( + ctx context.Context, offset int64, limit int16, filter string, +) ([]types.PublicRoom, error) { + var rows *sql.Rows + var err error + + if len(filter) > 0 { + pattern := "%" + filter + "%" + if limit == 0 { + rows, err = s.selectPublicRoomsWithFilterStmt.QueryContext( + ctx, pattern, offset, + ) + } else { + rows, err = s.selectPublicRoomsWithLimitAndFilterStmt.QueryContext( + ctx, pattern, offset, limit, + ) + } + } else { + if limit == 0 { + rows, err = s.selectPublicRoomsStmt.QueryContext(ctx, offset) + } else { + rows, err = s.selectPublicRoomsWithLimitStmt.QueryContext( + ctx, offset, limit, + ) + } + } + + if err != nil { + return []types.PublicRoom{}, nil + } + + rooms := []types.PublicRoom{} + for rows.Next() { + var r types.PublicRoom + var aliases pq.StringArray + + err = rows.Scan( + &r.RoomID, &r.NumJoinedMembers, &aliases, &r.CanonicalAlias, + &r.Name, &r.Topic, &r.WorldReadable, &r.GuestCanJoin, &r.AvatarURL, + ) + if err != nil { + return rooms, err + } + + r.Aliases = aliases + + rooms = append(rooms, r) + } + + return rooms, nil +} + +func (s *publicRoomsStatements) selectRoomVisibility( + ctx context.Context, roomID string, +) (v bool, err error) { + err = s.selectRoomVisibilityStmt.QueryRowContext(ctx, roomID).Scan(&v) + return +} + +func (s *publicRoomsStatements) insertNewRoom( + ctx context.Context, roomID string, +) error { + _, err := s.insertNewRoomStmt.ExecContext(ctx, roomID) + return err +} + +func (s *publicRoomsStatements) incrementJoinedMembersInRoom( + ctx context.Context, roomID string, +) error { + _, err := s.incrementJoinedMembersInRoomStmt.ExecContext(ctx, roomID) + return err +} + +func (s *publicRoomsStatements) decrementJoinedMembersInRoom( + ctx context.Context, roomID string, +) error { + _, err := s.decrementJoinedMembersInRoomStmt.ExecContext(ctx, roomID) + return err +} + +func (s *publicRoomsStatements) updateRoomAttribute( + ctx context.Context, attrName string, attrValue attributeValue, roomID string, +) error { + stmt, isEditable := s.updateRoomAttributeStmts[attrName] + + if !isEditable { + return errors.New("Cannot edit " + attrName) + } + + var value interface{} + switch v := attrValue.(type) { + case []string: + value = pq.StringArray(v) + case bool, string: + value = attrValue + default: + return errors.New("Unsupported attribute type, must be bool, string or []string") + } + + _, err := stmt.ExecContext(ctx, value, roomID) + return err +} diff --git a/publicroomsapi/storage/sqlite3/storage.go b/publicroomsapi/storage/sqlite3/storage.go new file mode 100644 index 000000000..dcb8920f9 --- /dev/null +++ b/publicroomsapi/storage/sqlite3/storage.go @@ -0,0 +1,256 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + + _ "github.com/mattn/go-sqlite3" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/publicroomsapi/types" + + "github.com/matrix-org/gomatrixserverlib" +) + +// PublicRoomsServerDatabase represents a public rooms server database. +type PublicRoomsServerDatabase struct { + db *sql.DB + common.PartitionOffsetStatements + statements publicRoomsStatements +} + +type attributeValue interface{} + +// NewPublicRoomsServerDatabase creates a new public rooms server database. +func NewPublicRoomsServerDatabase(dataSourceName string) (*PublicRoomsServerDatabase, error) { + var db *sql.DB + var err error + if db, err = sql.Open("sqlite3", dataSourceName); err != nil { + return nil, err + } + storage := PublicRoomsServerDatabase{ + db: db, + } + if err = storage.PartitionOffsetStatements.Prepare(db, "publicroomsapi"); err != nil { + return nil, err + } + if err = storage.statements.prepare(db); err != nil { + return nil, err + } + return &storage, nil +} + +// GetRoomVisibility returns the room visibility as a boolean: true if the room +// is publicly visible, false if not. +// Returns an error if the retrieval failed. +func (d *PublicRoomsServerDatabase) GetRoomVisibility( + ctx context.Context, roomID string, +) (bool, error) { + return d.statements.selectRoomVisibility(ctx, roomID) +} + +// SetRoomVisibility updates the visibility attribute of a room. This attribute +// must be set to true if the room is publicly visible, false if not. +// Returns an error if the update failed. +func (d *PublicRoomsServerDatabase) SetRoomVisibility( + ctx context.Context, visible bool, roomID string, +) error { + return d.statements.updateRoomAttribute(ctx, "visibility", visible, roomID) +} + +// CountPublicRooms returns the number of room set as publicly visible on the server. +// Returns an error if the retrieval failed. +func (d *PublicRoomsServerDatabase) CountPublicRooms(ctx context.Context) (int64, error) { + return d.statements.countPublicRooms(ctx) +} + +// GetPublicRooms returns an array containing the local rooms set as publicly visible, ordered by their number +// of joined members. This array can be limited by a given number of elements, and offset by a given value. +// If the limit is 0, doesn't limit the number of results. If the offset is 0 too, the array contains all +// the rooms set as publicly visible on the server. +// Returns an error if the retrieval failed. +func (d *PublicRoomsServerDatabase) GetPublicRooms( + ctx context.Context, offset int64, limit int16, filter string, +) ([]types.PublicRoom, error) { + return d.statements.selectPublicRooms(ctx, offset, limit, filter) +} + +// UpdateRoomFromEvents iterate over a slice of state events and call +// UpdateRoomFromEvent on each of them to update the database representation of +// the rooms updated by each event. +// The slice of events to remove is used to update the number of joined members +// for the room in the database. +// If the update triggered by one of the events failed, aborts the process and +// returns an error. +func (d *PublicRoomsServerDatabase) UpdateRoomFromEvents( + ctx context.Context, + eventsToAdd []gomatrixserverlib.Event, + eventsToRemove []gomatrixserverlib.Event, +) error { + for _, event := range eventsToAdd { + if err := d.UpdateRoomFromEvent(ctx, event); err != nil { + return err + } + } + + for _, event := range eventsToRemove { + if event.Type() == "m.room.member" { + if err := d.updateNumJoinedUsers(ctx, event, true); err != nil { + return err + } + } + } + + return nil +} + +// UpdateRoomFromEvent updates the database representation of a room from a Matrix event, by +// checking the event's type to know which attribute to change and using the event's content +// to define the new value of the attribute. +// If the event doesn't match with any property used to compute the public room directory, +// does nothing. +// If something went wrong during the process, returns an error. +func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent( + ctx context.Context, event gomatrixserverlib.Event, +) error { + // Process the event according to its type + switch event.Type() { + case "m.room.create": + return d.statements.insertNewRoom(ctx, event.RoomID()) + case "m.room.member": + return d.updateNumJoinedUsers(ctx, event, false) + case "m.room.aliases": + return d.updateRoomAliases(ctx, event) + case "m.room.canonical_alias": + var content common.CanonicalAliasContent + field := &(content.Alias) + attrName := "canonical_alias" + return d.updateStringAttribute(ctx, attrName, event, &content, field) + case "m.room.name": + var content common.NameContent + field := &(content.Name) + attrName := "name" + return d.updateStringAttribute(ctx, attrName, event, &content, field) + case "m.room.topic": + var content common.TopicContent + field := &(content.Topic) + attrName := "topic" + return d.updateStringAttribute(ctx, attrName, event, &content, field) + case "m.room.avatar": + var content common.AvatarContent + field := &(content.URL) + attrName := "avatar_url" + return d.updateStringAttribute(ctx, attrName, event, &content, field) + case "m.room.history_visibility": + var content common.HistoryVisibilityContent + field := &(content.HistoryVisibility) + attrName := "world_readable" + strForTrue := "world_readable" + return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue) + case "m.room.guest_access": + var content common.GuestAccessContent + field := &(content.GuestAccess) + attrName := "guest_can_join" + strForTrue := "can_join" + return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue) + } + + // If the event type didn't match, return with no error + return nil +} + +// updateNumJoinedUsers updates the number of joined user in the database representation +// of a room using a given "m.room.member" Matrix event. +// If the membership property of the event isn't "join", ignores it and returs nil. +// If the remove parameter is set to false, increments the joined members counter in the +// database, if set to truem decrements it. +// Returns an error if the update failed. +func (d *PublicRoomsServerDatabase) updateNumJoinedUsers( + ctx context.Context, membershipEvent gomatrixserverlib.Event, remove bool, +) error { + membership, err := membershipEvent.Membership() + if err != nil { + return err + } + + if membership != gomatrixserverlib.Join { + return nil + } + + if remove { + return d.statements.decrementJoinedMembersInRoom(ctx, membershipEvent.RoomID()) + } + return d.statements.incrementJoinedMembersInRoom(ctx, membershipEvent.RoomID()) +} + +// updateStringAttribute updates a given string attribute in the database +// representation of a room using a given string data field from content of the +// Matrix event triggering the update. +// Returns an error if decoding the Matrix event's content or updating the attribute +// failed. +func (d *PublicRoomsServerDatabase) updateStringAttribute( + ctx context.Context, attrName string, event gomatrixserverlib.Event, + content interface{}, field *string, +) error { + if err := json.Unmarshal(event.Content(), content); err != nil { + return err + } + + return d.statements.updateRoomAttribute(ctx, attrName, *field, event.RoomID()) +} + +// updateBooleanAttribute updates a given boolean attribute in the database +// representation of a room using a given string data field from content of the +// Matrix event triggering the update. +// The attribute is set to true if the field matches a given string, false if not. +// Returns an error if decoding the Matrix event's content or updating the attribute +// failed. +func (d *PublicRoomsServerDatabase) updateBooleanAttribute( + ctx context.Context, attrName string, event gomatrixserverlib.Event, + content interface{}, field *string, strForTrue string, +) error { + if err := json.Unmarshal(event.Content(), content); err != nil { + return err + } + + var attrValue bool + if *field == strForTrue { + attrValue = true + } else { + attrValue = false + } + + return d.statements.updateRoomAttribute(ctx, attrName, attrValue, event.RoomID()) +} + +// updateRoomAliases decodes the content of a "m.room.aliases" Matrix event and update the list of aliases of +// a given room with it. +// Returns an error if decoding the Matrix event or updating the list failed. +func (d *PublicRoomsServerDatabase) updateRoomAliases( + ctx context.Context, aliasesEvent gomatrixserverlib.Event, +) error { + var content common.AliasesContent + if err := json.Unmarshal(aliasesEvent.Content(), &content); err != nil { + return err + } + + return d.statements.updateRoomAttribute( + ctx, "aliases", content.Aliases, aliasesEvent.RoomID(), + ) +} diff --git a/roomserver/api/query.go b/roomserver/api/query.go index b3fa01840..e1850e723 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -1,4 +1,6 @@ // Copyright 2017 Vector Creations Ltd +// Copyright 2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -244,6 +246,15 @@ type QueryServersInRoomAtEventResponse struct { Servers []gomatrixserverlib.ServerName `json:"servers"` } +// QueryRoomVersionCapabilities asks for the default room version +type QueryRoomVersionCapabilitiesRequest struct{} + +// QueryRoomVersionCapabilitiesResponse is a response to QueryServersInRoomAtEventResponse +type QueryRoomVersionCapabilitiesResponse struct { + DefaultRoomVersion string `json:"default"` + AvailableRoomVersions map[string]string `json:"available"` +} + // RoomserverQueryAPI is used to query information from the room server. type RoomserverQueryAPI interface { // Query the latest events and state for a room from the room server. @@ -323,6 +334,13 @@ type RoomserverQueryAPI interface { request *QueryServersInRoomAtEventRequest, response *QueryServersInRoomAtEventResponse, ) error + + // Asks for the default room version as preferred by the server. + QueryRoomVersionCapabilities( + ctx context.Context, + request *QueryRoomVersionCapabilitiesRequest, + response *QueryRoomVersionCapabilitiesResponse, + ) error } // RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API. @@ -358,6 +376,9 @@ const RoomserverQueryBackfillPath = "/api/roomserver/queryBackfill" // RoomserverQueryServersInRoomAtEventPath is the HTTP path for the QueryServersInRoomAtEvent API const RoomserverQueryServersInRoomAtEventPath = "/api/roomserver/queryServersInRoomAtEvents" +// RoomserverQueryRoomVersionCapabilitiesPath is the HTTP path for the QueryRoomVersionCapabilities API +const RoomserverQueryRoomVersionCapabilitiesPath = "/api/roomserver/queryRoomVersionCapabilities" + // NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API. // If httpClient is nil then it uses the http.DefaultClient func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverQueryAPI { @@ -514,3 +535,16 @@ func (h *httpRoomserverQueryAPI) QueryServersInRoomAtEvent( apiURL := h.roomserverURL + RoomserverQueryServersInRoomAtEventPath return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } + +// QueryServersInRoomAtEvent implements RoomServerQueryAPI +func (h *httpRoomserverQueryAPI) QueryRoomVersionCapabilities( + ctx context.Context, + request *QueryRoomVersionCapabilitiesRequest, + response *QueryRoomVersionCapabilitiesResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionCapabilities") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryRoomVersionCapabilitiesPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/roomserver/input/events.go b/roomserver/input/events.go index b30c39928..03023a4af 100644 --- a/roomserver/input/events.go +++ b/roomserver/input/events.go @@ -1,4 +1,6 @@ // Copyright 2017 Vector Creations Ltd +// Copyright 2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -21,13 +23,14 @@ import ( "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/state/database" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) // A RoomEventDatabase has the storage APIs needed to store a room event. type RoomEventDatabase interface { - state.RoomStateDatabase + database.RoomStateDatabase // Stores a matrix room event in the database StoreEvent( ctx context.Context, @@ -149,7 +152,12 @@ func calculateAndSetState( stateAtEvent *types.StateAtEvent, event gomatrixserverlib.Event, ) error { - var err error + // TODO: get the correct room version + roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, db) + if err != nil { + return err + } + if input.HasState { // We've been told what the state at the event is so we don't need to calculate it. // Check that those state events are in the database and store the state. @@ -163,7 +171,7 @@ func calculateAndSetState( } } else { // We haven't been told what the state at the event is so we need to calculate it from the prev_events - if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event, roomNID); err != nil { return err } } diff --git a/roomserver/input/latest_events.go b/roomserver/input/latest_events.go index c2f06393f..7e03d544a 100644 --- a/roomserver/input/latest_events.go +++ b/roomserver/input/latest_events.go @@ -1,4 +1,6 @@ // Copyright 2017 Vector Creations Ltd +// Copyright 2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -171,27 +173,32 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error + // TODO: get the correct room version + roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, u.db) + if err != nil { + return err + } latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)) for i := range u.latest { latestStateAtEvents[i] = u.latest[i].StateAtEvent } - u.newStateNID, err = state.CalculateAndStoreStateAfterEvents( - u.ctx, u.db, u.roomNID, latestStateAtEvents, + u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents( + u.ctx, u.roomNID, latestStateAtEvents, ) if err != nil { return err } - u.removed, u.added, err = state.DifferenceBetweeenStateSnapshots( - u.ctx, u.db, u.oldStateNID, u.newStateNID, + u.removed, u.added, err = roomState.DifferenceBetweeenStateSnapshots( + u.ctx, u.oldStateNID, u.newStateNID, ) if err != nil { return err } - u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = state.DifferenceBetweeenStateSnapshots( - u.ctx, u.db, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID, + u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots( + u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID, ) return err } diff --git a/roomserver/query/query.go b/roomserver/query/query.go index da8fe23e5..f138686b5 100644 --- a/roomserver/query/query.go +++ b/roomserver/query/query.go @@ -1,4 +1,6 @@ // Copyright 2017 Vector Creations Ltd +// Copyright 2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,12 +20,15 @@ import ( "context" "encoding/json" "net/http" + "strconv" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/auth" "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/state/database" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -39,7 +44,7 @@ type RoomserverQueryAPIEventDB interface { // RoomserverQueryAPIDatabase has the storage APIs needed to implement the query API. type RoomserverQueryAPIDatabase interface { - state.RoomStateDatabase + database.RoomStateDatabase RoomserverQueryAPIEventDB // Look up the numeric ID for the room. // Returns 0 if the room doesn't exists. @@ -98,6 +103,11 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState( request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { + // TODO: get the correct room version + roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB) + if err != nil { + return err + } response.QueryLatestEventsAndStateRequest = *request roomNID, err := r.DB.RoomNID(ctx, request.RoomID) if err != nil { @@ -115,8 +125,8 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState( } // Look up the currrent state for the requested tuples. - stateEntries, err := state.LoadStateAtSnapshotForStringTuples( - ctx, r.DB, currentStateSnapshotNID, request.StateToFetch, + stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( + ctx, currentStateSnapshotNID, request.StateToFetch, ) if err != nil { return err @@ -137,6 +147,11 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents( request *api.QueryStateAfterEventsRequest, response *api.QueryStateAfterEventsResponse, ) error { + // TODO: get the correct room version + roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB) + if err != nil { + return err + } response.QueryStateAfterEventsRequest = *request roomNID, err := r.DB.RoomNID(ctx, request.RoomID) if err != nil { @@ -159,8 +174,8 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents( response.PrevEventsExist = true // Look up the currrent state for the requested tuples. - stateEntries, err := state.LoadStateAfterEventsForStringTuples( - ctx, r.DB, prevStates, request.StateToFetch, + stateEntries, err := roomState.LoadStateAfterEventsForStringTuples( + ctx, prevStates, request.StateToFetch, ) if err != nil { return err @@ -315,6 +330,11 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom( func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID( ctx context.Context, eventNID types.EventNID, joinedOnly bool, ) ([]types.Event, error) { + // TODO: get the correct room version + roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB) + if err != nil { + return []types.Event{}, err + } events := []types.Event{} // Lookup the event NID eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID}) @@ -329,7 +349,7 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID( } // Fetch the state as it was when this event was fired - stateEntries, err := state.LoadCombinedStateAfterEvents(ctx, r.DB, prevState) + stateEntries, err := roomState.LoadCombinedStateAfterEvents(ctx, prevState) if err != nil { return nil, err } @@ -416,7 +436,13 @@ func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent( func (r *RoomserverQueryAPI) checkServerAllowedToSeeEvent( ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName, ) (bool, error) { - stateEntries, err := state.LoadStateAtEvent(ctx, r.DB, eventID) + // TODO: get the correct room version + roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB) + if err != nil { + return false, err + } + + stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) if err != nil { return false, err } @@ -570,6 +596,12 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain( request *api.QueryStateAndAuthChainRequest, response *api.QueryStateAndAuthChainResponse, ) error { + // TODO: get the correct room version + roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB) + if err != nil { + return err + } + response.QueryStateAndAuthChainRequest = *request roomNID, err := r.DB.RoomNID(ctx, request.RoomID) if err != nil { @@ -592,8 +624,8 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain( response.PrevEventsExist = true // Look up the currrent state for the requested tuples. - stateEntries, err := state.LoadCombinedStateAfterEvents( - ctx, r.DB, prevStates, + stateEntries, err := roomState.LoadCombinedStateAfterEvents( + ctx, prevStates, ) if err != nil { return err @@ -695,6 +727,25 @@ func (r *RoomserverQueryAPI) QueryServersInRoomAtEvent( return nil } +// QueryRoomVersionCapabilities implements api.RoomserverQueryAPI +func (r *RoomserverQueryAPI) QueryRoomVersionCapabilities( + ctx context.Context, + request *api.QueryRoomVersionCapabilitiesRequest, + response *api.QueryRoomVersionCapabilitiesResponse, +) error { + response.DefaultRoomVersion = strconv.Itoa(int(version.GetDefaultRoomVersion())) + response.AvailableRoomVersions = make(map[string]string) + for v, desc := range version.GetSupportedRoomVersions() { + sv := strconv.Itoa(int(v)) + if desc.Stable { + response.AvailableRoomVersions[sv] = "stable" + } else { + response.AvailableRoomVersions[sv] = "unstable" + } + } + return nil +} + // SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux. // nolint: gocyclo func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) { @@ -852,4 +903,18 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + servMux.Handle( + api.RoomserverQueryRoomVersionCapabilitiesPath, + common.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse { + var request api.QueryRoomVersionCapabilitiesRequest + var response api.QueryRoomVersionCapabilitiesResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryRoomVersionCapabilities(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/roomserver/state/database/database.go b/roomserver/state/database/database.go new file mode 100644 index 000000000..ede6c5ec3 --- /dev/null +++ b/roomserver/state/database/database.go @@ -0,0 +1,64 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "context" + + "github.com/matrix-org/dendrite/roomserver/types" +) + +// A RoomStateDatabase has the storage APIs needed to load state from the database +type RoomStateDatabase interface { + // Store the room state at an event in the database + AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, + ) (types.StateSnapshotNID, error) + // Look up the state of a room at each event for a list of string event IDs. + // Returns an error if there is an error talking to the database + // Returns a types.MissingEventError if the room state for the event IDs aren't in the database + StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + // Look up the numeric IDs for a list of string event types. + // Returns a map from string event type to numeric ID for the event type. + EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + // Look up the numeric IDs for a list of string event state keys. + // Returns a map from string state key to numeric ID for the state key. + EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + // Look up the numeric state data IDs for each numeric state snapshot ID + // The returned slice is sorted by numeric state snapshot ID. + StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + // Look up the state data for each numeric state data ID + // The returned slice is sorted by numeric state data ID. + StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + // Look up the state data for the state key tuples for each numeric state block ID + // This is used to fetch a subset of the room state at a snapshot. + // If a block doesn't contain any of the requested tuples then it can be discarded from the result. + // The returned slice is sorted by numeric state block ID. + StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, + ) ([]types.StateEntryList, error) + // Look up the Events for a list of numeric event IDs. + // Returns a sorted list of events. + Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) + // Look up snapshot NID for an event ID string + SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) +} diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 1cbb4d12b..687a120e3 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -1,4 +1,6 @@ // Copyright 2017 Vector Creations Ltd +// Copyright 2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,955 +14,68 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package state provides functions for reading state from the database. -// The functions for writing state to the database are the input package. package state import ( "context" - "fmt" - "sort" - "time" + "errors" + + "github.com/matrix-org/dendrite/roomserver/state/database" + v1 "github.com/matrix-org/dendrite/roomserver/state/v1" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" ) -// A RoomStateDatabase has the storage APIs needed to load state from the database -type RoomStateDatabase interface { - // Store the room state at an event in the database - AddState( +type StateResolutionVersion int + +const ( + StateResolutionAlgorithmV1 StateResolutionVersion = iota + 1 + StateResolutionAlgorithmV2 +) + +func GetStateResolutionAlgorithm( + version StateResolutionVersion, db database.RoomStateDatabase, +) (StateResolutionImpl, error) { + switch version { + case StateResolutionAlgorithmV1: + return v1.Prepare(db), nil + default: + return nil, errors.New("unsupported room version") + } +} + +type StateResolutionImpl interface { + LoadStateAtSnapshot( + ctx context.Context, stateNID types.StateSnapshotNID, + ) ([]types.StateEntry, error) + LoadStateAtEvent( + ctx context.Context, eventID string, + ) ([]types.StateEntry, error) + LoadCombinedStateAfterEvents( + ctx context.Context, prevStates []types.StateAtEvent, + ) ([]types.StateEntry, error) + DifferenceBetweeenStateSnapshots( + ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID, + ) (removed, added []types.StateEntry, err error) + LoadStateAtSnapshotForStringTuples( + ctx context.Context, + stateNID types.StateSnapshotNID, + stateKeyTuples []gomatrixserverlib.StateKeyTuple, + ) ([]types.StateEntry, error) + LoadStateAfterEventsForStringTuples( + ctx context.Context, + prevStates []types.StateAtEvent, + stateKeyTuples []gomatrixserverlib.StateKeyTuple, + ) ([]types.StateEntry, error) + CalculateAndStoreStateBeforeEvent( + ctx context.Context, + event gomatrixserverlib.Event, + roomNID types.RoomNID, + ) (types.StateSnapshotNID, error) + CalculateAndStoreStateAfterEvents( ctx context.Context, roomNID types.RoomNID, - stateBlockNIDs []types.StateBlockNID, - state []types.StateEntry, + prevStates []types.StateAtEvent, ) (types.StateSnapshotNID, error) - // Look up the state of a room at each event for a list of string event IDs. - // Returns an error if there is an error talking to the database - // Returns a types.MissingEventError if the room state for the event IDs aren't in the database - StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) - // Look up the numeric IDs for a list of string event types. - // Returns a map from string event type to numeric ID for the event type. - EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) - // Look up the numeric IDs for a list of string event state keys. - // Returns a map from string state key to numeric ID for the state key. - EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) - // Look up the numeric state data IDs for each numeric state snapshot ID - // The returned slice is sorted by numeric state snapshot ID. - StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) - // Look up the state data for each numeric state data ID - // The returned slice is sorted by numeric state data ID. - StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) - // Look up the state data for the state key tuples for each numeric state block ID - // This is used to fetch a subset of the room state at a snapshot. - // If a block doesn't contain any of the requested tuples then it can be discarded from the result. - // The returned slice is sorted by numeric state block ID. - StateEntriesForTuples( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, - ) ([]types.StateEntryList, error) - // Look up the Events for a list of numeric event IDs. - // Returns a sorted list of events. - Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) - // Look up snapshot NID for an event ID string - SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) -} - -// LoadStateAtSnapshot loads the full state of a room at a particular snapshot. -// This is typically the state before an event or the current state of a room. -// Returns a sorted list of state entries or an error if there was a problem talking to the database. -func LoadStateAtSnapshot( - ctx context.Context, db RoomStateDatabase, stateNID types.StateSnapshotNID, -) ([]types.StateEntry, error) { - stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) - if err != nil { - return nil, err - } - // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. - stateBlockNIDList := stateBlockNIDLists[0] - - stateEntryLists, err := db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs) - if err != nil { - return nil, err - } - stateEntriesMap := stateEntryListMap(stateEntryLists) - - // Combine all the state entries for this snapshot. - // The order of state block NIDs in the list tells us the order to combine them in. - var fullState []types.StateEntry - for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { - entries, ok := stateEntriesMap.lookup(stateBlockNID) - if !ok { - // This should only get hit if the database is corrupt. - // It should be impossible for an event to reference a NID that doesn't exist - panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) - } - fullState = append(fullState, entries...) - } - - // Stable sort so that the most recent entry for each state key stays - // remains later in the list than the older entries for the same state key. - sort.Stable(stateEntryByStateKeySorter(fullState)) - // Unique returns the last entry and hence the most recent entry for each state key. - fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] - return fullState, nil -} - -// LoadStateAtEvent loads the full state of a room at a particular event. -func LoadStateAtEvent( - ctx context.Context, db RoomStateDatabase, eventID string, -) ([]types.StateEntry, error) { - snapshotNID, err := db.SnapshotNIDFromEventID(ctx, eventID) - if err != nil { - return nil, err - } - - stateEntries, err := LoadStateAtSnapshot(ctx, db, snapshotNID) - if err != nil { - return nil, err - } - - return stateEntries, nil -} - -// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events -// and combines those snapshots together into a single list. -func LoadCombinedStateAfterEvents( - ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent, -) ([]types.StateEntry, error) { - stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) - for i, state := range prevStates { - stateNIDs[i] = state.BeforeStateSnapshotNID - } - // Fetch the state snapshots for the state before the each prev event from the database. - // Deduplicate the IDs before passing them to the database. - // There could be duplicates because the events could be state events where - // the snapshot of the room state before them was the same. - stateBlockNIDLists, err := db.StateBlockNIDs(ctx, uniqueStateSnapshotNIDs(stateNIDs)) - if err != nil { - return nil, err - } - - var stateBlockNIDs []types.StateBlockNID - for _, list := range stateBlockNIDLists { - stateBlockNIDs = append(stateBlockNIDs, list.StateBlockNIDs...) - } - // Fetch the state entries that will be combined to create the snapshots. - // Deduplicate the IDs before passing them to the database. - // There could be duplicates because a block of state entries could be reused by - // multiple snapshots. - stateEntryLists, err := db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs)) - if err != nil { - return nil, err - } - stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists) - stateEntriesMap := stateEntryListMap(stateEntryLists) - - // Combine the entries from all the snapshots of state after each prev event into a single list. - var combined []types.StateEntry - for _, prevState := range prevStates { - // Grab the list of state data NIDs for this snapshot. - stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID) - if !ok { - // This should only get hit if the database is corrupt. - // It should be impossible for an event to reference a NID that doesn't exist - panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID)) - } - - // Combine all the state entries for this snapshot. - // The order of state block NIDs in the list tells us the order to combine them in. - var fullState []types.StateEntry - for _, stateBlockNID := range stateBlockNIDs { - entries, ok := stateEntriesMap.lookup(stateBlockNID) - if !ok { - // This should only get hit if the database is corrupt. - // It should be impossible for an event to reference a NID that doesn't exist - panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) - } - fullState = append(fullState, entries...) - } - if prevState.IsStateEvent() { - // If the prev event was a state event then add an entry for the event itself - // so that we get the state after the event rather than the state before. - fullState = append(fullState, prevState.StateEntry) - } - - // Stable sort so that the most recent entry for each state key stays - // remains later in the list than the older entries for the same state key. - sort.Stable(stateEntryByStateKeySorter(fullState)) - // Unique returns the last entry and hence the most recent entry for each state key. - fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] - // Add the full state for this StateSnapshotNID. - combined = append(combined, fullState...) - } - return combined, nil -} - -// DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots. -func DifferenceBetweeenStateSnapshots( - ctx context.Context, db RoomStateDatabase, oldStateNID, newStateNID types.StateSnapshotNID, -) (removed, added []types.StateEntry, err error) { - if oldStateNID == newStateNID { - // If the snapshot NIDs are the same then nothing has changed - return nil, nil, nil - } - - var oldEntries []types.StateEntry - var newEntries []types.StateEntry - if oldStateNID != 0 { - oldEntries, err = LoadStateAtSnapshot(ctx, db, oldStateNID) - if err != nil { - return nil, nil, err - } - } - if newStateNID != 0 { - newEntries, err = LoadStateAtSnapshot(ctx, db, newStateNID) - if err != nil { - return nil, nil, err - } - } - - var oldI int - var newI int - for { - switch { - case oldI == len(oldEntries): - // We've reached the end of the old entries. - // The rest of the new list must have been newly added. - added = append(added, newEntries[newI:]...) - return - case newI == len(newEntries): - // We've reached the end of the new entries. - // The rest of the old list must be have been removed. - removed = append(removed, oldEntries[oldI:]...) - return - case oldEntries[oldI] == newEntries[newI]: - // The entry is in both lists so skip over it. - oldI++ - newI++ - case oldEntries[oldI].LessThan(newEntries[newI]): - // The lists are sorted so the old entry being less than the new entry means that it only appears in the old list. - removed = append(removed, oldEntries[oldI]) - oldI++ - default: - // Reaching the default case implies that the new entry is less than the old entry. - // Since the lists are sorted this means that it only appears in the new list. - added = append(added, newEntries[newI]) - newI++ - } - } -} - -// LoadStateAtSnapshotForStringTuples loads the state for a list of event type and state key pairs at a snapshot. -// This is used when we only want to load a subset of the room state at a snapshot. -// If there is no entry for a given event type and state key pair then it will be discarded. -// This is typically the state before an event or the current state of a room. -// Returns a sorted list of state entries or an error if there was a problem talking to the database. -func LoadStateAtSnapshotForStringTuples( - ctx context.Context, - db RoomStateDatabase, - stateNID types.StateSnapshotNID, - stateKeyTuples []gomatrixserverlib.StateKeyTuple, -) ([]types.StateEntry, error) { - numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples) - if err != nil { - return nil, err - } - return loadStateAtSnapshotForNumericTuples(ctx, db, stateNID, numericTuples) -} - -// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs -// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded. -// Returns an error if there was a problem talking to the database. -func stringTuplesToNumericTuples( - ctx context.Context, - db RoomStateDatabase, - stringTuples []gomatrixserverlib.StateKeyTuple, -) ([]types.StateKeyTuple, error) { - eventTypes := make([]string, len(stringTuples)) - stateKeys := make([]string, len(stringTuples)) - for i := range stringTuples { - eventTypes[i] = stringTuples[i].EventType - stateKeys[i] = stringTuples[i].StateKey - } - eventTypes = util.UniqueStrings(eventTypes) - eventTypeMap, err := db.EventTypeNIDs(ctx, eventTypes) - if err != nil { - return nil, err - } - stateKeys = util.UniqueStrings(stateKeys) - stateKeyMap, err := db.EventStateKeyNIDs(ctx, stateKeys) - if err != nil { - return nil, err - } - - var result []types.StateKeyTuple - for _, stringTuple := range stringTuples { - var numericTuple types.StateKeyTuple - var ok1, ok2 bool - numericTuple.EventTypeNID, ok1 = eventTypeMap[stringTuple.EventType] - numericTuple.EventStateKeyNID, ok2 = stateKeyMap[stringTuple.StateKey] - // Discard the tuple if there wasn't a numeric ID for either the event type or the state key. - if ok1 && ok2 { - result = append(result, numericTuple) - } - } - - return result, nil -} - -// loadStateAtSnapshotForNumericTuples loads the state for a list of event type and state key pairs at a snapshot. -// This is used when we only want to load a subset of the room state at a snapshot. -// If there is no entry for a given event type and state key pair then it will be discarded. -// This is typically the state before an event or the current state of a room. -// Returns a sorted list of state entries or an error if there was a problem talking to the database. -func loadStateAtSnapshotForNumericTuples( - ctx context.Context, - db RoomStateDatabase, - stateNID types.StateSnapshotNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntry, error) { - stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) - if err != nil { - return nil, err - } - // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. - stateBlockNIDList := stateBlockNIDLists[0] - - stateEntryLists, err := db.StateEntriesForTuples( - ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples, - ) - if err != nil { - return nil, err - } - stateEntriesMap := stateEntryListMap(stateEntryLists) - - // Combine all the state entries for this snapshot. - // The order of state block NIDs in the list tells us the order to combine them in. - var fullState []types.StateEntry - for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { - entries, ok := stateEntriesMap.lookup(stateBlockNID) - if !ok { - // If the block is missing from the map it means that none of its entries matched a requested tuple. - // This can happen if the block doesn't contain an update for one of the requested tuples. - // If none of the requested tuples are in the block then it can be safely skipped. - continue - } - fullState = append(fullState, entries...) - } - - // Stable sort so that the most recent entry for each state key stays - // remains later in the list than the older entries for the same state key. - sort.Stable(stateEntryByStateKeySorter(fullState)) - // Unique returns the last entry and hence the most recent entry for each state key. - fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] - return fullState, nil -} - -// LoadStateAfterEventsForStringTuples loads the state for a list of event type -// and state key pairs after list of events. -// This is used when we only want to load a subset of the room state after a list of events. -// If there is no entry for a given event type and state key pair then it will be discarded. -// This is typically the state before an event. -// Returns a sorted list of state entries or an error if there was a problem talking to the database. -func LoadStateAfterEventsForStringTuples( - ctx context.Context, - db RoomStateDatabase, - prevStates []types.StateAtEvent, - stateKeyTuples []gomatrixserverlib.StateKeyTuple, -) ([]types.StateEntry, error) { - numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples) - if err != nil { - return nil, err - } - return loadStateAfterEventsForNumericTuples(ctx, db, prevStates, numericTuples) -} - -func loadStateAfterEventsForNumericTuples( - ctx context.Context, - db RoomStateDatabase, - prevStates []types.StateAtEvent, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntry, error) { - if len(prevStates) == 1 { - // Fast path for a single event. - prevState := prevStates[0] - result, err := loadStateAtSnapshotForNumericTuples( - ctx, db, prevState.BeforeStateSnapshotNID, stateKeyTuples, - ) - if err != nil { - return nil, err - } - if prevState.IsStateEvent() { - // The result is current the state before the requested event. - // We want the state after the requested event. - // If the requested event was a state event then we need to - // update that key in the result. - // If the requested event wasn't a state event then the state after - // it is the same as the state before it. - for i := range result { - if result[i].StateKeyTuple == prevState.StateKeyTuple { - result[i] = prevState.StateEntry - } - } - } - return result, nil - } - - // Slow path for more that one event. - // Load the entire state so that we can do conflict resolution if we need to. - // TODO: The are some optimistations we could do here: - // 1) We only need to do conflict resolution if there is a conflict in the - // requested tuples so we might try loading just those tuples and then - // checking for conflicts. - // 2) When there is a conflict we still only need to load the state - // needed to do conflict resolution which would save us having to load - // the full state. - - // TODO: Add metrics for this as it could take a long time for big rooms - // with large conflicts. - fullState, _, _, err := calculateStateAfterManyEvents(ctx, db, prevStates) - if err != nil { - return nil, err - } - - // Sort the full state so we can use it as a map. - sort.Sort(stateEntrySorter(fullState)) - - // Filter the full state down to the required tuples. - var result []types.StateEntry - for _, tuple := range stateKeyTuples { - eventNID, ok := stateEntryMap(fullState).lookup(tuple) - if ok { - result = append(result, types.StateEntry{ - StateKeyTuple: tuple, - EventNID: eventNID, - }) - } - } - sort.Sort(stateEntrySorter(result)) - return result, nil -} - -var calculateStateDurations = prometheus.NewSummaryVec( - prometheus.SummaryOpts{ - Namespace: "dendrite", - Subsystem: "roomserver", - Name: "calculate_state_duration_microseconds", - Help: "How long it takes to calculate the state after a list of events", - }, - // Takes two labels: - // algorithm: - // The algorithm used to calculate the state or the step it failed on if it failed. - // Labels starting with "_" are used to indicate when the algorithm fails halfway. - // outcome: - // Whether the state was successfully calculated. - // - // The possible values for algorithm are: - // empty_state -> The list of events was empty so the state is empty. - // no_change -> The state hasn't changed. - // single_delta -> There was a single event added to the state in a way that can be encoded as a single delta - // full_state_no_conflicts -> We created a new copy of the full room state, but didn't enounter any conflicts - // while doing so. - // full_state_with_conflicts -> We created a new copy of the full room state and had to resolve conflicts to do so. - // _load_state_block_nids -> Failed loading the state block nids for a single previous state. - // _load_combined_state -> Failed to load the combined state. - // _resolve_conflicts -> Failed to resolve conflicts. - []string{"algorithm", "outcome"}, -) - -var calculateStatePrevEventLength = prometheus.NewSummaryVec( - prometheus.SummaryOpts{ - Namespace: "dendrite", - Subsystem: "roomserver", - Name: "calculate_state_prev_event_length", - Help: "The length of the list of events to calculate the state after", - }, - []string{"algorithm", "outcome"}, -) - -var calculateStateFullStateLength = prometheus.NewSummaryVec( - prometheus.SummaryOpts{ - Namespace: "dendrite", - Subsystem: "roomserver", - Name: "calculate_state_full_state_length", - Help: "The length of the full room state.", - }, - []string{"algorithm", "outcome"}, -) - -var calculateStateConflictLength = prometheus.NewSummaryVec( - prometheus.SummaryOpts{ - Namespace: "dendrite", - Subsystem: "roomserver", - Name: "calculate_state_conflict_state_length", - Help: "The length of the conflicted room state.", - }, - []string{"algorithm", "outcome"}, -) - -type calculateStateMetrics struct { - algorithm string - startTime time.Time - prevEventLength int - fullStateLength int - conflictLength int -} - -func (c *calculateStateMetrics) stop(stateNID types.StateSnapshotNID, err error) (types.StateSnapshotNID, error) { - var outcome string - if err == nil { - outcome = "success" - } else { - outcome = "failure" - } - endTime := time.Now() - calculateStateDurations.WithLabelValues(c.algorithm, outcome).Observe( - float64(endTime.Sub(c.startTime).Nanoseconds()) / 1000., - ) - calculateStatePrevEventLength.WithLabelValues(c.algorithm, outcome).Observe( - float64(c.prevEventLength), - ) - calculateStateFullStateLength.WithLabelValues(c.algorithm, outcome).Observe( - float64(c.fullStateLength), - ) - calculateStateConflictLength.WithLabelValues(c.algorithm, outcome).Observe( - float64(c.conflictLength), - ) - return stateNID, err -} - -func init() { - prometheus.MustRegister( - calculateStateDurations, calculateStatePrevEventLength, - calculateStateFullStateLength, calculateStateConflictLength, - ) -} - -// CalculateAndStoreStateBeforeEvent calculates a snapshot of the state of a room before an event. -// Stores the snapshot of the state in the database. -// Returns a numeric ID for the snapshot of the state before the event. -func CalculateAndStoreStateBeforeEvent( - ctx context.Context, - db RoomStateDatabase, - event gomatrixserverlib.Event, - roomNID types.RoomNID, -) (types.StateSnapshotNID, error) { - // Load the state at the prev events. - prevEventRefs := event.PrevEvents() - prevEventIDs := make([]string, len(prevEventRefs)) - for i := range prevEventRefs { - prevEventIDs[i] = prevEventRefs[i].EventID - } - - prevStates, err := db.StateAtEventIDs(ctx, prevEventIDs) - if err != nil { - return 0, err - } - - // The state before this event will be the state after the events that came before it. - return CalculateAndStoreStateAfterEvents(ctx, db, roomNID, prevStates) -} - -// CalculateAndStoreStateAfterEvents finds the room state after the given events. -// Stores the resulting state in the database and returns a numeric ID for that snapshot. -func CalculateAndStoreStateAfterEvents( - ctx context.Context, - db RoomStateDatabase, - roomNID types.RoomNID, - prevStates []types.StateAtEvent, -) (types.StateSnapshotNID, error) { - metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} - if len(prevStates) == 0 { - // 2) There weren't any prev_events for this event so the state is - // empty. - metrics.algorithm = "empty_state" - return metrics.stop(db.AddState(ctx, roomNID, nil, nil)) - } - - if len(prevStates) == 1 { - prevState := prevStates[0] - if prevState.EventStateKeyNID == 0 { - // 3) None of the previous events were state events and they all - // have the same state, so this event has exactly the same state - // as the previous events. - // This should be the common case. - metrics.algorithm = "no_change" - return metrics.stop(prevState.BeforeStateSnapshotNID, nil) - } - - // The previous event was a state event so we need to store a copy - // of the previous state updated with that event. - stateBlockNIDLists, err := db.StateBlockNIDs( - ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID}, - ) - if err != nil { - metrics.algorithm = "_load_state_blocks" - return metrics.stop(0, err) - } - stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs - if len(stateBlockNIDs) < maxStateBlockNIDs { - // 4) The number of state data blocks is small enough that we can just - // add the state event as a block of size one to the end of the blocks. - metrics.algorithm = "single_delta" - return metrics.stop(db.AddState( - ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, - )) - } - // If there are too many deltas then we need to calculate the full state - // So fall through to calculateAndStoreStateAfterManyEvents - } - - return calculateAndStoreStateAfterManyEvents(ctx, db, roomNID, prevStates, metrics) -} - -// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state. -// Increasing this number means that we can encode more of the state changes as simple deltas which means that -// we need fewer entries in the state data table. However making this number bigger will increase the size of -// the rows in the state table itself and will require more index lookups when retrieving a snapshot. -// TODO: Tune this to get the right balance between size and lookup performance. -const maxStateBlockNIDs = 64 - -// calculateAndStoreStateAfterManyEvents finds the room state after the given events. -// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event. -// Stores the resulting state and returns a numeric ID for the snapshot. -func calculateAndStoreStateAfterManyEvents( - ctx context.Context, - db RoomStateDatabase, - roomNID types.RoomNID, - prevStates []types.StateAtEvent, - metrics calculateStateMetrics, -) (types.StateSnapshotNID, error) { - - state, algorithm, conflictLength, err := - calculateStateAfterManyEvents(ctx, db, prevStates) - metrics.algorithm = algorithm - if err != nil { - return metrics.stop(0, err) - } - - // TODO: Check if we can encode the new state as a delta against the - // previous state. - metrics.conflictLength = conflictLength - metrics.fullStateLength = len(state) - return metrics.stop(db.AddState(ctx, roomNID, nil, state)) -} - -func calculateStateAfterManyEvents( - ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent, -) (state []types.StateEntry, algorithm string, conflictLength int, err error) { - var combined []types.StateEntry - // Conflict resolution. - // First stage: load the state after each of the prev events. - combined, err = LoadCombinedStateAfterEvents(ctx, db, prevStates) - if err != nil { - algorithm = "_load_combined_state" - return - } - - // Collect all the entries with the same type and key together. - // We don't care about the order here because the conflict resolution - // algorithm doesn't depend on the order of the prev events. - // Remove duplicate entires. - combined = combined[:util.SortAndUnique(stateEntrySorter(combined))] - - // Find the conflicts - conflicts := findDuplicateStateKeys(combined) - - if len(conflicts) > 0 { - conflictLength = len(conflicts) - - // 5) There are conflicting state events, for each conflict workout - // what the appropriate state event is. - - // Work out which entries aren't conflicted. - var notConflicted []types.StateEntry - for _, entry := range combined { - if _, ok := stateEntryMap(conflicts).lookup(entry.StateKeyTuple); !ok { - notConflicted = append(notConflicted, entry) - } - } - - var resolved []types.StateEntry - resolved, err = resolveConflicts(ctx, db, notConflicted, conflicts) - if err != nil { - algorithm = "_resolve_conflicts" - return - } - algorithm = "full_state_with_conflicts" - state = resolved - } else { - algorithm = "full_state_no_conflicts" - // 6) There weren't any conflicts - state = combined - } - return -} - -// resolveConflicts resolves a list of conflicted state entries. It takes two lists. -// The first is a list of all state entries that are not conflicted. -// The second is a list of all state entries that are conflicted -// A state entry is conflicted when there is more than one numeric event ID for the same state key tuple. -// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts. -// The returned list is sorted by state key tuple. -// Returns an error if there was a problem talking to the database. -func resolveConflicts( - ctx context.Context, - db RoomStateDatabase, - notConflicted, conflicted []types.StateEntry, -) ([]types.StateEntry, error) { - - // Load the conflicted events - conflictedEvents, eventIDMap, err := loadStateEvents(ctx, db, conflicted) - if err != nil { - return nil, err - } - - // Work out which auth events we need to load. - needed := gomatrixserverlib.StateNeededForAuth(conflictedEvents) - - // Find the numeric IDs for the necessary state keys. - var neededStateKeys []string - neededStateKeys = append(neededStateKeys, needed.Member...) - neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) - stateKeyNIDMap, err := db.EventStateKeyNIDs(ctx, neededStateKeys) - if err != nil { - return nil, err - } - - // Load the necessary auth events. - tuplesNeeded := stateKeyTuplesNeeded(stateKeyNIDMap, needed) - var authEntries []types.StateEntry - for _, tuple := range tuplesNeeded { - if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok { - authEntries = append(authEntries, types.StateEntry{ - StateKeyTuple: tuple, - EventNID: eventNID, - }) - } - } - authEvents, _, err := loadStateEvents(ctx, db, authEntries) - if err != nil { - return nil, err - } - - // Resolve the conflicts. - resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents) - - // Map from the full events back to numeric state entries. - for _, resolvedEvent := range resolvedEvents { - entry, ok := eventIDMap[resolvedEvent.EventID()] - if !ok { - panic(fmt.Errorf("Missing state entry for event ID %q", resolvedEvent.EventID())) - } - notConflicted = append(notConflicted, entry) - } - - // Sort the result so it can be searched. - sort.Sort(stateEntrySorter(notConflicted)) - return notConflicted, nil -} - -// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events. -func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { - var keyTuples []types.StateKeyTuple - if stateNeeded.Create { - keyTuples = append(keyTuples, types.StateKeyTuple{ - EventTypeNID: types.MRoomCreateNID, - EventStateKeyNID: types.EmptyStateKeyNID, - }) - } - if stateNeeded.PowerLevels { - keyTuples = append(keyTuples, types.StateKeyTuple{ - EventTypeNID: types.MRoomPowerLevelsNID, - EventStateKeyNID: types.EmptyStateKeyNID, - }) - } - if stateNeeded.JoinRules { - keyTuples = append(keyTuples, types.StateKeyTuple{ - EventTypeNID: types.MRoomJoinRulesNID, - EventStateKeyNID: types.EmptyStateKeyNID, - }) - } - for _, member := range stateNeeded.Member { - stateKeyNID, ok := stateKeyNIDMap[member] - if ok { - keyTuples = append(keyTuples, types.StateKeyTuple{ - EventTypeNID: types.MRoomMemberNID, - EventStateKeyNID: stateKeyNID, - }) - } - } - for _, token := range stateNeeded.ThirdPartyInvite { - stateKeyNID, ok := stateKeyNIDMap[token] - if ok { - keyTuples = append(keyTuples, types.StateKeyTuple{ - EventTypeNID: types.MRoomThirdPartyInviteNID, - EventStateKeyNID: stateKeyNID, - }) - } - } - return keyTuples -} - -// loadStateEvents loads the matrix events for a list of state entries. -// Returns a list of state events in no particular order and a map from string event ID back to state entry. -// The map can be used to recover which numeric state entry a given event is for. -// Returns an error if there was a problem talking to the database. -func loadStateEvents( - ctx context.Context, db RoomStateDatabase, entries []types.StateEntry, -) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) { - eventNIDs := make([]types.EventNID, len(entries)) - for i := range entries { - eventNIDs[i] = entries[i].EventNID - } - events, err := db.Events(ctx, eventNIDs) - if err != nil { - return nil, nil, err - } - eventIDMap := map[string]types.StateEntry{} - result := make([]gomatrixserverlib.Event, len(entries)) - for i := range entries { - event, ok := eventMap(events).lookup(entries[i].EventNID) - if !ok { - panic(fmt.Errorf("Corrupt DB: Missing event numeric ID %d", entries[i].EventNID)) - } - result[i] = event.Event - eventIDMap[event.Event.EventID()] = entries[i] - } - return result, eventIDMap, nil -} - -// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list. -// Returns a sorted list of those state entries. -func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry { - var result []types.StateEntry - // j is the starting index of a block of entries with the same state key tuple. - j := 0 - for i := 1; i < len(a); i++ { - // Check if the state key tuple matches the start of the block - if a[j].StateKeyTuple != a[i].StateKeyTuple { - // If the state key tuple is different then we've reached the end of a block of duplicates. - // Check if the size of the block is bigger than one. - // If the size is one then there was only a single entry with that state key tuple so we don't add it to the result - if j+1 != i { - // Add the block to the result. - result = append(result, a[j:i]...) - } - // Start a new block for the next state key tuple. - j = i - } - } - // Check if the last block with the same state key tuple had more than one event in it. - if j+1 != len(a) { - result = append(result, a[j:]...) - } - return result -} - -type stateEntrySorter []types.StateEntry - -func (s stateEntrySorter) Len() int { return len(s) } -func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } -func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -type stateBlockNIDListMap []types.StateBlockNIDList - -func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) { - list := []types.StateBlockNIDList(m) - i := sort.Search(len(list), func(i int) bool { - return list[i].StateSnapshotNID >= stateNID - }) - if i < len(list) && list[i].StateSnapshotNID == stateNID { - ok = true - stateBlockNIDs = list[i].StateBlockNIDs - } - return -} - -type stateEntryListMap []types.StateEntryList - -func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) { - list := []types.StateEntryList(m) - i := sort.Search(len(list), func(i int) bool { - return list[i].StateBlockNID >= stateBlockNID - }) - if i < len(list) && list[i].StateBlockNID == stateBlockNID { - ok = true - stateEntries = list[i].StateEntries - } - return -} - -type stateEntryByStateKeySorter []types.StateEntry - -func (s stateEntryByStateKeySorter) Len() int { return len(s) } -func (s stateEntryByStateKeySorter) Less(i, j int) bool { - return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple) -} -func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -type stateNIDSorter []types.StateSnapshotNID - -func (s stateNIDSorter) Len() int { return len(s) } -func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] } -func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -func uniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID { - return nids[:util.SortAndUnique(stateNIDSorter(nids))] -} - -type stateBlockNIDSorter []types.StateBlockNID - -func (s stateBlockNIDSorter) Len() int { return len(s) } -func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] } -func (s stateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID { - return nids[:util.SortAndUnique(stateBlockNIDSorter(nids))] -} - -// Map from event type, state key tuple to numeric event ID. -// Implemented using binary search on a sorted array. -type stateEntryMap []types.StateEntry - -// lookup an entry in the event map. -func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) { - // Since the list is sorted we can implement this using binary search. - // This is faster than using a hash map. - // We don't have to worry about pathological cases because the keys are fixed - // size and are controlled by us. - list := []types.StateEntry(m) - i := sort.Search(len(list), func(i int) bool { - return !list[i].StateKeyTuple.LessThan(stateKey) - }) - if i < len(list) && list[i].StateKeyTuple == stateKey { - ok = true - eventNID = list[i].EventNID - } - return -} - -// Map from numeric event ID to event. -// Implemented using binary search on a sorted array. -type eventMap []types.Event - -// lookup an entry in the event map. -func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) { - // Since the list is sorted we can implement this using binary search. - // This is faster than using a hash map. - // We don't have to worry about pathological cases because the keys are fixed - // size are controlled by us. - list := []types.Event(m) - i := sort.Search(len(list), func(i int) bool { - return list[i].EventNID >= eventNID - }) - if i < len(list) && list[i].EventNID == eventNID { - ok = true - event = &list[i] - } - return } diff --git a/roomserver/state/v1/state.go b/roomserver/state/v1/state.go new file mode 100644 index 000000000..5683745bf --- /dev/null +++ b/roomserver/state/v1/state.go @@ -0,0 +1,927 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package state provides functions for reading state from the database. +// The functions for writing state to the database are the input package. +package v1 + +import ( + "context" + "fmt" + "sort" + "time" + + "github.com/matrix-org/dendrite/roomserver/state/database" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" +) + +type StateResolutionV1 struct { + db database.RoomStateDatabase +} + +func Prepare(db database.RoomStateDatabase) StateResolutionV1 { + return StateResolutionV1{ + db: db, + } +} + +// LoadStateAtSnapshot loads the full state of a room at a particular snapshot. +// This is typically the state before an event or the current state of a room. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func (v StateResolutionV1) LoadStateAtSnapshot( + ctx context.Context, stateNID types.StateSnapshotNID, +) ([]types.StateEntry, error) { + stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) + if err != nil { + return nil, err + } + // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. + stateBlockNIDList := stateBlockNIDLists[0] + + stateEntryLists, err := v.db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs) + if err != nil { + return nil, err + } + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combine all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) + } + fullState = append(fullState, entries...) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + return fullState, nil +} + +// LoadStateAtEvent loads the full state of a room at a particular event. +func (v StateResolutionV1) LoadStateAtEvent( + ctx context.Context, eventID string, +) ([]types.StateEntry, error) { + snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) + if err != nil { + return nil, err + } + + stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID) + if err != nil { + return nil, err + } + + return stateEntries, nil +} + +// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events +// and combines those snapshots together into a single list. +func (v StateResolutionV1) LoadCombinedStateAfterEvents( + ctx context.Context, prevStates []types.StateAtEvent, +) ([]types.StateEntry, error) { + stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) + for i, state := range prevStates { + stateNIDs[i] = state.BeforeStateSnapshotNID + } + // Fetch the state snapshots for the state before the each prev event from the database. + // Deduplicate the IDs before passing them to the database. + // There could be duplicates because the events could be state events where + // the snapshot of the room state before them was the same. + stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, uniqueStateSnapshotNIDs(stateNIDs)) + if err != nil { + return nil, err + } + + var stateBlockNIDs []types.StateBlockNID + for _, list := range stateBlockNIDLists { + stateBlockNIDs = append(stateBlockNIDs, list.StateBlockNIDs...) + } + // Fetch the state entries that will be combined to create the snapshots. + // Deduplicate the IDs before passing them to the database. + // There could be duplicates because a block of state entries could be reused by + // multiple snapshots. + stateEntryLists, err := v.db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs)) + if err != nil { + return nil, err + } + stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists) + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combine the entries from all the snapshots of state after each prev event into a single list. + var combined []types.StateEntry + for _, prevState := range prevStates { + // Grab the list of state data NIDs for this snapshot. + stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID)) + } + + // Combine all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) + } + fullState = append(fullState, entries...) + } + if prevState.IsStateEvent() { + // If the prev event was a state event then add an entry for the event itself + // so that we get the state after the event rather than the state before. + fullState = append(fullState, prevState.StateEntry) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + // Add the full state for this StateSnapshotNID. + combined = append(combined, fullState...) + } + return combined, nil +} + +// DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots. +func (v StateResolutionV1) DifferenceBetweeenStateSnapshots( + ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID, +) (removed, added []types.StateEntry, err error) { + if oldStateNID == newStateNID { + // If the snapshot NIDs are the same then nothing has changed + return nil, nil, nil + } + + var oldEntries []types.StateEntry + var newEntries []types.StateEntry + if oldStateNID != 0 { + oldEntries, err = v.LoadStateAtSnapshot(ctx, oldStateNID) + if err != nil { + return nil, nil, err + } + } + if newStateNID != 0 { + newEntries, err = v.LoadStateAtSnapshot(ctx, newStateNID) + if err != nil { + return nil, nil, err + } + } + + var oldI int + var newI int + for { + switch { + case oldI == len(oldEntries): + // We've reached the end of the old entries. + // The rest of the new list must have been newly added. + added = append(added, newEntries[newI:]...) + return + case newI == len(newEntries): + // We've reached the end of the new entries. + // The rest of the old list must be have been removed. + removed = append(removed, oldEntries[oldI:]...) + return + case oldEntries[oldI] == newEntries[newI]: + // The entry is in both lists so skip over it. + oldI++ + newI++ + case oldEntries[oldI].LessThan(newEntries[newI]): + // The lists are sorted so the old entry being less than the new entry means that it only appears in the old list. + removed = append(removed, oldEntries[oldI]) + oldI++ + default: + // Reaching the default case implies that the new entry is less than the old entry. + // Since the lists are sorted this means that it only appears in the new list. + added = append(added, newEntries[newI]) + newI++ + } + } +} + +// LoadStateAtSnapshotForStringTuples loads the state for a list of event type and state key pairs at a snapshot. +// This is used when we only want to load a subset of the room state at a snapshot. +// If there is no entry for a given event type and state key pair then it will be discarded. +// This is typically the state before an event or the current state of a room. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func (v StateResolutionV1) LoadStateAtSnapshotForStringTuples( + ctx context.Context, + stateNID types.StateSnapshotNID, + stateKeyTuples []gomatrixserverlib.StateKeyTuple, +) ([]types.StateEntry, error) { + numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) + if err != nil { + return nil, err + } + return v.loadStateAtSnapshotForNumericTuples(ctx, stateNID, numericTuples) +} + +// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs +// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded. +// Returns an error if there was a problem talking to the database. +func (v StateResolutionV1) stringTuplesToNumericTuples( + ctx context.Context, + stringTuples []gomatrixserverlib.StateKeyTuple, +) ([]types.StateKeyTuple, error) { + eventTypes := make([]string, len(stringTuples)) + stateKeys := make([]string, len(stringTuples)) + for i := range stringTuples { + eventTypes[i] = stringTuples[i].EventType + stateKeys[i] = stringTuples[i].StateKey + } + eventTypes = util.UniqueStrings(eventTypes) + eventTypeMap, err := v.db.EventTypeNIDs(ctx, eventTypes) + if err != nil { + return nil, err + } + stateKeys = util.UniqueStrings(stateKeys) + stateKeyMap, err := v.db.EventStateKeyNIDs(ctx, stateKeys) + if err != nil { + return nil, err + } + + var result []types.StateKeyTuple + for _, stringTuple := range stringTuples { + var numericTuple types.StateKeyTuple + var ok1, ok2 bool + numericTuple.EventTypeNID, ok1 = eventTypeMap[stringTuple.EventType] + numericTuple.EventStateKeyNID, ok2 = stateKeyMap[stringTuple.StateKey] + // Discard the tuple if there wasn't a numeric ID for either the event type or the state key. + if ok1 && ok2 { + result = append(result, numericTuple) + } + } + + return result, nil +} + +// loadStateAtSnapshotForNumericTuples loads the state for a list of event type and state key pairs at a snapshot. +// This is used when we only want to load a subset of the room state at a snapshot. +// If there is no entry for a given event type and state key pair then it will be discarded. +// This is typically the state before an event or the current state of a room. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func (v StateResolutionV1) loadStateAtSnapshotForNumericTuples( + ctx context.Context, + stateNID types.StateSnapshotNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntry, error) { + stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) + if err != nil { + return nil, err + } + // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. + stateBlockNIDList := stateBlockNIDLists[0] + + stateEntryLists, err := v.db.StateEntriesForTuples( + ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples, + ) + if err != nil { + return nil, err + } + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combine all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // If the block is missing from the map it means that none of its entries matched a requested tuple. + // This can happen if the block doesn't contain an update for one of the requested tuples. + // If none of the requested tuples are in the block then it can be safely skipped. + continue + } + fullState = append(fullState, entries...) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + return fullState, nil +} + +// LoadStateAfterEventsForStringTuples loads the state for a list of event type +// and state key pairs after list of events. +// This is used when we only want to load a subset of the room state after a list of events. +// If there is no entry for a given event type and state key pair then it will be discarded. +// This is typically the state before an event. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func (v StateResolutionV1) LoadStateAfterEventsForStringTuples( + ctx context.Context, + prevStates []types.StateAtEvent, + stateKeyTuples []gomatrixserverlib.StateKeyTuple, +) ([]types.StateEntry, error) { + numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) + if err != nil { + return nil, err + } + return v.loadStateAfterEventsForNumericTuples(ctx, prevStates, numericTuples) +} + +func (v StateResolutionV1) loadStateAfterEventsForNumericTuples( + ctx context.Context, + prevStates []types.StateAtEvent, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntry, error) { + if len(prevStates) == 1 { + // Fast path for a single event. + prevState := prevStates[0] + result, err := v.loadStateAtSnapshotForNumericTuples( + ctx, prevState.BeforeStateSnapshotNID, stateKeyTuples, + ) + if err != nil { + return nil, err + } + if prevState.IsStateEvent() { + // The result is current the state before the requested event. + // We want the state after the requested event. + // If the requested event was a state event then we need to + // update that key in the result. + // If the requested event wasn't a state event then the state after + // it is the same as the state before it. + for i := range result { + if result[i].StateKeyTuple == prevState.StateKeyTuple { + result[i] = prevState.StateEntry + } + } + } + return result, nil + } + + // Slow path for more that one event. + // Load the entire state so that we can do conflict resolution if we need to. + // TODO: The are some optimistations we could do here: + // 1) We only need to do conflict resolution if there is a conflict in the + // requested tuples so we might try loading just those tuples and then + // checking for conflicts. + // 2) When there is a conflict we still only need to load the state + // needed to do conflict resolution which would save us having to load + // the full state. + + // TODO: Add metrics for this as it could take a long time for big rooms + // with large conflicts. + fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, prevStates) + if err != nil { + return nil, err + } + + // Sort the full state so we can use it as a map. + sort.Sort(stateEntrySorter(fullState)) + + // Filter the full state down to the required tuples. + var result []types.StateEntry + for _, tuple := range stateKeyTuples { + eventNID, ok := stateEntryMap(fullState).lookup(tuple) + if ok { + result = append(result, types.StateEntry{ + StateKeyTuple: tuple, + EventNID: eventNID, + }) + } + } + sort.Sort(stateEntrySorter(result)) + return result, nil +} + +var calculateStateDurations = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "dendrite", + Subsystem: "roomserver", + Name: "calculate_state_duration_microseconds", + Help: "How long it takes to calculate the state after a list of events", + }, + // Takes two labels: + // algorithm: + // The algorithm used to calculate the state or the step it failed on if it failed. + // Labels starting with "_" are used to indicate when the algorithm fails halfway. + // outcome: + // Whether the state was successfully calculated. + // + // The possible values for algorithm are: + // empty_state -> The list of events was empty so the state is empty. + // no_change -> The state hasn't changed. + // single_delta -> There was a single event added to the state in a way that can be encoded as a single delta + // full_state_no_conflicts -> We created a new copy of the full room state, but didn't enounter any conflicts + // while doing so. + // full_state_with_conflicts -> We created a new copy of the full room state and had to resolve conflicts to do so. + // _load_state_block_nids -> Failed loading the state block nids for a single previous state. + // _load_combined_state -> Failed to load the combined state. + // _resolve_conflicts -> Failed to resolve conflicts. + []string{"algorithm", "outcome"}, +) + +var calculateStatePrevEventLength = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "dendrite", + Subsystem: "roomserver", + Name: "calculate_state_prev_event_length", + Help: "The length of the list of events to calculate the state after", + }, + []string{"algorithm", "outcome"}, +) + +var calculateStateFullStateLength = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "dendrite", + Subsystem: "roomserver", + Name: "calculate_state_full_state_length", + Help: "The length of the full room state.", + }, + []string{"algorithm", "outcome"}, +) + +var calculateStateConflictLength = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "dendrite", + Subsystem: "roomserver", + Name: "calculate_state_conflict_state_length", + Help: "The length of the conflicted room state.", + }, + []string{"algorithm", "outcome"}, +) + +type calculateStateMetrics struct { + algorithm string + startTime time.Time + prevEventLength int + fullStateLength int + conflictLength int +} + +func (c *calculateStateMetrics) stop(stateNID types.StateSnapshotNID, err error) (types.StateSnapshotNID, error) { + var outcome string + if err == nil { + outcome = "success" + } else { + outcome = "failure" + } + endTime := time.Now() + calculateStateDurations.WithLabelValues(c.algorithm, outcome).Observe( + float64(endTime.Sub(c.startTime).Nanoseconds()) / 1000., + ) + calculateStatePrevEventLength.WithLabelValues(c.algorithm, outcome).Observe( + float64(c.prevEventLength), + ) + calculateStateFullStateLength.WithLabelValues(c.algorithm, outcome).Observe( + float64(c.fullStateLength), + ) + calculateStateConflictLength.WithLabelValues(c.algorithm, outcome).Observe( + float64(c.conflictLength), + ) + return stateNID, err +} + +func init() { + prometheus.MustRegister( + calculateStateDurations, calculateStatePrevEventLength, + calculateStateFullStateLength, calculateStateConflictLength, + ) +} + +// CalculateAndStoreStateBeforeEvent calculates a snapshot of the state of a room before an event. +// Stores the snapshot of the state in the database. +// Returns a numeric ID for the snapshot of the state before the event. +func (v StateResolutionV1) CalculateAndStoreStateBeforeEvent( + ctx context.Context, + event gomatrixserverlib.Event, + roomNID types.RoomNID, +) (types.StateSnapshotNID, error) { + // Load the state at the prev events. + prevEventRefs := event.PrevEvents() + prevEventIDs := make([]string, len(prevEventRefs)) + for i := range prevEventRefs { + prevEventIDs[i] = prevEventRefs[i].EventID + } + + prevStates, err := v.db.StateAtEventIDs(ctx, prevEventIDs) + if err != nil { + return 0, err + } + + // The state before this event will be the state after the events that came before it. + return v.CalculateAndStoreStateAfterEvents(ctx, roomNID, prevStates) +} + +// CalculateAndStoreStateAfterEvents finds the room state after the given events. +// Stores the resulting state in the database and returns a numeric ID for that snapshot. +func (v StateResolutionV1) CalculateAndStoreStateAfterEvents( + ctx context.Context, + roomNID types.RoomNID, + prevStates []types.StateAtEvent, +) (types.StateSnapshotNID, error) { + metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} + + if len(prevStates) == 0 { + // 2) There weren't any prev_events for this event so the state is + // empty. + metrics.algorithm = "empty_state" + return metrics.stop(v.db.AddState(ctx, roomNID, nil, nil)) + } + + if len(prevStates) == 1 { + prevState := prevStates[0] + if prevState.EventStateKeyNID == 0 { + // 3) None of the previous events were state events and they all + // have the same state, so this event has exactly the same state + // as the previous events. + // This should be the common case. + metrics.algorithm = "no_change" + return metrics.stop(prevState.BeforeStateSnapshotNID, nil) + } + // The previous event was a state event so we need to store a copy + // of the previous state updated with that event. + stateBlockNIDLists, err := v.db.StateBlockNIDs( + ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID}, + ) + if err != nil { + metrics.algorithm = "_load_state_blocks" + return metrics.stop(0, err) + } + stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs + if len(stateBlockNIDs) < maxStateBlockNIDs { + // 4) The number of state data blocks is small enough that we can just + // add the state event as a block of size one to the end of the blocks. + metrics.algorithm = "single_delta" + return metrics.stop(v.db.AddState( + ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, + )) + } + // If there are too many deltas then we need to calculate the full state + // So fall through to calculateAndStoreStateAfterManyEvents + } + + return v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics) +} + +// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state. +// Increasing this number means that we can encode more of the state changes as simple deltas which means that +// we need fewer entries in the state data table. However making this number bigger will increase the size of +// the rows in the state table itself and will require more index lookups when retrieving a snapshot. +// TODO: Tune this to get the right balance between size and lookup performance. +const maxStateBlockNIDs = 64 + +// calculateAndStoreStateAfterManyEvents finds the room state after the given events. +// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event. +// Stores the resulting state and returns a numeric ID for the snapshot. +func (v StateResolutionV1) calculateAndStoreStateAfterManyEvents( + ctx context.Context, + roomNID types.RoomNID, + prevStates []types.StateAtEvent, + metrics calculateStateMetrics, +) (types.StateSnapshotNID, error) { + + state, algorithm, conflictLength, err := + v.calculateStateAfterManyEvents(ctx, prevStates) + metrics.algorithm = algorithm + if err != nil { + return metrics.stop(0, err) + } + + // TODO: Check if we can encode the new state as a delta against the + // previous state. + metrics.conflictLength = conflictLength + metrics.fullStateLength = len(state) + return metrics.stop(v.db.AddState(ctx, roomNID, nil, state)) +} + +func (v StateResolutionV1) calculateStateAfterManyEvents( + ctx context.Context, prevStates []types.StateAtEvent, +) (state []types.StateEntry, algorithm string, conflictLength int, err error) { + var combined []types.StateEntry + // Conflict resolution. + // First stage: load the state after each of the prev events. + combined, err = v.LoadCombinedStateAfterEvents(ctx, prevStates) + if err != nil { + algorithm = "_load_combined_state" + return + } + + // Collect all the entries with the same type and key together. + // We don't care about the order here because the conflict resolution + // algorithm doesn't depend on the order of the prev events. + // Remove duplicate entires. + combined = combined[:util.SortAndUnique(stateEntrySorter(combined))] + + // Find the conflicts + conflicts := findDuplicateStateKeys(combined) + + if len(conflicts) > 0 { + conflictLength = len(conflicts) + + // 5) There are conflicting state events, for each conflict workout + // what the appropriate state event is. + + // Work out which entries aren't conflicted. + var notConflicted []types.StateEntry + for _, entry := range combined { + if _, ok := stateEntryMap(conflicts).lookup(entry.StateKeyTuple); !ok { + notConflicted = append(notConflicted, entry) + } + } + + var resolved []types.StateEntry + resolved, err = v.resolveConflicts(ctx, notConflicted, conflicts) + if err != nil { + algorithm = "_resolve_conflicts" + return + } + algorithm = "full_state_with_conflicts" + state = resolved + } else { + algorithm = "full_state_no_conflicts" + // 6) There weren't any conflicts + state = combined + } + return +} + +// resolveConflicts resolves a list of conflicted state entries. It takes two lists. +// The first is a list of all state entries that are not conflicted. +// The second is a list of all state entries that are conflicted +// A state entry is conflicted when there is more than one numeric event ID for the same state key tuple. +// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts. +// The returned list is sorted by state key tuple. +// Returns an error if there was a problem talking to the database. +func (v StateResolutionV1) resolveConflicts( + ctx context.Context, + notConflicted, conflicted []types.StateEntry, +) ([]types.StateEntry, error) { + + // Load the conflicted events + conflictedEvents, eventIDMap, err := v.loadStateEvents(ctx, conflicted) + if err != nil { + return nil, err + } + + // Work out which auth events we need to load. + needed := gomatrixserverlib.StateNeededForAuth(conflictedEvents) + + // Find the numeric IDs for the necessary state keys. + var neededStateKeys []string + neededStateKeys = append(neededStateKeys, needed.Member...) + neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) + stateKeyNIDMap, err := v.db.EventStateKeyNIDs(ctx, neededStateKeys) + if err != nil { + return nil, err + } + + // Load the necessary auth events. + tuplesNeeded := v.stateKeyTuplesNeeded(stateKeyNIDMap, needed) + var authEntries []types.StateEntry + for _, tuple := range tuplesNeeded { + if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok { + authEntries = append(authEntries, types.StateEntry{ + StateKeyTuple: tuple, + EventNID: eventNID, + }) + } + } + authEvents, _, err := v.loadStateEvents(ctx, authEntries) + if err != nil { + return nil, err + } + + // Resolve the conflicts. + resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents) + + // Map from the full events back to numeric state entries. + for _, resolvedEvent := range resolvedEvents { + entry, ok := eventIDMap[resolvedEvent.EventID()] + if !ok { + panic(fmt.Errorf("Missing state entry for event ID %q", resolvedEvent.EventID())) + } + notConflicted = append(notConflicted, entry) + } + + // Sort the result so it can be searched. + sort.Sort(stateEntrySorter(notConflicted)) + return notConflicted, nil +} + +// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events. +func (v StateResolutionV1) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { + var keyTuples []types.StateKeyTuple + if stateNeeded.Create { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomCreateNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) + } + if stateNeeded.PowerLevels { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomPowerLevelsNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) + } + if stateNeeded.JoinRules { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomJoinRulesNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) + } + for _, member := range stateNeeded.Member { + stateKeyNID, ok := stateKeyNIDMap[member] + if ok { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomMemberNID, + EventStateKeyNID: stateKeyNID, + }) + } + } + for _, token := range stateNeeded.ThirdPartyInvite { + stateKeyNID, ok := stateKeyNIDMap[token] + if ok { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomThirdPartyInviteNID, + EventStateKeyNID: stateKeyNID, + }) + } + } + return keyTuples +} + +// loadStateEvents loads the matrix events for a list of state entries. +// Returns a list of state events in no particular order and a map from string event ID back to state entry. +// The map can be used to recover which numeric state entry a given event is for. +// Returns an error if there was a problem talking to the database. +func (v StateResolutionV1) loadStateEvents( + ctx context.Context, entries []types.StateEntry, +) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) { + eventNIDs := make([]types.EventNID, len(entries)) + for i := range entries { + eventNIDs[i] = entries[i].EventNID + } + events, err := v.db.Events(ctx, eventNIDs) + if err != nil { + return nil, nil, err + } + eventIDMap := map[string]types.StateEntry{} + result := make([]gomatrixserverlib.Event, len(entries)) + for i := range entries { + event, ok := eventMap(events).lookup(entries[i].EventNID) + if !ok { + panic(fmt.Errorf("Corrupt DB: Missing event numeric ID %d", entries[i].EventNID)) + } + result[i] = event.Event + eventIDMap[event.Event.EventID()] = entries[i] + } + return result, eventIDMap, nil +} + +// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list. +// Returns a sorted list of those state entries. +func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry { + var result []types.StateEntry + // j is the starting index of a block of entries with the same state key tuple. + j := 0 + for i := 1; i < len(a); i++ { + // Check if the state key tuple matches the start of the block + if a[j].StateKeyTuple != a[i].StateKeyTuple { + // If the state key tuple is different then we've reached the end of a block of duplicates. + // Check if the size of the block is bigger than one. + // If the size is one then there was only a single entry with that state key tuple so we don't add it to the result + if j+1 != i { + // Add the block to the result. + result = append(result, a[j:i]...) + } + // Start a new block for the next state key tuple. + j = i + } + } + // Check if the last block with the same state key tuple had more than one event in it. + if j+1 != len(a) { + result = append(result, a[j:]...) + } + return result +} + +type stateEntrySorter []types.StateEntry + +func (s stateEntrySorter) Len() int { return len(s) } +func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } +func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type stateBlockNIDListMap []types.StateBlockNIDList + +func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) { + list := []types.StateBlockNIDList(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].StateSnapshotNID >= stateNID + }) + if i < len(list) && list[i].StateSnapshotNID == stateNID { + ok = true + stateBlockNIDs = list[i].StateBlockNIDs + } + return +} + +type stateEntryListMap []types.StateEntryList + +func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) { + list := []types.StateEntryList(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].StateBlockNID >= stateBlockNID + }) + if i < len(list) && list[i].StateBlockNID == stateBlockNID { + ok = true + stateEntries = list[i].StateEntries + } + return +} + +type stateEntryByStateKeySorter []types.StateEntry + +func (s stateEntryByStateKeySorter) Len() int { return len(s) } +func (s stateEntryByStateKeySorter) Less(i, j int) bool { + return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple) +} +func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type stateNIDSorter []types.StateSnapshotNID + +func (s stateNIDSorter) Len() int { return len(s) } +func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] } +func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func uniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID { + return nids[:util.SortAndUnique(stateNIDSorter(nids))] +} + +type stateBlockNIDSorter []types.StateBlockNID + +func (s stateBlockNIDSorter) Len() int { return len(s) } +func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] } +func (s stateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID { + return nids[:util.SortAndUnique(stateBlockNIDSorter(nids))] +} + +// Map from event type, state key tuple to numeric event ID. +// Implemented using binary search on a sorted array. +type stateEntryMap []types.StateEntry + +// lookup an entry in the event map. +func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) { + // Since the list is sorted we can implement this using binary search. + // This is faster than using a hash map. + // We don't have to worry about pathological cases because the keys are fixed + // size and are controlled by us. + list := []types.StateEntry(m) + i := sort.Search(len(list), func(i int) bool { + return !list[i].StateKeyTuple.LessThan(stateKey) + }) + if i < len(list) && list[i].StateKeyTuple == stateKey { + ok = true + eventNID = list[i].EventNID + } + return +} + +// Map from numeric event ID to event. +// Implemented using binary search on a sorted array. +type eventMap []types.Event + +// lookup an entry in the event map. +func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) { + // Since the list is sorted we can implement this using binary search. + // This is faster than using a hash map. + // We don't have to worry about pathological cases because the keys are fixed + // size are controlled by us. + list := []types.Event(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].EventNID >= eventNID + }) + if i < len(list) && list[i].EventNID == eventNID { + ok = true + event = &list[i] + } + return +} diff --git a/roomserver/state/state_test.go b/roomserver/state/v1/state_test.go similarity index 94% rename from roomserver/state/state_test.go rename to roomserver/state/v1/state_test.go index 67af18671..4dc7e52ec 100644 --- a/roomserver/state/state_test.go +++ b/roomserver/state/v1/state_test.go @@ -1,4 +1,6 @@ // Copyright 2017 Vector Creations Ltd +// Copyright 2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package state +package v1 import ( "testing" diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index 415fb84eb..0b7ef6aa7 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -102,5 +102,5 @@ func (s *eventJSONStatements) bulkSelectEventJSON( } result.EventNID = types.EventNID(eventNID) } - return results[:i], nil + return results[:i], rows.Err() } diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index c3aaa498e..cbc29a69d 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -125,7 +125,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( } result[stateKey] = types.EventStateKeyNID(stateKeyNID) } - return result, nil + return result, rows.Err() } func (s *eventStateKeyStatements) bulkSelectEventStateKey( @@ -150,5 +150,5 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey( } result[types.EventStateKeyNID(stateKeyNID)] = stateKey } - return result, nil + return result, rows.Err() } diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index 1ec2e7cde..faa887545 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -143,5 +143,5 @@ func (s *eventTypeStatements) bulkSelectEventTypeNID( } result[eventType] = types.EventTypeNID(eventTypeNID) } - return result, nil + return result, rows.Err() } diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 1e8a5665b..d9b269bc8 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -209,6 +209,9 @@ func (s *eventStatements) bulkSelectStateEventByID( return nil, err } } + if err = rows.Err(); err != nil { + return nil, err + } if i != len(eventIDs) { // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. // We don't know which ones were missing because we don't return the string IDs in the query. @@ -219,7 +222,7 @@ func (s *eventStatements) bulkSelectStateEventByID( fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)), ) } - return results, err + return results, nil } // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. @@ -251,12 +254,15 @@ func (s *eventStatements) bulkSelectStateAtEventByID( ) } } + if err = rows.Err(); err != nil { + return nil, err + } if i != len(eventIDs) { return nil, types.MissingEventError( fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)), ) } - return results, err + return results, nil } func (s *eventStatements) updateEventState( @@ -321,6 +327,9 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference( result.EventID = eventID result.EventSHA256 = eventSHA256 } + if err = rows.Err(); err != nil { + return nil, err + } if i != len(eventNIDs) { return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) } @@ -343,6 +352,9 @@ func (s *eventStatements) bulkSelectEventReference( return nil, err } } + if err = rows.Err(); err != nil { + return nil, err + } if i != len(eventNIDs) { return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) } @@ -366,6 +378,9 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []typ } results[types.EventNID(eventNID)] = eventID } + if err = rows.Err(); err != nil { + return nil, err + } if i != len(eventNIDs) { return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) } @@ -389,7 +404,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []str } results[eventID] = types.EventNID(eventNID) } - return results, nil + return results, rows.Err() } func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) { diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go index 43cd5ba09..603fed31b 100644 --- a/roomserver/storage/postgres/invite_table.go +++ b/roomserver/storage/postgres/invite_table.go @@ -114,21 +114,23 @@ func (s *inviteStatements) insertInviteEvent( func (s *inviteStatements) updateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (eventIDs []string, err error) { +) ([]string, error) { stmt := common.TxStmt(txn, s.updateInviteRetiredStmt) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) if err != nil { return nil, err } - defer (func() { err = rows.Close() })() + defer rows.Close() // nolint: errcheck + + var eventIDs []string for rows.Next() { var inviteEventID string - if err := rows.Scan(&inviteEventID); err != nil { + if err = rows.Scan(&inviteEventID); err != nil { return nil, err } eventIDs = append(eventIDs, inviteEventID) } - return + return eventIDs, rows.Err() } // selectInviteActiveForUserInRoom returns a list of sender state key NIDs @@ -151,5 +153,5 @@ func (s *inviteStatements) selectInviteActiveForUserInRoom( } result = append(result, types.EventStateKeyNID(senderUserNID)) } - return result, nil + return result, rows.Err() } diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 9f41fd67b..70032fd1e 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -151,6 +151,7 @@ func (s *membershipStatements) selectMembershipsFromRoom( if err != nil { return } + defer rows.Close() // nolint: errcheck for rows.Next() { var eNID types.EventNID @@ -159,8 +160,9 @@ func (s *membershipStatements) selectMembershipsFromRoom( } eventNIDs = append(eventNIDs, eNID) } - return + return eventNIDs, rows.Err() } + func (s *membershipStatements) selectMembershipsFromRoomAndMembership( ctx context.Context, roomNID types.RoomNID, membership membershipState, @@ -170,6 +172,7 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership( if err != nil { return } + defer rows.Close() // nolint: errcheck for rows.Next() { var eNID types.EventNID @@ -178,7 +181,7 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership( } eventNIDs = append(eventNIDs, eNID) } - return + return eventNIDs, rows.Err() } func (s *membershipStatements) updateMembership( diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go index ad1b560c2..6de898c41 100644 --- a/roomserver/storage/postgres/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -90,23 +90,23 @@ func (s *roomAliasesStatements) selectRoomIDFromAlias( func (s *roomAliasesStatements) selectAliasesFromRoomID( ctx context.Context, roomID string, -) (aliases []string, err error) { - aliases = []string{} +) ([]string, error) { rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) if err != nil { - return + return nil, err } + defer rows.Close() // nolint: errcheck + var aliases []string for rows.Next() { var alias string if err = rows.Scan(&alias); err != nil { - return + return nil, err } aliases = append(aliases, alias) } - - return + return aliases, rows.Err() } func (s *roomAliasesStatements) selectCreatorIDFromAlias( diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index ccc201b18..edd15a338 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -39,7 +39,10 @@ CREATE TABLE IF NOT EXISTS roomserver_rooms ( last_event_sent_nid BIGINT NOT NULL DEFAULT 0, -- The state of the room after the current set of latest events. -- This will be 0 if there are no latest events in the room. - state_snapshot_nid BIGINT NOT NULL DEFAULT 0 + state_snapshot_nid BIGINT NOT NULL DEFAULT 0, + -- The version of the room, which will assist in determining the state resolution + -- algorithm, event ID format, etc. + room_version BIGINT NOT NULL DEFAULT 1 ); ` @@ -61,12 +64,16 @@ const selectLatestEventNIDsForUpdateSQL = "" + const updateLatestEventNIDsSQL = "" + "UPDATE roomserver_rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1" +const selectRoomVersionForRoomNIDSQL = "" + + "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt + selectRoomVersionForRoomNIDStmt *sql.Stmt } func (s *roomStatements) prepare(db *sql.DB) (err error) { @@ -80,6 +87,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, + {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, }.prepare(db) } @@ -154,3 +162,12 @@ func (s *roomStatements) updateLatestEventNIDs( ) return err } + +func (s *roomStatements) selectRoomVersionForRoomNID( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) (int64, error) { + var roomVersion int64 + stmt := common.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt) + err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) + return roomVersion, err +} diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index 15e69cc98..e6f4f7fe9 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -152,7 +152,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries( eventNID int64 entry types.StateEntry ) - if err := rows.Scan( + if err = rows.Scan( &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, ); err != nil { return nil, err @@ -169,10 +169,13 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries( } current.StateEntries = append(current.StateEntries, entry) } + if err = rows.Err(); err != nil { + return nil, err + } if i != len(stateBlockNIDs) { return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs)) } - return results, nil + return results, err } func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( @@ -237,7 +240,7 @@ func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( if current.StateEntries != nil { results = append(results, current) } - return results, nil + return results, rows.Err() } func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array { diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 76f1d2b66..a1f26e228 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -104,7 +104,7 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( for ; rows.Next(); i++ { result := &results[i] var stateBlockNIDs pq.Int64Array - if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil { + if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil { return nil, err } result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs)) @@ -112,6 +112,9 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k]) } } + if err = rows.Err(); err != nil { + return nil, err + } if i != len(stateNIDs) { return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs)) } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 93450e5a5..77a792d68 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -697,6 +697,14 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type return d.Events(ctx, nids) } +func (d *Database) GetRoomVersionForRoom( + ctx context.Context, roomNID types.RoomNID, +) (int64, error) { + return d.statements.selectRoomVersionForRoomNID( + ctx, nil, roomNID, + ) +} + type transaction struct { ctx context.Context txn *sql.Tx diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index fb5ff219c..bf237728d 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -30,7 +30,8 @@ const roomsSchema = ` room_id TEXT NOT NULL UNIQUE, latest_event_nids TEXT NOT NULL DEFAULT '{}', last_event_sent_nid INTEGER NOT NULL DEFAULT 0, - state_snapshot_nid INTEGER NOT NULL DEFAULT 0 + state_snapshot_nid INTEGER NOT NULL DEFAULT 0, + room_version INTEGER NOT NULL DEFAULT 1 ); ` @@ -52,12 +53,16 @@ const selectLatestEventNIDsForUpdateSQL = "" + const updateLatestEventNIDsSQL = "" + "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4" +const selectRoomVersionForRoomNIDSQL = "" + + "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt + selectRoomVersionForRoomNIDStmt *sql.Stmt } func (s *roomStatements) prepare(db *sql.DB) (err error) { @@ -71,6 +76,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, + {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, }.prepare(db) } @@ -148,3 +154,12 @@ func (s *roomStatements) updateLatestEventNIDs( ) return err } + +func (s *roomStatements) selectRoomVersionForRoomNID( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) (int64, error) { + var roomVersion int64 + stmt := common.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt) + err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) + return roomVersion, err +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 6ed0789be..b25fd542f 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -831,6 +831,14 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type return d.Events(ctx, nids) } +func (d *Database) GetRoomVersionForRoom( + ctx context.Context, roomNID types.RoomNID, +) (int64, error) { + return d.statements.selectRoomVersionForRoomNID( + ctx, nil, roomNID, + ) +} + type transaction struct { ctx context.Context txn *sql.Tx diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go index 1516e2ad7..551d97cd1 100644 --- a/roomserver/storage/storage.go +++ b/roomserver/storage/storage.go @@ -19,7 +19,7 @@ import ( "net/url" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/state" + statedb "github.com/matrix-org/dendrite/roomserver/state/database" "github.com/matrix-org/dendrite/roomserver/storage/postgres" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" "github.com/matrix-org/dendrite/roomserver/types" @@ -27,7 +27,7 @@ import ( ) type Database interface { - state.RoomStateDatabase + statedb.RoomStateDatabase StoreEvent(ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) @@ -48,6 +48,7 @@ type Database interface { GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) + GetRoomVersionForRoom(ctx context.Context, roomNID types.RoomNID) (int64, error) } // NewPublicRoomsServerDatabase opens a database connection. diff --git a/roomserver/version/version.go b/roomserver/version/version.go new file mode 100644 index 000000000..0943e3843 --- /dev/null +++ b/roomserver/version/version.go @@ -0,0 +1,112 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package version + +import ( + "errors" + + "github.com/matrix-org/dendrite/roomserver/state" +) + +type RoomVersionID int +type EventFormatID int + +const ( + RoomVersionV1 RoomVersionID = iota + 1 + RoomVersionV2 + RoomVersionV3 + RoomVersionV4 + RoomVersionV5 +) + +const ( + EventFormatV1 EventFormatID = iota + 1 // original event ID formatting + EventFormatV2 // event ID is event hash + EventFormatV3 // event ID is URL-safe base64 event hash +) + +type RoomVersionDescription struct { + Supported bool + Stable bool + StateResolution state.StateResolutionVersion + EventFormat EventFormatID + EnforceSigningKeyValidity bool +} + +var roomVersions = map[RoomVersionID]RoomVersionDescription{ + RoomVersionV1: RoomVersionDescription{ + Supported: true, + Stable: true, + StateResolution: state.StateResolutionAlgorithmV1, + EventFormat: EventFormatV1, + EnforceSigningKeyValidity: false, + }, + RoomVersionV2: RoomVersionDescription{ + Supported: false, + Stable: true, + StateResolution: state.StateResolutionAlgorithmV2, + EventFormat: EventFormatV1, + EnforceSigningKeyValidity: false, + }, + RoomVersionV3: RoomVersionDescription{ + Supported: false, + Stable: true, + StateResolution: state.StateResolutionAlgorithmV2, + EventFormat: EventFormatV2, + EnforceSigningKeyValidity: false, + }, + RoomVersionV4: RoomVersionDescription{ + Supported: false, + Stable: true, + StateResolution: state.StateResolutionAlgorithmV2, + EventFormat: EventFormatV3, + EnforceSigningKeyValidity: false, + }, + RoomVersionV5: RoomVersionDescription{ + Supported: false, + Stable: true, + StateResolution: state.StateResolutionAlgorithmV2, + EventFormat: EventFormatV3, + EnforceSigningKeyValidity: true, + }, +} + +func GetDefaultRoomVersion() RoomVersionID { + return RoomVersionV1 +} + +func GetRoomVersions() map[RoomVersionID]RoomVersionDescription { + return roomVersions +} + +func GetSupportedRoomVersions() map[RoomVersionID]RoomVersionDescription { + versions := make(map[RoomVersionID]RoomVersionDescription) + for id, version := range GetRoomVersions() { + if version.Supported { + versions[id] = version + } + } + return versions +} + +func GetSupportedRoomVersion(version RoomVersionID) (desc RoomVersionDescription, err error) { + if version, ok := roomVersions[version]; ok { + desc = version + } + if !desc.Supported { + err = errors.New("unsupported room version") + } + return +} diff --git a/syncapi/api/query.go b/syncapi/api/query.go new file mode 100644 index 000000000..2993829e0 --- /dev/null +++ b/syncapi/api/query.go @@ -0,0 +1,123 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "net/http" + + commonHTTP "github.com/matrix-org/dendrite/common/http" + "github.com/matrix-org/util" + opentracing "github.com/opentracing/opentracing-go" +) + +const ( + SyncAPIQuerySyncPath = "/api/syncapi/querySync" + SyncAPIQueryStatePath = "/api/syncapi/queryState" + SyncAPIQueryStateTypePath = "/api/syncapi/queryStateType" + SyncAPIQueryMessagesPath = "/api/syncapi/queryMessages" +) + +func NewSyncQueryAPIHTTP(syncapiURL string, httpClient *http.Client) SyncQueryAPI { + if httpClient == nil { + httpClient = http.DefaultClient + } + return &httpSyncQueryAPI{syncapiURL, httpClient} +} + +type httpSyncQueryAPI struct { + syncapiURL string + httpClient *http.Client +} + +type SyncQueryAPI interface { + QuerySync(ctx context.Context, request *QuerySyncRequest, response *QuerySyncResponse) error + QueryState(ctx context.Context, request *QueryStateRequest, response *QueryStateResponse) error + QueryStateType(ctx context.Context, request *QueryStateTypeRequest, response *QueryStateTypeResponse) error + QueryMessages(ctx context.Context, request *QueryMessagesRequest, response *QueryMessagesResponse) error +} + +type QuerySyncRequest struct{} + +type QueryStateRequest struct { + RoomID string +} + +type QueryStateTypeRequest struct { + RoomID string + EventType string + StateKey string +} + +type QueryMessagesRequest struct { + RoomID string +} + +type QuerySyncResponse util.JSONResponse +type QueryStateResponse util.JSONResponse +type QueryStateTypeResponse util.JSONResponse +type QueryMessagesResponse util.JSONResponse + +// QueryLatestEventsAndState implements SyncQueryAPI +func (h *httpSyncQueryAPI) QuerySync( + ctx context.Context, + request *QuerySyncRequest, + response *QuerySyncResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySync") + defer span.Finish() + + apiURL := h.syncapiURL + SyncAPIQuerySyncPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryStateAfterEvents implements SyncQueryAPI +func (h *httpSyncQueryAPI) QueryState( + ctx context.Context, + request *QueryStateRequest, + response *QueryStateResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryState") + defer span.Finish() + + apiURL := h.syncapiURL + SyncAPIQueryStatePath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryEventsByID implements SyncQueryAPI +func (h *httpSyncQueryAPI) QueryStateType( + ctx context.Context, + request *QueryStateTypeRequest, + response *QueryStateTypeResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateType") + defer span.Finish() + + apiURL := h.syncapiURL + SyncAPIQueryStateTypePath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryMembershipForUser implements SyncQueryAPI +func (h *httpSyncQueryAPI) QueryMessages( + ctx context.Context, + request *QueryMessagesRequest, + response *QueryMessagesResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMessages") + defer span.Finish() + + apiURL := h.syncapiURL + SyncAPIQueryMessagesPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/syncapi/routing/state.go b/syncapi/routing/state.go index dbee267d6..cf67f7522 100644 --- a/syncapi/routing/state.go +++ b/syncapi/routing/state.go @@ -22,7 +22,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" @@ -45,10 +44,10 @@ func OnIncomingStateRequest(req *http.Request, db storage.Database, roomID strin // TODO(#287): Auth request and handle the case where the user has left (where // we should return the state at the poin they left) - stateFilterPart := gomatrix.DefaultFilterPart() - // TODO: stateFilterPart should not limit the number of state events (or only limits abusive number of events) + stateFilter := gomatrixserverlib.DefaultStateFilter() + // TODO: stateFilter should not limit the number of state events (or only limits abusive number of events) - stateEvents, err := db.GetStateEventsForRoom(req.Context(), roomID, &stateFilterPart) + stateEvents, err := db.GetStateEventsForRoom(req.Context(), roomID, &stateFilter) if err != nil { return httputil.LogThenError(req, err) } diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index 94e6ac41c..d1811aa66 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -22,7 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` @@ -99,7 +99,7 @@ func (s *accountDataStatements) selectAccountDataInRange( ctx context.Context, userID string, oldPos, newPos types.StreamPosition, - accountDataFilterPart *gomatrix.FilterPart, + accountDataEventFilter *gomatrixserverlib.EventFilter, ) (data map[string][]string, err error) { data = make(map[string][]string) @@ -111,13 +111,14 @@ func (s *accountDataStatements) selectAccountDataInRange( } rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos, - pq.StringArray(filterConvertTypeWildcardToSQL(accountDataFilterPart.Types)), - pq.StringArray(filterConvertTypeWildcardToSQL(accountDataFilterPart.NotTypes)), - accountDataFilterPart.Limit, + pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)), + pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.NotTypes)), + accountDataEventFilter.Limit, ) if err != nil { return } + defer rows.Close() // nolint: errcheck for rows.Next() { var dataType string @@ -133,8 +134,7 @@ func (s *accountDataStatements) selectAccountDataInRange( data[roomID] = []string{dataType} } } - - return + return data, rows.Err() } func (s *accountDataStatements) selectMaxAccountDataID( diff --git a/syncapi/storage/postgres/backward_extremities_table.go b/syncapi/storage/postgres/backward_extremities_table.go index 1489f7f91..d63c546e3 100644 --- a/syncapi/storage/postgres/backward_extremities_table.go +++ b/syncapi/storage/postgres/backward_extremities_table.go @@ -91,6 +91,7 @@ func (s *backwardExtremitiesStatements) selectBackwardExtremitiesForRoom( if err != nil { return } + defer rows.Close() // nolint: errcheck for rows.Next() { var eID string @@ -101,7 +102,7 @@ func (s *backwardExtremitiesStatements) selectBackwardExtremitiesForRoom( eventIDs = append(eventIDs, eID) } - return + return eventIDs, rows.Err() } func (s *backwardExtremitiesStatements) isBackwardExtremity( diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 816cbb44a..6f5c1e803 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -23,7 +23,6 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" ) @@ -154,7 +153,7 @@ func (s *currentRoomStateStatements) selectJoinedUsers( users = append(users, userID) result[roomID] = users } - return result, nil + return result, rows.Err() } // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. @@ -179,22 +178,22 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership( } result = append(result, roomID) } - return result, nil + return result, rows.Err() } // CurrentState returns all the current state events for the given room. func (s *currentRoomStateStatements) selectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, - stateFilterPart *gomatrix.FilterPart, + stateFilter *gomatrixserverlib.StateFilter, ) ([]gomatrixserverlib.Event, error) { stmt := common.TxStmt(txn, s.selectCurrentStateStmt) rows, err := stmt.QueryContext(ctx, roomID, - pq.StringArray(stateFilterPart.Senders), - pq.StringArray(stateFilterPart.NotSenders), - pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)), - pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)), - stateFilterPart.ContainsURL, - stateFilterPart.Limit, + pq.StringArray(stateFilter.Senders), + pq.StringArray(stateFilter.NotSenders), + pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), + pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), + stateFilter.ContainsURL, + stateFilter.Limit, ) if err != nil { return nil, err @@ -267,7 +266,7 @@ func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.Event, error) { } result = append(result, ev) } - return result, nil + return result, rows.Err() } func (s *currentRoomStateStatements) selectStateEvent( diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index ca4bbeb5c..2cb8fb199 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -133,7 +133,7 @@ func (s *inviteEventsStatements) selectInviteEventsInRange( result[roomID] = event } - return result, nil + return result, rows.Err() } func (s *inviteEventsStatements) selectMaxInviteID( diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 6d213a57e..2db46c5db 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrix" "github.com/lib/pq" "github.com/matrix-org/dendrite/common" @@ -154,22 +153,23 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { // two positions, only the most recent state is returned. func (s *outputRoomEventsStatements) selectStateInRange( ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition, - stateFilterPart *gomatrix.FilterPart, + stateFilter *gomatrixserverlib.StateFilter, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { stmt := common.TxStmt(txn, s.selectStateInRangeStmt) rows, err := stmt.QueryContext( ctx, oldPos, newPos, - pq.StringArray(stateFilterPart.Senders), - pq.StringArray(stateFilterPart.NotSenders), - pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)), - pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)), - stateFilterPart.ContainsURL, - stateFilterPart.Limit, + pq.StringArray(stateFilter.Senders), + pq.StringArray(stateFilter.NotSenders), + pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), + pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), + stateFilter.ContainsURL, + stateFilter.Limit, ) if err != nil { return nil, nil, err } + defer rows.Close() // nolint: errcheck // Fetch all the state change events for all rooms between the two positions then loop each event and: // - Keep a cache of the event by ID (99% of state change events are for the event itself) // - For each room ID, build up an array of event IDs which represents cumulative adds/removes @@ -226,7 +226,7 @@ func (s *outputRoomEventsStatements) selectStateInRange( } } - return stateNeeded, eventIDToEvent, nil + return stateNeeded, eventIDToEvent, rows.Err() } // MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied, @@ -392,5 +392,5 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { ExcludeFromSync: excludeFromSync, }) } - return result, nil + return result, rows.Err() } diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index 793d1e236..78a381da9 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -134,6 +134,7 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( } else if err != nil { return } + defer rows.Close() // nolint: errcheck // Return the IDs. var eventID string @@ -144,7 +145,7 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( eventIDs = append(eventIDs, eventID) } - return + return eventIDs, rows.Err() } // selectPositionInTopology returns the position of a given event in the @@ -176,6 +177,7 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsFromPosition( } else if err != nil { return } + defer rows.Close() // nolint: errcheck // Return the IDs. var eventID string for rows.Next() { @@ -184,5 +186,5 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsFromPosition( } eventIDs = append(eventIDs, eventID) } - return + return eventIDs, rows.Err() } diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index f391c5784..aec37185d 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -26,7 +26,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/gomatrix" // Import the postgres database driver. _ "github.com/lib/pq" @@ -237,10 +236,10 @@ func (d *SyncServerDatasource) GetStateEvent( // Returns an empty slice if no state events could be found for this room. // Returns an error if there was an issue with the retrieval. func (d *SyncServerDatasource) GetStateEventsForRoom( - ctx context.Context, roomID string, stateFilterPart *gomatrix.FilterPart, + ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, ) (stateEvents []gomatrixserverlib.Event, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart) + stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilter) return err }) return @@ -422,7 +421,7 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse( var succeeded bool defer common.EndTransaction(txn, &succeeded) - stateFilterPart := gomatrix.DefaultFilterPart() // TODO: use filter provided in request + stateFilter := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request // Work out which rooms to return in the response. This is done by getting not only the currently // joined rooms, but also which rooms have membership transitions for this user between the 2 PDU stream positions. @@ -432,11 +431,11 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse( var joinedRoomIDs []string if !wantFullState { deltas, joinedRoomIDs, err = d.getStateDeltas( - ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilterPart, + ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilter, ) } else { deltas, joinedRoomIDs, err = d.getStateDeltasForFullStateSync( - ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilterPart, + ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilter, ) } if err != nil { @@ -587,12 +586,12 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( return } - stateFilterPart := gomatrix.DefaultFilterPart() // TODO: use filter provided in request + stateFilter := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request // Build up a /sync response. Add joined rooms. for _, roomID := range joinedRoomIDs { var stateEvents []gomatrixserverlib.Event - stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, &stateFilterPart) + stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, &stateFilter) if err != nil { return } @@ -681,7 +680,7 @@ var txReadOnlySnapshot = sql.TxOptions{ // If there was an issue with the retrieval, returns an error func (d *SyncServerDatasource) GetAccountDataInRange( ctx context.Context, userID string, oldPos, newPos types.StreamPosition, - accountDataFilterPart *gomatrix.FilterPart, + accountDataFilterPart *gomatrixserverlib.EventFilter, ) (map[string][]string, error) { return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart) } @@ -931,7 +930,7 @@ func (d *SyncServerDatasource) fetchMissingStateEvents( func (d *SyncServerDatasource) getStateDeltas( ctx context.Context, device *authtypes.Device, txn *sql.Tx, fromPos, toPos types.StreamPosition, userID string, - stateFilterPart *gomatrix.FilterPart, + stateFilter *gomatrixserverlib.StateFilter, ) ([]stateDelta, []string, error) { // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // - Get membership list changes for this user in this sync response @@ -944,7 +943,7 @@ func (d *SyncServerDatasource) getStateDeltas( var deltas []stateDelta // get all the state events ever between these two positions - stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilterPart) + stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilter) if err != nil { return nil, nil, err } @@ -964,7 +963,7 @@ func (d *SyncServerDatasource) getStateDeltas( if membership == gomatrixserverlib.Join { // send full room state down instead of a delta var s []types.StreamEvent - s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilterPart) + s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter) if err != nil { return nil, nil, err } @@ -1006,7 +1005,7 @@ func (d *SyncServerDatasource) getStateDeltas( func (d *SyncServerDatasource) getStateDeltasForFullStateSync( ctx context.Context, device *authtypes.Device, txn *sql.Tx, fromPos, toPos types.StreamPosition, userID string, - stateFilterPart *gomatrix.FilterPart, + stateFilter *gomatrixserverlib.StateFilter, ) ([]stateDelta, []string, error) { joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) if err != nil { @@ -1018,7 +1017,7 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync( // Add full states for all joined rooms for _, joinedRoomID := range joinedRoomIDs { - s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilterPart) + s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter) if stateErr != nil { return nil, nil, stateErr } @@ -1030,7 +1029,7 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync( } // Get all the state events ever between these two positions - stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilterPart) + stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilter) if err != nil { return nil, nil, err } @@ -1061,9 +1060,9 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync( func (d *SyncServerDatasource) currentStateStreamEventsForRoom( ctx context.Context, txn *sql.Tx, roomID string, - stateFilterPart *gomatrix.FilterPart, + stateFilter *gomatrixserverlib.StateFilter, ) ([]types.StreamEvent, error) { - allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart) + allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilter) if err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 8ebf79bdd..3274e66ea 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -22,7 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` @@ -92,7 +92,7 @@ func (s *accountDataStatements) selectAccountDataInRange( ctx context.Context, userID string, oldPos, newPos types.StreamPosition, - accountDataFilterPart *gomatrix.FilterPart, + accountDataFilterPart *gomatrixserverlib.EventFilter, ) (data map[string][]string, err error) { data = make(map[string][]string) diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 2145dea29..4ce946667 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -23,7 +23,6 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" ) @@ -175,7 +174,7 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership( // CurrentState returns all the current state events for the given room. func (s *currentRoomStateStatements) selectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, - stateFilterPart *gomatrix.FilterPart, + stateFilterPart *gomatrixserverlib.StateFilter, ) ([]gomatrixserverlib.Event, error) { stmt := common.TxStmt(txn, s.selectCurrentStateStmt) rows, err := stmt.QueryContext(ctx, roomID, diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index ddc9375ad..c0091a38c 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrix" "github.com/lib/pq" "github.com/matrix-org/dendrite/common" @@ -153,7 +152,7 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB, streamID *streamIDState // two positions, only the most recent state is returned. func (s *outputRoomEventsStatements) selectStateInRange( ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition, - stateFilterPart *gomatrix.FilterPart, + stateFilterPart *gomatrixserverlib.StateFilter, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { stmt := common.TxStmt(txn, s.selectStateInRangeStmt) diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index d5875c3b4..5517c3bc2 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -28,7 +28,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/gomatrix" // Import the postgres database driver. _ "github.com/lib/pq" @@ -262,7 +261,7 @@ func (d *SyncServerDatasource) GetStateEvent( // Returns an empty slice if no state events could be found for this room. // Returns an error if there was an issue with the retrieval. func (d *SyncServerDatasource) GetStateEventsForRoom( - ctx context.Context, roomID string, stateFilterPart *gomatrix.FilterPart, + ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, ) (stateEvents []gomatrixserverlib.Event, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart) @@ -447,7 +446,7 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse( var succeeded bool defer common.EndTransaction(txn, &succeeded) - stateFilterPart := gomatrix.DefaultFilterPart() // TODO: use filter provided in request + stateFilterPart := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request // Work out which rooms to return in the response. This is done by getting not only the currently // joined rooms, but also which rooms have membership transitions for this user between the 2 PDU stream positions. @@ -613,7 +612,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( } fmt.Println("Joined rooms:", joinedRoomIDs) - stateFilterPart := gomatrix.DefaultFilterPart() // TODO: use filter provided in request + stateFilterPart := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request // Build up a /sync response. Add joined rooms. for _, roomID := range joinedRoomIDs { @@ -716,7 +715,7 @@ var txReadOnlySnapshot = sql.TxOptions{ // If there was an issue with the retrieval, returns an error func (d *SyncServerDatasource) GetAccountDataInRange( ctx context.Context, userID string, oldPos, newPos types.StreamPosition, - accountDataFilterPart *gomatrix.FilterPart, + accountDataFilterPart *gomatrixserverlib.EventFilter, ) (map[string][]string, error) { return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart) } @@ -972,7 +971,7 @@ func (d *SyncServerDatasource) fetchMissingStateEvents( func (d *SyncServerDatasource) getStateDeltas( ctx context.Context, device *authtypes.Device, txn *sql.Tx, fromPos, toPos types.StreamPosition, userID string, - stateFilterPart *gomatrix.FilterPart, + stateFilterPart *gomatrixserverlib.StateFilter, ) ([]stateDelta, []string, error) { // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // - Get membership list changes for this user in this sync response @@ -1047,7 +1046,7 @@ func (d *SyncServerDatasource) getStateDeltas( func (d *SyncServerDatasource) getStateDeltasForFullStateSync( ctx context.Context, device *authtypes.Device, txn *sql.Tx, fromPos, toPos types.StreamPosition, userID string, - stateFilterPart *gomatrix.FilterPart, + stateFilterPart *gomatrixserverlib.StateFilter, ) ([]stateDelta, []string, error) { joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) if err != nil { @@ -1102,7 +1101,7 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync( func (d *SyncServerDatasource) currentStateStreamEventsForRoom( ctx context.Context, txn *sql.Tx, roomID string, - stateFilterPart *gomatrix.FilterPart, + stateFilterPart *gomatrixserverlib.StateFilter, ) ([]types.StreamEvent, error) { allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart) if err != nil { diff --git a/syncapi/storage/storage.go b/syncapi/storage/storage.go index d6ec79ad9..c87024b29 100644 --- a/syncapi/storage/storage.go +++ b/syncapi/storage/storage.go @@ -26,7 +26,6 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/typingserver/cache" - "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" ) @@ -36,11 +35,11 @@ type Database interface { Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) WriteEvent(context.Context, *gomatrixserverlib.Event, []gomatrixserverlib.Event, []string, []string, *api.TransactionID, bool) (types.StreamPosition, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.Event, error) - GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrix.FilterPart) (stateEvents []gomatrixserverlib.Event, err error) + GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []gomatrixserverlib.Event, err error) SyncPosition(ctx context.Context) (types.PaginationToken, error) IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.PaginationToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error) - GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos types.StreamPosition, accountDataFilterPart *gomatrix.FilterPart) (map[string][]string, error) + GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos types.StreamPosition, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, error) UpsertAccountData(ctx context.Context, userID, roomID, dataType string) (types.StreamPosition, error) AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.Event) (types.StreamPosition, error) RetireInviteEvent(ctx context.Context, inviteEventID string) error diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 3daf21028..22bd239fc 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -24,7 +24,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" @@ -142,14 +141,14 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Pagin return } - accountDataFilter := gomatrix.DefaultFilterPart() // TODO: use filter provided in req instead - res, err = rp.appendAccountData(res, req.device.UserID, req, int64(latestPos.PDUPosition), &accountDataFilter) + accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead + res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition, &accountDataFilter) return } func (rp *RequestPool) appendAccountData( - data *types.Response, userID string, req syncRequest, currentPos int64, - accountDataFilter *gomatrix.FilterPart, + data *types.Response, userID string, req syncRequest, currentPos types.StreamPosition, + accountDataFilter *gomatrixserverlib.EventFilter, ) (*types.Response, error) { // TODO: Account data doesn't have a sync position of its own, meaning that // account data might be sent multiple time to the client if multiple account diff --git a/sytest-blacklist b/sytest-blacklist index dd5e2cd5c..2df2b3a86 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -7,6 +7,9 @@ POST /login can log in as a user with just the local part of the id # Blacklisted due to flakiness avatar_url updates affect room member events +# Blacklisted due to flakiness +displayname updates affect room member events + # Blacklisted due to flakiness Room members can override their displayname on a room-specific basis @@ -16,3 +19,12 @@ Alias creators can delete alias with no ops # Blacklisted because matrix-org/dendrite#847 might have broken it but we're not # really sure and we need it pretty badly anyway Real non-joined users can get individual state for world_readable rooms after leaving + +# Blacklisted until matrix-org/dendrite#862 is reverted due to Riot bug +Latest account data appears in v2 /sync + +# Blacklisted due to flakiness +Outbound federation can backfill events + +# Blacklisted due to alias work on Synapse +Alias creators can delete canonical alias with no ops diff --git a/sytest-whitelist b/sytest-whitelist index 4c333d3cb..47fd58286 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -67,7 +67,6 @@ Can get rooms/{roomId}/members for a departed room (SPEC-216) 3pid invite join valid signature but revoked keys are rejected 3pid invite join valid signature but unreachable ID server are rejected Room members can join a room with an overridden displayname -displayname updates affect room member events Real non-joined user cannot call /events on shared room Real non-joined user cannot call /events on invited room Real non-joined user cannot call /events on joined room @@ -82,6 +81,7 @@ Can't forget room you're still in Can get rooms/{roomId}/members Can create filter Can download filter +Lazy loading parameters in the filter are strictly boolean Can sync Can sync a joined room Newly joined room is included in an incremental sync @@ -113,7 +113,7 @@ User can invite local user to room with version 4 Should reject keys claiming to belong to a different user Can add account data Can add account data to room -Latest account data appears in v2 /sync +#Latest account data appears in v2 /sync New account data appears in incremental v2 /sync Checking local federation server Inbound federation can query profile data @@ -210,3 +210,24 @@ Message history can be paginated Getting messages going forward is limited for a departed room (SPEC-216) m.room.history_visibility == "world_readable" allows/forbids appropriately for Real users Backfill works correctly with history visibility set to joined +Guest user cannot call /events globally +Guest users can join guest_access rooms +Guest user can set display names +Guest user cannot upgrade other users +m.room.history_visibility == "world_readable" allows/forbids appropriately for Guest users +Guest non-joined user cannot call /events on shared room +Guest non-joined user cannot call /events on invited room +Guest non-joined user cannot call /events on joined room +Guest non-joined user cannot call /events on default room +Guest non-joined users can get state for world_readable rooms +Guest non-joined users can get individual state for world_readable rooms +Guest non-joined users cannot room initalSync for non-world_readable rooms +Guest non-joined users can get individual state for world_readable rooms after leaving +Guest non-joined users cannot send messages to guest_access rooms if not joined +Guest users can sync from world_readable guest_access rooms if joined +Guest users can sync from default guest_access rooms if joined +Real non-joined users cannot room initalSync for non-world_readable rooms +Push rules come down in an initial /sync +Regular users can add and delete aliases in the default room configuration +Regular users can add and delete aliases when m.room.aliases is restricted +GET /r0/capabilities is not public