diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 1204582e2..8014e9414 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,6 +2,6 @@ -* [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/master/docs/CONTRIBUTING.md#sign-off) +* [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md#sign-off) Signed-off-by: `Your Name ` diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index a4ef8b395..de6c79ddc 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -2,9 +2,9 @@ name: "CodeQL" on: push: - branches: [master] + branches: [main] pull_request: - branches: [master] + branches: [main] jobs: analyze: @@ -14,21 +14,21 @@ jobs: strategy: fail-fast: false matrix: - language: ['go'] + language: ["go"] steps: - - name: Checkout repository - uses: actions/checkout@v2 - with: - fetch-depth: 2 + - name: Checkout repository + uses: actions/checkout@v2 + with: + fetch-depth: 2 - - run: git checkout HEAD^2 - if: ${{ github.event_name == 'pull_request' }} + - run: git checkout HEAD^2 + if: ${{ github.event_name == 'pull_request' }} - - name: Initialize CodeQL - uses: github/codeql-action/init@v1 - with: - languages: ${{ matrix.language }} + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: ${{ matrix.language }} - - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v1 + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ad5a2660c..124940f71 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -2,7 +2,7 @@ name: Tests on: push: - branches: [ 'master' ] + branches: ["main"] pull_request: concurrency: @@ -33,7 +33,7 @@ jobs: path: dendrite # Attempt to check out the same branch of Complement as the PR. If it - # doesn't exist, fallback to master. + # doesn't exist, fallback to main. - name: Checkout complement shell: bash run: | @@ -63,9 +63,9 @@ jobs: # Run Complement - run: | set -o pipefail && - go test -v -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt + go test -v -p 1 -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt shell: bash name: Run Complement Tests env: COMPLEMENT_BASE_IMAGE: complement-dendrite:latest - working-directory: complement \ No newline at end of file + working-directory: complement diff --git a/.gitignore b/.gitignore index dbc84edb1..092f4501c 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ /vendor/bin /docker/build /logs +/jetstream # Architecture specific extensions/prefixes *.[568vq] diff --git a/CHANGES.md b/CHANGES.md index 94edc6288..4df8e869a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,79 @@ # Changelog +## Dendrite 0.6.3 (2022-02-10) + +### Features + +* Initial support for `m.login.token` +* A number of regressions from earlier v0.6.x versions should now be corrected + +### Fixes + +* Missing state is now correctly retrieved in cases where a gap in the timeline was closed but some of those events were missing state snapshots, which should help to unstick slow or broken rooms +* Fixed a transaction issue where inserting events into the database could deadlock, which should stop rooms from getting stuck +* Fixed a problem where rejected events could result in rolled back database transactions +* Avoided a potential race condition on fetching latest events by using the room updater instead +* Processing events from `/get_missing_events` will no longer result in potential recursion +* Federation events are now correctly generated for updated self-signing keys and signed devices +* Rejected events can now be un-rejected if they are reprocessed and all of the correct conditions are met +* Fetching missing auth events will no longer error as long as all needed events for auth were satisfied +* Users can now correctly forget rooms if they were not a member of the room + +## Dendrite 0.6.2 (2022-02-04) + +### Fixes + +* Resolves an issue where the key change consumer in the keyserver could consume extreme amounts of CPU + +## Dendrite 0.6.1 (2022-02-04) + +### Features + +* Roomserver inputs now take place with full transactional isolation in PostgreSQL deployments +* Pull consumers are now used instead of push consumers when retrieving messages from NATS to better guarantee ordering and to reduce redelivery of duplicate messages +* Further logging tweaks, particularly when joining rooms +* Improved calculation of servers in the room, when checking for missing auth/prev events or state +* Dendrite will now skip dead servers more quickly when federating by reducing the TCP dial timeout +* The key change consumers have now been converted to use native NATS code rather than a wrapper +* Go 1.16 is now the minimum supported version for Dendrite + +### Fixes + +* Local clients should now be notified correctly of invites +* The roomserver input API now has more time to process events, particularly when fetching missing events or state, which should fix a number of errors from expired contexts +* Fixed a panic that could happen due to a closed channel in the roomserver input API +* Logging in with uppercase usernames from old installations is now supported again (contributed by [hoernschen](https://github.com/hoernschen)) +* Federated room joins now have more time to complete and should not fail due to expired contexts +* Events that were sent to the roomserver along with a complete state snapshot are now persisted with the correct state, even if they were rejected or soft-failed + +## Dendrite 0.6.0 (2022-01-28) + +### Features + +* NATS JetStream is now used instead of Kafka and Naffka + * For monolith deployments, a built-in NATS Server is embedded into Dendrite or a standalone NATS Server deployment can be optionally used instead + * For polylith deployments, a standalone NATS Server deployment is required + * Requires the version 2 configuration file — please see the new `dendrite-config.yaml` sample config file + * Kafka and Naffka are no longer supported as of this release +* The roomserver is now responsible for fetching missing events and state instead of the federation API + * Removes a number of race conditions between the federation API and roomserver, which reduces duplicate work and overall lowers CPU usage +* The roomserver input API is now strictly ordered with support for asynchronous requests, smoothing out incoming federation significantly +* Consolidated the federation API, federation sender and signing key server into a single component + * If multiple databases are used, tables for the federation sender and signing key server should be merged into the federation API database (table names have not changed) +* Device list synchronisation is now database-backed rather than using the now-removed Kafka logs + +### Fixes + +* The code for fetching missing events and state now correctly identifies when gaps in history have been closed, so federation traffic will consume less CPU and memory than before +* The stream position is now correctly advanced when typing notifications time out in the sync API +* Event NIDs are now correctly returned when persisting events in the roomserver in SQLite mode + * The built-in SQLite was updated to version 3.37.0 as a result +* The `/event_auth` endpoint now strictly returns the auth chain for the requested event without loading the room state, which should reduce spikes in memory usage +* Filters are now correctly sent when using federated public room directories (contributed by [S7evinK](https://github.com/S7evinK)) +* Login usernames are now squashed to lower-case (contributed by [BernardZhao](https://github.com/BernardZhao)) +* The logs should no longer be flooded with `Failed to get server ACLs for room` warnings at startup +* Backfilling will now attempt federation as a last resort when trying to retrieve missing events from the database fails + ## Dendrite 0.5.1 (2021-11-16) ### Features diff --git a/README.md b/README.md index 3ec9f0296..a077788cf 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ If you have further questions, please take a look at [our FAQ](docs/FAQ.md) or j ## Requirements -To build Dendrite, you will need Go 1.15 or later. +To build Dendrite, you will need Go 1.16 or later. For a usable federating Dendrite deployment, you will also need: - A domain name (or subdomain) diff --git a/appservice/api/query.go b/appservice/api/query.go index cd74d866c..e53ad4259 100644 --- a/appservice/api/query.go +++ b/appservice/api/query.go @@ -23,7 +23,7 @@ import ( "errors" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -85,7 +85,7 @@ func RetrieveUserProfile( ctx context.Context, userID string, asAPI AppServiceQueryAPI, - accountDB accounts.Database, + accountDB userdb.Database, ) (*authtypes.Profile, error) { localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { diff --git a/appservice/appservice.go b/appservice/appservice.go index 924a609ea..b33d7b701 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -22,6 +22,8 @@ import ( "time" "github.com/gorilla/mux" + "github.com/sirupsen/logrus" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/appservice/consumers" "github.com/matrix-org/dendrite/appservice/inthttp" @@ -34,7 +36,6 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/sirupsen/logrus" ) // AddInternalRoutes registers HTTP handlers for internal API calls @@ -58,7 +59,7 @@ func NewInternalAPI( }, }, } - js, _, _ := jetstream.Prepare(&base.Cfg.Global.JetStream) + js := jetstream.Prepare(&base.Cfg.Global.JetStream) // Create a connection to the appservice postgres DB appserviceDB, err := storage.NewDatabase(&base.Cfg.AppServiceAPI.Database) @@ -121,7 +122,7 @@ func generateAppServiceAccount( ) error { var accRes userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ - AccountType: userapi.AccountTypeUser, + AccountType: userapi.AccountTypeAppService, Localpart: as.SenderLocalpart, AppServiceID: as.ID, OnConflict: userapi.ConflictUpdate, diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 8aea5c347..7b59e3704 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -34,7 +34,7 @@ import ( type OutputRoomEventConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string asDB storage.Database rsAPI api.RoomserverInternalAPI @@ -66,37 +66,37 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) - return err + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } // onMessage is called when the appservice component receives a new event from // the room server output log. -func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Parse out the event JSON - var output api.OutputEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("roomserver output log: message parse failure") - return true - } - - if output.Type != api.OutputTypeNewRoomEvent { - return true - } - - events := []*gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event} - events = append(events, output.NewRoomEvent.AddStateEvents...) - - // Send event to any relevant application services - if err := s.filterRoomserverEvents(context.TODO(), events); err != nil { - log.WithError(err).Errorf("roomserver output log: filter error") - return true - } - +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + // Parse out the event JSON + var output api.OutputEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") return true - }) + } + + if output.Type != api.OutputTypeNewRoomEvent { + return true + } + + events := []*gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event} + events = append(events, output.NewRoomEvent.AddStateEvents...) + + // Send event to any relevant application services + if err := s.filterRoomserverEvents(context.TODO(), events); err != nil { + log.WithError(err).Errorf("roomserver output log: filter error") + return true + } + + return true } // filterRoomserverEvents takes in events and decides whether any of them need diff --git a/build.sh b/build.sh index 8196fc653..700e6434f 100755 --- a/build.sh +++ b/build.sh @@ -7,7 +7,7 @@ if [ -d ".git" ] then export BUILD=`git rev-parse --short HEAD || ""` export BRANCH=`(git symbolic-ref --short HEAD | tr -d \/ ) || ""` - if [ "$BRANCH" = master ] + if [ "$BRANCH" = main ] then export BRANCH="" fi diff --git a/build/docker/DendriteJS.Dockerfile b/build/docker/DendriteJS.Dockerfile index e8d742b7e..5e1cffcad 100644 --- a/build/docker/DendriteJS.Dockerfile +++ b/build/docker/DendriteJS.Dockerfile @@ -9,9 +9,9 @@ FROM golang:1.14-alpine AS gobuild # Download and build dendrite WORKDIR /build -ADD https://github.com/matrix-org/dendrite/archive/master.tar.gz /build/master.tar.gz -RUN tar xvfz master.tar.gz -WORKDIR /build/dendrite-master +ADD https://github.com/matrix-org/dendrite/archive/main.tar.gz /build/main.tar.gz +RUN tar xvfz main.tar.gz +WORKDIR /build/dendrite-main RUN GOOS=js GOARCH=wasm go build -o main.wasm ./cmd/dendritejs @@ -21,7 +21,7 @@ RUN apt-get update && apt-get -y install python # Download riot-web and libp2p repos WORKDIR /build -ADD https://github.com/matrix-org/go-http-js-libp2p/archive/master.tar.gz /build/libp2p.tar.gz +ADD https://github.com/matrix-org/go-http-js-libp2p/archive/main.tar.gz /build/libp2p.tar.gz RUN tar xvfz libp2p.tar.gz ADD https://github.com/vector-im/element-web/archive/matthew/p2p.tar.gz /build/p2p.tar.gz RUN tar xvfz p2p.tar.gz @@ -31,21 +31,21 @@ WORKDIR /build/element-web-matthew-p2p RUN yarn install RUN ln -s /build/go-http-js-libp2p-master /build/element-web-matthew-p2p/node_modules/go-http-js-libp2p RUN (cd node_modules/go-http-js-libp2p && yarn install) -COPY --from=gobuild /build/dendrite-master/main.wasm ./src/vector/dendrite.wasm +COPY --from=gobuild /build/dendrite-main/main.wasm ./src/vector/dendrite.wasm # build it all RUN yarn build:p2p SHELL ["/bin/bash", "-c"] RUN echo $'\ -{ \n\ + { \n\ "default_server_config": { \n\ - "m.homeserver": { \n\ - "base_url": "https://p2p.riot.im", \n\ - "server_name": "p2p.riot.im" \n\ - }, \n\ - "m.identity_server": { \n\ - "base_url": "https://vector.im" \n\ - } \n\ + "m.homeserver": { \n\ + "base_url": "https://p2p.riot.im", \n\ + "server_name": "p2p.riot.im" \n\ + }, \n\ + "m.identity_server": { \n\ + "base_url": "https://vector.im" \n\ + } \n\ }, \n\ "disable_custom_urls": false, \n\ "disable_guests": true, \n\ @@ -55,57 +55,57 @@ RUN echo $'\ "integrations_ui_url": "https://scalar.vector.im/", \n\ "integrations_rest_url": "https://scalar.vector.im/api", \n\ "integrations_widgets_urls": [ \n\ - "https://scalar.vector.im/_matrix/integrations/v1", \n\ - "https://scalar.vector.im/api", \n\ - "https://scalar-staging.vector.im/_matrix/integrations/v1", \n\ - "https://scalar-staging.vector.im/api", \n\ - "https://scalar-staging.riot.im/scalar/api" \n\ + "https://scalar.vector.im/_matrix/integrations/v1", \n\ + "https://scalar.vector.im/api", \n\ + "https://scalar-staging.vector.im/_matrix/integrations/v1", \n\ + "https://scalar-staging.vector.im/api", \n\ + "https://scalar-staging.riot.im/scalar/api" \n\ ], \n\ "integrations_jitsi_widget_url": "https://scalar.vector.im/api/widgets/jitsi.html", \n\ "bug_report_endpoint_url": "https://riot.im/bugreports/submit", \n\ "defaultCountryCode": "GB", \n\ "showLabsSettings": false, \n\ "features": { \n\ - "feature_pinning": "labs", \n\ - "feature_custom_status": "labs", \n\ - "feature_custom_tags": "labs", \n\ - "feature_state_counters": "labs" \n\ + "feature_pinning": "labs", \n\ + "feature_custom_status": "labs", \n\ + "feature_custom_tags": "labs", \n\ + "feature_state_counters": "labs" \n\ }, \n\ "default_federate": true, \n\ "default_theme": "light", \n\ "roomDirectory": { \n\ - "servers": [ \n\ - "matrix.org" \n\ - ] \n\ + "servers": [ \n\ + "matrix.org" \n\ + ] \n\ }, \n\ "welcomeUserId": "", \n\ "piwik": { \n\ - "url": "https://piwik.riot.im/", \n\ - "whitelistedHSUrls": ["https://matrix.org"], \n\ - "whitelistedISUrls": ["https://vector.im", "https://matrix.org"], \n\ - "siteId": 1 \n\ + "url": "https://piwik.riot.im/", \n\ + "whitelistedHSUrls": ["https://matrix.org"], \n\ + "whitelistedISUrls": ["https://vector.im", "https://matrix.org"], \n\ + "siteId": 1 \n\ }, \n\ "enable_presence_by_hs_url": { \n\ - "https://matrix.org": false, \n\ - "https://matrix-client.matrix.org": false \n\ + "https://matrix.org": false, \n\ + "https://matrix-client.matrix.org": false \n\ }, \n\ "settingDefaults": { \n\ - "breadcrumbs": true \n\ + "breadcrumbs": true \n\ } \n\ -}' > webapp/config.json + }' > webapp/config.json FROM nginx # Add "Service-Worker-Allowed: /" header so the worker can sniff traffic on this domain rather # than just the path this gets hosted under. NB this newline echo syntax only works on bash. SHELL ["/bin/bash", "-c"] RUN echo $'\ -server { \n\ + server { \n\ listen 80; \n\ add_header \'Service-Worker-Allowed\' \'/\'; \n\ location / { \n\ - root /usr/share/nginx/html; \n\ - index index.html index.htm; \n\ + root /usr/share/nginx/html; \n\ + index index.html index.htm; \n\ } \n\ -}' > /etc/nginx/conf.d/default.conf + }' > /etc/nginx/conf.d/default.conf RUN sed -i 's/}/ application\/wasm wasm;\n}/g' /etc/nginx/mime.types COPY --from=jsbuild /build/element-web-matthew-p2p/webapp /usr/share/nginx/html diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 1c9c0ac4e..aa8cc6e6e 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -281,10 +281,9 @@ func (m *DendriteMonolith) Start() { cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) cfg.Global.PrivateKey = sk cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("file:%s/%s", m.StorageDirectory, prefix)) + cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-device.db", m.StorageDirectory, prefix)) - cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-mediaapi.db", m.CacheDirectory, prefix)) + cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix)) diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 1aae418d1..8b9c88f2a 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -86,9 +86,8 @@ func (m *DendriteMonolith) Start() { cfg.Global.ServerName = gomatrixserverlib.ServerName(ygg.DerivedServerName()) cfg.Global.PrivateKey = ygg.PrivateKey() cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("file:%s/", m.StorageDirectory)) + cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-device.db", m.StorageDirectory)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-syncapi.db", m.StorageDirectory)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory)) diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 55b381ba5..1d520b4e7 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -2,6 +2,10 @@ FROM golang:1.16-stretch as build RUN apt-get update && apt-get install -y sqlite3 WORKDIR /build +# we will dump the binaries and config file to this location to ensure any local untracked files +# that come from the COPY . . file don't contaminate the build +RUN mkdir /dendrite + # Utilise Docker caching when downloading dependencies, this stops us needlessly # downloading dependencies every time. COPY go.mod . @@ -9,13 +13,19 @@ COPY go.sum . RUN go mod download COPY . . -RUN go build ./cmd/dendrite-monolith-server -RUN go build ./cmd/generate-keys -RUN go build ./cmd/generate-config -RUN ./generate-config --ci > dendrite.yaml -RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key +RUN go build -o /dendrite ./cmd/dendrite-monolith-server +RUN go build -o /dendrite ./cmd/generate-keys +RUN go build -o /dendrite ./cmd/generate-config + +WORKDIR /dendrite +RUN ./generate-keys --private-key matrix_key.pem ENV SERVER_NAME=localhost EXPOSE 8008 8448 -CMD sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml && ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml +# At runtime, generate TLS cert based on the CA now mounted at /ca +# At runtime, replace the SERVER_NAME with what we are told +CMD ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ + ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ + cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ + ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml diff --git a/build/scripts/ComplementLocal.Dockerfile b/build/scripts/ComplementLocal.Dockerfile new file mode 100644 index 000000000..60b4d983a --- /dev/null +++ b/build/scripts/ComplementLocal.Dockerfile @@ -0,0 +1,53 @@ +# A local development Complement dockerfile, to be used with host mounts +# /cache -> Contains the entire dendrite code at Dockerfile build time. Builds binaries but only keeps the generate-* ones. Pre-compilation saves time. +# /dendrite -> Host-mounted sources +# /runtime -> Binaries and config go here and are run at runtime +# At runtime, dendrite is built from /dendrite and run in /runtime. +# +# Use these mounts to make use of this dockerfile: +# COMPLEMENT_HOST_MOUNTS='/your/local/dendrite:/dendrite:ro;/your/go/path:/go:ro' +FROM golang:1.16-stretch +RUN apt-get update && apt-get install -y sqlite3 + +WORKDIR /runtime + +ENV SERVER_NAME=localhost +EXPOSE 8008 8448 + +# This script compiles Dendrite for us. +RUN echo '\ +#!/bin/bash -eux \n\ +if test -f "/runtime/dendrite-monolith-server"; then \n\ + echo "Skipping compilation; binaries exist" \n\ + exit 0 \n\ +fi \n\ +cd /dendrite \n\ +go build -v -o /runtime /dendrite/cmd/dendrite-monolith-server \n\ +' > compile.sh && chmod +x compile.sh + +# This script runs Dendrite for us. Must be run in the /runtime directory. +RUN echo '\ +#!/bin/bash -eu \n\ +./generate-keys --private-key matrix_key.pem \n\ +./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\ +./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\ +cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\ +./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ +' > run.sh && chmod +x run.sh + + +WORKDIR /cache +# Pre-download deps; we don't need to do this if the GOPATH is mounted. +COPY go.mod . +COPY go.sum . +RUN go mod download + +# Build the monolith in /cache - we won't actually use this but will rely on build artifacts to speed +# up the real compilation. Build the generate-* binaries in the true /runtime locations. +# If the generate-* source is changed, this dockerfile needs re-running. +COPY . . +RUN go build ./cmd/dendrite-monolith-server && go build -o /runtime ./cmd/generate-keys && go build -o /runtime ./cmd/generate-config + + +WORKDIR /runtime +CMD /runtime/compile.sh && /runtime/run.sh diff --git a/build/scripts/ComplementPostgres.Dockerfile b/build/scripts/ComplementPostgres.Dockerfile new file mode 100644 index 000000000..6024ae8da --- /dev/null +++ b/build/scripts/ComplementPostgres.Dockerfile @@ -0,0 +1,53 @@ +FROM golang:1.16-stretch as build +RUN apt-get update && apt-get install -y postgresql +WORKDIR /build + +# No password when connecting over localhost +RUN sed -i "s%127.0.0.1/32 md5%127.0.0.1/32 trust%g" /etc/postgresql/9.6/main/pg_hba.conf && \ + # Bump up max conns for moar concurrency + sed -i 's/max_connections = 100/max_connections = 2000/g' /etc/postgresql/9.6/main/postgresql.conf + +# This entry script starts postgres, waits for it to be up then starts dendrite +RUN echo '\ +#!/bin/bash -eu \n\ +pg_lsclusters \n\ +pg_ctlcluster 9.6 main start \n\ + \n\ +until pg_isready \n\ +do \n\ + echo "Waiting for postgres"; \n\ + sleep 1; \n\ +done \n\ +' > run_postgres.sh && chmod +x run_postgres.sh + +# we will dump the binaries and config file to this location to ensure any local untracked files +# that come from the COPY . . file don't contaminate the build +RUN mkdir /dendrite + +# Utilise Docker caching when downloading dependencies, this stops us needlessly +# downloading dependencies every time. +COPY go.mod . +COPY go.sum . +RUN go mod download + +COPY . . +RUN go build -o /dendrite ./cmd/dendrite-monolith-server +RUN go build -o /dendrite ./cmd/generate-keys +RUN go build -o /dendrite ./cmd/generate-config + +WORKDIR /dendrite +RUN ./generate-keys --private-key matrix_key.pem + +ENV SERVER_NAME=localhost +EXPOSE 8008 8448 + + +# At runtime, generate TLS cert based on the CA now mounted at /ca +# At runtime, replace the SERVER_NAME with what we are told +CMD /build/run_postgres.sh && ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ + ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ + # Replace the connection string with a single postgres DB, using user/db = 'postgres' and no password, bump max_conns + sed -i "s%connection_string:.*$%connection_string: postgresql://postgres@localhost/postgres?sslmode=disable%g" dendrite.yaml && \ + sed -i 's/max_open_conns:.*$/max_open_conns: 100/g' dendrite.yaml && \ + cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ + ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml \ No newline at end of file diff --git a/build/scripts/find-lint.sh b/build/scripts/find-lint.sh index af87e14d7..e3564ae38 100755 --- a/build/scripts/find-lint.sh +++ b/build/scripts/find-lint.sh @@ -33,7 +33,7 @@ echo "Looking for lint..." # Capture exit code to ensure go.{mod,sum} is restored before exiting exit_code=0 -golangci-lint run $args || exit_code=1 +PATH="$PATH:${GOPATH:-~/go}/bin" golangci-lint run $args || exit_code=1 # Restore go.{mod,sum} mv go.mod.bak go.mod && mv go.sum.bak go.sum diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index c850bf91e..575c5377f 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -42,6 +42,7 @@ type DeviceDatabase interface { type AccountDatabase interface { // Look up the account matching the given localpart. GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) + GetAccountByPassword(ctx context.Context, localpart, password string) (*api.Account, error) } // VerifyUserFromRequest authenticates the HTTP request, diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index da0324251..f01e48f80 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -10,4 +10,5 @@ const ( LoginTypeSharedSecret = "org.matrix.login.shared_secret" LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeApplicationService = "m.login.application_service" + LoginTypeToken = "m.login.token" ) diff --git a/clientapi/auth/login.go b/clientapi/auth/login.go new file mode 100644 index 000000000..1c14c6fbd --- /dev/null +++ b/clientapi/auth/login.go @@ -0,0 +1,83 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "encoding/json" + "io" + "io/ioutil" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// LoginFromJSONReader performs authentication given a login request body reader and +// some context. It returns the basic login information and a cleanup function to be +// called after authorization has completed, with the result of the authorization. +// If the final return value is non-nil, an error occurred and the cleanup function +// is nil. +func LoginFromJSONReader(ctx context.Context, r io.Reader, accountDB AccountDatabase, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) { + reqBytes, err := ioutil.ReadAll(r) + if err != nil { + err := &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Reading request body failed: " + err.Error()), + } + return nil, nil, err + } + + var header struct { + Type string `json:"type"` + } + if err := json.Unmarshal(reqBytes, &header); err != nil { + err := &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Reading request body failed: " + err.Error()), + } + return nil, nil, err + } + + var typ Type + switch header.Type { + case authtypes.LoginTypePassword: + typ = &LoginTypePassword{ + GetAccountByPassword: accountDB.GetAccountByPassword, + Config: cfg, + } + case authtypes.LoginTypeToken: + typ = &LoginTypeToken{ + UserAPI: userAPI, + Config: cfg, + } + default: + err := util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("unhandled login type: " + header.Type), + } + return nil, nil, &err + } + + return typ.LoginFromJSON(ctx, reqBytes) +} + +// UserInternalAPIForLogin contains the aspects of UserAPI required for logging in. +type UserInternalAPIForLogin interface { + uapi.LoginTokenInternalAPI +} diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go new file mode 100644 index 000000000..e295f8f07 --- /dev/null +++ b/clientapi/auth/login_test.go @@ -0,0 +1,194 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "database/sql" + "net/http" + "reflect" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +func TestLoginFromJSONReader(t *testing.T) { + ctx := context.Background() + + tsts := []struct { + Name string + Body string + + WantUsername string + WantDeviceID string + WantDeletedTokens []string + }{ + { + Name: "passwordWorks", + Body: `{ + "type": "m.login.password", + "identifier": { "type": "m.id.user", "user": "alice" }, + "password": "herpassword", + "device_id": "adevice" + }`, + WantUsername: "alice", + WantDeviceID: "adevice", + }, + { + Name: "tokenWorks", + Body: `{ + "type": "m.login.token", + "token": "atoken", + "device_id": "adevice" + }`, + WantUsername: "@auser:example.com", + WantDeviceID: "adevice", + WantDeletedTokens: []string{"atoken"}, + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + var accountDB fakeAccountDB + var userAPI fakeUserInternalAPI + cfg := &config.ClientAPI{ + Matrix: &config.Global{ + ServerName: serverName, + }, + } + login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg) + if err != nil { + t.Fatalf("LoginFromJSONReader failed: %+v", err) + } + cleanup(ctx, &util.JSONResponse{Code: http.StatusOK}) + + if login.Username() != tst.WantUsername { + t.Errorf("Username: got %q, want %q", login.Username(), tst.WantUsername) + } + + if login.DeviceID == nil { + if tst.WantDeviceID != "" { + t.Errorf("DeviceID: got %v, want %q", login.DeviceID, tst.WantDeviceID) + } + } else { + if *login.DeviceID != tst.WantDeviceID { + t.Errorf("DeviceID: got %q, want %q", *login.DeviceID, tst.WantDeviceID) + } + } + + if !reflect.DeepEqual(userAPI.DeletedTokens, tst.WantDeletedTokens) { + t.Errorf("DeletedTokens: got %+v, want %+v", userAPI.DeletedTokens, tst.WantDeletedTokens) + } + }) + } +} + +func TestBadLoginFromJSONReader(t *testing.T) { + ctx := context.Background() + + tsts := []struct { + Name string + Body string + + WantErrCode string + }{ + {Name: "empty", WantErrCode: "M_BAD_JSON"}, + { + Name: "badUnmarshal", + Body: `badsyntaxJSON`, + WantErrCode: "M_BAD_JSON", + }, + { + Name: "badPassword", + Body: `{ + "type": "m.login.password", + "identifier": { "type": "m.id.user", "user": "alice" }, + "password": "invalidpassword", + "device_id": "adevice" + }`, + WantErrCode: "M_FORBIDDEN", + }, + { + Name: "badToken", + Body: `{ + "type": "m.login.token", + "token": "invalidtoken", + "device_id": "adevice" + }`, + WantErrCode: "M_FORBIDDEN", + }, + { + Name: "badType", + Body: `{ + "type": "m.login.invalid", + "device_id": "adevice" + }`, + WantErrCode: "M_INVALID_ARGUMENT_VALUE", + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + var accountDB fakeAccountDB + var userAPI fakeUserInternalAPI + cfg := &config.ClientAPI{ + Matrix: &config.Global{ + ServerName: serverName, + }, + } + _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg) + if errRes == nil { + cleanup(ctx, nil) + t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) + } else if merr, ok := errRes.JSON.(*jsonerror.MatrixError); ok && merr.ErrCode != tst.WantErrCode { + t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) + } + }) + } +} + +type fakeAccountDB struct { + AccountDatabase +} + +func (*fakeAccountDB) GetAccountByPassword(ctx context.Context, localpart, password string) (*uapi.Account, error) { + if password == "invalidpassword" { + return nil, sql.ErrNoRows + } + + return &uapi.Account{}, nil +} + +type fakeUserInternalAPI struct { + UserInternalAPIForLogin + + DeletedTokens []string +} + +func (ua *fakeUserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *uapi.PerformLoginTokenDeletionRequest, res *uapi.PerformLoginTokenDeletionResponse) error { + ua.DeletedTokens = append(ua.DeletedTokens, req.Token) + return nil +} + +func (*fakeUserInternalAPI) QueryLoginToken(ctx context.Context, req *uapi.QueryLoginTokenRequest, res *uapi.QueryLoginTokenResponse) error { + if req.Token == "invalidtoken" { + return nil + } + + res.Data = &uapi.LoginTokenData{UserID: "@auser:example.com"} + return nil +} diff --git a/clientapi/auth/login_token.go b/clientapi/auth/login_token.go new file mode 100644 index 000000000..845eb5de9 --- /dev/null +++ b/clientapi/auth/login_token.go @@ -0,0 +1,83 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// LoginTypeToken describes how to authenticate with a login token. +type LoginTypeToken struct { + UserAPI uapi.LoginTokenInternalAPI + Config *config.ClientAPI +} + +// Name implements Type. +func (t *LoginTypeToken) Name() string { + return authtypes.LoginTypeToken +} + +// LoginFromJSON implements Type. The cleanup function deletes the token from +// the database on success. +func (t *LoginTypeToken) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) { + var r loginTokenRequest + if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil { + return nil, nil, err + } + + var res uapi.QueryLoginTokenResponse + if err := t.UserAPI.QueryLoginToken(ctx, &uapi.QueryLoginTokenRequest{Token: r.Token}, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("UserAPI.QueryLoginToken failed") + jsonErr := jsonerror.InternalServerError() + return nil, nil, &jsonErr + } + if res.Data == nil { + return nil, nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("invalid login token"), + } + } + + r.Login.Identifier.Type = "m.id.user" + r.Login.Identifier.User = res.Data.UserID + + cleanup := func(ctx context.Context, authRes *util.JSONResponse) { + if authRes == nil { + util.GetLogger(ctx).Error("No JSONResponse provided to LoginTokenType cleanup function") + return + } + if authRes.Code == http.StatusOK { + var res uapi.PerformLoginTokenDeletionResponse + if err := t.UserAPI.PerformLoginTokenDeletion(ctx, &uapi.PerformLoginTokenDeletionRequest{Token: r.Token}, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("UserAPI.PerformLoginTokenDeletion failed") + } + } + } + return &r.Login, cleanup, nil +} + +// loginTokenRequest struct to hold the possible parameters from an HTTP request. +type loginTokenRequest struct { + Login + Token string `json:"token"` +} diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 7dd21b3f2..18cf94979 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -16,9 +16,12 @@ package auth import ( "context" + "database/sql" "net/http" "strings" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" @@ -40,17 +43,26 @@ type LoginTypePassword struct { } func (t *LoginTypePassword) Name() string { - return "m.login.password" + return authtypes.LoginTypePassword } -func (t *LoginTypePassword) Request() interface{} { - return &PasswordRequest{} +func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) { + var r PasswordRequest + if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil { + return nil, nil, err + } + + login, err := t.Login(ctx, &r) + if err != nil { + return nil, nil, err + } + + return login, func(context.Context, *util.JSONResponse) {}, nil } func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { r := req.(*PasswordRequest) - // Squash username to all lowercase letters - username := strings.ToLower(r.Username()) + username := strings.ToLower(r.Username()) if username == "" { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, @@ -64,8 +76,15 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, JSON: jsonerror.InvalidUsername(err.Error()), } } - _, err = t.GetAccountByPassword(ctx, localpart, r.Password) + // Squash username to all lowercase letters + _, err = t.GetAccountByPassword(ctx, strings.ToLower(localpart), r.Password) if err != nil { + if err == sql.ErrNoRows { + _, err = t.GetAccountByPassword(ctx, localpart, r.Password) + if err == nil { + return &r.Login, nil + } + } // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows // but that would leak the existence of the user. return nil, &util.JSONResponse{ diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 30469fc47..9cab7956c 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -32,22 +32,24 @@ import ( type Type interface { // Name returns the name of the auth type e.g `m.login.password` Name() string - // Request returns a pointer to a new request body struct to unmarshal into. - Request() interface{} // Login with the auth type, returning an error response on failure. // Not all types support login, only m.login.password and m.login.token // See https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-login - // `req` is guaranteed to be the type returned from Request() // This function will be called when doing login and when doing 'sudo' style // actions e.g deleting devices. The response must be a 401 as per: // "If the homeserver decides that an attempt on a stage was unsuccessful, but the // client may make a second attempt, it returns the same HTTP status 401 response as above, // with the addition of the standard errcode and error fields describing the error." - Login(ctx context.Context, req interface{}) (login *Login, errRes *util.JSONResponse) + // + // The returned cleanup function must be non-nil on success, and will be called after + // authorization has been completed. Its argument is the final result of authorization. + LoginFromJSON(ctx context.Context, reqBytes []byte) (login *Login, cleanup LoginCleanupFunc, errRes *util.JSONResponse) // TODO: Extend to support Register() flow // Register(ctx context.Context, sessionID string, req interface{}) } +type LoginCleanupFunc func(context.Context, *util.JSONResponse) + // LoginIdentifier represents identifier types // https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types type LoginIdentifier struct { @@ -61,11 +63,8 @@ type LoginIdentifier struct { // Login represents the shared fields used in all forms of login/sudo endpoints. type Login struct { - Type string `json:"type"` - Identifier LoginIdentifier `json:"identifier"` - User string `json:"user"` // deprecated in favour of identifier - Medium string `json:"medium"` // deprecated in favour of identifier - Address string `json:"address"` // deprecated in favour of identifier + LoginIdentifier // Flat fields deprecated in favour of `identifier`. + Identifier LoginIdentifier `json:"identifier"` // Both DeviceID and InitialDisplayName can be omitted, or empty strings ("") // Thus a pointer is needed to differentiate between the two @@ -111,12 +110,11 @@ type UserInteractive struct { Sessions map[string][]string } -func NewUserInteractive(getAccByPass GetAccountByPassword, cfg *config.ClientAPI) *UserInteractive { +func NewUserInteractive(accountDB AccountDatabase, cfg *config.ClientAPI) *UserInteractive { typePassword := &LoginTypePassword{ - GetAccountByPassword: getAccByPass, + GetAccountByPassword: accountDB.GetAccountByPassword, Config: cfg, } - // TODO: Add SSO login return &UserInteractive{ Completed: []string{}, Flows: []userInteractiveFlow{ @@ -236,18 +234,13 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device * } } - r := loginType.Request() - if err := json.Unmarshal([]byte(gjson.GetBytes(bodyBytes, "auth").Raw), r); err != nil { - return nil, &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), - } + login, cleanup, resErr := loginType.LoginFromJSON(ctx, []byte(gjson.GetBytes(bodyBytes, "auth").Raw)) + if resErr != nil { + return nil, u.ResponseWithChallenge(sessionID, resErr.JSON) } - login, resErr := loginType.Login(ctx, r) - if resErr == nil { - u.AddCompletedStage(sessionID, authType) - // TODO: Check if there's more stages to go and return an error - return login, nil - } - return nil, u.ResponseWithChallenge(sessionID, resErr.JSON) + + u.AddCompletedStage(sessionID, authType) + cleanup(ctx, nil) + // TODO: Check if there's more stages to go and return an error + return login, nil } diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index 0b7df3545..76d161a74 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -24,7 +24,11 @@ var ( } ) -func getAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) { +type fakeAccountDatabase struct { + AccountDatabase +} + +func (*fakeAccountDatabase) GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) { acc, ok := lookup[localpart+" "+plaintextPassword] if !ok { return nil, fmt.Errorf("unknown user/password") @@ -38,7 +42,7 @@ func setup() *UserInteractive { ServerName: serverName, }, } - return NewUserInteractive(getAccountByPassword, cfg) + return NewUserInteractive(&fakeAccountDatabase{}, cfg) } func TestUserInteractiveChallenge(t *testing.T) { diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 7c772125a..a65f3b70d 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -28,7 +28,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -37,7 +37,7 @@ func AddPublicRoutes( router *mux.Router, synapseAdminRouter *mux.Router, cfg *config.ClientAPI, - accountsDB accounts.Database, + accountsDB userdb.Database, federation *gomatrixserverlib.FederationClient, rsAPI roomserverAPI.RoomserverInternalAPI, eduInputAPI eduServerAPI.EDUServerInputAPI, @@ -49,7 +49,7 @@ func AddPublicRoutes( extRoomsProvider api.ExtraPublicRoomsProvider, mscCfg *config.MSCs, ) { - js, _, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) syncProducer := &producers.SyncAPIProducer{ JetStream: js, diff --git a/clientapi/httputil/httputil.go b/clientapi/httputil/httputil.go index 29d7b0b37..b47701368 100644 --- a/clientapi/httputil/httputil.go +++ b/clientapi/httputil/httputil.go @@ -36,6 +36,10 @@ func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONRespon return &resp } + return UnmarshalJSON(body, iface) +} + +func UnmarshalJSON(body []byte, iface interface{}) *util.JSONResponse { if !utf8.Valid(body) { return &util.JSONResponse{ Code: http.StatusBadRequest, diff --git a/clientapi/jsonerror/jsonerror.go b/clientapi/jsonerror/jsonerror.go index caa216e62..97c597030 100644 --- a/clientapi/jsonerror/jsonerror.go +++ b/clientapi/jsonerror/jsonerror.go @@ -149,6 +149,15 @@ func MissingParam(msg string) *MatrixError { return &MatrixError{"M_MISSING_PARAM", msg} } +// LeaveServerNoticeError is an error returned when trying to reject an invite +// for a server notice room. +func LeaveServerNoticeError() *MatrixError { + return &MatrixError{ + ErrCode: "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM", + Err: "You cannot reject this invite", + } +} + type IncompatibleRoomVersionError struct { RoomVersion string `json:"room_version"` Error string `json:"error"` diff --git a/clientapi/producers/syncapi.go b/clientapi/producers/syncapi.go index bd6af5f1f..9b1d6b1a2 100644 --- a/clientapi/producers/syncapi.go +++ b/clientapi/producers/syncapi.go @@ -51,7 +51,7 @@ func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string "user_id": userID, "room_id": roomID, "data_type": dataType, - }).Infof("Producing to topic '%s'", p.Topic) + }).Tracef("Producing to topic '%s'", p.Topic) _, err = p.JetStream.PublishMsg(m) return err diff --git a/clientapi/routing/admin_whois.go b/clientapi/routing/admin_whois.go index b448791c3..87bb79366 100644 --- a/clientapi/routing/admin_whois.go +++ b/clientapi/routing/admin_whois.go @@ -47,8 +47,8 @@ func GetAdminWhois( req *http.Request, userAPI api.UserInternalAPI, device *api.Device, userID string, ) util.JSONResponse { - if userID != device.UserID { - // TODO: Still allow if user is admin + allowed := device.AccountType == api.AccountTypeAdmin || userID == device.UserID + if !allowed { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("userID does not match the current user"), diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index e89d8ff24..fcacc76c0 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -15,6 +15,7 @@ package routing import ( + "context" "encoding/json" "fmt" "net/http" @@ -30,7 +31,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" @@ -137,36 +138,17 @@ type fledglingEvent struct { func CreateRoom( req *http.Request, device *api.Device, cfg *config.ClientAPI, - accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, + accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { - // TODO (#267): Check room ID doesn't clash with an existing one, and we - // probably shouldn't be using pseudo-random strings, maybe GUIDs? - roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName) - return createRoom(req, device, cfg, roomID, accountDB, rsAPI, asAPI) -} - -// createRoom implements /createRoom -// nolint: gocyclo -func createRoom( - req *http.Request, device *api.Device, - cfg *config.ClientAPI, roomID string, - accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, - asAPI appserviceAPI.AppServiceQueryAPI, -) util.JSONResponse { - logger := util.GetLogger(req.Context()) - userID := device.UserID var r createRoomRequest resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { return *resErr } - // TODO: apply rate-limit - if resErr = r.Validate(); resErr != nil { return *resErr } - evTime, err := httputil.ParseTSParam(req) if err != nil { return util.JSONResponse{ @@ -174,6 +156,25 @@ func createRoom( JSON: jsonerror.InvalidArgumentValue(err.Error()), } } + return createRoom(req.Context(), r, device, cfg, accountDB, rsAPI, asAPI, evTime) +} + +// createRoom implements /createRoom +// nolint: gocyclo +func createRoom( + ctx context.Context, + r createRoomRequest, device *api.Device, + cfg *config.ClientAPI, + accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI, + asAPI appserviceAPI.AppServiceQueryAPI, + evTime time.Time, +) util.JSONResponse { + // TODO (#267): Check room ID doesn't clash with an existing one, and we + // probably shouldn't be using pseudo-random strings, maybe GUIDs? + roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName) + + logger := util.GetLogger(ctx) + userID := device.UserID // Clobber keys: creator, room_version @@ -200,16 +201,16 @@ func createRoom( "roomVersion": roomVersion, }).Info("Creating new room") - profile, err := appserviceAPI.RetrieveUserProfile(req.Context(), userID, asAPI, accountDB) + profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, accountDB) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed") + util.GetLogger(ctx).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed") return jsonerror.InternalServerError() } createContent := map[string]interface{}{} if len(r.CreationContent) > 0 { if err = json.Unmarshal(r.CreationContent, &createContent); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal for creation_content failed") + util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for creation_content failed") return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON("invalid create content"), @@ -230,7 +231,7 @@ func createRoom( // Merge powerLevelContentOverride fields by unmarshalling it atop the defaults err = json.Unmarshal(r.PowerLevelContentOverride, &powerLevelContent) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal for power_level_content_override failed") + util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for power_level_content_override failed") return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON("malformed power_level_content_override"), @@ -319,9 +320,9 @@ func createRoom( } var aliasResp roomserverAPI.GetRoomIDForAliasResponse - err = rsAPI.GetRoomIDForAlias(req.Context(), &hasAliasReq, &aliasResp) + err = rsAPI.GetRoomIDForAlias(ctx, &hasAliasReq, &aliasResp) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") + util.GetLogger(ctx).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") return jsonerror.InternalServerError() } if aliasResp.RoomID != "" { @@ -426,7 +427,7 @@ func createRoom( } err = builder.SetContent(e.Content) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed") + util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed") return jsonerror.InternalServerError() } if i > 0 { @@ -435,12 +436,12 @@ func createRoom( var ev *gomatrixserverlib.Event ev, err = buildEvent(&builder, &authEvents, cfg, evTime, roomVersion) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("buildEvent failed") + util.GetLogger(ctx).WithError(err).Error("buildEvent failed") return jsonerror.InternalServerError() } if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.Allowed failed") + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") return jsonerror.InternalServerError() } @@ -448,7 +449,7 @@ func createRoom( builtEvents = append(builtEvents, ev.Headered(roomVersion)) err = authEvents.AddEvent(ev) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("authEvents.AddEvent failed") + util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed") return jsonerror.InternalServerError() } } @@ -462,8 +463,8 @@ func createRoom( SendAsServer: roomserverAPI.DoNotSendToOtherServers, }) } - if err = roomserverAPI.SendInputRoomEvents(req.Context(), rsAPI, inputs, false); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") + if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, inputs, false); err != nil { + util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") return jsonerror.InternalServerError() } @@ -478,9 +479,9 @@ func createRoom( } var aliasResp roomserverAPI.SetRoomAliasResponse - err = rsAPI.SetRoomAlias(req.Context(), &aliasReq, &aliasResp) + err = rsAPI.SetRoomAlias(ctx, &aliasReq, &aliasResp) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed") + util.GetLogger(ctx).WithError(err).Error("aliasAPI.SetRoomAlias failed") return jsonerror.InternalServerError() } @@ -519,11 +520,11 @@ func createRoom( for _, invitee := range r.Invite { // Build the invite event. inviteEvent, err := buildMembershipEvent( - req.Context(), invitee, "", accountDB, device, gomatrixserverlib.Invite, + ctx, invitee, "", accountDB, device, gomatrixserverlib.Invite, roomID, true, cfg, evTime, rsAPI, asAPI, ) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvent failed") + util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") continue } inviteStrippedState := append( @@ -532,7 +533,7 @@ func createRoom( ) // Send the invite event to the roomserver. err = roomserverAPI.SendInvite( - req.Context(), + ctx, rsAPI, inviteEvent.Headered(roomVersion), inviteStrippedState, // invite room state @@ -544,7 +545,7 @@ func createRoom( return e.JSONResponse() case nil: default: - util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInvite failed") + util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInvite failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError(), @@ -556,13 +557,13 @@ func createRoom( if r.Visibility == "public" { // expose this room in the published room list var pubRes roomserverAPI.PerformPublishResponse - rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ + rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{ RoomID: roomID, Visibility: "public", }, &pubRes) if pubRes.Error != nil { // treat as non-fatal since the room is already made by this point - util.GetLogger(req.Context()).WithError(pubRes.Error).Error("failed to visibility:public") + util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public") } } diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index 2e3283be1..0dacfced5 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -63,7 +63,12 @@ func GetPostPublicRooms( serverName := gomatrixserverlib.ServerName(request.Server) if serverName != "" && serverName != cfg.Matrix.ServerName { - res, err := federation.GetPublicRooms(req.Context(), serverName, int(request.Limit), request.Since, false, "") + res, err := federation.GetPublicRoomsFiltered( + req.Context(), serverName, + int(request.Limit), request.Since, + request.Filter.SearchTerms, false, + "", + ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to get public rooms") return jsonerror.InternalServerError() diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index 578aaec56..d30a87a57 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -32,7 +32,7 @@ func JoinRoomByIDOrAlias( req *http.Request, device *api.Device, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomIDOrAlias string, ) util.JSONResponse { // Prepare to ask the roomserver to perform the room join. diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 7b9d8acd2..7ecab9d4e 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -24,7 +24,7 @@ import ( "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/util" ) @@ -36,7 +36,7 @@ type crossSigningRequest struct { func UploadCrossSigningDeviceKeys( req *http.Request, userInteractiveAuth *auth.UserInteractive, keyserverAPI api.KeyInternalAPI, device *userapi.Device, - accountDB accounts.Database, cfg *config.ClientAPI, + accountDB userdb.Database, cfg *config.ClientAPI, ) util.JSONResponse { uploadReq := &crossSigningRequest{} uploadRes := &api.PerformUploadDeviceKeysResponse{} diff --git a/clientapi/routing/leaveroom.go b/clientapi/routing/leaveroom.go index 38cef118e..a34dd02d3 100644 --- a/clientapi/routing/leaveroom.go +++ b/clientapi/routing/leaveroom.go @@ -38,6 +38,12 @@ func LeaveRoomByID( // Ask the roomserver to perform the leave. if err := rsAPI.PerformLeave(req.Context(), &leaveReq, &leaveRes); err != nil { + if leaveRes.Code != 0 { + return util.JSONResponse{ + Code: leaveRes.Code, + JSON: jsonerror.LeaveServerNoticeError(), + } + } return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.Unknown(err.Error()), diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 589efe0b2..ec5c998be 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -19,12 +19,11 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -55,7 +54,7 @@ func passwordLogin() flows { // Login implements GET and POST /login func Login( - req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI, + req *http.Request, accountDB userdb.Database, userAPI userapi.UserInternalAPI, cfg *config.ClientAPI, ) util.JSONResponse { if req.Method == http.MethodGet { @@ -65,21 +64,14 @@ func Login( JSON: passwordLogin(), } } else if req.Method == http.MethodPost { - typePassword := auth.LoginTypePassword{ - GetAccountByPassword: accountDB.GetAccountByPassword, - Config: cfg, - } - r := typePassword.Request() - resErr := httputil.UnmarshalJSONRequest(req, r) - if resErr != nil { - return *resErr - } - login, authErr := typePassword.Login(req.Context(), r) + login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, accountDB, userAPI, cfg) if authErr != nil { return *authErr } // make a device/access token - return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) + authErr2 := completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) + cleanup(req.Context(), &authErr2) + return authErr2 } return util.JSONResponse{ Code: http.StatusMethodNotAllowed, diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 4ce820797..ffe8da136 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -17,6 +17,7 @@ package routing import ( "context" "errors" + "fmt" "net/http" "time" @@ -29,7 +30,7 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -38,7 +39,7 @@ import ( var errMissingUserID = errors.New("'user_id' must be supplied") func SendBan( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -80,7 +81,7 @@ func SendBan( return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI) } -func sendMembership(ctx context.Context, accountDB accounts.Database, device *userapi.Device, +func sendMembership(ctx context.Context, accountDB userdb.Database, device *userapi.Device, roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time, roomVer gomatrixserverlib.RoomVersion, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse { @@ -124,7 +125,7 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us } func SendKick( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -164,7 +165,7 @@ func SendKick( } func SendUnban( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -199,7 +200,7 @@ func SendUnban( } func SendInvite( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -225,27 +226,42 @@ func SendInvite( } } + // We already received the return value, so no need to check for an error here. + response, _ := sendInvite(req.Context(), accountDB, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime) + return response +} + +// sendInvite sends an invitation to a user. Returns a JSONResponse and an error +func sendInvite( + ctx context.Context, + accountDB userdb.Database, + device *userapi.Device, + roomID, userID, reason string, + cfg *config.ClientAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, + asAPI appserviceAPI.AppServiceQueryAPI, evTime time.Time, +) (util.JSONResponse, error) { event, err := buildMembershipEvent( - req.Context(), body.UserID, body.Reason, accountDB, device, "invite", + ctx, userID, reason, accountDB, device, "invite", roomID, false, cfg, evTime, rsAPI, asAPI, ) if err == errMissingUserID { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error()), - } + }, err } else if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound(err.Error()), - } + }, err } else if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvent failed") - return jsonerror.InternalServerError() + util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") + return jsonerror.InternalServerError(), err } err = roomserverAPI.SendInvite( - req.Context(), rsAPI, + ctx, rsAPI, event, nil, // ask the roomserver to draw up invite room state for us cfg.Matrix.ServerName, @@ -253,24 +269,24 @@ func SendInvite( ) switch e := err.(type) { case *roomserverAPI.PerformError: - return e.JSONResponse() + return e.JSONResponse(), err case nil: return util.JSONResponse{ Code: http.StatusOK, JSON: struct{}{}, - } + }, nil default: - util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInvite failed") + util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInvite failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError(), - } + }, err } } func buildMembershipEvent( ctx context.Context, - targetUserID, reason string, accountDB accounts.Database, + targetUserID, reason string, accountDB userdb.Database, device *userapi.Device, membership, roomID string, isDirect bool, cfg *config.ClientAPI, evTime time.Time, @@ -311,7 +327,7 @@ func loadProfile( ctx context.Context, userID string, cfg *config.ClientAPI, - accountDB accounts.Database, + accountDB userdb.Database, asAPI appserviceAPI.AppServiceQueryAPI, ) (*authtypes.Profile, error) { _, serverName, err := gomatrixserverlib.SplitID('@', userID) @@ -365,7 +381,7 @@ func checkAndProcessThreepid( body *threepid.MembershipRequest, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomID string, evTime time.Time, ) (inviteStored bool, errRes *util.JSONResponse) { @@ -459,13 +475,7 @@ func SendForget( if membershipRes.IsInRoom { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Forbidden("user is still a member of the room"), - } - } - if !membershipRes.HasBeenInRoom { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.Forbidden("user did not belong to room"), + JSON: jsonerror.Unknown(fmt.Sprintf("User %s is in room %s", device.UserID, roomID)), } } diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index b24424430..499510193 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -9,7 +9,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -29,7 +29,7 @@ type newPasswordAuth struct { func Password( req *http.Request, userAPI api.UserInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, device *api.Device, cfg *config.ClientAPI, ) util.JSONResponse { diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go index 26aa64ce1..8f89e97f4 100644 --- a/clientapi/routing/peekroom.go +++ b/clientapi/routing/peekroom.go @@ -19,7 +19,7 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -28,7 +28,7 @@ func PeekRoomByIDOrAlias( req *http.Request, device *api.Device, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomIDOrAlias string, ) util.JSONResponse { // if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to @@ -82,7 +82,7 @@ func UnpeekRoomByID( req *http.Request, device *api.Device, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomID string, ) util.JSONResponse { unpeekReq := roomserverAPI.PerformUnpeekRequest{ diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 017facd20..717cbda75 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -27,7 +27,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrix" @@ -36,7 +36,7 @@ import ( // GetProfile implements GET /profile/{userID} func GetProfile( - req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, + req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, @@ -65,7 +65,7 @@ func GetProfile( // GetAvatarURL implements GET /profile/{userID}/avatar_url func GetAvatarURL( - req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, + req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { @@ -92,7 +92,7 @@ func GetAvatarURL( // SetAvatarURL implements PUT /profile/{userID}/avatar_url func SetAvatarURL( - req *http.Request, accountDB accounts.Database, + req *http.Request, accountDB userdb.Database, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { if userID != device.UserID { @@ -182,7 +182,7 @@ func SetAvatarURL( // GetDisplayName implements GET /profile/{userID}/displayname func GetDisplayName( - req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, + req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { @@ -209,7 +209,7 @@ func GetDisplayName( // SetDisplayName implements PUT /profile/{userID}/displayname func SetDisplayName( - req *http.Request, accountDB accounts.Database, + req *http.Request, accountDB userdb.Database, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { if userID != device.UserID { @@ -302,7 +302,7 @@ func SetDisplayName( // Returns an error when something goes wrong or specifically // eventutil.ErrProfileNoExists when the profile doesn't exist. func getProfile( - ctx context.Context, accountDB accounts.Database, cfg *config.ClientAPI, + ctx context.Context, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 8823a41e3..d00d9886e 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -32,18 +32,19 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/tokens" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/gomatrixserverlib/tokens" - "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" + userdb "github.com/matrix-org/dendrite/userapi/storage" ) var ( @@ -447,7 +448,7 @@ func validateApplicationService( func Register( req *http.Request, userAPI userapi.UserInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, cfg *config.ClientAPI, ) util.JSONResponse { var r registerRequest @@ -531,6 +532,13 @@ func handleGuestRegistration( cfg *config.ClientAPI, userAPI userapi.UserInternalAPI, ) util.JSONResponse { + if cfg.RegistrationDisabled || cfg.GuestsDisabled { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("Guest registration is disabled"), + } + } + var res userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ AccountType: userapi.AccountTypeGuest, @@ -701,7 +709,7 @@ func handleApplicationServiceRegistration( // application service registration is entirely separate. return completeRegistration( req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), - r.InhibitLogin, r.InitialDisplayName, r.DeviceID, + r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, ) } @@ -720,7 +728,7 @@ func checkAndCompleteFlow( // This flow was completed, registration can continue return completeRegistration( req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), - r.InhibitLogin, r.InitialDisplayName, r.DeviceID, + r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, ) } @@ -745,6 +753,7 @@ func completeRegistration( username, password, appserviceID, ipAddr, userAgent string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, + accType userapi.AccountType, ) util.JSONResponse { if username == "" { return util.JSONResponse{ @@ -759,13 +768,12 @@ func completeRegistration( JSON: jsonerror.BadJSON("missing password"), } } - var accRes userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ AppServiceID: appserviceID, Localpart: username, Password: password, - AccountType: userapi.AccountTypeUser, + AccountType: accType, OnConflict: userapi.ConflictAbort, }, &accRes) if err != nil { @@ -891,7 +899,7 @@ type availableResponse struct { func RegisterAvailable( req *http.Request, cfg *config.ClientAPI, - accountDB accounts.Database, + accountDB userdb.Database, ) util.JSONResponse { username := req.URL.Query().Get("username") @@ -963,5 +971,10 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS return *resErr } deviceID := "shared_secret_registration" - return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID) + + accType := userapi.AccountTypeUser + if ssrr.Admin { + accType = userapi.AccountTypeAdmin + } + return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID, accType) } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 9263c66bb..d75f58b81 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -15,6 +15,7 @@ package routing import ( + "context" "encoding/json" "net/http" "strings" @@ -34,7 +35,7 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -51,7 +52,7 @@ func Setup( eduAPI eduServerAPI.EDUServerInputAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, - accountDB accounts.Database, + accountDB userdb.Database, userAPI userapi.UserInternalAPI, federation *gomatrixserverlib.FederationClient, syncProducer *producers.SyncAPIProducer, @@ -62,7 +63,7 @@ func Setup( mscCfg *config.MSCs, ) { rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) - userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, cfg) + userInteractiveAuth := auth.NewUserInteractive(accountDB, cfg) unstableFeatures := map[string]bool{ "org.matrix.e2e_cross_signing": true, @@ -117,15 +118,66 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) } - r0mux := publicAPIMux.PathPrefix("/r0").Subrouter() + // server notifications + if cfg.Matrix.ServerNotices.Enabled { + logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") + serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, accountDB, cfg) + if err != nil { + logrus.WithError(err).Fatal("unable to get account for sending sending server notices") + } + + synapseAdminRouter.Handle("/admin/v1/send_server_notice/{txnID}", + httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + // not specced, but ensure we're rate limiting requests to this endpoint + if r := rateLimits.Limit(req); r != nil { + return *r + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + txnID := vars["txnID"] + return SendServerNotice( + req, &cfg.Matrix.ServerNotices, + cfg, userAPI, rsAPI, accountDB, asAPI, + device, serverNotificationSender, + &txnID, transactionsCache, + ) + }), + ).Methods(http.MethodPut, http.MethodOptions) + + synapseAdminRouter.Handle("/admin/v1/send_server_notice", + httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + // not specced, but ensure we're rate limiting requests to this endpoint + if r := rateLimits.Limit(req); r != nil { + return *r + } + return SendServerNotice( + req, &cfg.Matrix.ServerNotices, + cfg, userAPI, rsAPI, accountDB, asAPI, + device, serverNotificationSender, + nil, transactionsCache, + ) + }), + ).Methods(http.MethodPost, http.MethodOptions) + } + + // You can't just do PathPrefix("/(r0|v3)") because regexps only apply when inside named path variables. + // So make a named path variable called 'apiversion' (which we will never read in handlers) and then do + // (r0|v3) - BUT this is a captured group, which makes no sense because you cannot extract this group + // from a match (gorilla/mux exposes no way to do this) so it demands you make it a non-capturing group + // using ?: so the final regexp becomes what is below. We also need a trailing slash to stop 'v33333' matching. + // Note that 'apiversion' is chosen because it must not collide with a variable used in any of the routing! + v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() + unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter() - r0mux.Handle("/createRoom", + v3mux.Handle("/createRoom", httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return CreateRoom(req, device, cfg, accountDB, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/join/{roomIDOrAlias}", + v3mux.Handle("/join/{roomIDOrAlias}", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -141,7 +193,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) if mscCfg.Enabled("msc2753") { - r0mux.Handle("/peek/{roomIDOrAlias}", + v3mux.Handle("/peek/{roomIDOrAlias}", httputil.MakeAuthAPI(gomatrixserverlib.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -156,12 +208,12 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) } - r0mux.Handle("/joined_rooms", + v3mux.Handle("/joined_rooms", httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetJoinedRooms(req, device, rsAPI) }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/join", + v3mux.Handle("/rooms/{roomID}/join", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -175,7 +227,7 @@ func Setup( ) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/leave", + v3mux.Handle("/rooms/{roomID}/leave", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -189,7 +241,7 @@ func Setup( ) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/unpeek", + v3mux.Handle("/rooms/{roomID}/unpeek", httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -200,7 +252,7 @@ func Setup( ) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/ban", + v3mux.Handle("/rooms/{roomID}/ban", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -209,7 +261,7 @@ func Setup( return SendBan(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/invite", + v3mux.Handle("/rooms/{roomID}/invite", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -221,7 +273,7 @@ func Setup( return SendInvite(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/kick", + v3mux.Handle("/rooms/{roomID}/kick", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -230,7 +282,7 @@ func Setup( return SendKick(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/unban", + v3mux.Handle("/rooms/{roomID}/unban", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -239,7 +291,7 @@ func Setup( return SendUnban(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/send/{eventType}", + v3mux.Handle("/rooms/{roomID}/send/{eventType}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -248,7 +300,7 @@ func Setup( return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", + v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -259,7 +311,7 @@ func Setup( nil, cfg, rsAPI, transactionsCache) }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/event/{eventID}", + v3mux.Handle("/rooms/{roomID}/event/{eventID}", httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -269,7 +321,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -277,7 +329,7 @@ func Setup( return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"]) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -285,7 +337,7 @@ func Setup( return GetAliases(req, rsAPI, device, vars["roomID"]) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -296,7 +348,7 @@ func Setup( return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -305,7 +357,7 @@ func Setup( return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", + v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -317,7 +369,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", + v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -328,21 +380,21 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { + v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r } return Register(req, userAPI, accountDB, cfg) })).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { + v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r } return RegisterAvailable(req, cfg, accountDB) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/directory/room/{roomAlias}", + v3mux.Handle("/directory/room/{roomAlias}", httputil.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -352,7 +404,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/directory/room/{roomAlias}", + v3mux.Handle("/directory/room/{roomAlias}", httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -362,7 +414,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/directory/room/{roomAlias}", + v3mux.Handle("/directory/room/{roomAlias}", httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -371,7 +423,7 @@ func Setup( return RemoveLocalAlias(req, device, vars["roomAlias"], rsAPI) }), ).Methods(http.MethodDelete, http.MethodOptions) - r0mux.Handle("/directory/list/room/{roomID}", + v3mux.Handle("/directory/list/room/{roomID}", httputil.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -381,7 +433,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) // TODO: Add AS support - r0mux.Handle("/directory/list/room/{roomID}", + v3mux.Handle("/directory/list/room/{roomID}", httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -390,25 +442,25 @@ func Setup( return SetVisibility(req, rsAPI, device, vars["roomID"]) }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/publicRooms", + v3mux.Handle("/publicRooms", httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { return GetPostPublicRooms(req, rsAPI, extRoomsProvider, federation, cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - r0mux.Handle("/logout", + v3mux.Handle("/logout", httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Logout(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/logout/all", + v3mux.Handle("/logout/all", httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return LogoutAll(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/typing/{userID}", + v3mux.Handle("/rooms/{roomID}/typing/{userID}", httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -420,7 +472,7 @@ func Setup( return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduAPI, rsAPI) }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/redact/{eventID}", + v3mux.Handle("/rooms/{roomID}/redact/{eventID}", httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -429,7 +481,7 @@ func Setup( return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}", + v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}", httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -439,7 +491,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/sendToDevice/{eventType}/{txnID}", + v3mux.Handle("/sendToDevice/{eventType}/{txnID}", httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -464,7 +516,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/account/whoami", + v3mux.Handle("/account/whoami", httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -473,7 +525,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/account/password", + v3mux.Handle("/account/password", httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -482,7 +534,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/account/deactivate", + v3mux.Handle("/account/deactivate", httputil.MakeAuthAPI("deactivate", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -493,7 +545,7 @@ func Setup( // Stub endpoints required by Element - r0mux.Handle("/login", + v3mux.Handle("/login", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -502,14 +554,14 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - r0mux.Handle("/auth/{authType}/fallback/web", + v3mux.Handle("/auth/{authType}/fallback/web", httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { vars := mux.Vars(req) return AuthFallback(w, req, vars["authType"], cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - r0mux.Handle("/pushrules/", + v3mux.Handle("/pushrules/", httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse { // TODO: Implement push rules API res := json.RawMessage(`{ @@ -530,7 +582,7 @@ func Setup( // Element user settings - r0mux.Handle("/profile/{userID}", + v3mux.Handle("/profile/{userID}", httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -540,7 +592,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/profile/{userID}/avatar_url", + v3mux.Handle("/profile/{userID}/avatar_url", httputil.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -550,7 +602,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/profile/{userID}/avatar_url", + v3mux.Handle("/profile/{userID}/avatar_url", httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -565,7 +617,7 @@ func Setup( // Browsers use the OPTIONS HTTP method to check if the CORS policy allows // PUT requests, so we need to allow this method - r0mux.Handle("/profile/{userID}/displayname", + v3mux.Handle("/profile/{userID}/displayname", httputil.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -575,7 +627,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/profile/{userID}/displayname", + v3mux.Handle("/profile/{userID}/displayname", httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -590,13 +642,13 @@ func Setup( // Browsers use the OPTIONS HTTP method to check if the CORS policy allows // PUT requests, so we need to allow this method - r0mux.Handle("/account/3pid", + v3mux.Handle("/account/3pid", httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetAssociated3PIDs(req, accountDB, device) }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/account/3pid", + v3mux.Handle("/account/3pid", httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return CheckAndSave3PIDAssociation(req, accountDB, device, cfg) }), @@ -608,14 +660,14 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken", + v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken", httputil.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse { return RequestEmailToken(req, accountDB, cfg) }), ).Methods(http.MethodPost, http.MethodOptions) // Element logs get flooded unless this is handled - r0mux.Handle("/presence/{userID}/status", + v3mux.Handle("/presence/{userID}/status", httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -628,7 +680,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/voip/turnServer", + v3mux.Handle("/voip/turnServer", httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -637,7 +689,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/thirdparty/protocols", + v3mux.Handle("/thirdparty/protocols", httputil.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse { // TODO: Return the third party protcols return util.JSONResponse{ @@ -647,7 +699,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/initialSync", + v3mux.Handle("/rooms/{roomID}/initialSync", httputil.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse { // TODO: Allow people to peek into rooms. return util.JSONResponse{ @@ -657,7 +709,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/user/{userID}/account_data/{type}", + v3mux.Handle("/user/{userID}/account_data/{type}", httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -667,7 +719,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", + v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -677,7 +729,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/user/{userID}/account_data/{type}", + v3mux.Handle("/user/{userID}/account_data/{type}", httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -687,7 +739,7 @@ func Setup( }), ).Methods(http.MethodGet) - r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", + v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -697,7 +749,7 @@ func Setup( }), ).Methods(http.MethodGet) - r0mux.Handle("/admin/whois/{userID}", + v3mux.Handle("/admin/whois/{userID}", httputil.MakeAuthAPI("admin_whois", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -707,7 +759,7 @@ func Setup( }), ).Methods(http.MethodGet) - r0mux.Handle("/user/{userID}/openid/request_token", + v3mux.Handle("/user/{userID}/openid/request_token", httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -720,7 +772,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/user_directory/search", + v3mux.Handle("/user_directory/search", httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -745,7 +797,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/members", + v3mux.Handle("/rooms/{roomID}/members", httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -755,7 +807,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/joined_members", + v3mux.Handle("/rooms/{roomID}/joined_members", httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -765,7 +817,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/read_markers", + v3mux.Handle("/rooms/{roomID}/read_markers", httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -778,7 +830,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/forget", + v3mux.Handle("/rooms/{roomID}/forget", httputil.MakeAuthAPI("rooms_forget", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -791,13 +843,13 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/devices", + v3mux.Handle("/devices", httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetDevicesByLocalpart(req, userAPI, device) }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/devices/{deviceID}", + v3mux.Handle("/devices/{deviceID}", httputil.MakeAuthAPI("get_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -807,7 +859,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/devices/{deviceID}", + v3mux.Handle("/devices/{deviceID}", httputil.MakeAuthAPI("device_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -817,7 +869,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/devices/{deviceID}", + v3mux.Handle("/devices/{deviceID}", httputil.MakeAuthAPI("delete_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -827,14 +879,14 @@ func Setup( }), ).Methods(http.MethodDelete, http.MethodOptions) - r0mux.Handle("/delete_devices", + v3mux.Handle("/delete_devices", httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return DeleteDevices(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) // Stub implementations for sytest - r0mux.Handle("/events", + v3mux.Handle("/events", httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "chunk": []interface{}{}, @@ -844,7 +896,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/initialSync", + v3mux.Handle("/initialSync", httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "end": "", @@ -852,7 +904,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/user/{userId}/rooms/{roomId}/tags", + v3mux.Handle("/user/{userId}/rooms/{roomId}/tags", httputil.MakeAuthAPI("get_tags", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -862,7 +914,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", + v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", httputil.MakeAuthAPI("put_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -872,7 +924,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", + v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", httputil.MakeAuthAPI("delete_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -882,7 +934,7 @@ func Setup( }), ).Methods(http.MethodDelete, http.MethodOptions) - r0mux.Handle("/capabilities", + v3mux.Handle("/capabilities", httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -925,11 +977,11 @@ func Setup( return CreateKeyBackupVersion(req, userAPI, device) }) - r0mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut) - r0mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete) - r0mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut) + v3mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete) + v3mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) @@ -1021,9 +1073,9 @@ func Setup( return UploadBackupKeys(req, userAPI, device, version, &keyReq) }) - r0mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) - r0mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) - r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut) + v3mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) + v3mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) + v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut) unstableMux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) unstableMux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) @@ -1051,9 +1103,9 @@ func Setup( return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"]) }) - r0mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) @@ -1071,34 +1123,34 @@ func Setup( return UploadCrossSigningDeviceSignatures(req, keyAPI, device) }) - r0mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) // Supplying a device ID is deprecated. - r0mux.Handle("/keys/upload/{deviceID}", + v3mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, keyAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/keys/upload", + v3mux.Handle("/keys/upload", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, keyAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/keys/query", + v3mux.Handle("/keys/query", httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return QueryKeys(req, keyAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/keys/claim", + v3mux.Handle("/keys/claim", httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return ClaimKeys(req, keyAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", + v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 606107b9f..23935b5d9 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -15,10 +15,16 @@ package routing import ( + "context" "net/http" "sync" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" @@ -26,10 +32,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" ) // http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-send-eventtype-txnid @@ -97,7 +99,22 @@ func SendEvent( defer mutex.(*sync.Mutex).Unlock() startedGeneratingEvent := time.Now() - e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, rsAPI) + + var r map[string]interface{} // must be a JSON object + resErr := httputil.UnmarshalJSONRequest(req, &r) + if resErr != nil { + return *resErr + } + + evTime, err := httputil.ParseTSParam(req) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue(err.Error()), + } + } + + e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, cfg, rsAPI, evTime) if resErr != nil { return *resErr } @@ -153,27 +170,16 @@ func SendEvent( } func generateSendEvent( - req *http.Request, + ctx context.Context, + r map[string]interface{}, device *userapi.Device, roomID, eventType string, stateKey *string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, + evTime time.Time, ) (*gomatrixserverlib.Event, *util.JSONResponse) { // parse the incoming http request userID := device.UserID - var r map[string]interface{} // must be a JSON object - resErr := httputil.UnmarshalJSONRequest(req, &r) - if resErr != nil { - return nil, resErr - } - - evTime, err := httputil.ParseTSParam(req) - if err != nil { - return nil, &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), - } - } // create the new event and set all the fields we can builder := gomatrixserverlib.EventBuilder{ @@ -182,15 +188,15 @@ func generateSendEvent( Type: eventType, StateKey: stateKey, } - err = builder.SetContent(r) + err := builder.SetContent(r) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed") + util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed") resErr := jsonerror.InternalServerError() return nil, &resErr } var queryRes api.QueryLatestEventsAndStateResponse - e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, evTime, rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return nil, &util.JSONResponse{ Code: http.StatusNotFound, @@ -213,7 +219,7 @@ func generateSendEvent( JSON: jsonerror.BadJSON(e.Error()), } } else if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(ctx).WithError(err).Error("eventutil.BuildEvent failed") resErr := jsonerror.InternalServerError() return nil, &resErr } diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index 3abf3db27..fd214b34b 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -20,7 +20,7 @@ import ( "github.com/matrix-org/dendrite/eduserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/util" ) @@ -33,7 +33,7 @@ type typingContentJSON struct { // sends the typing events to client API typingProducer func SendTyping( req *http.Request, device *userapi.Device, roomID string, - userID string, accountDB accounts.Database, + userID string, accountDB userdb.Database, eduAPI api.EDUServerInputAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ) util.JSONResponse { diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go new file mode 100644 index 000000000..42a303a6b --- /dev/null +++ b/clientapi/routing/server_notices.go @@ -0,0 +1,343 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + userdb "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/tokens" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/transactions" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + userapi "github.com/matrix-org/dendrite/userapi/api" +) + +// Unspecced server notice request +// https://github.com/matrix-org/synapse/blob/develop/docs/admin_api/server_notices.md +type sendServerNoticeRequest struct { + UserID string `json:"user_id,omitempty"` + Content struct { + MsgType string `json:"msgtype,omitempty"` + Body string `json:"body,omitempty"` + } `json:"content,omitempty"` + Type string `json:"type,omitempty"` + StateKey string `json:"state_key,omitempty"` +} + +// SendServerNotice sends a message to a specific user. It can only be invoked by an admin. +func SendServerNotice( + req *http.Request, + cfgNotices *config.ServerNotices, + cfgClient *config.ClientAPI, + userAPI userapi.UserInternalAPI, + rsAPI api.RoomserverInternalAPI, + accountsDB userdb.Database, + asAPI appserviceAPI.AppServiceQueryAPI, + device *userapi.Device, + senderDevice *userapi.Device, + txnID *string, + txnCache *transactions.Cache, +) util.JSONResponse { + if device.AccountType != userapi.AccountTypeAdmin { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("This API can only be used by admin users."), + } + } + + if txnID != nil { + // Try to fetch response from transactionsCache + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + return *res + } + } + + ctx := req.Context() + var r sendServerNoticeRequest + resErr := httputil.UnmarshalJSONRequest(req, &r) + if resErr != nil { + return *resErr + } + + // check that all required fields are set + if !r.valid() { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Invalid request"), + } + } + + // get rooms for specified user + allUserRooms := []string{} + userRooms := api.QueryRoomsForUserResponse{} + if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ + UserID: r.UserID, + WantMembership: "join", + }, &userRooms); err != nil { + return util.ErrorResponse(err) + } + allUserRooms = append(allUserRooms, userRooms.RoomIDs...) + // get invites for specified user + if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ + UserID: r.UserID, + WantMembership: "invite", + }, &userRooms); err != nil { + return util.ErrorResponse(err) + } + allUserRooms = append(allUserRooms, userRooms.RoomIDs...) + // get left rooms for specified user + if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ + UserID: r.UserID, + WantMembership: "leave", + }, &userRooms); err != nil { + return util.ErrorResponse(err) + } + allUserRooms = append(allUserRooms, userRooms.RoomIDs...) + + // get rooms of the sender + senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName) + senderRooms := api.QueryRoomsForUserResponse{} + if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ + UserID: senderUserID, + WantMembership: "join", + }, &senderRooms); err != nil { + return util.ErrorResponse(err) + } + + // check if we have rooms in common + commonRooms := []string{} + for _, userRoomID := range allUserRooms { + for _, senderRoomID := range senderRooms.RoomIDs { + if userRoomID == senderRoomID { + commonRooms = append(commonRooms, senderRoomID) + } + } + } + + if len(commonRooms) > 1 { + return util.ErrorResponse(fmt.Errorf("expected to find one room, but got %d", len(commonRooms))) + } + + var ( + roomID string + roomVersion = gomatrixserverlib.RoomVersionV6 + ) + + // create a new room for the user + if len(commonRooms) == 0 { + powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID) + powerLevelContent.Users[r.UserID] = -10 // taken from Synapse + pl, err := json.Marshal(powerLevelContent) + if err != nil { + return util.ErrorResponse(err) + } + createContent := map[string]interface{}{} + createContent["m.federate"] = false + cc, err := json.Marshal(createContent) + if err != nil { + return util.ErrorResponse(err) + } + crReq := createRoomRequest{ + Invite: []string{r.UserID}, + Name: cfgNotices.RoomName, + Visibility: "private", + Preset: presetPrivateChat, + CreationContent: cc, + GuestCanJoin: false, + RoomVersion: roomVersion, + PowerLevelContentOverride: pl, + } + + roomRes := createRoom(ctx, crReq, senderDevice, cfgClient, accountsDB, rsAPI, asAPI, time.Now()) + + switch data := roomRes.JSON.(type) { + case createRoomResponse: + roomID = data.RoomID + + // tag the room, so we can later check if the user tries to reject an invite + serverAlertTag := gomatrix.TagContent{Tags: map[string]gomatrix.TagProperties{ + "m.server_notice": { + Order: 1.0, + }, + }} + if err = saveTagData(req, r.UserID, roomID, userAPI, serverAlertTag); err != nil { + util.GetLogger(ctx).WithError(err).Error("saveTagData failed") + return jsonerror.InternalServerError() + } + + default: + // if we didn't get a createRoomResponse, we probably received an error, so return that. + return roomRes + } + + } else { + // we've found a room in common, check the membership + roomID = commonRooms[0] + // re-invite the user + res, err := sendInvite(ctx, accountsDB, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now()) + if err != nil { + return res + } + } + + startedGeneratingEvent := time.Now() + + request := map[string]interface{}{ + "body": r.Content.Body, + "msgtype": r.Content.MsgType, + } + e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, cfgClient, rsAPI, time.Now()) + if resErr != nil { + logrus.Errorf("failed to send message: %+v", resErr) + return *resErr + } + timeToGenerateEvent := time.Since(startedGeneratingEvent) + + var txnAndSessionID *api.TransactionID + if txnID != nil { + txnAndSessionID = &api.TransactionID{ + TransactionID: *txnID, + SessionID: device.SessionID, + } + } + + // pass the new event to the roomserver and receive the correct event ID + // event ID in case of duplicate transaction is discarded + startedSubmittingEvent := time.Now() + if err := api.SendEvents( + ctx, rsAPI, + api.KindNew, + []*gomatrixserverlib.HeaderedEvent{ + e.Headered(roomVersion), + }, + cfgClient.Matrix.ServerName, + cfgClient.Matrix.ServerName, + txnAndSessionID, + false, + ); err != nil { + util.GetLogger(ctx).WithError(err).Error("SendEvents failed") + return jsonerror.InternalServerError() + } + util.GetLogger(ctx).WithFields(logrus.Fields{ + "event_id": e.EventID(), + "room_id": roomID, + "room_version": roomVersion, + }).Info("Sent event to roomserver") + timeToSubmitEvent := time.Since(startedSubmittingEvent) + + res := util.JSONResponse{ + Code: http.StatusOK, + JSON: sendEventResponse{e.EventID()}, + } + // Add response to transactionsCache + if txnID != nil { + txnCache.AddTransaction(device.AccessToken, *txnID, &res) + } + + // Take a note of how long it took to generate the event vs submit + // it to the roomserver. + sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds())) + sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvent.Milliseconds())) + + return res +} + +func (r sendServerNoticeRequest) valid() (ok bool) { + if r.UserID == "" { + return false + } + if r.Content.MsgType == "" || r.Content.Body == "" { + return false + } + return true +} + +// getSenderDevice creates a user account to be used when sending server notices. +// It returns an userapi.Device, which is used for building the event +func getSenderDevice( + ctx context.Context, + userAPI userapi.UserInternalAPI, + accountDB userdb.Database, + cfg *config.ClientAPI, +) (*userapi.Device, error) { + var accRes userapi.PerformAccountCreationResponse + // create account if it doesn't exist + err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ + AccountType: userapi.AccountTypeUser, + Localpart: cfg.Matrix.ServerNotices.LocalPart, + OnConflict: userapi.ConflictUpdate, + }, &accRes) + if err != nil { + return nil, err + } + + // set the avatarurl for the user + if err = accountDB.SetAvatarURL(ctx, cfg.Matrix.ServerNotices.LocalPart, cfg.Matrix.ServerNotices.AvatarURL); err != nil { + util.GetLogger(ctx).WithError(err).Error("accountDB.SetAvatarURL failed") + return nil, err + } + + // Check if we got existing devices + deviceRes := &userapi.QueryDevicesResponse{} + err = userAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{ + UserID: accRes.Account.UserID, + }, deviceRes) + if err != nil { + return nil, err + } + + if len(deviceRes.Devices) > 0 { + return &deviceRes.Devices[0], nil + } + + // create an AccessToken + token, err := tokens.GenerateLoginToken(tokens.TokenOptions{ + ServerPrivateKey: cfg.Matrix.PrivateKey.Seed(), + ServerName: string(cfg.Matrix.ServerName), + UserID: accRes.Account.UserID, + }) + if err != nil { + return nil, err + } + + // create a new device, if we didn't find any + var devRes userapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ + Localpart: cfg.Matrix.ServerNotices.LocalPart, + DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart, + AccessToken: token, + NoDeviceListUpdate: true, + }, &devRes) + + if err != nil { + return nil, err + } + return devRes.Device, nil +} diff --git a/clientapi/routing/server_notices_test.go b/clientapi/routing/server_notices_test.go new file mode 100644 index 000000000..2fac072cd --- /dev/null +++ b/clientapi/routing/server_notices_test.go @@ -0,0 +1,83 @@ +package routing + +import ( + "testing" +) + +func Test_sendServerNoticeRequest_validate(t *testing.T) { + type fields struct { + UserID string `json:"user_id,omitempty"` + Content struct { + MsgType string `json:"msgtype,omitempty"` + Body string `json:"body,omitempty"` + } `json:"content,omitempty"` + Type string `json:"type,omitempty"` + StateKey string `json:"state_key,omitempty"` + } + + content := struct { + MsgType string `json:"msgtype,omitempty"` + Body string `json:"body,omitempty"` + }{ + MsgType: "m.text", + Body: "Hello world!", + } + + tests := []struct { + name string + fields fields + wantOk bool + }{ + { + name: "empty request", + fields: fields{}, + }, + { + name: "msgtype empty", + fields: fields{ + UserID: "@alice:localhost", + Content: struct { + MsgType string `json:"msgtype,omitempty"` + Body string `json:"body,omitempty"` + }{ + Body: "Hello world!", + }, + }, + }, + { + name: "msg body empty", + fields: fields{ + UserID: "@alice:localhost", + }, + }, + { + name: "statekey empty", + fields: fields{ + UserID: "@alice:localhost", + Content: content, + }, + wantOk: true, + }, + { + name: "type empty", + fields: fields{ + UserID: "@alice:localhost", + Content: content, + }, + wantOk: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := sendServerNoticeRequest{ + UserID: tt.fields.UserID, + Content: tt.fields.Content, + Type: tt.fields.Type, + StateKey: tt.fields.StateKey, + } + if gotOk := r.valid(); gotOk != tt.wantOk { + t.Errorf("valid() = %v, want %v", gotOk, tt.wantOk) + } + }) + } +} diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index f4d233798..d89b62953 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -40,7 +40,7 @@ type threePIDsResponse struct { // RequestEmailToken implements: // POST /account/3pid/email/requestToken // POST /register/email/requestToken -func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI) util.JSONResponse { +func RequestEmailToken(req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI) util.JSONResponse { var body threepid.EmailAssociationRequest if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { return *reqErr @@ -61,7 +61,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf Code: http.StatusBadRequest, JSON: jsonerror.MatrixError{ ErrCode: "M_THREEPID_IN_USE", - Err: accounts.Err3PIDInUse.Error(), + Err: userdb.Err3PIDInUse.Error(), }, } } @@ -85,7 +85,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf // CheckAndSave3PIDAssociation implements POST /account/3pid func CheckAndSave3PIDAssociation( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, accountDB userdb.Database, device *api.Device, cfg *config.ClientAPI, ) util.JSONResponse { var body threepid.EmailAssociationCheckRequest @@ -149,7 +149,7 @@ func CheckAndSave3PIDAssociation( // GetAssociated3PIDs implements GET /account/3pid func GetAssociated3PIDs( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, accountDB userdb.Database, device *api.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { @@ -170,7 +170,7 @@ func GetAssociated3PIDs( } // Forget3PID implements POST /account/3pid/delete -func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse { +func Forget3PID(req *http.Request, accountDB userdb.Database) util.JSONResponse { var body authtypes.ThreePID if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { return *reqErr diff --git a/clientapi/routing/whoami.go b/clientapi/routing/whoami.go index 26280f6cc..a1d9d6675 100644 --- a/clientapi/routing/whoami.go +++ b/clientapi/routing/whoami.go @@ -21,7 +21,9 @@ import ( // whoamiResponse represents an response for a `whoami` request type whoamiResponse struct { - UserID string `json:"user_id"` + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + IsGuest bool `json:"is_guest"` } // Whoami implements `/account/whoami` which enables client to query their account user id. @@ -29,6 +31,10 @@ type whoamiResponse struct { func Whoami(req *http.Request, device *api.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusOK, - JSON: whoamiResponse{UserID: device.UserID}, + JSON: whoamiResponse{ + UserID: device.UserID, + DeviceID: device.ID, + IsGuest: device.AccountType == api.AccountTypeGuest, + }, } } diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index db62ce060..9d9a2ba7a 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -29,7 +29,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -87,7 +87,7 @@ var ( func CheckAndProcessInvite( ctx context.Context, device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI, - rsAPI api.RoomserverInternalAPI, db accounts.Database, + rsAPI api.RoomserverInternalAPI, db userdb.Database, roomID string, evTime time.Time, ) (inviteStoredOnIDServer bool, err error) { @@ -137,7 +137,7 @@ func CheckAndProcessInvite( // Returns an error if a check or a request failed. func queryIDServer( ctx context.Context, - db accounts.Database, cfg *config.ClientAPI, device *userapi.Device, + db userdb.Database, cfg *config.ClientAPI, device *userapi.Device, body *MembershipRequest, roomID string, ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { if err = isTrusted(body.IDServer, cfg); err != nil { @@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe // Returns an error if the request failed to send or if the response couldn't be parsed. func queryIDServerStoreInvite( ctx context.Context, - db accounts.Database, cfg *config.ClientAPI, device *userapi.Device, + db userdb.Database, cfg *config.ClientAPI, device *userapi.Device, body *MembershipRequest, roomID string, ) (*idServerStoreInviteResponse, error) { // Retrieve the sender's profile to get their display name diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 3ac077705..3003896c8 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -23,12 +23,14 @@ import ( "os" "strings" - "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" "golang.org/x/term" + + "github.com/matrix-org/dendrite/setup" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/api" + userdb "github.com/matrix-org/dendrite/userapi/storage" ) const usage = `Usage: %s @@ -57,6 +59,7 @@ var ( pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") askPass = flag.Bool("ask-pass", false, "Ask for the password to use") + isAdmin = flag.Bool("admin", false, "Create an admin account") ) func main() { @@ -74,14 +77,23 @@ func main() { pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin) - accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ - ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString, - }, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS) + accountDB, err := userdb.NewDatabase( + &config.DatabaseOptions{ + ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString, + }, + cfg.Global.ServerName, bcrypt.DefaultCost, + cfg.UserAPI.OpenIDTokenLifetimeMS, + api.DefaultLoginTokenLifetime, + ) if err != nil { logrus.Fatalln("Failed to connect to the database:", err.Error()) } - _, err = accountDB.CreateAccount(context.Background(), *username, pass, "") + accType := api.AccountTypeUser + if *isAdmin { + accType = api.AccountTypeAdmin + } + _, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType) if err != nil { logrus.Fatalln("Failed to create the account:", err.Error()) } diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 7cbd0b6d4..78536901c 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -126,7 +126,6 @@ func main() { cfg.FederationAPI.FederationMaxRetries = 6 cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index a897dcd1a..5810a7f18 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -160,7 +160,6 @@ func main() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) diff --git a/cmd/dendrite-demo-yggdrasil/README.md b/cmd/dendrite-demo-yggdrasil/README.md index c471cef22..946333576 100644 --- a/cmd/dendrite-demo-yggdrasil/README.md +++ b/cmd/dendrite-demo-yggdrasil/README.md @@ -1,6 +1,6 @@ # Yggdrasil Demo -This is the Dendrite Yggdrasil demo! It's easy to get started - all you need is Go 1.15 or later. +This is the Dendrite Yggdrasil demo! It's easy to get started - all you need is Go 1.16 or later. To run the homeserver, start at the root of the Dendrite repository and run: diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index d34e5159d..d16f0e9e5 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -64,14 +64,6 @@ func main() { if err != nil { panic(err) } - /* - ygg.SetMulticastEnabled(true) - if instancePeer != nil && *instancePeer != "" { - if err = ygg.SetStaticPeer(*instancePeer); err != nil { - logrus.WithError(err).Error("Failed to set static peer") - } - } - */ // iterate through the cli args and check if the config flag was set configFlagSet := false @@ -89,16 +81,14 @@ func main() { cfg = setup.ParseFlags(true) } else { cfg.Defaults(true) - cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) + cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName)) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName)) cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) - cfg.Global.Kafka.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-naffka.db", *instanceName)) cfg.MSCs.MSCs = []string{"msc2836"} cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName)) if err = cfg.Derive(); err != nil { diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 4d0598f3f..bb2685208 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -132,6 +132,7 @@ func main() { // dependency. Other components also need updating after their dependencies are up. rsImpl.SetFederationAPI(fsAPI, keyRing) rsImpl.SetAppserviceAPI(asAPI) + rsImpl.SetUserAPI(userAPI) keyImpl.SetUserAPI(userAPI) eduInputAPI := eduserver.NewInternalAPI( diff --git a/cmd/dendritejs-pinecone/main.go b/cmd/dendritejs-pinecone/main.go index 62eea78f2..664f644f3 100644 --- a/cmd/dendritejs-pinecone/main.go +++ b/cmd/dendritejs-pinecone/main.go @@ -164,7 +164,6 @@ func startup() { cfg.Defaults(true) cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db" cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db" - cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db" cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db" cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db" cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index 59de07cd0..0ea41b4c4 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -167,7 +167,6 @@ func main() { cfg.Defaults(true) cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db" cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db" - cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db" cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db" cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db" cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index a79470d83..ba5a87a7a 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -32,7 +32,6 @@ func main() { cfg.RoomServer.Database.ConnectionString = config.DataSource(*dbURI) cfg.SyncAPI.Database.ConnectionString = config.DataSource(*dbURI) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(*dbURI) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(*dbURI) } cfg.Global.TrustedIDServers = []string{ "matrix.org", @@ -83,7 +82,7 @@ func main() { if *defaultsForCI { cfg.AppServiceAPI.DisableTLSValidation = true cfg.ClientAPI.RateLimiting.Enabled = false - cfg.FederationAPI.DisableTLSValidation = true + cfg.FederationAPI.DisableTLSValidation = false // don't hit matrix.org when running tests!!! cfg.FederationAPI.KeyPerspectives = config.KeyPerspectives{} cfg.MSCs.MSCs = []string{"msc2836", "msc2946", "msc2444", "msc2753"} @@ -91,6 +90,7 @@ func main() { cfg.Logging[0].Type = "std" cfg.UserAPI.BCryptCost = bcrypt.MinCost cfg.Global.JetStream.InMemory = true + cfg.ClientAPI.RegistrationSharedSecret = "complement" } j, err := yaml.Marshal(cfg) diff --git a/cmd/generate-keys/main.go b/cmd/generate-keys/main.go index 743109f13..bddf219dc 100644 --- a/cmd/generate-keys/main.go +++ b/cmd/generate-keys/main.go @@ -32,9 +32,12 @@ Arguments: ` var ( - tlsCertFile = flag.String("tls-cert", "", "An X509 certificate file to generate for use for TLS") - tlsKeyFile = flag.String("tls-key", "", "An RSA private key file to generate for use for TLS") - privateKeyFile = flag.String("private-key", "", "An Ed25519 private key to generate for use for object signing") + tlsCertFile = flag.String("tls-cert", "", "An X509 certificate file to generate for use for TLS") + tlsKeyFile = flag.String("tls-key", "", "An RSA private key file to generate for use for TLS") + privateKeyFile = flag.String("private-key", "", "An Ed25519 private key to generate for use for object signing") + authorityCertFile = flag.String("tls-authority-cert", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.") + authorityKeyFile = flag.String("tls-authority-key", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.") + serverName = flag.String("server", "", "Optional: Create TLS certificate/keys with this domain name set. Useful for integration testing.") ) func main() { @@ -54,8 +57,15 @@ func main() { if *tlsCertFile == "" || *tlsKeyFile == "" { log.Fatal("Zero or both of --tls-key and --tls-cert must be supplied") } - if err := test.NewTLSKey(*tlsKeyFile, *tlsCertFile); err != nil { - panic(err) + if *authorityCertFile == "" && *authorityKeyFile == "" { + if err := test.NewTLSKey(*tlsKeyFile, *tlsCertFile); err != nil { + panic(err) + } + } else { + // generate the TLS cert/key based on the authority given. + if err := test.NewTLSKeyWithAuthority(*serverName, *tlsKeyFile, *tlsCertFile, *authorityKeyFile, *authorityCertFile); err != nil { + panic(err) + } } fmt.Printf("Created TLS cert file: %s\n", *tlsCertFile) fmt.Printf("Created TLS key file: %s\n", *tlsKeyFile) diff --git a/cmd/goose/main.go b/cmd/goose/main.go index 8ed5cbd91..31a5b0050 100644 --- a/cmd/goose/main.go +++ b/cmd/goose/main.go @@ -8,12 +8,11 @@ import ( "log" "os" - pgaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas" - slaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas" - pgdevices "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas" - sldevices "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas" "github.com/pressly/goose" + pgusers "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + slusers "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) @@ -26,8 +25,7 @@ const ( RoomServer = "roomserver" SigningKeyServer = "signingkeyserver" SyncAPI = "syncapi" - UserAPIAccounts = "userapi_accounts" - UserAPIDevices = "userapi_devices" + UserAPI = "userapi" ) var ( @@ -35,7 +33,7 @@ var ( flags = flag.NewFlagSet("goose", flag.ExitOnError) component = flags.String("component", "", "dendrite component name") knownDBs = []string{ - AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPIAccounts, UserAPIDevices, + AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPI, } ) @@ -143,18 +141,14 @@ Commands: func loadSQLiteDeltas(component string) { switch component { - case UserAPIAccounts: - slaccounts.LoadFromGoose() - case UserAPIDevices: - sldevices.LoadFromGoose() + case UserAPI: + slusers.LoadFromGoose() } } func loadPostgresDeltas(component string) { switch component { - case UserAPIAccounts: - pgaccounts.LoadFromGoose() - case UserAPIDevices: - pgdevices.LoadFromGoose() + case UserAPI: + pgusers.LoadFromGoose() } } diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 38b146d70..6d086ed77 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -68,6 +68,18 @@ global: # to other servers and the federation API will not be exposed. disable_federation: false + # Server notices allows server admins to send messages to all users. + server_notices: + enabled: false + # The server localpart to be used when sending notices, ensure this is not yet taken + local_part: "_server" + # The displayname to be used when sending notices + display_name: "Server alerts" + # The mxid of the avatar to use + avatar_url: "" + # The roomname to be used when creating messages + room_name: "Server Alerts" + # Configuration for NATS JetStream jetstream: # A list of NATS Server addresses to connect to. If none are specified, an @@ -142,6 +154,10 @@ client_api: # using the registration shared secret below. registration_disabled: false + # Prevents new guest accounts from being created. Guest registration is also + # disabled implicitly by setting 'registration_disabled' above. + guests_disabled: true + # If set, allows registration by anyone who knows the shared secret, regardless of # whether registration is otherwise disabled. registration_shared_secret: "" @@ -204,13 +220,6 @@ federation_api: # enable this option in production as it presents a security risk! disable_tls_validation: false - # Use the following proxy server for outbound federation traffic. - proxy_outbound: - enabled: false - protocol: http - host: localhost - port: 8080 - # Perspective keyservers to use as a backup when direct key fetches fail. This may # be required to satisfy key requests for servers that are no longer online when # joining some rooms. diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index ea4b2b27d..fe7127c76 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -37,7 +37,7 @@ If a job fails, click the "details" button and you should be taken to the job's logs. ![Click the details button on the failing build -step](https://raw.githubusercontent.com/matrix-org/dendrite/master/docs/images/details-button-location.jpg) +step](https://raw.githubusercontent.com/matrix-org/dendrite/main/docs/images/details-button-location.jpg) Scroll down to the failing step and you should see some log output. Scan the logs until you find what it's complaining about, fix it, submit a new commit, @@ -57,7 +57,7 @@ significant amount of CPU and RAM. Once the code builds, run [Sytest](https://github.com/matrix-org/sytest) according to the guide in -[docs/sytest.md](https://github.com/matrix-org/dendrite/blob/master/docs/sytest.md#using-a-sytest-docker-image) +[docs/sytest.md](https://github.com/matrix-org/dendrite/blob/main/docs/sytest.md#using-a-sytest-docker-image) so you can see whether something is being broken and whether there are newly passing tests. @@ -94,4 +94,4 @@ For more general questions please use We ask that everyone who contributes to the project signs off their contributions, in accordance with the -[DCO](https://github.com/matrix-org/matrix-doc/blob/master/CONTRIBUTING.rst#sign-off). +[DCO](https://github.com/matrix-org/matrix-doc/blob/main/CONTRIBUTING.rst#sign-off). diff --git a/docs/INSTALL.md b/docs/INSTALL.md index 2afb43c6a..686ae1dbb 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -27,7 +27,7 @@ use in production environments just yet! Dendrite requires: -* Go 1.15 or higher +* Go 1.16 or higher * PostgreSQL 12 or higher (if using PostgreSQL databases, not needed for SQLite) If you want to run a polylith deployment, you also need: diff --git a/docs/p2p.md b/docs/p2p.md index e858ba114..4e9a50524 100644 --- a/docs/p2p.md +++ b/docs/p2p.md @@ -6,7 +6,7 @@ These are the instructions for setting up P2P Dendrite, current as of May 2020. #### Build -- The `master` branch has a WASM-only binary for dendrite: `./cmd/dendritejs`. +- The `main` branch has a WASM-only binary for dendrite: `./cmd/dendritejs`. - Build it and copy assets to riot-web. ``` $ ./build-dendritejs.sh diff --git a/eduserver/api/input.go b/eduserver/api/input.go index 2fa253f4d..2aab107b2 100644 --- a/eduserver/api/input.go +++ b/eduserver/api/input.go @@ -100,10 +100,4 @@ type EDUServerInputAPI interface { request *InputReceiptEventRequest, response *InputReceiptEventResponse, ) error - - InputCrossSigningKeyUpdate( - ctx context.Context, - request *InputCrossSigningKeyUpdateRequest, - response *InputCrossSigningKeyUpdateResponse, - ) error } diff --git a/eduserver/eduserver.go b/eduserver/eduserver.go index db03001ba..9b7e21651 100644 --- a/eduserver/eduserver.go +++ b/eduserver/eduserver.go @@ -42,7 +42,7 @@ func NewInternalAPI( ) api.EDUServerInputAPI { cfg := &base.Cfg.EDUServer - js, _, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) return &input.EDUServerInputAPI{ Cache: eduCache, @@ -51,7 +51,6 @@ func NewInternalAPI( OutputTypingEventTopic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputTypingEvent), OutputSendToDeviceEventTopic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputSendToDeviceEvent), OutputReceiptEventTopic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent), - OutputKeyChangeEventTopic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), ServerName: cfg.Matrix.ServerName, } } diff --git a/eduserver/input/input.go b/eduserver/input/input.go index e7501a907..e58f0dd34 100644 --- a/eduserver/input/input.go +++ b/eduserver/input/input.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/cache" - keyapi "github.com/matrix-org/dendrite/keyserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" @@ -40,8 +39,6 @@ type EDUServerInputAPI struct { OutputSendToDeviceEventTopic string // The kafka topic to output new receipt events to OutputReceiptEventTopic string - // The kafka topic to output new key change events to - OutputKeyChangeEventTopic string // kafka producer JetStream nats.JetStreamContext // Internal user query API @@ -80,34 +77,6 @@ func (t *EDUServerInputAPI) InputSendToDeviceEvent( return t.sendToDeviceEvent(ise) } -// InputCrossSigningKeyUpdate implements api.EDUServerInputAPI -func (t *EDUServerInputAPI) InputCrossSigningKeyUpdate( - ctx context.Context, - request *api.InputCrossSigningKeyUpdateRequest, - response *api.InputCrossSigningKeyUpdateResponse, -) error { - eventJSON, err := json.Marshal(&keyapi.DeviceMessage{ - Type: keyapi.TypeCrossSigningUpdate, - OutputCrossSigningKeyUpdate: &api.OutputCrossSigningKeyUpdate{ - CrossSigningKeyUpdate: request.CrossSigningKeyUpdate, - }, - }) - if err != nil { - return err - } - - logrus.WithFields(logrus.Fields{ - "user_id": request.UserID, - }).Infof("Producing to topic '%s'", t.OutputKeyChangeEventTopic) - - _, err = t.JetStream.PublishMsg(&nats.Msg{ - Subject: t.OutputKeyChangeEventTopic, - Header: nats.Header{}, - Data: eventJSON, - }) - return err -} - func (t *EDUServerInputAPI) sendTypingEvent(ite *api.InputTypingEvent) error { ev := &api.TypingEvent{ Type: gomatrixserverlib.MTyping, @@ -134,7 +103,7 @@ func (t *EDUServerInputAPI) sendTypingEvent(ite *api.InputTypingEvent) error { "room_id": ite.RoomID, "user_id": ite.UserID, "typing": ite.Typing, - }).Infof("Producing to topic '%s'", t.OutputTypingEventTopic) + }).Tracef("Producing to topic '%s'", t.OutputTypingEventTopic) _, err = t.JetStream.PublishMsg(&nats.Msg{ Subject: t.OutputTypingEventTopic, @@ -175,7 +144,7 @@ func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) e "user_id": ise.UserID, "num_devices": len(devices), "type": ise.Type, - }).Infof("Producing to topic '%s'", t.OutputSendToDeviceEventTopic) + }).Tracef("Producing to topic '%s'", t.OutputSendToDeviceEventTopic) for _, device := range devices { ote := &api.OutputSendToDeviceEvent{ UserID: ise.UserID, @@ -208,7 +177,7 @@ func (t *EDUServerInputAPI) InputReceiptEvent( request *api.InputReceiptEventRequest, response *api.InputReceiptEventResponse, ) error { - logrus.WithFields(logrus.Fields{}).Infof("Producing to topic '%s'", t.OutputReceiptEventTopic) + logrus.WithFields(logrus.Fields{}).Tracef("Producing to topic '%s'", t.OutputReceiptEventTopic) output := &api.OutputReceiptEvent{ UserID: request.InputReceiptEvent.UserID, RoomID: request.InputReceiptEvent.RoomID, diff --git a/eduserver/inthttp/client.go b/eduserver/inthttp/client.go index 9a6f483c2..0690ed827 100644 --- a/eduserver/inthttp/client.go +++ b/eduserver/inthttp/client.go @@ -12,10 +12,9 @@ import ( // HTTP paths for the internal HTTP APIs const ( - EDUServerInputTypingEventPath = "/eduserver/input" - EDUServerInputSendToDeviceEventPath = "/eduserver/sendToDevice" - EDUServerInputReceiptEventPath = "/eduserver/receipt" - EDUServerInputCrossSigningKeyUpdatePath = "/eduserver/crossSigningKeyUpdate" + EDUServerInputTypingEventPath = "/eduserver/input" + EDUServerInputSendToDeviceEventPath = "/eduserver/sendToDevice" + EDUServerInputReceiptEventPath = "/eduserver/receipt" ) // NewEDUServerClient creates a EDUServerInputAPI implemented by talking to a HTTP POST API. @@ -69,16 +68,3 @@ func (h *httpEDUServerInputAPI) InputReceiptEvent( apiURL := h.eduServerURL + EDUServerInputReceiptEventPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } - -// InputCrossSigningKeyUpdate implements EDUServerInputAPI -func (h *httpEDUServerInputAPI) InputCrossSigningKeyUpdate( - ctx context.Context, - request *api.InputCrossSigningKeyUpdateRequest, - response *api.InputCrossSigningKeyUpdateResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "InputCrossSigningKeyUpdate") - defer span.Finish() - - apiURL := h.eduServerURL + EDUServerInputCrossSigningKeyUpdatePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} diff --git a/eduserver/inthttp/server.go b/eduserver/inthttp/server.go index a50ca84f9..a34943750 100644 --- a/eduserver/inthttp/server.go +++ b/eduserver/inthttp/server.go @@ -51,17 +51,4 @@ func AddRoutes(t api.EDUServerInputAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - internalAPIMux.Handle(EDUServerInputCrossSigningKeyUpdatePath, - httputil.MakeInternalAPI("inputCrossSigningKeyUpdate", func(req *http.Request) util.JSONResponse { - var request api.InputCrossSigningKeyUpdateRequest - var response api.InputCrossSigningKeyUpdateResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := t.InputCrossSigningKeyUpdate(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) } diff --git a/federationapi/consumers/eduserver.go b/federationapi/consumers/eduserver.go index c3e5b4d49..1f81fa258 100644 --- a/federationapi/consumers/eduserver.go +++ b/federationapi/consumers/eduserver.go @@ -34,7 +34,7 @@ import ( type OutputEDUConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string db storage.Database queues *queue.OutgoingQueues ServerName gomatrixserverlib.ServerName @@ -66,13 +66,22 @@ func NewOutputEDUConsumer( // Start consuming from EDU servers func (t *OutputEDUConsumer) Start() error { - if _, err := t.jetstream.Subscribe(t.typingTopic, t.onTypingEvent, t.durable); err != nil { + if err := jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.typingTopic, t.durable, t.onTypingEvent, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { return err } - if _, err := t.jetstream.Subscribe(t.sendToDeviceTopic, t.onSendToDeviceEvent, t.durable); err != nil { + if err := jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.sendToDeviceTopic, t.durable, t.onSendToDeviceEvent, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { return err } - if _, err := t.jetstream.Subscribe(t.receiptTopic, t.onReceiptEvent, t.durable); err != nil { + if err := jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.receiptTopic, t.durable, t.onReceiptEvent, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { return err } return nil @@ -80,175 +89,169 @@ func (t *OutputEDUConsumer) Start() error { // onSendToDeviceEvent is called in response to a message received on the // send-to-device events topic from the EDU server. -func (t *OutputEDUConsumer) onSendToDeviceEvent(msg *nats.Msg) { +func (t *OutputEDUConsumer) onSendToDeviceEvent(ctx context.Context, msg *nats.Msg) bool { // Extract the send-to-device event from msg. - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - var ote api.OutputSendToDeviceEvent - if err := json.Unmarshal(msg.Data, &ote); err != nil { - log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)") - return true - } - - // only send send-to-device events which originated from us - _, originServerName, err := gomatrixserverlib.SplitID('@', ote.Sender) - if err != nil { - log.WithError(err).WithField("user_id", ote.Sender).Error("Failed to extract domain from send-to-device sender") - return true - } - if originServerName != t.ServerName { - log.WithField("other_server", originServerName).Info("Suppressing send-to-device: originated elsewhere") - return true - } - - _, destServerName, err := gomatrixserverlib.SplitID('@', ote.UserID) - if err != nil { - log.WithError(err).WithField("user_id", ote.UserID).Error("Failed to extract domain from send-to-device destination") - return true - } - - // Pack the EDU and marshal it - edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MDirectToDevice, - Origin: string(t.ServerName), - } - tdm := gomatrixserverlib.ToDeviceMessage{ - Sender: ote.Sender, - Type: ote.Type, - MessageID: util.RandomString(32), - Messages: map[string]map[string]json.RawMessage{ - ote.UserID: { - ote.DeviceID: ote.Content, - }, - }, - } - if edu.Content, err = json.Marshal(tdm); err != nil { - log.WithError(err).Error("failed to marshal EDU JSON") - return true - } - - log.Infof("Sending send-to-device message into %q destination queue", destServerName) - if err := t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { - log.WithError(err).Error("failed to send EDU") - return false - } - + var ote api.OutputSendToDeviceEvent + if err := json.Unmarshal(msg.Data, &ote); err != nil { + log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)") return true - }) + } + + // only send send-to-device events which originated from us + _, originServerName, err := gomatrixserverlib.SplitID('@', ote.Sender) + if err != nil { + log.WithError(err).WithField("user_id", ote.Sender).Error("Failed to extract domain from send-to-device sender") + return true + } + if originServerName != t.ServerName { + log.WithField("other_server", originServerName).Info("Suppressing send-to-device: originated elsewhere") + return true + } + + _, destServerName, err := gomatrixserverlib.SplitID('@', ote.UserID) + if err != nil { + log.WithError(err).WithField("user_id", ote.UserID).Error("Failed to extract domain from send-to-device destination") + return true + } + + // Pack the EDU and marshal it + edu := &gomatrixserverlib.EDU{ + Type: gomatrixserverlib.MDirectToDevice, + Origin: string(t.ServerName), + } + tdm := gomatrixserverlib.ToDeviceMessage{ + Sender: ote.Sender, + Type: ote.Type, + MessageID: util.RandomString(32), + Messages: map[string]map[string]json.RawMessage{ + ote.UserID: { + ote.DeviceID: ote.Content, + }, + }, + } + if edu.Content, err = json.Marshal(tdm); err != nil { + log.WithError(err).Error("failed to marshal EDU JSON") + return true + } + + log.Debugf("Sending send-to-device message into %q destination queue", destServerName) + if err := t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { + log.WithError(err).Error("failed to send EDU") + return false + } + + return true } // onTypingEvent is called in response to a message received on the typing // events topic from the EDU server. -func (t *OutputEDUConsumer) onTypingEvent(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Extract the typing event from msg. - var ote api.OutputTypingEvent - if err := json.Unmarshal(msg.Data, &ote); err != nil { - // Skip this msg but continue processing messages. - log.WithError(err).Errorf("eduserver output log: message parse failed (expected typing)") - _ = msg.Ack() - return true - } - - // only send typing events which originated from us - _, typingServerName, err := gomatrixserverlib.SplitID('@', ote.Event.UserID) - if err != nil { - log.WithError(err).WithField("user_id", ote.Event.UserID).Error("Failed to extract domain from typing sender") - _ = msg.Ack() - return true - } - if typingServerName != t.ServerName { - return true - } - - joined, err := t.db.GetJoinedHosts(t.ctx, ote.Event.RoomID) - if err != nil { - log.WithError(err).WithField("room_id", ote.Event.RoomID).Error("failed to get joined hosts for room") - return false - } - - names := make([]gomatrixserverlib.ServerName, len(joined)) - for i := range joined { - names[i] = joined[i].ServerName - } - - edu := &gomatrixserverlib.EDU{Type: ote.Event.Type} - if edu.Content, err = json.Marshal(map[string]interface{}{ - "room_id": ote.Event.RoomID, - "user_id": ote.Event.UserID, - "typing": ote.Event.Typing, - }); err != nil { - log.WithError(err).Error("failed to marshal EDU JSON") - return true - } - - if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { - log.WithError(err).Error("failed to send EDU") - return false - } - +func (t *OutputEDUConsumer) onTypingEvent(ctx context.Context, msg *nats.Msg) bool { + // Extract the typing event from msg. + var ote api.OutputTypingEvent + if err := json.Unmarshal(msg.Data, &ote); err != nil { + // Skip this msg but continue processing messages. + log.WithError(err).Errorf("eduserver output log: message parse failed (expected typing)") + _ = msg.Ack() return true - }) + } + + // only send typing events which originated from us + _, typingServerName, err := gomatrixserverlib.SplitID('@', ote.Event.UserID) + if err != nil { + log.WithError(err).WithField("user_id", ote.Event.UserID).Error("Failed to extract domain from typing sender") + _ = msg.Ack() + return true + } + if typingServerName != t.ServerName { + return true + } + + joined, err := t.db.GetJoinedHosts(ctx, ote.Event.RoomID) + if err != nil { + log.WithError(err).WithField("room_id", ote.Event.RoomID).Error("failed to get joined hosts for room") + return false + } + + names := make([]gomatrixserverlib.ServerName, len(joined)) + for i := range joined { + names[i] = joined[i].ServerName + } + + edu := &gomatrixserverlib.EDU{Type: ote.Event.Type} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "room_id": ote.Event.RoomID, + "user_id": ote.Event.UserID, + "typing": ote.Event.Typing, + }); err != nil { + log.WithError(err).Error("failed to marshal EDU JSON") + return true + } + + if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { + log.WithError(err).Error("failed to send EDU") + return false + } + + return true } // onReceiptEvent is called in response to a message received on the receipt // events topic from the EDU server. -func (t *OutputEDUConsumer) onReceiptEvent(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Extract the typing event from msg. - var receipt api.OutputReceiptEvent - if err := json.Unmarshal(msg.Data, &receipt); err != nil { - // Skip this msg but continue processing messages. - log.WithError(err).Errorf("eduserver output log: message parse failed (expected receipt)") - return true - } - - // only send receipt events which originated from us - _, receiptServerName, err := gomatrixserverlib.SplitID('@', receipt.UserID) - if err != nil { - log.WithError(err).WithField("user_id", receipt.UserID).Error("failed to extract domain from receipt sender") - return true - } - if receiptServerName != t.ServerName { - return true - } - - joined, err := t.db.GetJoinedHosts(t.ctx, receipt.RoomID) - if err != nil { - log.WithError(err).WithField("room_id", receipt.RoomID).Error("failed to get joined hosts for room") - return false - } - - names := make([]gomatrixserverlib.ServerName, len(joined)) - for i := range joined { - names[i] = joined[i].ServerName - } - - content := map[string]api.FederationReceiptMRead{} - content[receipt.RoomID] = api.FederationReceiptMRead{ - User: map[string]api.FederationReceiptData{ - receipt.UserID: { - Data: api.ReceiptTS{ - TS: receipt.Timestamp, - }, - EventIDs: []string{receipt.EventID}, - }, - }, - } - - edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MReceipt, - Origin: string(t.ServerName), - } - if edu.Content, err = json.Marshal(content); err != nil { - log.WithError(err).Error("failed to marshal EDU JSON") - return true - } - - if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { - log.WithError(err).Error("failed to send EDU") - return false - } - +func (t *OutputEDUConsumer) onReceiptEvent(ctx context.Context, msg *nats.Msg) bool { + // Extract the typing event from msg. + var receipt api.OutputReceiptEvent + if err := json.Unmarshal(msg.Data, &receipt); err != nil { + // Skip this msg but continue processing messages. + log.WithError(err).Errorf("eduserver output log: message parse failed (expected receipt)") return true - }) + } + + // only send receipt events which originated from us + _, receiptServerName, err := gomatrixserverlib.SplitID('@', receipt.UserID) + if err != nil { + log.WithError(err).WithField("user_id", receipt.UserID).Error("failed to extract domain from receipt sender") + return true + } + if receiptServerName != t.ServerName { + return true + } + + joined, err := t.db.GetJoinedHosts(ctx, receipt.RoomID) + if err != nil { + log.WithError(err).WithField("room_id", receipt.RoomID).Error("failed to get joined hosts for room") + return false + } + + names := make([]gomatrixserverlib.ServerName, len(joined)) + for i := range joined { + names[i] = joined[i].ServerName + } + + content := map[string]api.FederationReceiptMRead{} + content[receipt.RoomID] = api.FederationReceiptMRead{ + User: map[string]api.FederationReceiptData{ + receipt.UserID: { + Data: api.ReceiptTS{ + TS: receipt.Timestamp, + }, + EventIDs: []string{receipt.EventID}, + }, + }, + } + + edu := &gomatrixserverlib.EDU{ + Type: gomatrixserverlib.MReceipt, + Origin: string(t.ServerName), + } + if edu.Content, err = json.Marshal(content); err != nil { + log.WithError(err).Error("failed to marshal EDU JSON") + return true + } + + if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { + log.WithError(err).Error("failed to send EDU") + return false + } + + return true } diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 6a737d0ad..22dbc32da 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -17,80 +17,73 @@ package consumers import ( "context" "encoding/json" - "fmt" - "github.com/Shopify/sarama" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) // KeyChangeConsumer consumes events that originate in key server. type KeyChangeConsumer struct { ctx context.Context - consumer *internal.ContinualConsumer + jetstream nats.JetStreamContext + durable string db storage.Database queues *queue.OutgoingQueues serverName gomatrixserverlib.ServerName rsAPI roomserverAPI.RoomserverInternalAPI + topic string } // NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers. func NewKeyChangeConsumer( process *process.ProcessContext, cfg *config.KeyServer, - kafkaConsumer sarama.Consumer, + js nats.JetStreamContext, queues *queue.OutgoingQueues, store storage.Database, rsAPI roomserverAPI.RoomserverInternalAPI, ) *KeyChangeConsumer { - c := &KeyChangeConsumer{ - ctx: process.Context(), - consumer: &internal.ContinualConsumer{ - Process: process, - ComponentName: "federationapi/keychange", - Topic: string(cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent)), - Consumer: kafkaConsumer, - PartitionStore: store, - }, + return &KeyChangeConsumer{ + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.TopicFor("FederationAPIKeyChangeConsumer"), + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), queues: queues, db: store, serverName: cfg.Matrix.ServerName, rsAPI: rsAPI, } - c.consumer.ProcessMessage = c.onMessage - - return c } // Start consuming from key servers func (t *KeyChangeConsumer) Start() error { - if err := t.consumer.Start(); err != nil { - return fmt.Errorf("t.consumer.Start: %w", err) - } - return nil + return jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } // onMessage is called in response to a message received on the // key change events topic from the key server. -func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error { +func (t *KeyChangeConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { var m api.DeviceMessage - if err := json.Unmarshal(msg.Value, &m); err != nil { + if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("failed to read device message from key change topic") - return nil + return true } if m.DeviceKeys == nil && m.OutputCrossSigningKeyUpdate == nil { // This probably shouldn't happen but stops us from panicking if we come // across an update that doesn't satisfy either types. - return nil + return true } switch m.Type { case api.TypeCrossSigningUpdate: @@ -102,9 +95,9 @@ func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error { } } -func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { +func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { if m.DeviceKeys == nil { - return nil + return true } logger := logrus.WithField("user_id", m.UserID) @@ -112,10 +105,10 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { _, originServerName, err := gomatrixserverlib.SplitID('@', m.UserID) if err != nil { logger.WithError(err).Error("Failed to extract domain from key change event") - return nil + return true } if originServerName != t.serverName { - return nil + return true } var queryRes roomserverAPI.QueryRoomsForUserResponse @@ -125,15 +118,18 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { }, &queryRes) if err != nil { logger.WithError(err).Error("failed to calculate joined rooms for user") - return nil + return true } // send this key change to all servers who share rooms with this user. destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) if err != nil { logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") - return nil + return true } + if len(destinations) == 0 { + return true + } // Pack the EDU and marshal it edu := &gomatrixserverlib.EDU{ Type: gomatrixserverlib.MDeviceListUpdate, @@ -149,24 +145,26 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { Keys: m.KeyJSON, } if edu.Content, err = json.Marshal(event); err != nil { - return err + logger.WithError(err).Error("failed to marshal EDU JSON") + return true } - logrus.Infof("Sending device list update message to %q", destinations) - return t.queues.SendEDU(edu, t.serverName, destinations) + logger.Debugf("Sending device list update message to %q", destinations) + err = t.queues.SendEDU(edu, t.serverName, destinations) + return err == nil } -func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { +func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { output := m.CrossSigningKeyUpdate _, host, err := gomatrixserverlib.SplitID('@', output.UserID) if err != nil { logrus.WithError(err).Errorf("fedsender key change consumer: user ID parse failure") - return nil + return true } if host != gomatrixserverlib.ServerName(t.serverName) { // Ignore any messages that didn't originate locally, otherwise we'll // end up parroting information we received from other servers. - return nil + return true } logger := logrus.WithField("user_id", output.UserID) @@ -177,13 +175,17 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { }, &queryRes) if err != nil { logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined rooms for user") - return nil + return true } // send this key change to all servers who share rooms with this user. destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) if err != nil { logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") - return nil + return true + } + + if len(destinations) == 0 { + return true } // Pack the EDU and marshal it @@ -193,11 +195,12 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { } if edu.Content, err = json.Marshal(output); err != nil { logger.WithError(err).Error("fedsender key change consumer: failed to marshal output, dropping") - return nil + return true } - logger.Infof("Sending cross-signing update message to %q", destinations) - return t.queues.SendEDU(edu, t.serverName, destinations) + logger.Debugf("Sending cross-signing update message to %q", destinations) + err = t.queues.SendEDU(edu, t.serverName, destinations) + return err == nil } func prevID(streamID int) []int { diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 25ea78274..173dcff01 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -19,6 +19,10 @@ import ( "encoding/json" "fmt" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/types" @@ -26,9 +30,6 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" ) // OutputRoomEventConsumer consumes events that originated in the room server. @@ -37,7 +38,7 @@ type OutputRoomEventConsumer struct { cfg *config.FederationAPI rsAPI api.RoomserverInternalAPI jetstream nats.JetStreamContext - durable nats.SubOpt + durable string db storage.Database queues *queue.OutgoingQueues topic string @@ -66,74 +67,63 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - _, err := s.jetstream.Subscribe( - s.topic, s.onMessage, s.durable, - nats.DeliverAll(), - nats.ManualAck(), + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), ) - return err } // onMessage is called when the federation server receives a new event from the room server output log. // It is unsafe to call this with messages for the same room in multiple gorountines // because updates it will likely fail with a types.EventIDMismatchError when it // realises that it cannot update the room state using the deltas. -func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Parse out the event JSON - var output api.OutputEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("roomserver output log: message parse failure") - return true - } +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + // Parse out the event JSON + var output api.OutputEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") + return true + } - switch output.Type { - case api.OutputTypeNewRoomEvent: - ev := output.NewRoomEvent.Event + switch output.Type { + case api.OutputTypeNewRoomEvent: + ev := output.NewRoomEvent.Event - if output.NewRoomEvent.RewritesState { - if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil { - log.WithError(err).Errorf("roomserver output log: purge room state failure") - return false - } - } - - if err := s.processMessage(*output.NewRoomEvent); err != nil { - switch err.(type) { - case *queue.ErrorFederationDisabled: - log.WithField("error", output.Type).Info( - err.Error(), - ) - default: - // panic rather than continue with an inconsistent database - log.WithFields(log.Fields{ - "event_id": ev.EventID(), - "event": string(ev.JSON()), - "add": output.NewRoomEvent.AddsStateEventIDs, - "del": output.NewRoomEvent.RemovesStateEventIDs, - log.ErrorKey: err, - }).Panicf("roomserver output log: write room event failure") - } - } - - case api.OutputTypeNewInboundPeek: - if err := s.processInboundPeek(*output.NewInboundPeek); err != nil { - log.WithFields(log.Fields{ - "event": output.NewInboundPeek, - log.ErrorKey: err, - }).Panicf("roomserver output log: remote peek event failure") + if output.NewRoomEvent.RewritesState { + if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil { + log.WithError(err).Errorf("roomserver output log: purge room state failure") return false } - - default: - log.WithField("type", output.Type).Debug( - "roomserver output log: ignoring unknown output type", - ) } - return true - }) + if err := s.processMessage(*output.NewRoomEvent); err != nil { + // panic rather than continue with an inconsistent database + log.WithFields(log.Fields{ + "event_id": ev.EventID(), + "event": string(ev.JSON()), + "add": output.NewRoomEvent.AddsStateEventIDs, + "del": output.NewRoomEvent.RemovesStateEventIDs, + log.ErrorKey: err, + }).Panicf("roomserver output log: write room event failure") + } + + case api.OutputTypeNewInboundPeek: + if err := s.processInboundPeek(*output.NewInboundPeek); err != nil { + log.WithFields(log.Fields{ + "event": output.NewInboundPeek, + log.ErrorKey: err, + }).Panicf("roomserver output log: remote peek event failure") + return false + } + + default: + log.WithField("type", output.Type).Debug( + "roomserver output log: ignoring unknown output type", + ) + } + + return true } // processInboundPeek starts tracking a new federated inbound peek (replacing the existing one if any) diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 63387b9d8..a982d8009 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -92,7 +92,7 @@ func NewInternalAPI( FailuresUntilBlacklist: cfg.FederationMaxRetries, } - js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, @@ -120,7 +120,7 @@ func NewInternalAPI( logrus.WithError(err).Panic("failed to start typing server consumer") } keyConsumer := consumers.NewKeyChangeConsumer( - base.ProcessContext, &base.Cfg.KeyServer, consumer, queues, federationDB, rsAPI, + base.ProcessContext, &base.Cfg.KeyServer, js, queues, federationDB, rsAPI, ) if err := keyConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start key server consumer") diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 4dd53c11b..c51ecf146 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -196,29 +196,23 @@ func (r *FederationInternalAPI) performJoinUsingServer( return fmt.Errorf("respMakeJoin.JoinEvent.Build: %w", err) } - // No longer reuse the request context from this point forward. - // We don't want the client timing out to interrupt the join. - var cancel context.CancelFunc - ctx, cancel = context.WithCancel(context.Background()) - // Try to perform a send_join using the newly built event. respSendJoin, err := r.federation.SendJoin( - ctx, + context.Background(), serverName, event, - respMakeJoin.RoomVersion, ) if err != nil { r.statistics.ForServer(serverName).Failure() - cancel() return fmt.Errorf("r.federation.SendJoin: %w", err) } r.statistics.ForServer(serverName).Success() + authEvents := respSendJoin.AuthEvents.UntrustedEvents(respMakeJoin.RoomVersion) + // Sanity-check the join response to ensure that it has a create // event, that the room version is known, etc. - if err := sanityCheckAuthChain(respSendJoin.AuthEvents); err != nil { - cancel() + if err = sanityCheckAuthChain(authEvents); err != nil { return fmt.Errorf("sanityCheckAuthChain: %w", err) } @@ -227,41 +221,36 @@ func (r *FederationInternalAPI) performJoinUsingServer( // to complete, but if the client does give up waiting, we'll // still continue to process the join anyway so that we don't // waste the effort. - go func() { - defer cancel() + // TODO: Can we expand Check here to return a list of missing auth + // events rather than failing one at a time? + var respState *gomatrixserverlib.RespState + respState, err = respSendJoin.Check( + context.Background(), + respMakeJoin.RoomVersion, + r.keyRing, + event, + federatedAuthProvider(ctx, r.federation, r.keyRing, serverName), + ) + if err != nil { + return fmt.Errorf("respSendJoin.Check: %w", err) + } - // TODO: Can we expand Check here to return a list of missing auth - // events rather than failing one at a time? - respState, err := respSendJoin.Check(ctx, r.keyRing, event, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)) - if err != nil { - logrus.WithFields(logrus.Fields{ - "room_id": roomID, - "user_id": userID, - }).WithError(err).Error("Failed to process room join response") - return - } + // If we successfully performed a send_join above then the other + // server now thinks we're a part of the room. Send the newly + // returned state to the roomserver to update our local view. + if err = roomserverAPI.SendEventWithState( + context.Background(), + r.rsAPI, + roomserverAPI.KindNew, + respState, + event.Headered(respMakeJoin.RoomVersion), + serverName, + nil, + false, + ); err != nil { + return fmt.Errorf("roomserverAPI.SendEventWithState: %w", err) + } - // If we successfully performed a send_join above then the other - // server now thinks we're a part of the room. Send the newly - // returned state to the roomserver to update our local view. - if err = roomserverAPI.SendEventWithState( - ctx, r.rsAPI, - roomserverAPI.KindNew, - respState, - event.Headered(respMakeJoin.RoomVersion), - serverName, - nil, - false, - ); err != nil { - logrus.WithFields(logrus.Fields{ - "room_id": roomID, - "user_id": userID, - }).WithError(err).Error("Failed to send room join response to roomserver") - return - } - }() - - <-ctx.Done() return nil } @@ -405,12 +394,13 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( ctx = context.Background() respState := respPeek.ToRespState() + authEvents := respState.AuthEvents.UntrustedEvents(respPeek.RoomVersion) // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() - if err = sanityCheckAuthChain(respState.AuthEvents); err != nil { + if err = sanityCheckAuthChain(authEvents); err != nil { return fmt.Errorf("sanityCheckAuthChain: %w", err) } - if err = respState.Check(ctx, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)); err != nil { + if err = respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)); err != nil { return fmt.Errorf("error checking state returned from peeking: %w", err) } @@ -562,10 +552,15 @@ func (r *FederationInternalAPI) PerformInvite( inviteRes, err := r.federation.SendInviteV2(ctx, destination, inviteReq) if err != nil { - return fmt.Errorf("r.federation.SendInviteV2: %w", err) + return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) } + logrus.Infof("GOT INVITE RESPONSE %s", string(inviteRes.Event)) - response.Event = inviteRes.Event.Headered(request.RoomVersion) + inviteEvent, err := inviteRes.Event.UntrustedEvent(request.RoomVersion) + if err != nil { + return fmt.Errorf("r.federation.SendInviteV2 failed to decode event response: %w", err) + } + response.Event = inviteEvent.Headered(request.RoomVersion) return nil } diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index a65df906f..f9b2a33d2 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -387,14 +387,7 @@ func (h *httpFederationInternalAPI) LookupMissingEvents( if request.Err != nil { return res, request.Err } - res.Events = make([]*gomatrixserverlib.Event, 0, len(request.Res.Events)) - for _, js := range request.Res.Events { - ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(js, roomVersion) - if err != nil { - return res, err - } - res.Events = append(res.Events, ev) - } + res.Events = request.Res.Events return res, nil } diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 1306e8588..09814b31f 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -297,7 +297,7 @@ func (oq *destinationQueue) backgroundSend() { // We haven't backed off yet, so wait for the suggested amount of // time. duration := time.Until(*until) - logrus.Warnf("Backing off %q for %s", oq.destination, duration) + logrus.Debugf("Backing off %q for %s", oq.destination, duration) oq.backingOff.Store(true) destinationQueueBackingOff.Inc() select { diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 8a6ad1555..dcd090856 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -22,15 +22,16 @@ import ( "sync" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/gomatrixserverlib" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" ) // OutgoingQueues is a collection of queues for sending transactions to other @@ -182,23 +183,14 @@ func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) { destinationQueueTotal.Dec() } -type ErrorFederationDisabled struct { - Message string -} - -func (e *ErrorFederationDisabled) Error() string { - return e.Message -} - // SendEvent sends an event to the destinations func (oqs *OutgoingQueues) SendEvent( ev *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, destinations []gomatrixserverlib.ServerName, ) error { if oqs.disabled { - return &ErrorFederationDisabled{ - Message: "Federation disabled", - } + log.Trace("Federation is disabled, not sending event") + return nil } if origin != oqs.origin { // TODO: Support virtual hosting; gh issue #577. @@ -262,9 +254,8 @@ func (oqs *OutgoingQueues) SendEDU( destinations []gomatrixserverlib.ServerName, ) error { if oqs.disabled { - return &ErrorFederationDisabled{ - Message: "Federation disabled", - } + log.Trace("Federation is disabled, not sending EDU") + return nil } if origin != oqs.origin { // TODO: Support virtual hosting; gh issue #577. diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go index d92b66f4b..0a03a0cb4 100644 --- a/federationapi/routing/eventauth.go +++ b/federationapi/routing/eventauth.go @@ -65,7 +65,7 @@ func GetEventAuth( return util.JSONResponse{ Code: http.StatusOK, JSON: gomatrixserverlib.RespEventAuth{ - AuthEvents: gomatrixserverlib.UnwrapEventHeaders(response.AuthChainEvents), + AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(response.AuthChainEvents), }, } } diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 468659651..58bf99f4a 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -178,12 +178,12 @@ func processInvite( if isInviteV2 { return util.JSONResponse{ Code: http.StatusOK, - JSON: gomatrixserverlib.RespInviteV2{Event: &signedEvent}, + JSON: gomatrixserverlib.RespInviteV2{Event: signedEvent.JSON()}, } } else { return util.JSONResponse{ Code: http.StatusOK, - JSON: gomatrixserverlib.RespInvite{Event: &signedEvent}, + JSON: gomatrixserverlib.RespInvite{Event: signedEvent.JSON()}, } } default: diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 7f8d31505..495b8c914 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -351,8 +351,8 @@ func SendJoin( return util.JSONResponse{ Code: http.StatusOK, JSON: gomatrixserverlib.RespSendJoin{ - StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateAndAuthChainResponse.StateEvents), - AuthEvents: gomatrixserverlib.UnwrapEventHeaders(stateAndAuthChainResponse.AuthChainEvents), + StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.StateEvents), + AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.AuthChainEvents), Origin: cfg.Matrix.ServerName, }, } diff --git a/federationapi/routing/missingevents.go b/federationapi/routing/missingevents.go index f79a2d2d8..dd3df7aa9 100644 --- a/federationapi/routing/missingevents.go +++ b/federationapi/routing/missingevents.go @@ -62,7 +62,7 @@ func GetMissingEvents( eventsResponse.Events = filterEvents(eventsResponse.Events, roomID) resp := gomatrixserverlib.RespMissingEvents{ - Events: gomatrixserverlib.UnwrapEventHeaders(eventsResponse.Events), + Events: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(eventsResponse.Events), } return util.JSONResponse{ diff --git a/federationapi/routing/peek.go b/federationapi/routing/peek.go index 511329997..827d1116d 100644 --- a/federationapi/routing/peek.go +++ b/federationapi/routing/peek.go @@ -88,8 +88,8 @@ func Peek( } respPeek := gomatrixserverlib.RespPeek{ - StateEvents: gomatrixserverlib.UnwrapEventHeaders(response.StateEvents), - AuthEvents: gomatrixserverlib.UnwrapEventHeaders(response.AuthChainEvents), + StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(response.StateEvents), + AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(response.AuthChainEvents), RoomVersion: response.RoomVersion, LatestEvent: response.LatestEvent.Unwrap(), RenewalInterval: renewalInterval, diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index dbfd3ff92..745e36de9 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -162,7 +162,7 @@ func Send( t.TransactionID = txnID t.Destination = cfg.Matrix.ServerName - util.GetLogger(httpReq.Context()).Infof("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs)) + util.GetLogger(httpReq.Context()).Debugf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs)) resp, jsonErr := t.processTransaction(httpReq.Context()) if jsonErr != nil { @@ -221,7 +221,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} verRes := api.QueryRoomVersionForRoomResponse{} if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { - util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID) + util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID) return "" } roomVersions[roomID] = verRes.RoomVersion @@ -234,7 +234,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res RoomID string `json:"room_id"` } if err := json.Unmarshal(pdu, &header); err != nil { - util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to extract room ID from event") + util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to extract room ID from event") // We don't know the event ID at this point so we can't return the // failure in the PDU results continue @@ -255,7 +255,10 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res JSON: jsonerror.BadJSON("PDU contains bad JSON"), } } - util.GetLogger(ctx).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %s", string(pdu)) + util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) + continue + } + if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { continue } if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { @@ -265,7 +268,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res continue } if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { - util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) + util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) results[event.EventID()] = gomatrixserverlib.PDUResult{ Error: err.Error(), } @@ -287,7 +290,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res nil, true, ); err != nil { - util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) + util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) results[event.EventID()] = gomatrixserverlib.PDUResult{ Error: err.Error(), } @@ -314,16 +317,16 @@ func (t *txnReq) processEDUs(ctx context.Context) { Typing bool `json:"typing"` } if err := json.Unmarshal(e.Content, &typingPayload); err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal typing event") + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event") continue } _, domain, err := gomatrixserverlib.SplitID('@', typingPayload.UserID) if err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to split domain from typing event sender") + util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from typing event sender") continue } if domain != t.Origin { - util.GetLogger(ctx).Warnf("Dropping typing event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) + util.GetLogger(ctx).Debugf("Dropping typing event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) continue } if err := eduserverAPI.SendTyping(ctx, t.eduAPI, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { @@ -333,7 +336,7 @@ func (t *txnReq) processEDUs(ctx context.Context) { // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema var directPayload gomatrixserverlib.ToDeviceMessage if err := json.Unmarshal(e.Content, &directPayload); err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal send-to-device events") + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events") continue } for userID, byUser := range directPayload.Messages { @@ -355,7 +358,7 @@ func (t *txnReq) processEDUs(ctx context.Context) { payload := map[string]eduserverAPI.FederationReceiptMRead{} if err := json.Unmarshal(e.Content, &payload); err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal receipt event") + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal receipt event") continue } @@ -363,11 +366,11 @@ func (t *txnReq) processEDUs(ctx context.Context) { for userID, mread := range receipt.User { _, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to split domain from receipt event sender") + util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from receipt event sender") continue } if t.Origin != domain { - util.GetLogger(ctx).Warnf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) + util.GetLogger(ctx).Debugf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) continue } if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil { @@ -382,20 +385,8 @@ func (t *txnReq) processEDUs(ctx context.Context) { } } case eduserverAPI.MSigningKeyUpdate: - var updatePayload eduserverAPI.CrossSigningKeyUpdate - if err := json.Unmarshal(e.Content, &updatePayload); err != nil { - util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ - "user_id": updatePayload.UserID, - }).Error("Failed to send signing key update to edu server") - continue - } - inputReq := &eduserverAPI.InputCrossSigningKeyUpdateRequest{ - CrossSigningKeyUpdate: updatePayload, - } - inputRes := &eduserverAPI.InputCrossSigningKeyUpdateResponse{} - if err := t.eduAPI.InputCrossSigningKeyUpdate(ctx, inputReq, inputRes); err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal cross-signing update") - continue + if err := t.processSigningKeyUpdate(ctx, e); err != nil { + logrus.WithError(err).Errorf("Failed to process signing key update") } default: util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") @@ -403,6 +394,34 @@ func (t *txnReq) processEDUs(ctx context.Context) { } } +func (t *txnReq) processSigningKeyUpdate(ctx context.Context, e gomatrixserverlib.EDU) error { + var updatePayload eduserverAPI.CrossSigningKeyUpdate + if err := json.Unmarshal(e.Content, &updatePayload); err != nil { + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "user_id": updatePayload.UserID, + }).Debug("Failed to unmarshal signing key update") + return err + } + + keys := gomatrixserverlib.CrossSigningKeys{} + if updatePayload.MasterKey != nil { + keys.MasterKey = *updatePayload.MasterKey + } + if updatePayload.SelfSigningKey != nil { + keys.SelfSigningKey = *updatePayload.SelfSigningKey + } + uploadReq := &keyapi.PerformUploadDeviceKeysRequest{ + CrossSigningKeys: keys, + UserID: updatePayload.UserID, + } + uploadRes := &keyapi.PerformUploadDeviceKeysResponse{} + t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) + if uploadRes.Error != nil { + return uploadRes.Error + } + return nil +} + // processReceiptEvent sends receipt events to the edu server func (t *txnReq) processReceiptEvent(ctx context.Context, userID, roomID, receiptType string, diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index f1f6169d9..4280643e9 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -93,11 +93,10 @@ func (o *testEDUProducer) InputCrossSigningKeyUpdate( type testRoomserverAPI struct { api.RoomserverInternalAPITrace - inputRoomEvents []api.InputRoomEvent - queryMissingAuthPrevEvents func(*api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse - queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse - queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse - queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse + inputRoomEvents []api.InputRoomEvent + queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse + queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse + queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse } func (t *testRoomserverAPI) InputRoomEvents( @@ -140,20 +139,6 @@ func (t *testRoomserverAPI) QueryStateAfterEvents( return nil } -// Query the state after a list of events in a room from the room server. -func (t *testRoomserverAPI) QueryMissingAuthPrevEvents( - ctx context.Context, - request *api.QueryMissingAuthPrevEventsRequest, - response *api.QueryMissingAuthPrevEventsResponse, -) error { - response.RoomVersion = testRoomVersion - res := t.queryMissingAuthPrevEvents(request) - response.RoomExists = res.RoomExists - response.MissingAuthEventIDs = res.MissingAuthEventIDs - response.MissingPrevEventIDs = res.MissingPrevEventIDs - return nil -} - // Query a list of events by event ID. func (t *testRoomserverAPI) QueryEventsByID( ctx context.Context, @@ -312,15 +297,7 @@ func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []*gomat // The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on // to the roomserver. It's the most basic test possible. func TestBasicTransaction(t *testing.T) { - rsAPI := &testRoomserverAPI{ - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: []string{}, - } - }, - } + rsAPI := &testRoomserverAPI{} pdus := []json.RawMessage{ testData[len(testData)-1], // a message event } @@ -332,15 +309,7 @@ func TestBasicTransaction(t *testing.T) { // The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver // as it does the auth check. func TestTransactionFailAuthChecks(t *testing.T) { - rsAPI := &testRoomserverAPI{ - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: []string{}, - } - }, - } + rsAPI := &testRoomserverAPI{} pdus := []json.RawMessage{ testData[len(testData)-1], // a message event } diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index 128df6187..37cbb9d1e 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -35,12 +35,15 @@ func GetState( return *err } - state, err := getState(ctx, request, rsAPI, roomID, eventID) + stateEvents, authChain, err := getState(ctx, request, rsAPI, roomID, eventID) if err != nil { return *err } - return util.JSONResponse{Code: http.StatusOK, JSON: state} + return util.JSONResponse{Code: http.StatusOK, JSON: &gomatrixserverlib.RespState{ + AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(authChain), + StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateEvents), + }} } // GetStateIDs returns state event IDs & auth event IDs for the roomID, eventID @@ -55,13 +58,13 @@ func GetStateIDs( return *err } - state, err := getState(ctx, request, rsAPI, roomID, eventID) + stateEvents, authEvents, err := getState(ctx, request, rsAPI, roomID, eventID) if err != nil { return *err } - stateEventIDs := getIDsFromEvent(state.StateEvents) - authEventIDs := getIDsFromEvent(state.AuthEvents) + stateEventIDs := getIDsFromEvent(stateEvents) + authEventIDs := getIDsFromEvent(authEvents) return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.RespStateIDs{ StateEventIDs: stateEventIDs, @@ -97,18 +100,18 @@ func getState( rsAPI api.RoomserverInternalAPI, roomID string, eventID string, -) (*gomatrixserverlib.RespState, *util.JSONResponse) { +) (stateEvents, authEvents []*gomatrixserverlib.HeaderedEvent, errRes *util.JSONResponse) { event, resErr := fetchEvent(ctx, rsAPI, eventID) if resErr != nil { - return nil, resErr + return nil, nil, resErr } if event.RoomID() != roomID { - return nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: jsonerror.NotFound("event does not belong to this room")} + return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: jsonerror.NotFound("event does not belong to this room")} } resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) if resErr != nil { - return nil, resErr + return nil, nil, resErr } var response api.QueryStateAndAuthChainResponse @@ -123,20 +126,17 @@ func getState( ) if err != nil { resErr := util.ErrorResponse(err) - return nil, &resErr + return nil, nil, &resErr } if !response.RoomExists { - return nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: nil} + return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: nil} } - return &gomatrixserverlib.RespState{ - StateEvents: gomatrixserverlib.UnwrapEventHeaders(response.StateEvents), - AuthEvents: gomatrixserverlib.UnwrapEventHeaders(response.AuthChainEvents), - }, nil + return response.StateEvents, response.AuthChainEvents, nil } -func getIDsFromEvent(events []*gomatrixserverlib.Event) []string { +func getIDsFromEvent(events []*gomatrixserverlib.HeaderedEvent) []string { IDs := make([]string, len(events)) for i := range events { IDs[i] = events[i].EventID() diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index b16c68d25..8ae7130c3 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -170,13 +170,18 @@ func ExchangeThirdPartyInvite( util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") return jsonerror.InternalServerError() } + inviteEvent, err := signedEvent.Event.UntrustedEvent(verRes.RoomVersion) + if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") + return jsonerror.InternalServerError() + } // Send the event to the roomserver if err = api.SendEvents( httpReq.Context(), rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{ - signedEvent.Event.Headered(verRes.RoomVersion), + inviteEvent.Headered(verRes.RoomVersion), }, request.Origin(), cfg.Matrix.ServerName, diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 21a919f6a..3fa8d1f7a 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -19,12 +19,10 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/types" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { - internal.PartitionStorer gomatrixserverlib.KeyDatabase UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error) diff --git a/go.mod b/go.mod index 6d482bd60..2316096df 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/matrix-org/dendrite -replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423 +replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c @@ -11,12 +11,11 @@ require ( github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect github.com/MFAshby/stdemuxerhook v1.0.0 github.com/Masterminds/semver/v3 v3.1.1 - github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32 - github.com/Shopify/sarama v1.29.0 github.com/codeclysm/extract v2.2.0+incompatible github.com/containerd/containerd v1.5.9 // indirect github.com/docker/docker v20.10.12+incompatible github.com/docker/go-connections v0.4.0 + github.com/frankban/quicktest v1.14.0 // indirect github.com/getsentry/sentry-go v0.12.0 github.com/gologme/log v1.3.0 github.com/gorilla/mux v1.8.0 @@ -40,13 +39,13 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220128100033-8d79e0c35e32 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 github.com/morikuni/aec v1.0.0 // indirect github.com/nats-io/nats-server/v2 v2.3.2 - github.com/nats-io/nats.go v1.13.1-0.20211122170419-d7c1d78a50fc + github.com/nats-io/nats.go v1.13.1-0.20220121202836-972a071d373d github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31 @@ -54,22 +53,24 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/pressly/goose v2.7.0+incompatible - github.com/prometheus/client_golang v1.12.0 + github.com/prometheus/client_golang v1.12.1 github.com/sirupsen/logrus v1.8.1 - github.com/tidwall/gjson v1.13.0 + github.com/tidwall/gjson v1.14.0 github.com/tidwall/sjson v1.2.4 github.com/uber/jaeger-client-go v2.30.0+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.2 go.uber.org/atomic v1.9.0 - golang.org/x/crypto v0.0.0-20220126234351-aa10faf2a1f8 + golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a golang.org/x/image v0.0.0-20211028202545-6944b10bf410 golang.org/x/mobile v0.0.0-20220112015953-858099ff7816 golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd + golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 gopkg.in/h2non/bimg.v1 v1.1.5 gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect nhooyr.io/websocket v1.8.7 ) -go 1.15 +go 1.16 diff --git a/go.sum b/go.sum index 3ef5a54aa..e79015e51 100644 --- a/go.sum +++ b/go.sum @@ -100,14 +100,8 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko github.com/RoaringBitmap/roaring v0.4.7/go.mod h1:8khRDP4HmeXns4xIj9oGrKSz7XTQiJx2zgh7AcNke4w= github.com/RyanCarrier/dijkstra v1.0.0/go.mod h1:5agGUBNEtUAGIANmbw09fuO3a2htPEkc1jNH01qxCWA= github.com/RyanCarrier/dijkstra-1 v0.0.0-20170512020943-0e5801a26345/go.mod h1:OK4EvWJ441LQqGzed5NGB6vKBAE34n3z7iayPcEwr30= -github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32 h1:i3fOph9Hjleo6LbuqN9ODFxnwt7mOtYMpCGeC8qJN50= -github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32/go.mod h1:ne+jkLlzafIzaE4Q0Ze81T27dNgXe1wxovVEoAtSHTc= github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0= github.com/Shopify/logrus-bugsnag v0.0.0-20171204204709-577dee27f20d/go.mod h1:HI8ITrYtUY+O+ZhtlqUnD8+KwNPOyugEhfP9fdUIaEQ= -github.com/Shopify/sarama v1.29.0 h1:ARid8o8oieau9XrHI55f/L3EoRAhm9px6sonbD7yuUE= -github.com/Shopify/sarama v1.29.0/go.mod h1:2QpgD79wpdAESqNQMxNc0KYMkycd4slxGdV3TWSVqrU= -github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc= -github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= github.com/VividCortex/ewma v1.1.1/go.mod h1:2Tkkvm3sRDVXaiyucHiACn4cqf7DpdyLvmxzcbUokwA= github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAUnGx7j5l4= github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= @@ -350,12 +344,6 @@ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3 github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v0.0.0-20180421182945-02af3965c54e/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/eapache/go-resiliency v1.2.0 h1:v7g92e/KSN71Rq7vSThKaWIq68fL4YHvWyiUKorFR1Q= -github.com/eapache/go-resiliency v1.2.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= -github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw= -github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= -github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc= -github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385/go.mod h1:0vRUJqYpeSZifjYj7uP3BG/gKcuzL9xWVV/Y+cK33KM= github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= @@ -376,13 +364,12 @@ github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6 h1:u/UEqS66A5ckRmS4yNp github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6/go.mod h1:1i71OnUq3iUe1ma7Lr6yG6/rjvM3emb6yoL7xLFzcVQ= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= -github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= -github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k= github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o= -github.com/frankban/quicktest v1.11.3 h1:8sXhOn0uLys67V8EsXLc6eszDs8VXWxL3iRvebPhedY= github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k= +github.com/frankban/quicktest v1.14.0 h1:+cqqvzZV87b4adx/5ayVOaYZ2CrvM4ejQvUdBzPPUss= +github.com/frankban/quicktest v1.14.0/go.mod h1:NeW+ay9A/U67EYXNFA1nPE8e/tnQv/09mUdL/ijj8og= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= @@ -425,7 +412,6 @@ github.com/go-openapi/jsonreference v0.19.3/go.mod h1:rjx6GuL8TTa9VaixXglHmQmIL9 github.com/go-openapi/spec v0.19.3/go.mod h1:FpwSN1ksY1eteniUU7X0N/BgJ7a4WvBFVA8Lj9mJglo= github.com/go-openapi/swag v0.19.2/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= -github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= @@ -494,7 +480,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gologme/log v1.2.0/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U= github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo= @@ -512,8 +497,9 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -547,10 +533,6 @@ github.com/gorilla/handlers v0.0.0-20150720190736-60c7bfde3e33/go.mod h1:Qkdc/uu github.com/gorilla/mux v1.7.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= -github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= -github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= @@ -577,8 +559,6 @@ github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHh github.com/hashicorp/go-multierror v1.1.0 h1:B9UzwGQJehnUY1yNrnwREHc3fGbC2xefo8g4TbElacI= github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+vmowP0z+KUhOZdA= github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= -github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= -github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= @@ -664,18 +644,6 @@ github.com/jbenet/goprocess v0.0.0-20160826012719-b497e2f366b8/go.mod h1:Ly/wlsj github.com/jbenet/goprocess v0.1.3/go.mod h1:5yspPrukOVuOLORacaBi858NqyClJPQxYZlqdZVfqY4= github.com/jbenet/goprocess v0.1.4 h1:DRGOFReOMqqDNXwW70QkacFW0YN9QnwLV0Vqk+3oU0o= github.com/jbenet/goprocess v0.1.4/go.mod h1:5yspPrukOVuOLORacaBi858NqyClJPQxYZlqdZVfqY4= -github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= -github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= -github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= -github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= -github.com/jcmturner/gofork v1.0.0 h1:J7uCkflzTEhUZ64xqKnkDxq3kzc96ajM1Gli5ktUem8= -github.com/jcmturner/gofork v1.0.0/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= -github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= -github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= -github.com/jcmturner/gokrb5/v8 v8.4.2 h1:6ZIM6b/JJN0X8UM43ZOM6Z4SJzla+a/u7scXFJzodkA= -github.com/jcmturner/gokrb5/v8 v8.4.2/go.mod h1:sb+Xq/fTY5yktf/VxLsE3wlfPqQjp0aWNYyvBVK62bc= -github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= -github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= @@ -744,8 +712,6 @@ github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0 github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.13/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= -github.com/klauspost/compress v1.12.2/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= -github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.14.2 h1:S0OHlFk/Gbon/yauFJ4FfJJF5V0fc5HbBTJazi28pRw= github.com/klauspost/compress v1.14.2/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= @@ -758,8 +724,9 @@ github.com/koron/go-ssdp v0.0.0-20191105050749-2e1c40ed0b5d/go.mod h1:5Ky9EC2xfo github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA= @@ -1016,8 +983,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220128100033-8d79e0c35e32 h1:DiWPsGAYMlBQq/urm7TJkIeSf9FnfzegcaQUpgwIbUs= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220128100033-8d79e0c35e32/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed h1:R8EiLWArq7KT96DrUq1xq9scPh8vLwKKeCTnORPyjhU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf h1:/nqfHUdQHr3WVdbZieaYFvHF1rin5pvDTa/NOZ/qCyE= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= @@ -1155,19 +1122,18 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= -github.com/nats-io/jwt/v2 v2.2.0 h1:Yg/4WFK6vsqMudRg91eBb7Dh6XeVcDMPHycDE8CfltE= -github.com/nats-io/jwt/v2 v2.2.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= +github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296 h1:vU9tpM3apjYlLLeY23zRWJ9Zktr5jp+mloR942LEOpY= +github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= -github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms= github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423 h1:BLQVdjMH5XD4BYb0fa+c2Oh2Nr1vrO7GKvRnIJDxChc= -github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423/go.mod h1:9sdEkBhyZMQG1M9TevnlYUwMusRACn2vlgOeqoHKwVo= +github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad h1:Z2nWMQsXWWqzj89nW6OaLJSdkFknqhaR5whEOz4++Y8= +github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad/go.mod h1:tckmrt0M6bVaDT3kmh9UrIq/CBOBBse+TpXQi5ldaa8= github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q= github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= @@ -1246,8 +1212,6 @@ github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/9 github.com/pelletier/go-toml v1.8.1/go.mod h1:T2/BmBdy8dvIRq1a/8aqjN41wvWlN4lrapLU/GW4pbc= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= -github.com/pierrec/lz4 v2.6.0+incompatible h1:Ix9yFKn1nSPBLFl/yZknTp8TU5G4Ps0JDmguYK6iH1A= -github.com/pierrec/lz4 v2.6.0+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -1268,8 +1232,8 @@ github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5Fsn github.com/prometheus/client_golang v1.1.0/go.mod h1:I1FGZT9+L76gKKOs5djB6ezCbFQP1xR9D75/vuwEF3g= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.12.0 h1:C+UIj/QWtmqY13Arb8kwMt5j34/0Z2iKamrJ+ryC0Gg= -github.com/prometheus/client_golang v1.12.0/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= +github.com/prometheus/client_golang v1.12.1 h1:ZiaPsmm9uiBeaSMRznKsCDNtPCS0T3JVDGF+06gjBzk= +github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= github.com/prometheus/client_model v0.0.0-20171117100541-99fa1f4be8e5/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -1301,12 +1265,12 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1 github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0VU= github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= -github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= -github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= @@ -1400,8 +1364,8 @@ github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpP github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tchap/go-patricia v2.2.6+incompatible/go.mod h1:bmLyhP68RS6kStMGxByiQ23RP/odRBOTVjwp2cDyi6I= github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.13.0 h1:3TFY9yxOQShrvmjdM76K+jc66zJeT6D3/VFFYCGQf7M= -github.com/tidwall/gjson v1.13.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.0 h1:6aeJ0bzojgWLa82gDQHcx3S0Lr/O51I9bJ5nv6JFx5w= +github.com/tidwall/gjson v1.14.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= @@ -1456,8 +1420,6 @@ github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPyS github.com/willf/bitset v1.1.11-0.20200630133818-d5bec3311243/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= github.com/willf/bitset v1.1.11/go.mod h1:83CECat5yLh5zVOf4P1ErAgKA5UDvKtgyUABdr3+MjI= github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= -github.com/xdg/scram v1.0.3/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= -github.com/xdg/stringprep v1.0.3/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v0.0.0-20180618132009-1d523034197f/go.mod h1:5yf86TLmAcydyeJq5YvxkGPE2fm/u4myDekKRoLuqhs= @@ -1540,17 +1502,16 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= -golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220126234351-aa10faf2a1f8 h1:kACShD3qhmr/3rLmg1yXyt+N4HcwutKyPRB93s54TIU= -golang.org/x/crypto v0.0.0-20220126234351-aa10faf2a1f8/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo= +golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1564,7 +1525,6 @@ golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= -golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6 h1:QE6XYQK6naiK1EPAe1g/ILLxN5RBoH5xkJk3CqlMI/Y= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= @@ -1646,7 +1606,6 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210427231257-85d9c07bbe3a/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -1776,8 +1735,10 @@ golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 h1:XfKQ4OlFl8okEOr5UvAqFRVj8pY/4yfcXrddB8qAbU0= +golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc= +golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= @@ -1796,10 +1757,10 @@ golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M= +golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -1873,9 +1834,7 @@ golang.zx2c4.com/wireguard v0.0.0-20210604143328-f9b48a961cd2/go.mod h1:laHzsbfM golang.zx2c4.com/wireguard v0.0.0-20210927201915-bb745b2ea326/go.mod h1:SDoazCvdy7RDjBPNEMBwrXhomlmtG7svs8mgwWEqtVI= golang.zx2c4.com/wireguard/windows v0.3.14/go.mod h1:3P4IEAsb+BjlKZmpUXgy74c0iX9AVwwr3WcVJ8nPgME= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= -gonum.org/v1/gonum v0.8.2 h1:CCXrcPKiGGotvnN6jfUsKk4rRqm7q09/YbKb5xCEvtM= gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= -gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= google.golang.org/api v0.0.0-20160322025152-9bf6e6e569ff/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= diff --git a/internal/caching/cache_roomservernids.go b/internal/caching/cache_roomservernids.go index bf4fe85ed..6d413093f 100644 --- a/internal/caching/cache_roomservernids.go +++ b/internal/caching/cache_roomservernids.go @@ -7,14 +7,6 @@ import ( ) const ( - RoomServerStateKeyNIDsCacheName = "roomserver_statekey_nids" - RoomServerStateKeyNIDsCacheMaxEntries = 1024 - RoomServerStateKeyNIDsCacheMutable = false - - RoomServerEventTypeNIDsCacheName = "roomserver_eventtype_nids" - RoomServerEventTypeNIDsCacheMaxEntries = 64 - RoomServerEventTypeNIDsCacheMutable = false - RoomServerRoomIDsCacheName = "roomserver_room_ids" RoomServerRoomIDsCacheMaxEntries = 1024 RoomServerRoomIDsCacheMutable = false @@ -29,44 +21,10 @@ type RoomServerCaches interface { // RoomServerNIDsCache contains the subset of functions needed for // a roomserver NID cache. type RoomServerNIDsCache interface { - GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool) - StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID) - - GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool) - StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID) - GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) } -func (c Caches) GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool) { - val, found := c.RoomServerStateKeyNIDs.Get(stateKey) - if found && val != nil { - if stateKeyNID, ok := val.(types.EventStateKeyNID); ok { - return stateKeyNID, true - } - } - return 0, false -} - -func (c Caches) StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID) { - c.RoomServerStateKeyNIDs.Set(stateKey, nid) -} - -func (c Caches) GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool) { - val, found := c.RoomServerEventTypeNIDs.Get(eventType) - if found && val != nil { - if eventTypeNID, ok := val.(types.EventTypeNID); ok { - return eventTypeNID, true - } - } - return 0, false -} - -func (c Caches) StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID) { - c.RoomServerEventTypeNIDs.Set(eventType, nid) -} - func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) { val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID))) if found && val != nil { diff --git a/internal/caching/caches.go b/internal/caching/caches.go index f04d05d42..e1642a663 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -4,14 +4,12 @@ package caching // different implementations as long as they satisfy the Cache // interface. type Caches struct { - RoomVersions Cache // RoomVersionCache - ServerKeys Cache // ServerKeyCache - RoomServerStateKeyNIDs Cache // RoomServerNIDsCache - RoomServerEventTypeNIDs Cache // RoomServerNIDsCache - RoomServerRoomNIDs Cache // RoomServerNIDsCache - RoomServerRoomIDs Cache // RoomServerNIDsCache - RoomInfos Cache // RoomInfoCache - FederationEvents Cache // FederationEventsCache + RoomVersions Cache // RoomVersionCache + ServerKeys Cache // ServerKeyCache + RoomServerRoomNIDs Cache // RoomServerNIDsCache + RoomServerRoomIDs Cache // RoomServerNIDsCache + RoomInfos Cache // RoomInfoCache + FederationEvents Cache // FederationEventsCache } // Cache is the interface that an implementation must satisfy. diff --git a/internal/caching/impl_inmemorylru.go b/internal/caching/impl_inmemorylru.go index f0915d7ca..ccb92852b 100644 --- a/internal/caching/impl_inmemorylru.go +++ b/internal/caching/impl_inmemorylru.go @@ -28,24 +28,6 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { if err != nil { return nil, err } - roomServerStateKeyNIDs, err := NewInMemoryLRUCachePartition( - RoomServerStateKeyNIDsCacheName, - RoomServerStateKeyNIDsCacheMutable, - RoomServerStateKeyNIDsCacheMaxEntries, - enablePrometheus, - ) - if err != nil { - return nil, err - } - roomServerEventTypeNIDs, err := NewInMemoryLRUCachePartition( - RoomServerEventTypeNIDsCacheName, - RoomServerEventTypeNIDsCacheMutable, - RoomServerEventTypeNIDsCacheMaxEntries, - enablePrometheus, - ) - if err != nil { - return nil, err - } roomServerRoomIDs, err := NewInMemoryLRUCachePartition( RoomServerRoomIDsCacheName, RoomServerRoomIDsCacheMutable, @@ -74,18 +56,15 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { return nil, err } go cacheCleaner( - roomVersions, serverKeys, roomServerStateKeyNIDs, - roomServerEventTypeNIDs, roomServerRoomIDs, + roomVersions, serverKeys, roomServerRoomIDs, roomInfos, federationEvents, ) return &Caches{ - RoomVersions: roomVersions, - ServerKeys: serverKeys, - RoomServerStateKeyNIDs: roomServerStateKeyNIDs, - RoomServerEventTypeNIDs: roomServerEventTypeNIDs, - RoomServerRoomIDs: roomServerRoomIDs, - RoomInfos: roomInfos, - FederationEvents: federationEvents, + RoomVersions: roomVersions, + ServerKeys: serverKeys, + RoomServerRoomIDs: roomServerRoomIDs, + RoomInfos: roomInfos, + FederationEvents: federationEvents, }, nil } diff --git a/internal/consumers.go b/internal/consumers.go deleted file mode 100644 index 3a4e0b7f8..000000000 --- a/internal/consumers.go +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "context" - "fmt" - - "github.com/Shopify/sarama" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/process" - "github.com/sirupsen/logrus" -) - -// A PartitionStorer has the storage APIs needed by the consumer. -type PartitionStorer interface { - // PartitionOffsets returns the offsets the consumer has reached for each partition. - PartitionOffsets(ctx context.Context, topic string) ([]sqlutil.PartitionOffset, error) - // SetPartitionOffset records where the consumer has reached for a partition. - SetPartitionOffset(ctx context.Context, topic string, partition int32, offset int64) error -} - -// A ContinualConsumer continually consumes logs even across restarts. It requires a PartitionStorer to -// remember the offset it reached. -type ContinualConsumer struct { - // The parent context for the listener, stop consuming when this context is done - Process *process.ProcessContext - // The component name - ComponentName string - // The kafkaesque topic to consume events from. - // This is the name used in kafka to identify the stream to consume events from. - Topic string - // A kafkaesque stream consumer providing the APIs for talking to the event source. - // The interface is taken from a client library for Apache Kafka. - // But any equivalent event streaming protocol could be made to implement the same interface. - Consumer sarama.Consumer - // A thing which can load and save partition offsets for a topic. - PartitionStore PartitionStorer - // ProcessMessage is a function which will be called for each message in the log. Return an error to - // stop processing messages. See ErrShutdown for specific control signals. - ProcessMessage func(msg *sarama.ConsumerMessage) error - // ShutdownCallback is called when ProcessMessage returns ErrShutdown, after the partition has been saved. - // It is optional. - ShutdownCallback func() -} - -// ErrShutdown can be returned from ContinualConsumer.ProcessMessage to stop the ContinualConsumer. -var ErrShutdown = fmt.Errorf("shutdown") - -// Start starts the consumer consuming. -// Starts up a goroutine for each partition in the kafka stream. -// Returns nil once all the goroutines are started. -// Returns an error if it can't start consuming for any of the partitions. -func (c *ContinualConsumer) Start() error { - _, err := c.StartOffsets() - return err -} - -// StartOffsets is the same as Start but returns the loaded offsets as well. -func (c *ContinualConsumer) StartOffsets() ([]sqlutil.PartitionOffset, error) { - offsets := map[int32]int64{} - - partitions, err := c.Consumer.Partitions(c.Topic) - if err != nil { - return nil, err - } - for _, partition := range partitions { - // Default all the offsets to the beginning of the stream. - offsets[partition] = sarama.OffsetOldest - } - - storedOffsets, err := c.PartitionStore.PartitionOffsets(context.TODO(), c.Topic) - if err != nil { - return nil, err - } - for _, offset := range storedOffsets { - // We've already processed events from this partition so advance the offset to where we got to. - // ConsumePartition will start streaming from the message with the given offset (inclusive), - // so increment 1 to avoid getting the same message a second time. - offsets[offset.Partition] = 1 + offset.Offset - } - - var partitionConsumers []sarama.PartitionConsumer - for partition, offset := range offsets { - pc, err := c.Consumer.ConsumePartition(c.Topic, partition, offset) - if err != nil { - for _, p := range partitionConsumers { - p.Close() // nolint: errcheck - } - return nil, err - } - partitionConsumers = append(partitionConsumers, pc) - } - for _, pc := range partitionConsumers { - go c.consumePartition(pc) - if c.Process != nil { - c.Process.ComponentStarted() - go func(pc sarama.PartitionConsumer) { - <-c.Process.WaitForShutdown() - _ = pc.Close() - c.Process.ComponentFinished() - logrus.Infof("Stopped consumer for %q topic %q", c.ComponentName, c.Topic) - }(pc) - } - } - - return storedOffsets, nil -} - -// consumePartition consumes the room events for a single partition of the kafkaesque stream. -func (c *ContinualConsumer) consumePartition(pc sarama.PartitionConsumer) { - defer pc.Close() // nolint: errcheck - for message := range pc.Messages() { - msgErr := c.ProcessMessage(message) - // Advance our position in the stream so that we will start at the right position after a restart. - if err := c.PartitionStore.SetPartitionOffset(context.TODO(), c.Topic, message.Partition, message.Offset); err != nil { - panic(fmt.Errorf("the ContinualConsumer in %q failed to SetPartitionOffset: %w", c.ComponentName, err)) - } - // Shutdown if we were told to do so. - if msgErr == ErrShutdown { - if c.ShutdownCallback != nil { - c.ShutdownCallback() - } - return - } - } -} diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 1fbd77da9..1a37a1eec 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -53,12 +53,13 @@ func MakeAuthAPI( f func(*http.Request, *userapi.Device) util.JSONResponse, ) http.Handler { h := func(req *http.Request) util.JSONResponse { + logger := util.GetLogger(req.Context()) device, err := auth.VerifyUserFromRequest(req, userAPI) if err != nil { + logger.Debugf("VerifyUserFromRequest %s -> HTTP %d", req.RemoteAddr, err.Code) return *err } // add the user ID to the logger - logger := util.GetLogger((req.Context())) logger = logger.WithField("user_id", device.UserID) req = req.WithContext(util.ContextWithLogger(req.Context(), logger)) // add the user to Sentry, if enabled diff --git a/internal/test/config.go b/internal/test/config.go index bb2f8a4c6..0372fb9c6 100644 --- a/internal/test/config.go +++ b/internal/test/config.go @@ -20,6 +20,7 @@ import ( "crypto/x509" "encoding/base64" "encoding/pem" + "errors" "fmt" "io/ioutil" "math/big" @@ -94,7 +95,6 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con cfg.RoomServer.Database.ConnectionString = config.DataSource(database) cfg.SyncAPI.Database.ConnectionString = config.DataSource(database) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(database) cfg.AppServiceAPI.InternalAPI.Listen = assignAddress() cfg.EDUServer.InternalAPI.Listen = assignAddress() @@ -158,11 +158,10 @@ func NewMatrixKey(matrixKeyPath string) (err error) { const certificateDuration = time.Hour * 24 * 365 * 10 -// NewTLSKey generates a new RSA TLS key and certificate and writes it to a file. -func NewTLSKey(tlsKeyPath, tlsCertPath string) error { +func generateTLSTemplate(dnsNames []string) (*rsa.PrivateKey, *x509.Certificate, error) { priv, err := rsa.GenerateKey(rand.Reader, 4096) if err != nil { - return err + return nil, nil, err } notBefore := time.Now() @@ -170,7 +169,7 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error { serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { - return err + return nil, nil, err } template := x509.Certificate{ @@ -180,20 +179,21 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error { KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, + DNSNames: dnsNames, } - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) - if err != nil { - return err - } + return priv, &template, nil +} + +func writeCertificate(tlsCertPath string, derBytes []byte) error { certOut, err := os.Create(tlsCertPath) if err != nil { return err } defer certOut.Close() // nolint: errcheck - if err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { - return err - } + return pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) +} +func writePrivateKey(tlsKeyPath string, priv *rsa.PrivateKey) error { keyOut, err := os.OpenFile(tlsKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return err @@ -205,3 +205,73 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error { }) return err } + +// NewTLSKey generates a new RSA TLS key and certificate and writes it to a file. +func NewTLSKey(tlsKeyPath, tlsCertPath string) error { + priv, template, err := generateTLSTemplate(nil) + if err != nil { + return err + } + + // Self-signed certificate: template == parent + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + return err + } + + if err = writeCertificate(tlsCertPath, derBytes); err != nil { + return err + } + return writePrivateKey(tlsKeyPath, priv) +} + +func NewTLSKeyWithAuthority(serverName, tlsKeyPath, tlsCertPath, authorityKeyPath, authorityCertPath string) error { + priv, template, err := generateTLSTemplate([]string{serverName}) + if err != nil { + return err + } + + // load the authority key + dat, err := ioutil.ReadFile(authorityKeyPath) + if err != nil { + return err + } + block, _ := pem.Decode([]byte(dat)) + if block == nil || block.Type != "RSA PRIVATE KEY" { + return errors.New("authority .key is not a valid pem encoded rsa private key") + } + authorityPriv, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return err + } + + // load the authority certificate + dat, err = ioutil.ReadFile(authorityCertPath) + if err != nil { + return err + } + block, _ = pem.Decode([]byte(dat)) + if block == nil || block.Type != "CERTIFICATE" { + return errors.New("authority .crt is not a valid pem encoded x509 cert") + } + var caCerts []*x509.Certificate + caCerts, err = x509.ParseCertificates(block.Bytes) + if err != nil { + return err + } + if len(caCerts) != 1 { + return errors.New("authority .crt contains none or more than one cert") + } + authorityCert := caCerts[0] + + // Sign the new certificate using the authority's key/cert + derBytes, err := x509.CreateCertificate(rand.Reader, template, authorityCert, &priv.PublicKey, authorityPriv) + if err != nil { + return err + } + + if err = writeCertificate(tlsCertPath, derBytes); err != nil { + return err + } + return writePrivateKey(tlsKeyPath, priv) +} diff --git a/internal/version.go b/internal/version.go index 88123693f..a07f01b61 100644 --- a/internal/version.go +++ b/internal/version.go @@ -16,8 +16,8 @@ var build string const ( VersionMajor = 0 - VersionMinor = 5 - VersionPatch = 1 + VersionMinor = 6 + VersionPatch = 3 VersionTag = "" // example: "rc1" ) diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 0eea2f0fa..3933961c1 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -228,7 +228,7 @@ type QueryKeyChangesRequest struct { // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning Offset int64 // The inclusive offset where to track key changes up to. Messages with this offset are included in the response. - // Use sarama.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). + // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). ToOffset int64 } diff --git a/keyserver/consumers/cross_signing.go b/keyserver/consumers/cross_signing.go deleted file mode 100644 index 4b2bd4a9a..000000000 --- a/keyserver/consumers/cross_signing.go +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package consumers - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" - - "github.com/Shopify/sarama" -) - -type OutputCrossSigningKeyUpdateConsumer struct { - eduServerConsumer *internal.ContinualConsumer - keyDB storage.Database - keyAPI api.KeyInternalAPI - serverName string -} - -func NewOutputCrossSigningKeyUpdateConsumer( - process *process.ProcessContext, - cfg *config.Dendrite, - kafkaConsumer sarama.Consumer, - keyDB storage.Database, - keyAPI api.KeyInternalAPI, -) *OutputCrossSigningKeyUpdateConsumer { - // The keyserver both produces and consumes on the TopicOutputKeyChangeEvent - // topic. We will only produce events where the UserID matches our server name, - // and we will only consume events where the UserID does NOT match our server - // name (because the update came from a remote server). - consumer := internal.ContinualConsumer{ - Process: process, - ComponentName: "keyserver/keyserver", - Topic: cfg.Global.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), - Consumer: kafkaConsumer, - PartitionStore: keyDB, - } - s := &OutputCrossSigningKeyUpdateConsumer{ - eduServerConsumer: &consumer, - keyDB: keyDB, - keyAPI: keyAPI, - serverName: string(cfg.Global.ServerName), - } - consumer.ProcessMessage = s.onMessage - - return s -} - -func (s *OutputCrossSigningKeyUpdateConsumer) Start() error { - return s.eduServerConsumer.Start() -} - -// onMessage is called in response to a message received on the -// key change events topic from the key server. -func (t *OutputCrossSigningKeyUpdateConsumer) onMessage(msg *sarama.ConsumerMessage) error { - var m api.DeviceMessage - if err := json.Unmarshal(msg.Value, &m); err != nil { - logrus.WithError(err).Errorf("failed to read device message from key change topic") - return nil - } - if m.OutputCrossSigningKeyUpdate == nil { - // This probably shouldn't happen but stops us from panicking if we come - // across an update that doesn't satisfy either types. - return nil - } - switch m.Type { - case api.TypeCrossSigningUpdate: - return t.onCrossSigningMessage(m) - default: - return nil - } -} - -func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.DeviceMessage) error { - output := m.CrossSigningKeyUpdate - _, host, err := gomatrixserverlib.SplitID('@', output.UserID) - if err != nil { - logrus.WithError(err).Errorf("eduserver output log: user ID parse failure") - return nil - } - if host == gomatrixserverlib.ServerName(s.serverName) { - // Ignore any messages that contain information about our own users, as - // they already originated from this server. - return nil - } - uploadReq := &api.PerformUploadDeviceKeysRequest{ - UserID: output.UserID, - } - if output.MasterKey != nil { - uploadReq.MasterKey = *output.MasterKey - } - if output.SelfSigningKey != nil { - uploadReq.SelfSigningKey = *output.SelfSigningKey - } - uploadRes := &api.PerformUploadDeviceKeysResponse{} - s.keyAPI.PerformUploadDeviceKeys(context.TODO(), uploadReq, uploadRes) - return uploadRes.Error -} diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index 1e1871b8b..bfb2037f8 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -219,25 +219,23 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P } // Finally, generate a notification that we updated the keys. - if _, host, err := gomatrixserverlib.SplitID('@', req.UserID); err == nil && host == a.ThisServer { - update := eduserverAPI.CrossSigningKeyUpdate{ - UserID: req.UserID, - } - if mk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster]; ok { - update.MasterKey = &mk - } - if ssk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning]; ok { - update.SelfSigningKey = &ssk - } - if update.MasterKey == nil && update.SelfSigningKey == nil { - return - } - if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), - } - return + update := eduserverAPI.CrossSigningKeyUpdate{ + UserID: req.UserID, + } + if mk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster]; ok { + update.MasterKey = &mk + } + if ssk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning]; ok { + update.SelfSigningKey = &ssk + } + if update.MasterKey == nil && update.SelfSigningKey == nil { + return + } + if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } + return } } @@ -310,16 +308,18 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req // Finally, generate a notification that we updated the signatures. for userID := range req.Signatures { - if _, host, err := gomatrixserverlib.SplitID('@', userID); err == nil && host == a.ThisServer { - update := eduserverAPI.CrossSigningKeyUpdate{ - UserID: userID, - } - if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), - } - return + masterKey := queryRes.MasterKeys[userID] + selfSigningKey := queryRes.SelfSigningKeys[userID] + update := eduserverAPI.CrossSigningKeyUpdate{ + UserID: userID, + MasterKey: &masterKey, + SelfSigningKey: &selfSigningKey, + } + if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } + return } } } diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 1b6e2d428..c5a5d40c7 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -367,10 +367,13 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam waitTime = fcerr.RetryAfter } else if fcerr.Blacklisted { waitTime = time.Hour * 8 + } else { + // For all other errors (DNS resolution, network etc.) wait 1 hour. + waitTime = time.Hour } } else { waitTime = time.Hour - logger.WithError(err).Warn("GetUserDevices returned unknown error type") + logger.WithError(err).WithField("user_id", userID).Warn("GetUserDevices returned unknown error type") } continue } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 259249217..0c264b718 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -198,7 +198,7 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne } func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) { - msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil) + msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query DB for device keys: %s", err), @@ -244,7 +244,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques domain := string(serverName) // query local devices if serverName == a.ThisServer { - deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) + deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query local device keys: %s", err), @@ -326,8 +326,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques if err = json.Unmarshal(key, &deviceKey); err != nil { continue } + if deviceKey.Signatures == nil { + deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } for sourceUserID, forSourceUser := range sigMap { for sourceKeyID, sourceSig := range forSourceUser { + if _, ok := deviceKey.Signatures[sourceUserID]; !ok { + deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig } } @@ -447,7 +453,6 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( for userID, deviceIDs := range devKeys { if len(deviceIDs) == 0 { userIDsForAllDevices[userID] = struct{}{} - delete(devKeys, userID) } } // for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing @@ -508,6 +513,11 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( // drop the error as it's already a failure at this point _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys) } + + // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache + if len(res.DeviceKeys) > 0 { + delete(res.Failures, serverName) + } respMu.Unlock() } @@ -515,7 +525,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string, ) error { - keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) + keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) // if we can't query the db or there are fewer keys than requested, fetch from remote. if err != nil { return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err) @@ -544,10 +554,58 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( } func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { + // get a list of devices from the user API that actually exist, as + // we won't store keys for devices that don't exist + uapidevices := &userapi.QueryDevicesResponse{} + if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil { + res.Error = &api.KeyError{ + Err: err.Error(), + } + return + } + if !uapidevices.UserExists { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("user %q does not exist", req.UserID), + } + return + } + existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices)) + for _, key := range uapidevices.Devices { + existingDeviceMap[key.ID] = struct{}{} + } + + // Get all of the user existing device keys so we can check for changes. + existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), + } + return + } + + // Work out whether we have device keys in the keyserver for devices that + // no longer exist in the user API. This is mostly an exercise to ensure + // that we keep some integrity between the two. + var toClean []gomatrixserverlib.KeyID + for _, k := range existingKeys { + if _, ok := existingDeviceMap[k.DeviceID]; !ok { + toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID)) + } + } + + if len(toClean) > 0 { + if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { + logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean)) + } else { + logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean)) + } + } + var keysToStore []api.DeviceMessage // assert that the user ID / device ID are not lying for each key for _, key := range req.DeviceKeys { - _, serverName, err := gomatrixserverlib.SplitID('@', key.UserID) + var serverName gomatrixserverlib.ServerName + _, serverName, err = gomatrixserverlib.SplitID('@', key.UserID) if err != nil { continue // ignore invalid users } @@ -558,6 +616,11 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per keysToStore = append(keysToStore, key.WithStreamID(0)) continue // deleted keys don't need sanity checking } + // check that the device in question actually exists in the user + // API before we try and store a key for it + if _, ok := existingDeviceMap[key.DeviceID]; !ok { + continue + } gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str if gotUserID == key.UserID && gotDeviceID == key.DeviceID { @@ -573,29 +636,12 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per }) } - // get existing device keys so we can check for changes - existingKeys := make([]api.DeviceMessage, len(keysToStore)) - for i := range keysToStore { - existingKeys[i] = api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - UserID: keysToStore[i].UserID, - DeviceID: keysToStore[i].DeviceID, - }, - } - } - if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), - } - return - } if req.OnlyDisplayNameUpdates { // add the display name field from keysToStore into existingKeys keysToStore = appendDisplayNames(existingKeys, keysToStore) } // store the device keys and emit changes - err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore) + err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 8cc50ea0d..bd36fd9f9 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -18,7 +18,6 @@ import ( "github.com/gorilla/mux" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/consumers" "github.com/matrix-org/dendrite/keyserver/internal" "github.com/matrix-org/dendrite/keyserver/inthttp" "github.com/matrix-org/dendrite/keyserver/producers" @@ -40,7 +39,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { func NewInternalAPI( base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient, ) api.KeyInternalAPI { - js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) db, err := storage.NewDatabase(&cfg.Database) if err != nil { @@ -65,12 +64,5 @@ func NewInternalAPI( } }() - keyconsumer := consumers.NewOutputCrossSigningKeyUpdateConsumer( - base.ProcessContext, base.Cfg, consumer, db, ap, - ) - if err := keyconsumer.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start keyserver EDU server consumer") - } - return ap } diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go index fd143c6cf..9e1c4c645 100644 --- a/keyserver/producers/keychange.go +++ b/keyserver/producers/keychange.go @@ -65,7 +65,7 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error { logrus.WithFields(logrus.Fields{ "user_id": userID, "num_key_changes": count, - }).Infof("Produced to key change topic '%s'", p.Topic) + }).Tracef("Produced to key change topic '%s'", p.Topic) } return nil } @@ -103,6 +103,6 @@ func (p *KeyChange) ProduceSigningKeyUpdate(key eduapi.CrossSigningKeyUpdate) er logrus.WithFields(logrus.Fields{ "user_id": key.UserID, - }).Infof("Produced to cross-signing update topic '%s'", p.Topic) + }).Tracef("Produced to cross-signing update topic '%s'", p.Topic) return nil } diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 87feae47d..4dffe695c 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -18,15 +18,12 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { - internal.PartitionStorer - // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) @@ -56,7 +53,7 @@ type Database interface { // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. - DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) + DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying // cross-signing signatures relating to that device. @@ -71,7 +68,7 @@ type Database interface { StoreKeyChange(ctx context.Context, userID string) (int64, error) // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). - // A to offset of sarama.OffsetNewest means no upper limit. + // A to offset of types.OffsetNewest means no upper limit. // Returns the offset of the latest key change. KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index 5ae0da969..628301cf7 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -56,6 +56,9 @@ const selectDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" +const selectBatchDeviceKeysWithEmptiesSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -69,14 +72,15 @@ const deleteAllDeviceKeysSQL = "" + "DELETE FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt - selectMaxStreamForUserStmt *sql.Stmt - countStreamIDsForUserStmt *sql.Stmt - deleteDeviceKeysStmt *sql.Stmt - deleteAllDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt + countStreamIDsForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt } func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -96,6 +100,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { + return nil, err + } if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { return nil, err } @@ -180,8 +187,14 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql return err } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { - rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + var stmt *sql.Stmt + if includeEmpty { + stmt = s.selectBatchDeviceKeysWithEmptiesStmt + } else { + stmt = s.selectBatchDeviceKeysStmt + } + rows, err := stmt.QueryContext(ctx, userID) if err != nil { return nil, err } diff --git a/keyserver/storage/postgres/key_changes_table.go b/keyserver/storage/postgres/key_changes_table.go index 20d227c24..f93a94bd3 100644 --- a/keyserver/storage/postgres/key_changes_table.go +++ b/keyserver/storage/postgres/key_changes_table.go @@ -17,9 +17,7 @@ package postgres import ( "context" "database/sql" - "math" - "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -78,9 +76,6 @@ func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID strin func (s *keyChangesStatements) SelectKeyChanges( ctx context.Context, fromOffset, toOffset int64, ) (userIDs []string, latestOffset int64, err error) { - if toOffset == sarama.OffsetNewest { - toOffset = math.MaxInt64 - } latestOffset = fromOffset rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) if err != nil { diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go index cc397ba84..0b143a1aa 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/keyserver/storage/postgres/one_time_keys_table.go @@ -59,6 +59,9 @@ const deleteOneTimeKeySQL = "" + const selectKeyByAlgorithmSQL = "" + "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" +const deleteOneTimeKeysSQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2" + type oneTimeKeysStatements struct { db *sql.DB upsertKeysStmt *sql.Stmt @@ -66,6 +69,7 @@ type oneTimeKeysStatements struct { selectKeysCountStmt *sql.Stmt selectKeyByAlgorithmStmt *sql.Stmt deleteOneTimeKeyStmt *sql.Stmt + deleteOneTimeKeysStmt *sql.Stmt } func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { @@ -91,6 +95,9 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { return nil, err } + if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil { + return nil, err + } return s, nil } @@ -187,3 +194,8 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( algorithm + ":" + keyID: json.RawMessage(keyJSON), }, err } + +func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 5914d28e1..f2790c8df 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -108,8 +108,8 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe }) } -func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { - return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs) +func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) } func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { @@ -171,6 +171,9 @@ func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceID if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) } + if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err) + } } return nil }) diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index fa1c930db..b461424c6 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -52,6 +52,9 @@ const selectDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" +const selectBatchDeviceKeysWithEmptiesSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -65,13 +68,14 @@ const deleteAllDeviceKeysSQL = "" + "DELETE FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt - selectMaxStreamForUserStmt *sql.Stmt - deleteDeviceKeysStmt *sql.Stmt - deleteAllDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt } func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -91,6 +95,9 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { + return nil, err + } if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { return nil, err } @@ -113,12 +120,18 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql return err } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { deviceIDMap := make(map[string]bool) for _, d := range deviceIDs { deviceIDMap[d] = true } - rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) + var stmt *sql.Stmt + if includeEmpty { + stmt = s.selectBatchDeviceKeysWithEmptiesStmt + } else { + stmt = s.selectBatchDeviceKeysStmt + } + rows, err := stmt.QueryContext(ctx, userID) if err != nil { return nil, err } diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go index d43c15ca9..e035e8c9c 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -17,9 +17,7 @@ package sqlite3 import ( "context" "database/sql" - "math" - "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -76,9 +74,6 @@ func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID strin func (s *keyChangesStatements) SelectKeyChanges( ctx context.Context, fromOffset, toOffset int64, ) (userIDs []string, latestOffset int64, err error) { - if toOffset == sarama.OffsetNewest { - toOffset = math.MaxInt64 - } latestOffset = fromOffset rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) if err != nil { diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index 185b88612..897839aca 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -58,6 +58,9 @@ const deleteOneTimeKeySQL = "" + const selectKeyByAlgorithmSQL = "" + "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" +const deleteOneTimeKeysSQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2" + type oneTimeKeysStatements struct { db *sql.DB upsertKeysStmt *sql.Stmt @@ -65,6 +68,7 @@ type oneTimeKeysStatements struct { selectKeysCountStmt *sql.Stmt selectKeyByAlgorithmStmt *sql.Stmt deleteOneTimeKeyStmt *sql.Stmt + deleteOneTimeKeysStmt *sql.Stmt } func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { @@ -90,6 +94,9 @@ func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { return nil, err } + if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil { + return nil, err + } return s, nil } @@ -201,3 +208,8 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( algorithm + ":" + keyID: json.RawMessage(keyJSON), }, err } + +func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index 2f8cf809b..4d5137249 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -9,8 +9,8 @@ import ( "reflect" "testing" - "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/dendrite/setup/config" ) @@ -50,7 +50,7 @@ func TestKeyChanges(t *testing.T) { MustNotError(t, err) deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, sarama.OffsetNewest) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } @@ -74,7 +74,7 @@ func TestKeyChangesNoDupes(t *testing.T) { } deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, 0, sarama.OffsetNewest) + userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } @@ -173,7 +173,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { } // Querying for device keys returns the latest stream IDs - msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}) + msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false) if err != nil { t.Fatalf("DeviceKeysForUser returned error: %s", err) } diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 0d94c94cc..cd1719598 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -31,6 +31,7 @@ type OneTimeKeys interface { // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. // Returns an empty map if the key does not exist. SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) + DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error } type DeviceKeys interface { @@ -38,7 +39,7 @@ type DeviceKeys interface { InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) - SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) + SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error } @@ -46,7 +47,7 @@ type DeviceKeys interface { type KeyChanges interface { InsertKeyChange(ctx context.Context, userID string) (int64, error) // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets. - // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset. + // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset. SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) Prepare() error diff --git a/keyserver/types/storage.go b/keyserver/types/storage.go index 3480ec65f..7fb90454e 100644 --- a/keyserver/types/storage.go +++ b/keyserver/types/storage.go @@ -14,7 +14,18 @@ package types -import "github.com/matrix-org/gomatrixserverlib" +import ( + "math" + + "github.com/matrix-org/gomatrixserverlib" +) + +const ( + // OffsetNewest tells e.g. the database to get the most current data + OffsetNewest int64 = math.MaxInt64 + // OffsetOldest tells e.g. the database to get the oldest data + OffsetOldest int64 = 0 +) // KeyTypePurposeToInt maps a purpose to an integer, which is used in the // database to reduce the amount of space taken up by this column. diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index 2358915ee..4ce738b6e 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -157,7 +157,7 @@ func (r *downloadRequest) jsonErrorResponse(w http.ResponseWriter, res util.JSON // Set status code and write the body w.WriteHeader(res.Code) - r.Logger.WithField("code", res.Code).Infof("Responding (%d bytes)", len(resBytes)) + r.Logger.WithField("code", res.Code).Tracef("Responding (%d bytes)", len(resBytes)) // we don't really care that much if we fail to write the error response w.Write(resBytes) // nolint: errcheck @@ -293,11 +293,11 @@ func (r *downloadRequest) respondFromLocalFile( "Base64Hash": r.MediaMetadata.Base64Hash, "FileSizeBytes": r.MediaMetadata.FileSizeBytes, "ContentType": r.MediaMetadata.ContentType, - }).Info("No good thumbnail found. Responding with original file.") + }).Trace("No good thumbnail found. Responding with original file.") responseFile = file responseMetadata = r.MediaMetadata } else { - r.Logger.Info("Responding with thumbnail") + r.Logger.Trace("Responding with thumbnail") responseFile = thumbFile responseMetadata = thumbMetadata.MediaMetadata } @@ -307,7 +307,7 @@ func (r *downloadRequest) respondFromLocalFile( "Base64Hash": r.MediaMetadata.Base64Hash, "FileSizeBytes": r.MediaMetadata.FileSizeBytes, "ContentType": r.MediaMetadata.ContentType, - }).Info("Responding with file") + }).Trace("Responding with file") responseFile = file responseMetadata = r.MediaMetadata if err := r.addDownloadFilenameToHeaders(w, responseMetadata); err != nil { @@ -436,7 +436,7 @@ func (r *downloadRequest) getThumbnailFile( "Width": thumbnailSize.Width, "Height": thumbnailSize.Height, "ResizeMethod": thumbnailSize.ResizeMethod, - }).Info("Pre-generating thumbnail for immediate response.") + }).Debug("Pre-generating thumbnail for immediate response.") thumbnail, err = r.generateThumbnail( ctx, filePath, *thumbnailSize, activeThumbnailGeneration, maxThumbnailGenerators, db, @@ -574,7 +574,7 @@ func (r *downloadRequest) getMediaMetadataFromActiveRequest(activeRemoteRequests defer activeRemoteRequests.Unlock() if activeRemoteRequestResult, ok := activeRemoteRequests.MXCToResult[mxcURL]; ok { - r.Logger.Info("Waiting for another goroutine to fetch the remote file.") + r.Logger.Trace("Waiting for another goroutine to fetch the remote file.") // NOTE: Wait unlocks and locks again internally. There is still a deferred Unlock() that will unlock this. activeRemoteRequestResult.Cond.Wait() @@ -604,7 +604,7 @@ func (r *downloadRequest) broadcastMediaMetadata(activeRemoteRequests *types.Act defer activeRemoteRequests.Unlock() mxcURL := "mxc://" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID) if activeRemoteRequestResult, ok := activeRemoteRequests.MXCToResult[mxcURL]; ok { - r.Logger.Info("Signalling other goroutines waiting for this goroutine to fetch the file.") + r.Logger.Trace("Signalling other goroutines waiting for this goroutine to fetch the file.") activeRemoteRequestResult.MediaMetadata = r.MediaMetadata activeRemoteRequestResult.Error = err activeRemoteRequestResult.Cond.Broadcast() @@ -635,7 +635,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata( "UploadName": r.MediaMetadata.UploadName, "FileSizeBytes": r.MediaMetadata.FileSizeBytes, "ContentType": r.MediaMetadata.ContentType, - }).Info("Storing file metadata to media repository database") + }).Debug("Storing file metadata to media repository database") // FIXME: timeout db request if err := db.StoreMediaMetadata(ctx, r.MediaMetadata); err != nil { @@ -669,7 +669,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata( "Base64Hash": r.MediaMetadata.Base64Hash, "FileSizeBytes": r.MediaMetadata.FileSizeBytes, "ContentType": r.MediaMetadata.ContentType, - }).Infof("Remote file cached") + }).Debug("Remote file cached") return nil } @@ -717,7 +717,7 @@ func (r *downloadRequest) fetchRemoteFile( absBasePath config.Path, maxFileSizeBytes config.FileSizeBytes, ) (types.Path, bool, error) { - r.Logger.Info("Fetching remote file") + r.Logger.Debug("Fetching remote file") // create request for remote file resp, err := r.createRemoteRequest(ctx, client) @@ -762,7 +762,7 @@ func (r *downloadRequest) fetchRemoteFile( } } - r.Logger.Info("Transferring remote file") + r.Logger.Trace("Transferring remote file") // The file data is hashed but is NOT used as the MediaID, unlike in Upload. The hash is useful as a // method of deduplicating files to save storage, as well as a way to conduct @@ -776,7 +776,7 @@ func (r *downloadRequest) fetchRemoteFile( return "", false, errors.New("file could not be downloaded from remote server") } - r.Logger.Info("Remote file transferred") + r.Logger.Trace("Remote file transferred") // It's possible the bytesWritten to the temporary file is different to the reported Content-Length from the remote // request's response. bytesWritten is therefore used as it is what would be sent to clients when reading from the local @@ -790,7 +790,7 @@ func (r *downloadRequest) fetchRemoteFile( return "", false, fmt.Errorf("fileutils.MoveFileWithHashCheck: %w", err) } if duplicate { - r.Logger.WithField("dst", finalPath).Info("File was stored previously - discarding duplicate") + r.Logger.WithField("dst", finalPath).Trace("File was stored previously - discarding duplicate") // Continue on to store the metadata in the database } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index d35fd84df..bcbf0e4f9 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -3,9 +3,11 @@ package api import ( "context" + "github.com/matrix-org/gomatrixserverlib" + asAPI "github.com/matrix-org/dendrite/appservice/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/gomatrixserverlib" + userapi "github.com/matrix-org/dendrite/userapi/api" ) // RoomserverInputAPI is used to write events to the room server. @@ -14,6 +16,7 @@ type RoomserverInternalAPI interface { // interdependencies between the roomserver and other input APIs SetFederationAPI(fsAPI fsAPI.FederationInternalAPI, keyRing *gomatrixserverlib.KeyRing) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) + SetUserAPI(userAPI userapi.UserInternalAPI) InputRoomEvents( ctx context.Context, @@ -83,13 +86,6 @@ type RoomserverInternalAPI interface { response *QueryStateAfterEventsResponse, ) error - // Query whether the roomserver is missing any auth or prev events. - QueryMissingAuthPrevEvents( - ctx context.Context, - request *QueryMissingAuthPrevEventsRequest, - response *QueryMissingAuthPrevEventsResponse, - ) error - // Query a list of events by event ID. QueryEventsByID( ctx context.Context, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 64cbaca49..88b372154 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -5,10 +5,12 @@ import ( "encoding/json" "fmt" - asAPI "github.com/matrix-org/dendrite/appservice/api" - fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + + asAPI "github.com/matrix-org/dendrite/appservice/api" + fsAPI "github.com/matrix-org/dendrite/federationapi/api" + userapi "github.com/matrix-org/dendrite/userapi/api" ) // RoomserverInternalAPITrace wraps a RoomserverInternalAPI and logs the @@ -25,6 +27,10 @@ func (t *RoomserverInternalAPITrace) SetAppserviceAPI(asAPI asAPI.AppServiceQuer t.Impl.SetAppserviceAPI(asAPI) } +func (t *RoomserverInternalAPITrace) SetUserAPI(userAPI userapi.UserInternalAPI) { + t.Impl.SetUserAPI(userAPI) +} + func (t *RoomserverInternalAPITrace) InputRoomEvents( ctx context.Context, req *InputRoomEventsRequest, @@ -129,16 +135,6 @@ func (t *RoomserverInternalAPITrace) QueryStateAfterEvents( return err } -func (t *RoomserverInternalAPITrace) QueryMissingAuthPrevEvents( - ctx context.Context, - req *QueryMissingAuthPrevEventsRequest, - res *QueryMissingAuthPrevEventsResponse, -) error { - err := t.Impl.QueryMissingAuthPrevEvents(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryMissingAuthPrevEvents req=%+v res=%+v", js(req), js(res)) - return err -} - func (t *RoomserverInternalAPITrace) QueryEventsByID( ctx context.Context, req *QueryEventsByIDRequest, diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 4b0704b9f..45a9ef497 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -42,6 +42,19 @@ const ( KindOld ) +func (k Kind) String() string { + switch k { + case KindOutlier: + return "KindOutlier" + case KindNew: + return "KindNew" + case KindOld: + return "KindOld" + default: + return "(unknown)" + } +} + // DoNotSendToOtherServers tells us not to send the event to other matrix // servers. const DoNotSendToOtherServers = "" diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 51cbcb1ad..d640858a6 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -95,6 +95,8 @@ type PerformLeaveRequest struct { } type PerformLeaveResponse struct { + Code int `json:"code,omitempty"` + Message interface{} `json:"message,omitempty"` } type PerformInviteRequest struct { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 283217157..96d6711c6 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -83,27 +83,6 @@ type QueryStateAfterEventsResponse struct { StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` } -type QueryMissingAuthPrevEventsRequest struct { - // The room ID to query the state in. - RoomID string `json:"room_id"` - // The list of auth events to check the existence of. - AuthEventIDs []string `json:"auth_event_ids"` - // The list of previous events to check the existence of. - PrevEventIDs []string `json:"prev_event_ids"` -} - -type QueryMissingAuthPrevEventsResponse struct { - // Does the room exist on this roomserver? - // If the room doesn't exist all other fields will be empty. - RoomExists bool `json:"room_exists"` - // The room version of the room. - RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` - // The event IDs of the auth events that we don't know locally. - MissingAuthEventIDs []string `json:"missing_auth_event_ids"` - // The event IDs of the previous events that we don't know locally. - MissingPrevEventIDs []string `json:"missing_prev_event_ids"` -} - // QueryEventsByIDRequest is a request to QueryEventsByID type QueryEventsByIDRequest struct { // The event IDs to look up. diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index e9b94e48c..012094c62 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -51,7 +51,7 @@ func SendEventWithState( state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, ) error { - outliers, err := state.Events() + outliers, err := state.Events(event.RoomVersion) if err != nil { return err } @@ -68,9 +68,10 @@ func SendEventWithState( }) } - stateEventIDs := make([]string, len(state.StateEvents)) - for i := range state.StateEvents { - stateEventIDs[i] = state.StateEvents[i].EventID() + stateEvents := state.StateEvents.UntrustedEvents(event.RoomVersion) + stateEventIDs := make([]string, len(stateEvents)) + for i := range stateEvents { + stateEventIDs[i] = stateEvents[i].EventID() } ires = append(ires, InputRoomEvent{ diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 5b87e623d..10c8c844e 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -14,6 +14,8 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" @@ -32,6 +34,7 @@ type RoomserverInternalAPI struct { *perform.Publisher *perform.Backfiller *perform.Forgetter + ProcessContext *process.ProcessContext DB storage.Database Cfg *config.RoomServer Cache caching.RoomServerCaches @@ -41,19 +44,20 @@ type RoomserverInternalAPI struct { fsAPI fsAPI.FederationInternalAPI asAPI asAPI.AppServiceQueryAPI JetStream nats.JetStreamContext - Durable nats.SubOpt + Durable string InputRoomEventTopic string // JetStream topic for new input room events OutputRoomEventTopic string // JetStream topic for new output room events PerspectiveServerNames []gomatrixserverlib.ServerName } func NewRoomserverAPI( - cfg *config.RoomServer, roomserverDB storage.Database, consumer nats.JetStreamContext, - inputRoomEventTopic, outputRoomEventTopic string, caches caching.RoomServerCaches, - perspectiveServerNames []gomatrixserverlib.ServerName, + processCtx *process.ProcessContext, cfg *config.RoomServer, roomserverDB storage.Database, + consumer nats.JetStreamContext, inputRoomEventTopic, outputRoomEventTopic string, + caches caching.RoomServerCaches, perspectiveServerNames []gomatrixserverlib.ServerName, ) *RoomserverInternalAPI { serverACLs := acls.NewServerACLs(roomserverDB) a := &RoomserverInternalAPI{ + ProcessContext: processCtx, DB: roomserverDB, Cfg: cfg, Cache: caches, @@ -83,11 +87,12 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA r.KeyRing = keyRing r.Inputer = &input.Inputer{ + ProcessContext: r.ProcessContext, DB: r.DB, InputRoomEventTopic: r.InputRoomEventTopic, OutputRoomEventTopic: r.OutputRoomEventTopic, JetStream: r.JetStream, - Durable: r.Durable, + Durable: nats.Durable(r.Durable), ServerName: r.Cfg.Matrix.ServerName, FSAPI: fsAPI, KeyRing: keyRing, @@ -155,6 +160,10 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA } } +func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.UserInternalAPI) { + r.Leaver.UserAPI = userAPI +} + func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) { r.asAPI = asAPI } diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index ddda8081c..9af0bf591 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -20,17 +20,22 @@ import ( "sort" "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) +type checkForAuthAndSoftFailStorage interface { + state.StateResolutionStorage + StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) +} + // CheckForSoftFail returns true if the event should be soft-failed // and false otherwise. The return error value should be checked before // the soft-fail bool. func CheckForSoftFail( ctx context.Context, - db storage.Database, + db checkForAuthAndSoftFailStorage, event *gomatrixserverlib.HeaderedEvent, stateEventIDs []string, ) (bool, error) { @@ -92,7 +97,7 @@ func CheckForSoftFail( // Returns the numeric IDs for the auth events. func CheckAuthEvents( ctx context.Context, - db storage.Database, + db checkForAuthAndSoftFailStorage, event *gomatrixserverlib.HeaderedEvent, authEventIDs []string, ) ([]types.EventNID, error) { @@ -193,7 +198,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * // loadAuthEvents loads the events needed for authentication from the supplied room state. func loadAuthEvents( ctx context.Context, - db storage.Database, + db state.StateResolutionStorage, needed gomatrixserverlib.StateNeeded, state []types.StateEntry, ) (result authEvents, err error) { diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index e0ddd07cf..22e4b67a0 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -19,25 +19,40 @@ import ( "context" "encoding/json" "errors" + "fmt" "sync" "time" "github.com/Arceliar/phony" "github.com/getsentry/sentry-go" fedapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" ) +type retryAction int +type commitAction int + +const ( + doNotRetry retryAction = iota + retryLater +) + +const ( + commitTransaction commitAction = iota + rollbackTransaction +) + var keyContentFields = map[string]string{ "m.room.join_rules": "join_rule", "m.room.history_visibility": "history_visibility", @@ -45,6 +60,7 @@ var keyContentFields = map[string]string{ } type Inputer struct { + ProcessContext *process.ProcessContext DB storage.Database JetStream nats.JetStreamContext Durable nats.SubOpt @@ -101,14 +117,23 @@ func (r *Inputer) Start() error { _ = msg.InProgress() // resets the acknowledgement wait timer defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - if err := r.processRoomEvent(context.Background(), &inputRoomEvent); err != nil { + action, err := r.processRoomEventUsingUpdater(r.ProcessContext.Context(), roomID, &inputRoomEvent) + if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) } - } else { - go hooks.Run(hooks.KindNewEventPersisted, inputRoomEvent.Event) + logrus.WithError(err).WithFields(logrus.Fields{ + "room_id": roomID, + "event_id": inputRoomEvent.Event.EventID(), + "type": inputRoomEvent.Event.Type(), + }).Warn("Roomserver failed to process async event") + } + switch action { + case retryLater: + _ = msg.Nak() + case doNotRetry: + _ = msg.Ack() } - _ = msg.Ack() }) }, // NATS wants to acknowledge automatically by default when the message is @@ -123,11 +148,42 @@ func (r *Inputer) Start() error { nats.DeliverAll(), // Ensure that NATS doesn't try to resend us something that wasn't done // within the period of time that we might still be processing it. - nats.AckWait(MaximumProcessingTime+(time.Second*10)), + nats.AckWait(MaximumMissingProcessingTime+(time.Second*10)), ) return err } +// processRoomEventUsingUpdater opens up a room updater and tries to +// process the event. It returns whether or not we should positively +// or negatively acknowledge the event (i.e. for NATS) and an error +// if it occurred. +func (r *Inputer) processRoomEventUsingUpdater( + ctx context.Context, + roomID string, + inputRoomEvent *api.InputRoomEvent, +) (retryAction, error) { + roomInfo, err := r.DB.RoomInfo(ctx, roomID) + if err != nil { + return doNotRetry, fmt.Errorf("r.DB.RoomInfo: %w", err) + } + updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return retryLater, fmt.Errorf("r.DB.GetRoomUpdater: %w", err) + } + action, err := r.processRoomEvent(ctx, updater, inputRoomEvent) + switch action { + case commitTransaction: + if cerr := updater.Commit(); cerr != nil { + return retryLater, fmt.Errorf("updater.Commit: %w", cerr) + } + case rollbackTransaction: + if rerr := updater.Rollback(); rerr != nil { + return retryLater, fmt.Errorf("updater.Rollback: %w", rerr) + } + } + return doNotRetry, err +} + // InputRoomEvents implements api.RoomserverInternalAPI func (r *Inputer) InputRoomEvents( ctx context.Context, @@ -149,12 +205,15 @@ func (r *Inputer) InputRoomEvents( return } if _, err = r.JetStream.PublishMsg(msg); err != nil { + logrus.WithError(err).WithFields(logrus.Fields{ + "room_id": roomID, + "event_id": e.Event.EventID(), + }).Error("Roomserver failed to queue async event") return } } } else { responses := make(chan error, len(request.InputRoomEvents)) - defer close(responses) for _, e := range request.InputRoomEvents { inputRoomEvent := e roomID := inputRoomEvent.Event.RoomID() @@ -171,13 +230,15 @@ func (r *Inputer) InputRoomEvents( worker.Act(nil, func() { defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - err := r.processRoomEvent(ctx, &inputRoomEvent) + _, err := r.processRoomEventUsingUpdater(ctx, roomID, &inputRoomEvent) if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) } - } else { - go hooks.Run(hooks.KindNewEventPersisted, inputRoomEvent.Event) + logrus.WithError(err).WithFields(logrus.Fields{ + "room_id": roomID, + "event_id": inputRoomEvent.Event.EventID(), + }).Warn("Roomserver failed to process sync event") } select { case <-ctx.Done(): diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 5f9115223..4e151699e 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -25,9 +25,11 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -40,7 +42,7 @@ func init() { } // TODO: Does this value make sense? -const MaximumProcessingTime = time.Minute * 2 +const MaximumMissingProcessingTime = time.Minute * 2 var processRoomEventDuration = prometheus.NewHistogramVec( prometheus.HistogramOpts{ @@ -65,25 +67,19 @@ var processRoomEventDuration = prometheus.NewHistogramVec( // TODO: Break up function - we should probably do transaction ID checks before calling this. // nolint:gocyclo func (r *Inputer) processRoomEvent( - inctx context.Context, + ctx context.Context, + updater *shared.RoomUpdater, input *api.InputRoomEvent, -) (err error) { +) (commitAction, error) { select { - case <-inctx.Done(): + case <-ctx.Done(): // Before we do anything, make sure the context hasn't expired for this pending task. // If it has then we'll give up straight away — it's probably a synchronous input // request and the caller has already given up, but the inbox task was still queued. - return context.DeadlineExceeded + return rollbackTransaction, context.DeadlineExceeded default: } - // Wrap the context with a time limit. We'll allow no more than MaximumProcessingTime for - // everything that we need to do for this event, or it's possible that we could end up wedging - // the roomserver for a very long time. - var cancel context.CancelFunc - ctx, cancel := context.WithTimeout(inctx, MaximumProcessingTime) - defer cancel() - // Measure how long it takes to process this event. started := time.Now() defer func() { @@ -99,13 +95,21 @@ func (r *Inputer) processRoomEvent( logger := util.GetLogger(ctx).WithFields(logrus.Fields{ "event_id": event.EventID(), "room_id": event.RoomID(), + "kind": input.Kind, + "origin": input.Origin, "type": event.Type(), }) + if input.HasState { + logger = logger.WithFields(logrus.Fields{ + "has_state": input.HasState, + "state_ids": len(input.StateEventIDs), + }) + } // if we have already got this event then do not process it again, if the input kind is an outlier. // Outliers contain no extra information which may warrant a re-processing. if input.Kind == api.KindOutlier { - evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()}) + evs, err2 := updater.EventsFromIDs(ctx, []string{event.EventID()}) if err2 == nil && len(evs) == 1 { // check hash matches if we're on early room versions where the event ID was a random string idFormat, err2 := headered.RoomVersion.EventIDFormat() @@ -114,39 +118,63 @@ func (r *Inputer) processRoomEvent( case gomatrixserverlib.EventIDFormatV1: if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) { logger.Debugf("Already processed event; ignoring") - return nil + return rollbackTransaction, nil } default: logger.Debugf("Already processed event; ignoring") - return nil + return rollbackTransaction, nil } } } } - missingRes := &api.QueryMissingAuthPrevEventsResponse{} - serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{} - if event.Type() != gomatrixserverlib.MRoomCreate || !event.StateKeyEquals("") { - missingReq := &api.QueryMissingAuthPrevEventsRequest{ - RoomID: event.RoomID(), - AuthEventIDs: event.AuthEventIDs(), - PrevEventIDs: event.PrevEventIDs(), - } - if err = r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil { - return fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err) - } + // Don't waste time processing the event if the room doesn't exist. + // A room entry locally will only be created in response to a create + // event. + isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") + if !updater.RoomExists() && !isCreateEvent { + return rollbackTransaction, fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } - if len(missingRes.MissingAuthEventIDs) > 0 || len(missingRes.MissingPrevEventIDs) > 0 { + + var missingAuth, missingPrev bool + serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{} + if !isCreateEvent { + missingAuthIDs, missingPrevIDs, err := updater.MissingAuthPrevEvents(ctx, event) + if err != nil { + return rollbackTransaction, fmt.Errorf("updater.MissingAuthPrevEvents: %w", err) + } + missingAuth = len(missingAuthIDs) > 0 + missingPrev = !input.HasState && len(missingPrevIDs) > 0 + } + + if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ RoomID: event.RoomID(), ExcludeSelf: true, } - if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { - return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) + if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { + return rollbackTransaction, fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) + } + // Sort all of the servers into a map so that we can randomise + // their order. Then make sure that the input origin and the + // event origin are first on the list. + servers := map[gomatrixserverlib.ServerName]struct{}{} + for _, server := range serverRes.ServerNames { + servers[server] = struct{}{} + } + serverRes.ServerNames = serverRes.ServerNames[:0] + if input.Origin != "" { + serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) + delete(servers, input.Origin) + } + if origin := event.Origin(); origin != input.Origin { + serverRes.ServerNames = append(serverRes.ServerNames, origin) + delete(servers, origin) + } + for server := range servers { + serverRes.ServerNames = append(serverRes.ServerNames, server) + delete(servers, server) } - } - if input.Origin != "" { - serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) } // First of all, check that the auth events of the event are known. @@ -154,8 +182,8 @@ func (r *Inputer) processRoomEvent( isRejected := false authEvents := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} - if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { - return fmt.Errorf("r.checkForMissingAuthEvents: %w", err) + if err := r.fetchAuthEvents(ctx, updater, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + return rollbackTransaction, fmt.Errorf("r.fetchAuthEvents: %w", err) } // Check if the event is allowed by its auth events. If it isn't then @@ -163,7 +191,7 @@ func (r *Inputer) processRoomEvent( var rejectionErr error if rejectionErr = gomatrixserverlib.Allowed(event, &authEvents); rejectionErr != nil { isRejected = true - logger.WithError(rejectionErr).Warnf("Event %s rejected", event.EventID()) + logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) } // Accumulate the auth event NIDs. @@ -171,18 +199,36 @@ func (r *Inputer) processRoomEvent( authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) for _, authEventID := range authEventIDs { if _, ok := knownEvents[authEventID]; !ok { - return fmt.Errorf("missing auth event %s", authEventID) + // Unknown auth events only really matter if the event actually failed + // auth. If it passed auth then we can assume that everything that was + // known was sufficient, even if extraneous auth events were specified + // but weren't found. + if isRejected { + if event.StateKey() != nil { + return commitTransaction, fmt.Errorf( + "missing auth event %s for state event %s (type %q, state key %q)", + authEventID, event.EventID(), event.Type(), *event.StateKey(), + ) + } else { + return commitTransaction, fmt.Errorf( + "missing auth event %s for timeline event %s (type %q)", + authEventID, event.EventID(), event.Type(), + ) + } + } + } else { + authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) } - authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) } var softfail bool if input.Kind == api.KindNew { // Check that the event passes authentication checks based on the // current room state. - softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) + var err error + softfail, err = helpers.CheckForSoftFail(ctx, updater, headered, input.StateEventIDs) if err != nil { - logger.WithError(err).Info("Error authing soft-failed event") + logger.WithError(err).Warn("Error authing soft-failed event") } } @@ -196,7 +242,6 @@ func (r *Inputer) processRoomEvent( // typical federated room join) then we won't bother trying to fetch prev events // because we may not be allowed to see them and we have no choice but to trust // the state event IDs provided to us in the join instead. - missingPrev := !input.HasState && len(missingRes.MissingPrevEventIDs) > 0 if missingPrev && input.Kind == api.KindNew { // Don't do this for KindOld events, otherwise old events that we fetch // to satisfy missing prev events/state will end up recursively calling @@ -205,41 +250,56 @@ func (r *Inputer) processRoomEvent( missingState := missingStateReq{ origin: input.Origin, inputer: r, - queryer: r.Queryer, - db: r.DB, + db: updater, federation: r.FSAPI, keys: r.KeyRing, roomsMu: internal.NewMutexByRoom(), - servers: map[gomatrixserverlib.ServerName]struct{}{}, + servers: serverRes.ServerNames, hadEvents: map[string]bool{}, - haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, + haveEvents: map[string]*gomatrixserverlib.Event{}, } - for _, serverName := range serverRes.ServerNames { - missingState.servers[serverName] = struct{}{} - } - if err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { + if stateSnapshot, err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { + // Something went wrong with retrieving the missing state, so we can't + // really do anything with the event other than reject it at this point. isRejected = true rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err) + } else if stateSnapshot != nil { + // We retrieved some state and we ended up having to call /state_ids for + // the new event in question (probably because closing the gap by using + // /get_missing_events didn't do what we hoped) so we'll instead overwrite + // the state snapshot with the newly resolved state. + missingPrev = false + input.HasState = true + input.StateEventIDs = make([]string, 0, len(stateSnapshot.StateEvents)) + for _, e := range stateSnapshot.StateEvents { + input.StateEventIDs = append(input.StateEventIDs, e.EventID()) + } } else { + // We retrieved some state and it would appear that rolling forward the + // state did everything we needed it to do, so we can just resolve the + // state for the event in the normal way. missingPrev = false } } else { + // We're missing prev events or state for the event, but for some reason + // we don't know any servers to ask. In this case we can't do anything but + // reject the event and hope that it gets unrejected later. isRejected = true rejectionErr = fmt.Errorf("missing prev events and no other servers to ask") } } // Store the event. - _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) + _, _, stateAtEvent, redactionEvent, redactedEventID, err := updater.StoreEvent(ctx, event, authEventNIDs, isRejected) if err != nil { - return fmt.Errorf("r.DB.StoreEvent: %w", err) + return rollbackTransaction, fmt.Errorf("updater.StoreEvent: %w", err) } // if storing this event results in it being redacted then do so. if !isRejected && redactedEventID == event.EventID() { r, rerr := eventutil.RedactEvent(redactionEvent, event) if rerr != nil { - return fmt.Errorf("eventutil.RedactEvent: %w", rerr) + return rollbackTransaction, fmt.Errorf("eventutil.RedactEvent: %w", rerr) } event = r } @@ -249,36 +309,44 @@ func (r *Inputer) processRoomEvent( // notify anyone about it. if input.Kind == api.KindOutlier { logger.Debug("Stored outlier") - return nil + hooks.Run(hooks.KindNewEventPersisted, headered) + return commitTransaction, nil } - roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID()) + roomInfo, err := updater.RoomInfo(ctx, event.RoomID()) if err != nil { - return fmt.Errorf("r.DB.RoomInfo: %w", err) + return rollbackTransaction, fmt.Errorf("updater.RoomInfo: %w", err) } if roomInfo == nil { - return fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID()) + return rollbackTransaction, fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID()) } - if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 { + if input.HasState || (!missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0) { // We haven't calculated a state for this event yet. // Lets calculate one. - err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected) + err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected) if err != nil { - return fmt.Errorf("r.calculateAndSetState: %w", err) + return rollbackTransaction, fmt.Errorf("r.calculateAndSetState: %w", err) } } // We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it. if isRejected || softfail { - logger.WithError(rejectionErr).WithField("soft_fail", softfail).Debug("Stored rejected event") - return rejectionErr + logger.WithError(rejectionErr).WithFields(logrus.Fields{ + "soft_fail": softfail, + "missing_prev": missingPrev, + }).Warn("Stored rejected event") + if rejectionErr != nil { + return commitTransaction, types.RejectedError(rejectionErr.Error()) + } + return commitTransaction, nil } switch input.Kind { case api.KindNew: if err = r.updateLatestEvents( ctx, // context + updater, // room updater roomInfo, // room info for the room being updated stateAtEvent, // state at event (below) event, // event @@ -286,7 +354,7 @@ func (r *Inputer) processRoomEvent( input.TransactionID, // transaction ID input.HasState, // rewrites state? ); err != nil { - return fmt.Errorf("r.updateLatestEvents: %w", err) + return rollbackTransaction, fmt.Errorf("r.updateLatestEvents: %w", err) } case api.KindOld: err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{ @@ -298,7 +366,7 @@ func (r *Inputer) processRoomEvent( }, }) if err != nil { - return fmt.Errorf("r.WriteOutputEvents (old): %w", err) + return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (old): %w", err) } } @@ -317,12 +385,14 @@ func (r *Inputer) processRoomEvent( }, }) if err != nil { - return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) + return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) } } - // Update the extremities of the event graph for the room - return nil + // Everything was OK — the latest events updater didn't error and + // we've sent output events. Finally, generate a hook call. + hooks.Run(hooks.KindNewEventPersisted, headered) + return commitTransaction, nil } // fetchAuthEvents will check to see if any of the @@ -334,6 +404,7 @@ func (r *Inputer) processRoomEvent( // they are now in the database. func (r *Inputer) fetchAuthEvents( ctx context.Context, + updater *shared.RoomUpdater, logger *logrus.Entry, event *gomatrixserverlib.HeaderedEvent, auth *gomatrixserverlib.AuthEvents, @@ -347,7 +418,7 @@ func (r *Inputer) fetchAuthEvents( } for _, authEventID := range authEventIDs { - authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID}) + authEvents, err := updater.EventsFromIDs(ctx, []string{authEventID}) if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { unknown[authEventID] = struct{}{} continue @@ -384,52 +455,57 @@ func (r *Inputer) fetchAuthEvents( return fmt.Errorf("no servers provided event auth for event ID %q, tried servers %v", event.EventID(), servers) } + // Reuse these to reduce allocations. + authEventNIDs := make([]types.EventNID, 0, 5) + isRejected := false +nextAuthEvent: for _, authEvent := range gomatrixserverlib.ReverseTopologicalOrdering( - res.AuthEvents, + res.AuthEvents.UntrustedEvents(event.RoomVersion), gomatrixserverlib.TopologicalOrderByAuthEvents, ) { // If we already know about this event from the database then we don't // need to store it again or do anything further with it, so just skip // over it rather than wasting cycles. if ev, ok := known[authEvent.EventID()]; ok && ev != nil { - continue + continue nextAuthEvent } - // Check the signatures of the event. - // TODO: It really makes sense for the federation API to be doing this, - // because then it can attempt another server if one serves up an event - // with an invalid signature. For now this will do. + // Check the signatures of the event. If this fails then we'll simply + // skip it, because gomatrixserverlib.Allowed() will notice a problem + // if a critical event is missing anyway. if err := authEvent.VerifyEventSignatures(ctx, r.FSAPI.KeyRing()); err != nil { - return fmt.Errorf("event.VerifyEventSignatures: %w", err) + continue nextAuthEvent } // In order to store the new auth event, we need to know its auth chain // as NIDs for the `auth_event_nids` column. Let's see if we can find those. - authEventNIDs := make([]types.EventNID, 0, len(authEvent.AuthEventIDs())) + authEventNIDs = authEventNIDs[:0] for _, eventID := range authEvent.AuthEventIDs() { knownEvent, ok := known[eventID] if !ok { - return fmt.Errorf("missing auth event %s for %s", eventID, authEvent.EventID()) + continue nextAuthEvent } authEventNIDs = append(authEventNIDs, knownEvent.EventNID) } - // Let's take a note of the fact that we now know about this event. - if err := auth.AddEvent(authEvent); err != nil { - return fmt.Errorf("auth.AddEvent: %w", err) - } - // Check if the auth event should be rejected. - isRejected := false - if err := gomatrixserverlib.Allowed(authEvent, auth); err != nil { - isRejected = true + err := gomatrixserverlib.Allowed(authEvent, auth) + if isRejected = err != nil; isRejected { logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) } // Finally, store the event in the database. - eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) + eventNID, _, _, _, _, err := updater.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) if err != nil { - return fmt.Errorf("r.DB.StoreEvent: %w", err) + return fmt.Errorf("updater.StoreEvent: %w", err) + } + + // Let's take a note of the fact that we now know about this event for + // authenticating future events. + if !isRejected { + if err := auth.AddEvent(authEvent); err != nil { + return fmt.Errorf("auth.AddEvent: %w", err) + } } // Now we know about this event, it was stored and the signatures were OK. @@ -444,6 +520,7 @@ func (r *Inputer) fetchAuthEvents( func (r *Inputer) calculateAndSetState( ctx context.Context, + updater *shared.RoomUpdater, input *api.InputRoomEvent, roomInfo *types.RoomInfo, stateAtEvent *types.StateAtEvent, @@ -451,30 +528,21 @@ func (r *Inputer) calculateAndSetState( isRejected bool, ) error { var err error - roomState := state.NewStateResolution(r.DB, roomInfo) + roomState := state.NewStateResolution(updater, roomInfo) - if input.HasState && !isRejected { - // Check here if we think we're in the room already. + if input.HasState { stateAtEvent.Overwrite = true - var joinEventNIDs []types.EventNID - // Request join memberships only for local users only. - if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { - // If we have no local users that are joined to the room then any state about - // the room that we have is quite possibly out of date. Therefore in that case - // we should overwrite it rather than merge it. - stateAtEvent.Overwrite = len(joinEventNIDs) == 0 - } // We've been told what the state at the event is so we don't need to calculate it. // Check that those state events are in the database and store the state. var entries []types.StateEntry - if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { - return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err) + if entries, err = updater.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { + return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err) } entries = types.DeduplicateStateEntries(entries) - if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { - return fmt.Errorf("r.DB.AddState: %w", err) + if stateAtEvent.BeforeStateSnapshotNID, err = updater.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { + return fmt.Errorf("updater.AddState: %w", err) } } else { stateAtEvent.Overwrite = false @@ -485,7 +553,7 @@ func (r *Inputer) calculateAndSetState( } } - err = r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) + err = updater.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) if err != nil { return fmt.Errorf("r.DB.SetState: %w", err) } diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 6137941e1..ae28ebefa 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -20,7 +20,6 @@ import ( "context" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -48,6 +47,7 @@ import ( // Can only be called once at a time func (r *Inputer) updateLatestEvents( ctx context.Context, + updater *shared.RoomUpdater, roomInfo *types.RoomInfo, stateAtEvent types.StateAtEvent, event *gomatrixserverlib.Event, @@ -55,13 +55,6 @@ func (r *Inputer) updateLatestEvents( transactionID *api.TransactionID, rewritesState bool, ) (err error) { - updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo) - if err != nil { - return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err) - } - succeeded := false - defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) - u := latestEventsUpdater{ ctx: ctx, api: r, @@ -78,7 +71,6 @@ func (r *Inputer) updateLatestEvents( return fmt.Errorf("u.doUpdateLatestEvents: %w", err) } - succeeded = true return } @@ -89,7 +81,7 @@ func (r *Inputer) updateLatestEvents( type latestEventsUpdater struct { ctx context.Context api *Inputer - updater *shared.LatestEventsUpdater + updater *shared.RoomUpdater roomInfo *types.RoomInfo stateAtEvent types.StateAtEvent event *gomatrixserverlib.Event @@ -199,7 +191,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - roomState := state.NewStateResolution(u.api.DB, u.roomInfo) + roomState := state.NewStateResolution(u.updater, u.roomInfo) // Work out if the state at the extremities has actually changed // or not. If they haven't then we won't bother doing all of the @@ -413,7 +405,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro if len(extraEventIDs) == 0 { return nil, nil } - extraEvents, err := u.api.DB.EventsFromIDs(u.ctx, extraEventIDs) + extraEvents, err := u.updater.UnsentEventsFromIDs(u.ctx, extraEventIDs) if err != nil { return nil, err } @@ -436,7 +428,7 @@ func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) stateEventNIDs = append(stateEventNIDs, entry.EventNID) } stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] - return u.api.DB.EventIDs(u.ctx, stateEventNIDs) + return u.updater.EventIDs(u.ctx, stateEventNIDs) } type eventNIDSorter []types.EventNID diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 2511097d0..3953586b2 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -31,7 +31,7 @@ import ( // consumers about the invites added or retired by the change in current state. func (r *Inputer) updateMemberships( ctx context.Context, - updater *shared.LatestEventsUpdater, + updater *shared.RoomUpdater, removed, added []types.StateEntry, ) ([]api.OutputEvent, error) { changes := membershipChanges(removed, added) @@ -48,7 +48,7 @@ func (r *Inputer) updateMemberships( // Load the event JSON so we can look up the "membership" key. // TODO: Maybe add a membership key to the events table so we can load that // key without having to load the entire event JSON? - events, err := r.DB.Events(ctx, eventNIDs) + events, err := updater.Events(ctx, eventNIDs) if err != nil { return nil, err } @@ -79,7 +79,7 @@ func (r *Inputer) updateMemberships( } func (r *Inputer) updateMembership( - updater *shared.LatestEventsUpdater, + updater *shared.RoomUpdater, targetUserNID types.EventStateKeyNID, remove, add *gomatrixserverlib.Event, updates []api.OutputEvent, diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 44710962c..fc3be7987 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -10,33 +10,39 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/internal/query" - "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) +type parsedRespState struct { + AuthEvents []*gomatrixserverlib.Event + StateEvents []*gomatrixserverlib.Event +} + type missingStateReq struct { origin gomatrixserverlib.ServerName - db storage.Database + db *shared.RoomUpdater inputer *Inputer - queryer *query.Queryer keys gomatrixserverlib.JSONVerifier federation fedapi.FederationInternalAPI roomsMu *internal.MutexByRoom - servers map[gomatrixserverlib.ServerName]struct{} + servers []gomatrixserverlib.ServerName hadEvents map[string]bool hadEventsMutex sync.Mutex - haveEvents map[string]*gomatrixserverlib.HeaderedEvent + haveEvents map[string]*gomatrixserverlib.Event haveEventsMutex sync.Mutex } // processEventWithMissingState is the entrypoint for a missingStateReq // request, as called from processRoomEvent. +// nolint:gocyclo func (t *missingStateReq) processEventWithMissingState( ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, -) error { +) (*parsedRespState, error) { // We are missing the previous events for this events. // This means that there is a gap in our view of the history of the // room. There two ways that we can handle such a gap: @@ -62,54 +68,180 @@ func (t *missingStateReq) processEventWithMissingState( // - fill in the gap completely then process event `e` returning no backwards extremity // - fail to fill in the gap and tell us to terminate the transaction err=not nil // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction - newEvents, isGapFilled, err := t.getMissingEvents(ctx, e, roomVersion) + newEvents, isGapFilled, prevStatesKnown, err := t.getMissingEvents(ctx, e, roomVersion) if err != nil { - return fmt.Errorf("t.getMissingEvents: %w", err) + return nil, fmt.Errorf("t.getMissingEvents: %w", err) } if len(newEvents) == 0 { - return fmt.Errorf("expected to find missing events but didn't") + return nil, fmt.Errorf("expected to find missing events but didn't") } if isGapFilled { - logger.Infof("gap filled by /get_missing_events, injecting %d new events", len(newEvents)) + logger.Infof("Gap filled by /get_missing_events, injecting %d new events", len(newEvents)) // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // in the gap in the DAG for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ - Kind: api.KindNew, + _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ + Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, SendAsServer: api.DoNotSendToOtherServers, }) if err != nil { - return fmt.Errorf("t.inputer.processRoomEvent: %w", err) + if _, ok := err.(types.RejectedError); !ok { + return nil, fmt.Errorf("t.inputer.processRoomEvent (filling gap): %w", err) + } + } + } + } + + // If we filled the gap *and* we know the state before the prev events + // then there's nothing else to do, we have everything we need to deal + // with the new event. + if isGapFilled && prevStatesKnown { + logger.Infof("Gap filled and state found for all prev events") + return nil, nil + } + + // Otherwise, if we've reached this point, it's possible that we've + // either not closed the gap, or we did but we still don't seem to + // know the events before the new event. Start by looking up the + // state at the event at the back of the gap and we'll try to roll + // forward the state first. + backwardsExtremity := newEvents[0] + newEvents = newEvents[1:] + + resolvedState, err := t.lookupResolvedStateBeforeEvent(ctx, backwardsExtremity, roomVersion) + if err != nil { + return nil, fmt.Errorf("t.lookupState (backwards extremity): %w", err) + } + + hadEvents := map[string]bool{} + t.hadEventsMutex.Lock() + for k, v := range t.hadEvents { + hadEvents[k] = v + } + t.hadEventsMutex.Unlock() + + sendOutliers := func(resolvedState *parsedRespState) error { + outliers, oerr := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion) + if oerr != nil { + return fmt.Errorf("gomatrixserverlib.OrderAuthAndStateEvents: %w", oerr) + } + var outlierRoomEvents []api.InputRoomEvent + for _, outlier := range outliers { + if hadEvents[outlier.EventID()] { + continue + } + outlierRoomEvents = append(outlierRoomEvents, api.InputRoomEvent{ + Kind: api.KindOutlier, + Event: outlier.Headered(roomVersion), + Origin: t.origin, + }) + } + for _, ire := range outlierRoomEvents { + _, err = t.inputer.processRoomEvent(ctx, t.db, &ire) + if err != nil { + if _, ok := err.(types.RejectedError); !ok { + return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) + } } } return nil } - backwardsExtremity := newEvents[0] - newEvents = newEvents[1:] + // Send outliers first so we can send the state along with the new backwards + // extremity without any missing auth events. + if err = sendOutliers(resolvedState); err != nil { + return nil, fmt.Errorf("sendOutliers: %w", err) + } + // Now send the backward extremity into the roomserver with the + // newly resolved state. This marks the "oldest" point in the backfill and + // sets the baseline state for any new events after this. + stateIDs := make([]string, 0, len(resolvedState.StateEvents)) + for _, event := range resolvedState.StateEvents { + stateIDs = append(stateIDs, event.EventID()) + } + + _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ + Kind: api.KindOld, + Event: backwardsExtremity.Headered(roomVersion), + Origin: t.origin, + HasState: true, + StateEventIDs: stateIDs, + SendAsServer: api.DoNotSendToOtherServers, + }) + if err != nil { + if _, ok := err.(types.RejectedError); !ok { + return nil, fmt.Errorf("t.inputer.processRoomEvent (backward extremity): %w", err) + } + } + + // Then send all of the newer backfilled events, of which will all be newer + // than the backward extremity, into the roomserver without state. This way + // they will automatically fast-forward based on the room state at the + // extremity in the last step. + for _, newEvent := range newEvents { + _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ + Kind: api.KindOld, + Event: newEvent.Headered(roomVersion), + Origin: t.origin, + SendAsServer: api.DoNotSendToOtherServers, + }) + if err != nil { + if _, ok := err.(types.RejectedError); !ok { + return nil, fmt.Errorf("t.inputer.processRoomEvent (fast forward): %w", err) + } + } + } + + // Finally, check again if we know everything we need to know in order to + // make forward progress. If the prev state is known then we consider the + // rolled forward state to be sufficient — we now know all of the state + // before the prev events. If we don't then we need to look up the state + // before the new event as well, otherwise we will never make any progress. + if t.isPrevStateKnown(ctx, e) { + return nil, nil + } + + // If we still haven't got the state for the prev events then we'll go and + // ask the federation for it if needed. + resolvedState, err = t.lookupResolvedStateBeforeEvent(ctx, e, roomVersion) + if err != nil { + return nil, fmt.Errorf("t.lookupState (new event): %w", err) + } + + // Send the outliers for the retrieved state. + if err = sendOutliers(resolvedState); err != nil { + return nil, fmt.Errorf("sendOutliers: %w", err) + } + + // Then return the resolved state, for which the caller can replace the + // HasState with the event IDs to create a new state snapshot when we + // process the new event. + return resolvedState, nil +} + +func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (*parsedRespState, error) { type respState struct { // A snapshot is considered trustworthy if it came from our own roomserver. // That's because the state will have been through state resolution once // already in QueryStateAfterEvent. trustworthy bool - *gomatrixserverlib.RespState + *parsedRespState } // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. var states []*respState - for _, prevEventID := range backwardsExtremity.PrevEventIDs() { + for _, prevEventID := range e.PrevEventIDs() { // Look up what the state is after the backward extremity. This will either // come from the roomserver, if we know all the required events, or it will // come from a remote server via /state_ids if not. - prevState, trustworthy, lerr := t.lookupStateAfterEvent(ctx, roomVersion, backwardsExtremity.RoomID(), prevEventID) - if lerr != nil { - logger.WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID) - return lerr + prevState, trustworthy, err := t.lookupStateAfterEvent(ctx, roomVersion, e.RoomID(), prevEventID) + if err != nil { + return nil, fmt.Errorf("t.lookupStateAfterEvent: %w", err) } // Append the state onto the collected state. We'll run this through the // state resolution next. @@ -122,114 +254,50 @@ func (t *missingStateReq) processEventWithMissingState( // 1. Ensures that the state is deduplicated fully for each state-key tuple // 2. Ensures that we pick the latest events from both sets, in the case that // one of the prev_events is quite a bit older than the others - resolvedState := &gomatrixserverlib.RespState{} + resolvedState := &parsedRespState{} switch len(states) { case 0: - extremityIsCreate := backwardsExtremity.Type() == gomatrixserverlib.MRoomCreate && backwardsExtremity.StateKeyEquals("") + extremityIsCreate := e.Type() == gomatrixserverlib.MRoomCreate && e.StateKeyEquals("") if !extremityIsCreate { // There are no previous states and this isn't the beginning of the // room - this is an error condition! - logger.Errorf("Failed to lookup any state after prev_events") - return fmt.Errorf("expected %d states but got %d", len(backwardsExtremity.PrevEventIDs()), len(states)) + return nil, fmt.Errorf("expected %d states but got %d", len(e.PrevEventIDs()), len(states)) } case 1: // There's only one previous state - if it's trustworthy (came from a // local state snapshot which will already have been through state res), - // use it as-is. There's no point in resolving it again. - if states[0].trustworthy { - resolvedState = states[0].RespState + // use it as-is. There's no point in resolving it again. Only trust a + // trustworthy state snapshot if it actually contains some state for all + // non-create events, otherwise we need to resolve what came from federation. + isCreate := e.Type() == gomatrixserverlib.MRoomCreate && e.StateKeyEquals("") + if states[0].trustworthy && (isCreate || len(states[0].StateEvents) > 0) { + resolvedState = states[0].parsedRespState break } // Otherwise, if it isn't trustworthy (came from federation), run it through // state resolution anyway for safety, in case there are duplicates. fallthrough default: - respStates := make([]*gomatrixserverlib.RespState, len(states)) + respStates := make([]*parsedRespState, len(states)) for i := range states { - respStates[i] = states[i].RespState + respStates[i] = states[i].parsedRespState } // There's more than one previous state - run them all through state res + var err error t.roomsMu.Lock(e.RoomID()) - resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, backwardsExtremity) + resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, e) t.roomsMu.Unlock(e.RoomID()) if err != nil { - logger.WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) - return err + return nil, fmt.Errorf("t.resolveStatesAndCheck: %w", err) } } - hadEvents := map[string]bool{} - t.hadEventsMutex.Lock() - for k, v := range t.hadEvents { - hadEvents[k] = v - } - t.hadEventsMutex.Unlock() - - // Send outliers first so we can send the new backwards extremity without causing errors - outliers, err := resolvedState.Events() - if err != nil { - return err - } - var outlierRoomEvents []api.InputRoomEvent - for _, outlier := range outliers { - if hadEvents[outlier.EventID()] { - continue - } - outlierRoomEvents = append(outlierRoomEvents, api.InputRoomEvent{ - Kind: api.KindOutlier, - Event: outlier.Headered(roomVersion), - Origin: t.origin, - }) - } - // TODO: we could do this concurrently? - for _, ire := range outlierRoomEvents { - if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { - return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err) - } - } - - // Now send the backward extremity into the roomserver with the - // newly resolved state. This marks the "oldest" point in the backfill and - // sets the baseline state for any new events after this. - stateIDs := make([]string, 0, len(resolvedState.StateEvents)) - for _, event := range resolvedState.StateEvents { - stateIDs = append(stateIDs, event.EventID()) - } - - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ - Kind: api.KindOld, - Event: backwardsExtremity.Headered(roomVersion), - Origin: t.origin, - HasState: true, - StateEventIDs: stateIDs, - SendAsServer: api.DoNotSendToOtherServers, - }) - if err != nil { - return fmt.Errorf("t.inputer.processRoomEvent: %w", err) - } - - // Then send all of the newer backfilled events, of which will all be newer - // than the backward extremity, into the roomserver without state. This way - // they will automatically fast-forward based on the room state at the - // extremity in the last step. - for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ - Kind: api.KindOld, - Event: newEvent.Headered(roomVersion), - Origin: t.origin, - SendAsServer: api.DoNotSendToOtherServers, - }) - if err != nil { - return fmt.Errorf("t.inputer.processRoomEvent: %w", err) - } - } - - return nil + return resolvedState, nil } // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // added into the mix. -func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*gomatrixserverlib.RespState, bool, error) { +func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*parsedRespState, bool, error) { // try doing all this locally before we resort to querying federation respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID) if respState != nil { @@ -257,20 +325,20 @@ func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion for i := range respState.StateEvents { se := respState.StateEvents[i] if se.Type() == h.Type() && se.StateKeyEquals(*h.StateKey()) { - respState.StateEvents[i] = h.Unwrap() + respState.StateEvents[i] = h addedToState = true break } } if !addedToState { - respState.StateEvents = append(respState.StateEvents, h.Unwrap()) + respState.StateEvents = append(respState.StateEvents, h) } } return respState, false, nil } -func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *gomatrixserverlib.HeaderedEvent { +func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.Event) *gomatrixserverlib.Event { t.haveEventsMutex.Lock() defer t.haveEventsMutex.Unlock() if cached, exists := t.haveEvents[ev.EventID()]; exists { @@ -280,33 +348,50 @@ func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *g return ev } -func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *gomatrixserverlib.RespState { - var res api.QueryStateAfterEventsResponse - err := t.queryer.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{ - RoomID: roomID, - PrevEventIDs: []string{eventID}, - }, &res) - if err != nil || !res.PrevEventsExist { - util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to query state after %s locally, prev exists=%v", eventID, res.PrevEventsExist) +func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *parsedRespState { + var res parsedRespState + roomInfo, err := t.db.RoomInfo(ctx, roomID) + if err != nil { return nil } - stateEvents := make([]*gomatrixserverlib.HeaderedEvent, len(res.StateEvents)) - for i, ev := range res.StateEvents { + roomState := state.NewStateResolution(t.db, roomInfo) + stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID}) + if err != nil { + util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to get state after %s locally", eventID) + return nil + } + stateEntries, err := roomState.LoadCombinedStateAfterEvents(ctx, stateAtEvents) + if err != nil { + util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to load combined state after %s locally", eventID) + return nil + } + stateEventNIDs := make([]types.EventNID, 0, len(stateEntries)) + for _, entry := range stateEntries { + stateEventNIDs = append(stateEventNIDs, entry.EventNID) + } + stateEvents, err := t.db.Events(ctx, stateEventNIDs) + if err != nil { + util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to load state events locally") + return nil + } + res.StateEvents = make([]*gomatrixserverlib.Event, 0, len(stateEvents)) + for _, ev := range stateEvents { // set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this // processEvent request, which is better for memory. - stateEvents[i] = t.cacheAndReturn(ev) + res.StateEvents = append(res.StateEvents, t.cacheAndReturn(ev.Event)) t.hadEvent(ev.EventID()) } - // we should never access res.StateEvents again so we delete it here to make GC faster - res.StateEvents = nil - var authEvents []*gomatrixserverlib.Event + // encourage GC + stateEvents, stateEventNIDs, stateEntries, stateAtEvents = nil, nil, nil, nil // nolint:ineffassign + missingAuthEvents := map[string]bool{} + res.AuthEvents = make([]*gomatrixserverlib.Event, 0, len(stateEvents)*3) for _, ev := range stateEvents { t.haveEventsMutex.Lock() for _, ae := range ev.AuthEventIDs() { if aev, ok := t.haveEvents[ae]; ok { - authEvents = append(authEvents, aev.Unwrap()) + res.AuthEvents = append(res.AuthEvents, aev) } else { missingAuthEvents[ae] = true } @@ -320,37 +405,30 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, room for evID := range missingAuthEvents { missingEventList = append(missingEventList, evID) } - queryReq := api.QueryEventsByIDRequest{ - EventIDs: missingEventList, - } - util.GetLogger(ctx).WithField("count", len(missingEventList)).Infof("Fetching missing auth events") - var queryRes api.QueryEventsByIDResponse - if err = t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { + util.GetLogger(ctx).WithField("count", len(missingEventList)).Debugf("Fetching missing auth events") + events, err := t.db.EventsFromIDs(ctx, missingEventList) + if err != nil { return nil } - for i, ev := range queryRes.Events { - authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap()) + for i, ev := range events { + res.AuthEvents = append(res.AuthEvents, t.cacheAndReturn(events[i].Event)) t.hadEvent(ev.EventID()) } - queryRes.Events = nil } - return &gomatrixserverlib.RespState{ - StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents), - AuthEvents: authEvents, - } + return &res } // lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what // the server supports. func (t *missingStateReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( - *gomatrixserverlib.RespState, error) { + *parsedRespState, error) { // Attempt to fetch the missing state using /state_ids and /events return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion) } -func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*gomatrixserverlib.RespState, backwardsExtremity *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) { +func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*parsedRespState, backwardsExtremity *gomatrixserverlib.Event) (*parsedRespState, error) { var authEventList []*gomatrixserverlib.Event var stateEventList []*gomatrixserverlib.Event for _, state := range states { @@ -369,7 +447,7 @@ retryAllowedState: h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) switch err2.(type) { case verifySigError: - return &gomatrixserverlib.RespState{ + return &parsedRespState{ AuthEvents: authEventList, StateEvents: resolvedStateEvents, }, nil @@ -378,14 +456,14 @@ retryAllowedState: default: return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2) } - util.GetLogger(ctx).Infof("fetched event %s", missing.AuthEventID) - resolvedStateEvents = append(resolvedStateEvents, h.Unwrap()) + util.GetLogger(ctx).Tracef("fetched event %s", missing.AuthEventID) + resolvedStateEvents = append(resolvedStateEvents, h) goto retryAllowedState default: } return nil, err } - return &gomatrixserverlib.RespState{ + return &parsedRespState{ AuthEvents: authEventList, StateEvents: resolvedStateEvents, }, nil @@ -393,27 +471,18 @@ retryAllowedState: // get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject, // without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events -func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled bool, err error) { +func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) { logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) - needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{e}) - // query latest events (our trusted forward extremities) - req := api.QueryLatestEventsAndStateRequest{ - RoomID: e.RoomID(), - StateToFetch: needed.Tuples(), - } - var res api.QueryLatestEventsAndStateResponse - if err = t.queryer.QueryLatestEventsAndState(ctx, &req, &res); err != nil { - logger.WithError(err).Warn("Failed to query latest events") - return nil, false, err - } - latestEvents := make([]string, len(res.LatestEvents)) - for i, ev := range res.LatestEvents { - latestEvents[i] = res.LatestEvents[i].EventID + + latest := t.db.LatestEvents() + latestEvents := make([]string, len(latest)) + for i, ev := range latest { + latestEvents[i] = ev.EventID t.hadEvent(ev.EventID) } var missingResp *gomatrixserverlib.RespMissingEvents - for server := range t.servers { + for _, server := range t.servers { var m gomatrixserverlib.RespMissingEvents if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ Limit: 20, @@ -425,11 +494,11 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve missingResp = &m break } else { - logger.WithError(err).Errorf("%s pushed us an event but %q did not respond to /get_missing_events", t.origin, server) + logger.WithError(err).Warnf("%s pushed us an event but %q did not respond to /get_missing_events", t.origin, server) if errors.Is(err, context.DeadlineExceeded) { select { case <-ctx.Done(): // the parent request context timed out - return nil, false, context.DeadlineExceeded + return nil, false, false, context.DeadlineExceeded default: // this request exceed its own timeout continue } @@ -438,11 +507,11 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve } if missingResp == nil { - logger.WithError(err).Errorf( + logger.WithError(err).Warnf( "%s pushed us an event but %d server(s) couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", t.origin, len(t.servers), ) - return nil, false, missingPrevEventsError{ + return nil, false, false, missingPrevEventsError{ eventID: e.EventID(), err: err, } @@ -450,13 +519,14 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve // Make sure events from the missingResp are using the cache - missing events // will be added and duplicates will be removed. - logger.Infof("get_missing_events returned %d events", len(missingResp.Events)) - for i, ev := range missingResp.Events { - missingResp.Events[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + logger.Debugf("get_missing_events returned %d events", len(missingResp.Events)) + missingEvents := make([]*gomatrixserverlib.Event, 0, len(missingResp.Events)) + for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { + missingEvents = append(missingEvents, t.cacheAndReturn(ev)) } // topologically sort and sanity check that we are making forward progress - newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents) + newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingEvents, gomatrixserverlib.TopologicalOrderByPrevEvents) shouldHaveSomeEventIDs := e.PrevEventIDs() hasPrevEvent := false Event: @@ -470,56 +540,88 @@ Event: } if !hasPrevEvent { err = fmt.Errorf("called /get_missing_events but server %s didn't return any prev_events with IDs %v", t.origin, shouldHaveSomeEventIDs) - logger.WithError(err).Errorf( + logger.WithError(err).Warnf( "%s pushed us an event but couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", t.origin, ) - return nil, false, missingPrevEventsError{ + return nil, false, false, missingPrevEventsError{ eventID: e.EventID(), err: err, } } if len(newEvents) == 0 { - return nil, false, nil // TODO: error instead? + return nil, false, false, nil // TODO: error instead? } - // now check if we can fill the gap. Look to see if we have state snapshot IDs for the earliest event earliestNewEvent := newEvents[0] - if state, err := t.db.StateAtEventIDs(ctx, []string{earliestNewEvent.EventID()}); err != nil || len(state) == 0 { - if earliestNewEvent.Type() == gomatrixserverlib.MRoomCreate && earliestNewEvent.StateKeyEquals("") { - // we got to the beginning of the room so there will be no state! It's all good we can process this - return newEvents, true, nil - } - // we don't have the state at this earliest event from /g_m_e so we won't have state for later events either - return newEvents, false, nil + + // If we retrieved back to the beginning of the room then there's nothing else + // to do - we closed the gap. + if len(earliestNewEvent.PrevEventIDs()) == 0 && earliestNewEvent.Type() == gomatrixserverlib.MRoomCreate && earliestNewEvent.StateKeyEquals("") { + return newEvents, true, t.isPrevStateKnown(ctx, e), nil } - // StateAtEventIDs returned some kind of state for the earliest event so we can fill in the gap! - return newEvents, true, nil + + // If our backward extremity was not a known event to us then we obviously didn't + // close the gap. + if state, err := t.db.StateAtEventIDs(ctx, []string{earliestNewEvent.EventID()}); err != nil || len(state) == 0 && state[0].BeforeStateSnapshotNID == 0 { + return newEvents, false, false, nil + } + + // At this point we are satisfied that we know the state both at the earliest + // retrieved event and at the prev events of the new event. + return newEvents, true, t.isPrevStateKnown(ctx, e), nil } -func (t *missingStateReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - respState *gomatrixserverlib.RespState, err error) { +func (t *missingStateReq) isPrevStateKnown(ctx context.Context, e *gomatrixserverlib.Event) bool { + expected := len(e.PrevEventIDs()) + state, err := t.db.StateAtEventIDs(ctx, e.PrevEventIDs()) + if err != nil || len(state) != expected { + // We didn't get as many state snapshots as we expected, or there was an error, + // so we haven't completely solved the problem for the new event. + return false + } + // Check to see if we have a populated state snapshot for all of the prev events. + for _, stateAtEvent := range state { + if stateAtEvent.BeforeStateSnapshotNID == 0 { + // One of the prev events still has unknown state, so we haven't really + // solved the problem. + return false + } + } + return true +} + +func (t *missingStateReq) lookupMissingStateViaState( + ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, +) (respState *parsedRespState, err error) { state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion) if err != nil { return nil, err } // Check that the returned state is valid. - if err := state.Check(ctx, t.keys, nil); err != nil { + if err := state.Check(ctx, roomVersion, t.keys, nil); err != nil { return nil, err } + parsedState := &parsedRespState{ + AuthEvents: make([]*gomatrixserverlib.Event, len(state.AuthEvents)), + StateEvents: make([]*gomatrixserverlib.Event, len(state.StateEvents)), + } // Cache the results of this state lookup and deduplicate anything we already // have in the cache, freeing up memory. - for i, ev := range state.AuthEvents { - state.AuthEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + // We load these as trusted as we called state.Check before which loaded them as untrusted. + for i, evJSON := range state.AuthEvents { + ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion) + parsedState.AuthEvents[i] = t.cacheAndReturn(ev) } - for i, ev := range state.StateEvents { - state.StateEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + for i, evJSON := range state.StateEvents { + ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion) + parsedState.StateEvents[i] = t.cacheAndReturn(ev) } - return &state, nil + return parsedState, nil } func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - *gomatrixserverlib.RespState, error) { + *parsedRespState, error) { util.GetLogger(ctx).WithField("room_id", roomID).Infof("lookupMissingStateViaStateIDs %s", eventID) // fetch the state event IDs at the time of the event stateIDs, err := t.federation.LookupStateIDs(ctx, t.origin, roomID, eventID) @@ -541,27 +643,26 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo } t.haveEventsMutex.Unlock() - // fetch as many as we can from the roomserver - queryReq := api.QueryEventsByIDRequest{ - EventIDs: missingEventList, + events, err := t.db.EventsFromIDs(ctx, missingEventList) + if err != nil { + return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err) } - var queryRes api.QueryEventsByIDResponse - if err = t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { - return nil, err - } - for i, ev := range queryRes.Events { - queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i]) + + for i, ev := range events { + events[i].Event = t.cacheAndReturn(events[i].Event) t.hadEvent(ev.EventID()) - evID := queryRes.Events[i].EventID() + evID := events[i].EventID() if missing[evID] { delete(missing, evID) } } - queryRes.Events = nil // allow it to be GCed + + // encourage GC + events = nil // nolint:ineffassign concurrentRequests := 8 missingCount := len(missing) - util.GetLogger(ctx).WithField("room_id", roomID).WithField("event_id", eventID).Infof("lookupMissingStateViaStateIDs missing %d/%d events", missingCount, len(wantIDs)) + util.GetLogger(ctx).WithField("room_id", roomID).WithField("event_id", eventID).Debugf("lookupMissingStateViaStateIDs missing %d/%d events", missingCount, len(wantIDs)) // If over 50% of the auth/state events from /state_ids are missing // then we'll just call /state instead, otherwise we'll just end up @@ -573,7 +674,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo "room_id": roomID, "total_state": len(stateIDs.StateEventIDs), "total_auth_events": len(stateIDs.AuthEventIDs), - }).Info("Fetching all state at event") + }).Debug("Fetching all state at event") return t.lookupMissingStateViaState(ctx, roomID, eventID, roomVersion) } @@ -585,7 +686,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo "total_state": len(stateIDs.StateEventIDs), "total_auth_events": len(stateIDs.AuthEventIDs), "concurrent_requests": concurrentRequests, - }).Info("Fetching missing state at event") + }).Debug("Fetching missing state at event") // Create a queue containing all of the missing event IDs that we want // to retrieve. @@ -611,7 +712,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo // Define what we'll do in order to fetch the missing event ID. fetch := func(missingEventID string) { - var h *gomatrixserverlib.HeaderedEvent + var h *gomatrixserverlib.Event h, err = t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false) switch err.(type) { case verifySigError: @@ -622,7 +723,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo util.GetLogger(ctx).WithFields(logrus.Fields{ "event_id": missingEventID, "room_id": roomID, - }).Info("Failed to fetch missing event") + }).Warn("Failed to fetch missing event") return } haveEventsMutex.Lock() @@ -651,29 +752,30 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo return resp, err } -func (t *missingStateReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs) ( - *gomatrixserverlib.RespState, error) { // nolint:unparam +func (t *missingStateReq) createRespStateFromStateIDs( + stateIDs gomatrixserverlib.RespStateIDs, +) (*parsedRespState, error) { // nolint:unparam t.haveEventsMutex.Lock() defer t.haveEventsMutex.Unlock() // create a RespState response using the response to /state_ids as a guide - respState := gomatrixserverlib.RespState{} + respState := parsedRespState{} for i := range stateIDs.StateEventIDs { ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]] if !ok { - logrus.Warnf("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i]) + logrus.Tracef("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i]) continue } - respState.StateEvents = append(respState.StateEvents, ev.Unwrap()) + respState.StateEvents = append(respState.StateEvents, ev) } for i := range stateIDs.AuthEventIDs { ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]] if !ok { - logrus.Warnf("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i]) + logrus.Tracef("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i]) continue } - respState.AuthEvents = append(respState.AuthEvents, ev.Unwrap()) + respState.AuthEvents = append(respState.AuthEvents, ev) } // We purposefully do not do auth checks on the returned events, as they will still // be processed in the exact same way, just as a 'rejected' event @@ -681,22 +783,19 @@ func (t *missingStateReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib return &respState, nil } -func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { +func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.Event, error) { if localFirst { // fetch from the roomserver - queryReq := api.QueryEventsByIDRequest{ - EventIDs: []string{missingEventID}, - } - var queryRes api.QueryEventsByIDResponse - if err := t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { + events, err := t.db.EventsFromIDs(ctx, []string{missingEventID}) + if err != nil { util.GetLogger(ctx).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) - } else if len(queryRes.Events) == 1 { - return queryRes.Events[0], nil + } else if len(events) == 1 { + return events[0].Event, nil } } var event *gomatrixserverlib.Event found := false - for serverName := range t.servers { + for _, serverName := range t.servers { reqctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() txn, err := t.federation.GetEvent(reqctx, serverName, missingEventID) @@ -714,7 +813,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs } event, err = gomatrixserverlib.NewEventFromUntrustedJSON(txn.PDUs[0], roomVersion) if err != nil { - util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warnf("Transaction: Failed to parse event JSON of event") + util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warnf("Failed to parse event JSON of event returned from /event") continue } found = true @@ -725,10 +824,10 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) } if err := event.VerifyEventSignatures(ctx, t.keys); err != nil { - util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) + util.GetLogger(ctx).WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) return nil, verifySigError{event.EventID(), err} } - return t.cacheAndReturn(event.Headered(roomVersion)), nil + return t.cacheAndReturn(event), nil } func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserverlib.Event) error { diff --git a/roomserver/internal/input/input_test.go b/roomserver/internal/input/input_test.go new file mode 100644 index 000000000..4fa966281 --- /dev/null +++ b/roomserver/internal/input/input_test.go @@ -0,0 +1,93 @@ +package input_test + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/input" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +func psqlConnectionString() config.DataSource { + user := os.Getenv("POSTGRES_USER") + if user == "" { + user = "dendrite" + } + dbName := os.Getenv("POSTGRES_DB") + if dbName == "" { + dbName = "dendrite" + } + connStr := fmt.Sprintf( + "user=%s dbname=%s sslmode=disable", user, dbName, + ) + password := os.Getenv("POSTGRES_PASSWORD") + if password != "" { + connStr += fmt.Sprintf(" password=%s", password) + } + host := os.Getenv("POSTGRES_HOST") + if host != "" { + connStr += fmt.Sprintf(" host=%s", host) + } + return config.DataSource(connStr) +} + +func TestSingleTransactionOnInput(t *testing.T) { + deadline, _ := t.Deadline() + if max := time.Now().Add(time.Second * 3); deadline.After(max) { + deadline = max + } + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + event, err := gomatrixserverlib.NewEventFromTrustedJSON( + []byte(`{"auth_events":[],"content":{"creator":"@neilalexander:dendrite.matrix.org","room_version":"6"},"depth":1,"hashes":{"sha256":"jqOqdNEH5r0NiN3xJtj0u5XUVmRqq9YvGbki1wxxuuM"},"origin":"dendrite.matrix.org","origin_server_ts":1644595362726,"prev_events":[],"prev_state":[],"room_id":"!jSZZRknA6GkTBXNP:dendrite.matrix.org","sender":"@neilalexander:dendrite.matrix.org","signatures":{"dendrite.matrix.org":{"ed25519:6jB2aB":"bsQXO1wketf1OSe9xlndDIWe71W9KIundc6rBw4KEZdGPW7x4Tv4zDWWvbxDsG64sS2IPWfIm+J0OOozbrWIDw"}},"state_key":"","type":"m.room.create"}`), + false, gomatrixserverlib.RoomVersionV6, + ) + if err != nil { + t.Fatal(err) + } + in := api.InputRoomEvent{ + Kind: api.KindOutlier, // don't panic if we generate an output event + Event: event.Headered(gomatrixserverlib.RoomVersionV6), + } + cache, err := caching.NewInMemoryLRUCache(false) + if err != nil { + t.Fatal(err) + } + db, err := storage.Open( + &config.DatabaseOptions{ + ConnectionString: psqlConnectionString(), + MaxOpenConnections: 1, + MaxIdleConnections: 1, + }, + cache, + ) + if err != nil { + t.Logf("PostgreSQL not available (%s), skipping", err) + t.SkipNow() + } + inputter := &input.Inputer{ + DB: db, + } + res := &api.InputRoomEventsResponse{} + inputter.InputRoomEvents( + ctx, + &api.InputRoomEventsRequest{ + InputRoomEvents: []api.InputRoomEvent{in}, + Asynchronous: false, + }, + res, + ) + // If we fail here then it's because we've hit the test deadline, + // so we probably deadlocked + if err := res.Err(); err != nil { + t.Fatal(err) + } +} diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 85b2322fe..6559cd081 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) @@ -54,18 +55,23 @@ func (r *Inviter) PerformInvite( return nil, fmt.Errorf("failed to load RoomInfo: %w", err) } - log.WithFields(log.Fields{ - "event_id": event.EventID(), - "room_id": roomID, - "room_version": req.RoomVersion, - "target_user_id": targetUserID, - "room_info_exists": info != nil, - }).Info("processing invite event") - _, domain, _ := gomatrixserverlib.SplitID('@', targetUserID) isTargetLocal := domain == r.Cfg.Matrix.ServerName isOriginLocal := event.Origin() == r.Cfg.Matrix.ServerName + logger := util.GetLogger(ctx).WithFields(map[string]interface{}{ + "inviter": event.Sender(), + "invitee": *event.StateKey(), + "room_id": roomID, + "event_id": event.EventID(), + }) + logger.WithFields(log.Fields{ + "room_version": req.RoomVersion, + "room_info_exists": info != nil, + "target_local": isTargetLocal, + "origin_local": isOriginLocal, + }).Debug("processing invite event") + inviteState := req.InviteRoomState if len(inviteState) == 0 && info != nil { var is []gomatrixserverlib.InviteV2StrippedState @@ -122,75 +128,17 @@ func (r *Inviter) PerformInvite( Code: api.PerformErrorNotAllowed, Msg: "User is already joined to room", } + logger.Debugf("user already joined") return nil, nil } - if isOriginLocal { - // The invite originated locally. Therefore we have a responsibility to - // try and see if the user is allowed to make this invite. We can't do - // this for invites coming in over federation - we have to take those on - // trust. - _, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs()) - if err != nil { - log.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( - "processInviteEvent.checkAuthEvents failed for event", - ) - res.Error = &api.PerformError{ - Msg: err.Error(), - Code: api.PerformErrorNotAllowed, - } - } - - // If the invite originated from us and the target isn't local then we - // should try and send the invite over federation first. It might be - // that the remote user doesn't exist, in which case we can give up - // processing here. - if req.SendAsServer != api.DoNotSendToOtherServers && !isTargetLocal { - fsReq := &federationAPI.PerformInviteRequest{ - RoomVersion: req.RoomVersion, - Event: event, - InviteRoomState: inviteState, - } - fsRes := &federationAPI.PerformInviteResponse{} - if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil { - res.Error = &api.PerformError{ - Msg: err.Error(), - Code: api.PerformErrorNotAllowed, - } - log.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed") - return nil, nil - } - event = fsRes.Event - } - - // Send the invite event to the roomserver input stream. This will - // notify existing users in the room about the invite, update the - // membership table and ensure that the event is ready and available - // to use as an auth event when accepting the invite. - inputReq := &api.InputRoomEventsRequest{ - InputRoomEvents: []api.InputRoomEvent{ - { - Kind: api.KindNew, - Event: event, - Origin: event.Origin(), - SendAsServer: req.SendAsServer, - }, - }, - } - inputRes := &api.InputRoomEventsResponse{} - r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) - if err = inputRes.Err(); err != nil { - res.Error = &api.PerformError{ - Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()), - Code: api.PerformErrorNotAllowed, - } - log.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") - return nil, nil - } - } else { + if !isOriginLocal { // The invite originated over federation. Process the membership // update, which will notify the sync API etc about the incoming - // invite. + // invite. We do NOT send an InputRoomEvent for the invite as it + // will never pass auth checks due to lacking room state, but we + // still need to tell the client about the invite so we can accept + // it, hence we return an output event to send to the sync api. updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion) if err != nil { return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) @@ -205,10 +153,77 @@ func (r *Inviter) PerformInvite( if err = updater.Commit(); err != nil { return nil, fmt.Errorf("updater.Commit: %w", err) } - + logger.Debugf("updated membership to invite and sending invite OutputEvent") return outputUpdates, nil } + // The invite originated locally. Therefore we have a responsibility to + // try and see if the user is allowed to make this invite. We can't do + // this for invites coming in over federation - we have to take those on + // trust. + _, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs()) + if err != nil { + logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( + "processInviteEvent.checkAuthEvents failed for event", + ) + res.Error = &api.PerformError{ + Msg: err.Error(), + Code: api.PerformErrorNotAllowed, + } + return nil, nil + } + + // If the invite originated from us and the target isn't local then we + // should try and send the invite over federation first. It might be + // that the remote user doesn't exist, in which case we can give up + // processing here. + if req.SendAsServer != api.DoNotSendToOtherServers && !isTargetLocal { + fsReq := &federationAPI.PerformInviteRequest{ + RoomVersion: req.RoomVersion, + Event: event, + InviteRoomState: inviteState, + } + fsRes := &federationAPI.PerformInviteResponse{} + if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil { + res.Error = &api.PerformError{ + Msg: err.Error(), + Code: api.PerformErrorNotAllowed, + } + logger.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed") + return nil, nil + } + event = fsRes.Event + logger.Debugf("Federated PerformInvite success with event ID %s", event.EventID()) + } + + // Send the invite event to the roomserver input stream. This will + // notify existing users in the room about the invite, update the + // membership table and ensure that the event is ready and available + // to use as an auth event when accepting the invite. + // It will NOT notify the invitee of this invite. + inputReq := &api.InputRoomEventsRequest{ + InputRoomEvents: []api.InputRoomEvent{ + { + Kind: api.KindNew, + Event: event, + Origin: event.Origin(), + SendAsServer: req.SendAsServer, + }, + }, + } + inputRes := &api.InputRoomEventsResponse{} + r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) + if err = inputRes.Err(); err != nil { + res.Error = &api.PerformError{ + Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()), + Code: api.PerformErrorNotAllowed, + } + logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") + return nil, nil + } + + // Don't notify the sync api of this event in the same way as a federated invite so the invitee + // gets the invite, as the roomserver will do this when it processes the m.room.member invite. return nil, nil } diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index a1ffab5dd..9d2a66d4c 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -51,8 +51,15 @@ func (r *Joiner) PerformJoin( req *rsAPI.PerformJoinRequest, res *rsAPI.PerformJoinResponse, ) { - roomID, joinedVia, err := r.performJoin(ctx, req) + logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ + "room_id": req.RoomIDOrAlias, + "user_id": req.UserID, + "servers": req.ServerNames, + }) + logger.Info("User requested to room join") + roomID, joinedVia, err := r.performJoin(context.Background(), req) if err != nil { + logger.WithError(err).Error("Failed to join room") sentry.CaptureException(err) perr, ok := err.(*rsAPI.PerformError) if ok { @@ -62,7 +69,9 @@ func (r *Joiner) PerformJoin( Msg: err.Error(), } } + return } + logger.Info("User joined room successfully") res.RoomID = roomID res.JoinedVia = joinedVia } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index eac528eaf..49ddd4810 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -16,24 +16,29 @@ package perform import ( "context" + "encoding/json" "fmt" "strings" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" + userapi "github.com/matrix-org/dendrite/userapi/api" ) type Leaver struct { - Cfg *config.RoomServer - DB storage.Database - FSAPI fsAPI.FederationInternalAPI - + Cfg *config.RoomServer + DB storage.Database + FSAPI fsAPI.FederationInternalAPI + UserAPI userapi.UserInternalAPI Inputer *input.Inputer } @@ -50,8 +55,19 @@ func (r *Leaver) PerformLeave( if domain != r.Cfg.Matrix.ServerName { return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID) } + logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ + "room_id": req.RoomID, + "user_id": req.UserID, + }) + logger.Info("User requested to leave join") if strings.HasPrefix(req.RoomID, "!") { - return r.performLeaveRoomByID(ctx, req, res) + output, err := r.performLeaveRoomByID(context.Background(), req, res) + if err != nil { + logger.WithError(err).Error("Failed to leave room") + } else { + logger.Info("User left room successfully") + } + return output, err } return nil, fmt.Errorf("room ID %q is invalid", req.RoomID) } @@ -73,6 +89,31 @@ func (r *Leaver) performLeaveRoomByID( if host != r.Cfg.Matrix.ServerName { return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) } + // check that this is not a "server notice room" + accData := &userapi.QueryAccountDataResponse{} + if err := r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{ + UserID: req.UserID, + RoomID: req.RoomID, + DataType: "m.tag", + }, accData); err != nil { + return nil, fmt.Errorf("unable to query account data") + } + + if roomData, ok := accData.RoomAccountData[req.RoomID]; ok { + tagData, ok := roomData["m.tag"] + if ok { + tags := gomatrix.TagContent{} + if err = json.Unmarshal(tagData, &tags); err != nil { + return nil, fmt.Errorf("unable to unmarshal tag content") + } + if _, ok = tags.Tags["m.server_notice"]; ok { + // mimic the returned values from Synapse + res.Message = "You cannot reject this invite" + res.Code = 403 + return nil, fmt.Errorf("You cannot reject this invite") + } + } + } } // There's no invite pending, so first of all we want to find out diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 6b4cb5816..c8bbe7705 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -125,38 +125,6 @@ func (r *Queryer) QueryStateAfterEvents( return nil } -// QueryMissingAuthPrevEvents implements api.RoomserverInternalAPI -func (r *Queryer) QueryMissingAuthPrevEvents( - ctx context.Context, - request *api.QueryMissingAuthPrevEventsRequest, - response *api.QueryMissingAuthPrevEventsResponse, -) error { - info, err := r.DB.RoomInfo(ctx, request.RoomID) - if err != nil { - return err - } - if info == nil { - return errors.New("room doesn't exist") - } - - response.RoomExists = !info.IsStub - response.RoomVersion = info.RoomVersion - - for _, authEventID := range request.AuthEventIDs { - if nids, err := r.DB.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 { - response.MissingAuthEventIDs = append(response.MissingAuthEventIDs, authEventID) - } - } - - for _, prevEventID := range request.PrevEventIDs { - if state, err := r.DB.StateAtEventIDs(ctx, []string{prevEventID}); err != nil || len(state) == 0 { - response.MissingPrevEventIDs = append(response.MissingPrevEventIDs, prevEventID) - } - } - - return nil -} - // QueryEventsByID implements api.RoomserverInternalAPI func (r *Queryer) QueryEventsByID( ctx context.Context, diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 4f6a58bde..99c596606 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -11,6 +11,8 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/opentracing/opentracing-go" ) @@ -40,7 +42,6 @@ const ( // Query operations RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" RoomserverQueryStateAfterEventsPath = "/roomserver/queryStateAfterEvents" - RoomserverQueryMissingAuthPrevEventsPath = "/roomserver/queryMissingAuthPrevEvents" RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID" RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser" RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom" @@ -91,6 +92,10 @@ func (h *httpRoomserverInternalAPI) SetFederationAPI(fsAPI fsInputAPI.Federation func (h *httpRoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) { } +// SetUserAPI no-ops in HTTP client mode as there is no chicken/egg scenario +func (h *httpRoomserverInternalAPI) SetUserAPI(userAPI userapi.UserInternalAPI) { +} + // SetRoomAlias implements RoomserverAliasAPI func (h *httpRoomserverInternalAPI) SetRoomAlias( ctx context.Context, @@ -302,19 +307,6 @@ func (h *httpRoomserverInternalAPI) QueryStateAfterEvents( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } -// QueryStateAfterEvents implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryMissingAuthPrevEvents( - ctx context.Context, - request *api.QueryMissingAuthPrevEventsRequest, - response *api.QueryMissingAuthPrevEventsResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingAuthPrevEvents") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryMissingAuthPrevEventsPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - // QueryEventsByID implements RoomserverQueryAPI func (h *httpRoomserverInternalAPI) QueryEventsByID( ctx context.Context, diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index bf319262f..691a45830 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -149,20 +149,6 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - internalAPIMux.Handle( - RoomserverQueryMissingAuthPrevEventsPath, - httputil.MakeInternalAPI("queryMissingAuthPrevEvents", func(req *http.Request) util.JSONResponse { - var request api.QueryMissingAuthPrevEventsRequest - var response api.QueryMissingAuthPrevEventsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryMissingAuthPrevEvents(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) internalAPIMux.Handle( RoomserverQueryEventsByIDPath, httputil.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse { diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 669957be1..950c6b4e7 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -50,10 +50,10 @@ func NewInternalAPI( logrus.WithError(err).Panicf("failed to connect to room server db") } - js, _, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) return internal.NewRoomserverAPI( - cfg, roomserverDB, js, + base.ProcessContext, cfg, roomserverDB, js, cfg.Matrix.JetStream.TopicFor(jetstream.InputRoomEvent), cfg.Matrix.JetStream.TopicFor(jetstream.OutputRoomEvent), base.Caches, perspectiveServerNames, diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 15d592b46..e5f69521e 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -22,7 +22,6 @@ import ( "sort" "time" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" @@ -30,13 +29,25 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) +type StateResolutionStorage interface { + EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) + StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) + Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) +} + type StateResolution struct { - db storage.Database + db StateResolutionStorage roomInfo *types.RoomInfo events map[types.EventNID]*gomatrixserverlib.Event } -func NewStateResolution(db storage.Database, roomInfo *types.RoomInfo) StateResolution { +func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution { return StateResolution{ db: db, roomInfo: roomInfo, diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 15764366b..a9851e05b 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -86,11 +86,10 @@ type Database interface { // Lookup the event IDs for a batch of event numeric IDs. // Returns an error if the retrieval went wrong. EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) - // Look up the latest events in a room in preparation for an update. - // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. - // Returns the latest events in the room and the last eventID sent to the log along with an updater. + // Opens and returns a room updater, which locks the room and opens a transaction. + // The GetRoomUpdater must have Commit or Rollback called on it if this doesn't return an error. // If this returns an error then no further action is required. - GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error) + GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error) // Look up event references for the latest events in the room and the current state snapshot. // Returns the latest events, the current state and the maximum depth of the latest events plus 1. // Returns an error if there was a problem talking to the database. diff --git a/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go b/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go index 06740dc8b..06442a4c3 100644 --- a/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go +++ b/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go @@ -256,23 +256,17 @@ func UpStateBlocksRefactor(tx *sql.Tx) error { return fmt.Errorf("assertion query failed: %s", err) } if count > 0 { - var debugEventID, debugRoomID string - var debugEventTypeNID, debugStateKeyNID, debugSnapNID, debugDepth int64 - err = tx.QueryRow( - `SELECT event_id, event_type_nid, event_state_key_nid, roomserver_events.state_snapshot_nid, depth, room_id FROM roomserver_events - JOIN roomserver_rooms ON roomserver_rooms.room_nid = roomserver_events.room_nid WHERE roomserver_events.state_snapshot_nid < $1 AND roomserver_events.state_snapshot_nid != 0`, maxsnapshotid, - ).Scan(&debugEventID, &debugEventTypeNID, &debugStateKeyNID, &debugSnapNID, &debugDepth, &debugRoomID) - if err != nil { - logrus.Errorf("cannot extract debug info: %v", err) - } else { - logrus.Errorf( - "Affected row: event_id=%v room_id=%v type=%v state_key=%v snapshot=%v depth=%v", - debugEventID, debugRoomID, debugEventTypeNID, debugStateKeyNID, debugSnapNID, debugDepth, - ) - logrus.Errorf("To fix this manually, run this query first then retry the migration: "+ - "UPDATE roomserver_events SET state_snapshot_nid=0 WHERE event_id='%v'", debugEventID) + var res sql.Result + var c int64 + res, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid) + if err != nil && err != sql.ErrNoRows { + return fmt.Errorf("failed to reset invalid state snapshots: %w", err) + } + if c, err = res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get row count for invalid state snapshots updated: %w", err) + } else if c != count { + return fmt.Errorf("expected to reset %d event(s) but only updated %d event(s)", count, c) } - return fmt.Errorf("%d events exist in roomserver_events which have not been converted to a new state_snapshot_nid; this is a bug, please report", count) } if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil { return fmt.Errorf("assertion query failed: %s", err) diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index 32e457821..b3220effd 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -76,14 +76,16 @@ func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) { func (s *eventJSONStatements) InsertEventJSON( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { - _, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON) + stmt := sqlutil.TxStmt(txn, s.insertEventJSONStmt) + _, err := stmt.ExecContext(ctx, int64(eventNID), eventJSON) return err } func (s *eventJSONStatements) BulkSelectEventJSON( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]tables.EventJSONPair, error) { - rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventJSONStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index 3a7cf03e3..762b3a1fc 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -111,9 +111,10 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( - ctx context.Context, eventStateKeys []string, + ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { - rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext( + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt) + rows, err := stmt.QueryContext( ctx, pq.StringArray(eventStateKeys), ) if err != nil { @@ -134,13 +135,14 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKey( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) for i := range eventStateKeyNIDs { nIDs[i] = int64(eventStateKeyNIDs[i]) } - rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs) + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyStmt) + rows, err := stmt.QueryContext(ctx, nIDs) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index e558072a5..1d5de5822 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -133,9 +133,10 @@ func (s *eventTypeStatements) SelectEventTypeNID( } func (s *eventTypeStatements) BulkSelectEventTypeNID( - ctx context.Context, eventTypes []string, + ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { - rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventTypeNIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventTypes)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 778cd8d73..8012174a0 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -71,10 +71,10 @@ CREATE TABLE IF NOT EXISTS roomserver_events ( ` const insertEventSQL = "" + - "INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" + + "INSERT INTO roomserver_events AS e (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + - " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique" + - " DO NOTHING" + + " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique DO UPDATE" + + " SET is_rejected = $8 WHERE e.event_id = $4 AND e.is_rejected = FALSE" + " RETURNING event_nid, state_snapshot_nid" const selectEventSQL = "" + @@ -127,6 +127,9 @@ const bulkSelectEventIDSQL = "" + const bulkSelectEventNIDSQL = "" + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1)" +const bulkSelectUnsentEventNIDSQL = "" + + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE" + const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)" @@ -147,6 +150,7 @@ type eventStatements struct { bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt + bulkSelectUnsentEventNIDStmt *sql.Stmt selectMaxEventDepthStmt *sql.Stmt selectRoomNIDsForEventNIDsStmt *sql.Stmt } @@ -173,6 +177,7 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + {&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL}, {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, {&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL}, }.Prepare(db) @@ -192,7 +197,8 @@ func (s *eventStatements) InsertEvent( ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 - err := s.insertEventStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.insertEventStmt) + err := stmt.QueryRowContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, isRejected, @@ -212,9 +218,10 @@ func (s *eventStatements) SelectEvent( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateEntry, error) { - rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -254,13 +261,14 @@ func (s *eventStatements) BulkSelectStateEventByID( // bulkSelectStateEventByNID lookups a list of state events by event NID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByNID( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { tuples := stateKeyTupleSorter(stateKeyTuples) sort.Sort(tuples) eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) if err != nil { return nil, err } @@ -291,9 +299,10 @@ func (s *eventStatements) BulkSelectStateEventByNID( // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. func (s *eventStatements) BulkSelectStateAtEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateAtEvent, error) { - rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -428,8 +437,9 @@ func (s *eventStatements) BulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { - rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) +func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventIDStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } @@ -453,10 +463,29 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ return results, nil } +// BulkSelectEventNIDs returns a map from string event ID to numeric event ID. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { + return s.bulkSelectEventNID(ctx, txn, eventIDs, false) +} + +// BulkSelectEventNIDs returns a map from string event ID to numeric event ID +// only for events that haven't already been sent to the roomserver output. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { + return s.bulkSelectEventNID(ctx, txn, eventIDs, true) +} + // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { - rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) +func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) { + var stmt *sql.Stmt + if onlyUnsent { + stmt = sqlutil.TxStmt(txn, s.bulkSelectUnsentEventNIDStmt) + } else { + stmt = sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt) + } + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -484,9 +513,10 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, } func (s *eventStatements) SelectRoomNIDsForEventNIDs( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) (map[types.EventNID]types.RoomNID, error) { - rows, err := s.selectRoomNIDsForEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDsForEventNIDsStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go index 344302c8f..176c16e48 100644 --- a/roomserver/storage/postgres/invite_table.go +++ b/roomserver/storage/postgres/invite_table.go @@ -97,8 +97,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { } func (s *inviteStatements) InsertInviteEvent( - ctx context.Context, - txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, + ctx context.Context, txn *sql.Tx, + inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { @@ -116,8 +116,8 @@ func (s *inviteStatements) InsertInviteEvent( } func (s *inviteStatements) UpdateInviteRetired( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) ([]string, error) { stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) @@ -139,10 +139,11 @@ func (s *inviteStatements) UpdateInviteRetired( // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs func (s *inviteStatements) SelectInviteActiveForUserInRoom( - ctx context.Context, + ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, []string, error) { - rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( + stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt) + rows, err := stmt.QueryContext( ctx, targetUserNID, roomNID, ) if err != nil { diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index b0d906c80..48c2c35cd 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -186,8 +186,8 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { } func (s *membershipStatements) InsertMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, ) error { stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) @@ -196,8 +196,8 @@ func (s *membershipStatements) InsertMembership( } func (s *membershipStatements) SelectMembershipForUpdate( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (membership tables.MembershipState, err error) { err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( ctx, roomNID, targetUserNID, @@ -206,17 +206,19 @@ func (s *membershipStatements) SelectMembershipForUpdate( } func (s *membershipStatements) SelectMembershipFromRoomAndTarget( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { - err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt) + err = stmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID, &forgotten) return } func (s *membershipStatements) SelectMembershipsFromRoom( - ctx context.Context, roomNID types.RoomNID, localOnly bool, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var stmt *sql.Stmt if localOnly { @@ -224,6 +226,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } else { stmt = s.selectMembershipsFromRoomStmt } + stmt = sqlutil.TxStmt(txn, stmt) rows, err := stmt.QueryContext(ctx, roomNID) if err != nil { return @@ -241,7 +244,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var rows *sql.Rows @@ -251,6 +254,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( } else { stmt = s.selectMembershipsFromRoomAndMembershipStmt } + stmt = sqlutil.TxStmt(txn, stmt) rows, err = stmt.QueryContext(ctx, roomNID, membership) if err != nil { return @@ -268,8 +272,8 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( } func (s *membershipStatements) UpdateMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, forgotten bool, ) error { _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( @@ -279,9 +283,11 @@ func (s *membershipStatements) UpdateMembership( } func (s *membershipStatements) SelectRoomsWithMembership( - ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, + ctx context.Context, txn *sql.Tx, + userID types.EventStateKeyNID, membershipState tables.MembershipState, ) ([]types.RoomNID, error) { - rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt) + rows, err := stmt.QueryContext(ctx, membershipState, userID) if err != nil { return nil, err } @@ -297,12 +303,16 @@ func (s *membershipStatements) SelectRoomsWithMembership( return roomNIDs, nil } -func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { +func (s *membershipStatements) SelectJoinedUsersSetForRooms( + ctx context.Context, txn *sql.Tx, + roomNIDs []types.RoomNID, +) (map[types.EventStateKeyNID]int, error) { roomIDarray := make([]int64, len(roomNIDs)) for i := range roomNIDs { roomIDarray[i] = int64(roomNIDs[i]) } - rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) + stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) + rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) if err != nil { return nil, err } @@ -319,8 +329,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, return result, rows.Err() } -func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { - rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) +func (s *membershipStatements) SelectKnownUsers( + ctx context.Context, txn *sql.Tx, + userID types.EventStateKeyNID, searchString string, limit int, +) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt) + rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) if err != nil { return nil, err } @@ -337,9 +351,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type } func (s *membershipStatements) UpdateForgetMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - forget bool, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool, ) error { _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( ctx, roomNID, targetUserNID, forget, @@ -347,9 +360,13 @@ func (s *membershipStatements) UpdateForgetMembership( return err } -func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { +func (s *membershipStatements) SelectLocalServerInRoom( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, +) (bool, error) { var nid types.RoomNID - err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil @@ -360,9 +377,13 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room return found, nil } -func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { +func (s *membershipStatements) SelectServerInRoom( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, serverName gomatrixserverlib.ServerName, +) (bool, error) { var nid types.RoomNID - err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index 8deb68441..15985fcd6 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -73,9 +73,10 @@ func (s *publishedStatements) UpsertRoomPublished( } func (s *publishedStatements) SelectPublishedFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (published bool, err error) { - err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) + stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt) + err = stmt.QueryRowContext(ctx, roomID).Scan(&published) if err == sql.ErrNoRows { return false, nil } @@ -83,9 +84,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID( } func (s *publishedStatements) SelectAllPublishedRooms( - ctx context.Context, published bool, + ctx context.Context, txn *sql.Tx, published bool, ) ([]string, error) { - rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) + stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) + rows, err := stmt.QueryContext(ctx, published) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go index 031825fee..d13df8e7f 100644 --- a/roomserver/storage/postgres/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -87,9 +87,10 @@ func (s *roomAliasesStatements) InsertRoomAlias( } func (s *roomAliasesStatements) SelectRoomIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (roomID string, err error) { - err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) + stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&roomID) if err == sql.ErrNoRows { return "", nil } @@ -97,9 +98,10 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias( } func (s *roomAliasesStatements) SelectAliasesFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) ([]string, error) { - rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) + stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt) + rows, err := stmt.QueryContext(ctx, roomID) if err != nil { return nil, err } @@ -118,9 +120,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID( } func (s *roomAliasesStatements) SelectCreatorIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (creatorID string, err error) { - err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID) if err == sql.ErrNoRows { return "", nil } diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index f51eba4d4..b2685084d 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -117,8 +117,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { }.Prepare(db) } -func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { - rows, err := s.selectRoomIDsStmt.QueryContext(ctx) +func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt) + rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } @@ -143,10 +144,11 @@ func (s *roomStatements) InsertRoomNID( return types.RoomNID(roomNID), err } -func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { +func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { var info types.RoomInfo var latestNIDs pq.Int64Array - err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( + stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan( &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs, ) if err == sql.ErrNoRows { @@ -170,7 +172,7 @@ func (s *roomStatements) SelectLatestEventNIDs( ) ([]types.EventNID, types.StateSnapshotNID, error) { var nids pq.Int64Array var stateSnapshotNID int64 - stmt := s.selectLatestEventNIDsStmt + stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt) err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID) if err != nil { return nil, 0, err @@ -220,9 +222,10 @@ func (s *roomStatements) UpdateLatestEventNIDs( } func (s *roomStatements) SelectRoomVersionsForRoomNIDs( - ctx context.Context, roomNIDs []types.RoomNID, + ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { - rows, err := s.selectRoomVersionsForRoomNIDsStmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs)) + stmt := sqlutil.TxStmt(txn, s.selectRoomVersionsForRoomNIDsStmt) + rows, err := stmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs)) if err != nil { return nil, err } @@ -239,12 +242,13 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs( return result, nil } -func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) { var array pq.Int64Array for _, nid := range roomNIDs { array = append(array, int64(nid)) } - rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array) + stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomIDsStmt) + rows, err := stmt.QueryContext(ctx, array) if err != nil { return nil, err } @@ -260,12 +264,13 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types return roomIDs, nil } -func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { +func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) { var array pq.StringArray for _, roomID := range roomIDs { array = append(array, roomID) } - rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array) + stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomNIDsStmt) + rows, err := stmt.QueryContext(ctx, array) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index 27d85e83b..6f8f9e1b5 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -86,8 +86,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { } func (s *stateBlockStatements) BulkInsertStateData( - ctx context.Context, - txn *sql.Tx, + ctx context.Context, txn *sql.Tx, entries types.StateEntries, ) (id types.StateBlockNID, err error) { entries = entries[:util.SortAndUnique(entries)] @@ -95,16 +94,18 @@ func (s *stateBlockStatements) BulkInsertStateData( for _, e := range entries { nids = append(nids, e.EventNID) } - err = s.insertStateDataStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt) + err = stmt.QueryRowContext( ctx, nids.Hash(), eventNIDsAsArray(nids), ).Scan(&id) return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs types.StateBlockNIDs, + ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs, ) ([][]types.EventNID, error) { - rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt) + rows, err := stmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 4fc0fa48a..ce9f24636 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -105,13 +105,14 @@ func (s *stateSnapshotStatements) InsertState( } func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, + ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { nids := make([]int64, len(stateNIDs)) for i := range stateNIDs { nids[i] = int64(stateNIDs[i]) } - rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt) + rows, err := stmt.QueryContext(ctx, pq.Int64Array(nids)) if err != nil { return nil, err } diff --git a/roomserver/storage/shared/latest_events_updater.go b/roomserver/storage/shared/latest_events_updater.go deleted file mode 100644 index 36865081a..000000000 --- a/roomserver/storage/shared/latest_events_updater.go +++ /dev/null @@ -1,133 +0,0 @@ -package shared - -import ( - "context" - "database/sql" - "fmt" - - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type LatestEventsUpdater struct { - transaction - d *Database - roomInfo types.RoomInfo - latestEvents []types.StateAtEventAndReference - lastEventIDSent string - currentStateSnapshotNID types.StateSnapshotNID -} - -func rollback(txn *sql.Tx) { - if txn == nil { - return - } - txn.Rollback() // nolint: errcheck -} - -func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) { - eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := - d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID) - if err != nil { - rollback(txn) - return nil, err - } - stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) - if err != nil { - rollback(txn) - return nil, err - } - var lastEventIDSent string - if lastEventNIDSent != 0 { - lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent) - if err != nil { - rollback(txn) - return nil, err - } - } - return &LatestEventsUpdater{ - transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, - }, nil -} - -// RoomVersion implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { - return u.roomInfo.RoomVersion -} - -// LatestEvents implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) LatestEvents() []types.StateAtEventAndReference { - return u.latestEvents -} - -// LastEventIDSent implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) LastEventIDSent() string { - return u.lastEventIDSent -} - -// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { - return u.currentStateSnapshotNID -} - -// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer -func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - for _, ref := range previousEventReferences { - if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) - } - } - return nil -} - -// IsReferenced implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { - err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) - if err == nil { - return true, nil - } - if err == sql.ErrNoRows { - return false, nil - } - return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err) -} - -// SetLatestEvents implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) SetLatestEvents( - roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, - currentStateSnapshotNID types.StateSnapshotNID, -) error { - eventNIDs := make([]types.EventNID, len(latest)) - for i := range latest { - eventNIDs[i] = latest[i].EventNID - } - return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil { - return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err) - } - if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok { - if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok { - roomInfo.StateSnapshotNID = currentStateSnapshotNID - roomInfo.IsStub = false - u.d.Cache.StoreRoomInfo(roomID, roomInfo) - } - } - return nil - }) -} - -// HasEventBeenSent implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { - return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID) -} - -// MarkEventAsSent implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID) - }) -} - -func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { - return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) -} diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index f1f589a31..8f3f3d631 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -136,7 +136,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd } // Look up the NID of the new join event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } @@ -170,7 +170,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s } // Look up the NID of the new leave event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } @@ -196,7 +196,7 @@ func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, er } if u.membership != tables.MembershipStateKnock { // Look up the NID of the new knock event - nIDs, err := u.d.EventNIDs(u.ctx, []string{event.EventID()}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}, false) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go new file mode 100644 index 000000000..810a18ef2 --- /dev/null +++ b/roomserver/storage/shared/room_updater.go @@ -0,0 +1,303 @@ +package shared + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type RoomUpdater struct { + transaction + d *Database + roomInfo *types.RoomInfo + latestEvents []types.StateAtEventAndReference + lastEventIDSent string + currentStateSnapshotNID types.StateSnapshotNID + roomExists bool +} + +func rollback(txn *sql.Tx) { + if txn == nil { + return + } + txn.Rollback() // nolint: errcheck +} + +func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *types.RoomInfo) (*RoomUpdater, error) { + // If the roomInfo is nil then that means that the room doesn't exist + // yet, so we can't do `SelectLatestEventsNIDsForUpdate` because that + // would involve locking a row on the table that doesn't exist. Instead + // we will just run with a normal database transaction. It'll either + // succeed, processing a create event which creates the room, or it won't. + if roomInfo == nil { + return &RoomUpdater{ + transaction{ctx, txn}, d, nil, nil, "", 0, false, + }, nil + } + + eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := + d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID) + if err != nil { + rollback(txn) + return nil, err + } + stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) + if err != nil { + rollback(txn) + return nil, err + } + var lastEventIDSent string + if lastEventNIDSent != 0 { + lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent) + if err != nil { + rollback(txn) + return nil, err + } + } + return &RoomUpdater{ + transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, true, + }, nil +} + +// RoomExists returns true if the room exists and false otherwise. +func (u *RoomUpdater) RoomExists() bool { + return u.roomExists +} + +// Implements sqlutil.Transaction +func (u *RoomUpdater) Commit() error { + if u.txn == nil { // SQLite mode probably + return nil + } + return u.txn.Commit() +} + +// Implements sqlutil.Transaction +func (u *RoomUpdater) Rollback() error { + if u.txn == nil { // SQLite mode probably + return nil + } + return u.txn.Rollback() +} + +// RoomVersion implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { + return u.roomInfo.RoomVersion +} + +// LatestEvents implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) LatestEvents() []types.StateAtEventAndReference { + return u.latestEvents +} + +// LastEventIDSent implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) LastEventIDSent() string { + return u.lastEventIDSent +} + +// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { + return u.currentStateSnapshotNID +} + +func (u *RoomUpdater) MissingAuthPrevEvents( + ctx context.Context, e *gomatrixserverlib.Event, +) (missingAuth, missingPrev []string, err error) { + for _, authEventID := range e.AuthEventIDs() { + if nids, err := u.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 { + missingAuth = append(missingAuth, authEventID) + } + } + + for _, prevEventID := range e.PrevEventIDs() { + state, err := u.StateAtEventIDs(ctx, []string{prevEventID}) + if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) { + missingPrev = append(missingPrev, prevEventID) + } + } + + return +} + +// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer +func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + for _, ref := range previousEventReferences { + if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) + } + } + return nil + }) +} + +func (u *RoomUpdater) Events( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.Event, error) { + return u.d.events(ctx, u.txn, eventNIDs) +} + +func (u *RoomUpdater) SnapshotNIDFromEventID( + ctx context.Context, eventID string, +) (types.StateSnapshotNID, error) { + return u.d.snapshotNIDFromEventID(ctx, u.txn, eventID) +} + +func (u *RoomUpdater) StoreEvent( + ctx context.Context, event *gomatrixserverlib.Event, + authEventNIDs []types.EventNID, isRejected bool, +) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { + return u.d.storeEvent(ctx, u, event, authEventNIDs, isRejected) +} + +func (u *RoomUpdater) StateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + return u.d.stateBlockNIDs(ctx, u.txn, stateNIDs) +} + +func (u *RoomUpdater) StateEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + return u.d.stateEntries(ctx, u.txn, stateBlockNIDs) +} + +func (u *RoomUpdater) StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return u.d.stateEntriesForTuples(ctx, u.txn, stateBlockNIDs, stateKeyTuples) +} + +func (u *RoomUpdater) AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, +) (stateNID types.StateSnapshotNID, err error) { + return u.d.addState(ctx, u.txn, roomNID, stateBlockNIDs, state) +} + +func (u *RoomUpdater) SetState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + return u.d.EventsTable.UpdateEventState(ctx, txn, eventNID, stateNID) + }) +} + +func (u *RoomUpdater) EventTypeNIDs( + ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + return u.d.eventTypeNIDs(ctx, u.txn, eventTypes) +} + +func (u *RoomUpdater) EventStateKeyNIDs( + ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + return u.d.eventStateKeyNIDs(ctx, u.txn, eventStateKeys) +} + +func (u *RoomUpdater) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { + return u.d.roomInfo(ctx, u.txn, roomID) +} + +func (u *RoomUpdater) EventIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (map[types.EventNID]string, error) { + return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs) +} + +func (u *RoomUpdater) EventNIDs( + ctx context.Context, eventIDs []string, +) (map[string]types.EventNID, error) { + return u.d.eventNIDs(ctx, u.txn, eventIDs, NoFilter) +} + +func (u *RoomUpdater) UnsentEventNIDs( + ctx context.Context, eventIDs []string, +) (map[string]types.EventNID, error) { + return u.d.eventNIDs(ctx, u.txn, eventIDs, FilterUnsentOnly) +} + +func (u *RoomUpdater) StateAtEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateAtEvent, error) { + return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) +} + +func (u *RoomUpdater) StateEntriesForEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateEntry, error) { + return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs) +} + +func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false) +} + +func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true) +} + +func (u *RoomUpdater) GetMembershipEventNIDsForRoom( + ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, +) ([]types.EventNID, error) { + return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly) +} + +// IsReferenced implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { + err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) + if err == nil { + return true, nil + } + if err == sql.ErrNoRows { + return false, nil + } + return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err) +} + +// SetLatestEvents implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) SetLatestEvents( + roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, + currentStateSnapshotNID types.StateSnapshotNID, +) error { + eventNIDs := make([]types.EventNID, len(latest)) + for i := range latest { + eventNIDs[i] = latest[i].EventNID + } + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil { + return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err) + } + if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok { + if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok { + roomInfo.StateSnapshotNID = currentStateSnapshotNID + roomInfo.IsStub = false + u.d.Cache.StoreRoomInfo(roomID, roomInfo) + } + } + return nil + }) +} + +// HasEventBeenSent implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { + return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID) +} + +// MarkEventAsSent implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error { + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID) + }) +} + +func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index d4c5ebb5b..b255cfb3f 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -26,23 +26,23 @@ import ( const redactionsArePermanent = true type Database struct { - DB *sql.DB - Cache caching.RoomServerCaches - Writer sqlutil.Writer - EventsTable tables.Events - EventJSONTable tables.EventJSON - EventTypesTable tables.EventTypes - EventStateKeysTable tables.EventStateKeys - RoomsTable tables.Rooms - StateSnapshotTable tables.StateSnapshot - StateBlockTable tables.StateBlock - RoomAliasesTable tables.RoomAliases - PrevEventsTable tables.PreviousEvents - InvitesTable tables.Invites - MembershipTable tables.Membership - PublishedTable tables.Published - RedactionsTable tables.Redactions - GetLatestEventsForUpdateFn func(ctx context.Context, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) + DB *sql.DB + Cache caching.RoomServerCaches + Writer sqlutil.Writer + EventsTable tables.Events + EventJSONTable tables.EventJSON + EventTypesTable tables.EventTypes + EventStateKeysTable tables.EventStateKeys + RoomsTable tables.Rooms + StateSnapshotTable tables.StateSnapshot + StateBlockTable tables.StateBlock + RoomAliasesTable tables.RoomAliases + PrevEventsTable tables.PreviousEvents + InvitesTable tables.Invites + MembershipTable tables.Membership + PublishedTable tables.Published + RedactionsTable tables.Redactions + GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } func (d *Database) SupportsConcurrentRoomInputs() bool { @@ -51,25 +51,20 @@ func (d *Database) SupportsConcurrentRoomInputs() bool { func (d *Database) EventTypeNIDs( ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + return d.eventTypeNIDs(ctx, nil, eventTypes) +} + +func (d *Database) eventTypeNIDs( + ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { result := make(map[string]types.EventTypeNID) - remaining := []string{} - for _, eventType := range eventTypes { - if nid, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok { - result[eventType] = nid - } else { - remaining = append(remaining, eventType) - } + nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, eventTypes) + if err != nil { + return nil, err } - if len(remaining) > 0 { - nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, remaining) - if err != nil { - return nil, err - } - for eventType, nid := range nids { - result[eventType] = nid - d.Cache.StoreRoomServerEventTypeNID(eventType, nid) - } + for eventType, nid := range nids { + result[eventType] = nid } return result, nil } @@ -77,30 +72,25 @@ func (d *Database) EventTypeNIDs( func (d *Database) EventStateKeys( ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { - return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs) + return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs) } func (d *Database) EventStateKeyNIDs( ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + return d.eventStateKeyNIDs(ctx, nil, eventStateKeys) +} + +func (d *Database) eventStateKeyNIDs( + ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { result := make(map[string]types.EventStateKeyNID) - remaining := []string{} - for _, eventStateKey := range eventStateKeys { - if nid, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok { - result[eventStateKey] = nid - } else { - remaining = append(remaining, eventStateKey) - } + nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) + if err != nil { + return nil, err } - if len(remaining) > 0 { - nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, remaining) - if err != nil { - return nil, err - } - for eventStateKey, nid := range nids { - result[eventStateKey] = nid - d.Cache.StoreRoomServerStateKeyNID(eventStateKey, nid) - } + for eventStateKey, nid := range nids { + result[eventStateKey] = nid } return result, nil } @@ -108,23 +98,31 @@ func (d *Database) EventStateKeyNIDs( func (d *Database) StateEntriesForEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateEntry, error) { - return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs) + return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs) } func (d *Database) StateEntriesForTuples( ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return d.stateEntriesForTuples(ctx, nil, stateBlockNIDs, stateKeyTuples) +} + +func (d *Database) stateEntriesForTuples( + ctx context.Context, txn *sql.Tx, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntryList, error) { entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( - ctx, stateBlockNIDs, + ctx, txn, stateBlockNIDs, ) if err != nil { return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) } lists := []types.StateEntryList{} for i, entry := range entries { - entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples) + entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, stateKeyTuples) if err != nil { return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) } @@ -137,10 +135,14 @@ func (d *Database) StateEntriesForTuples( } func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { + return d.roomInfo(ctx, nil, roomID) +} + +func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok { return &roomInfo, nil } - roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, roomID) + roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID) if err == nil && roomInfo != nil { d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID) d.Cache.StoreRoomInfo(roomID, *roomInfo) @@ -153,13 +155,22 @@ func (d *Database) AddState( roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, +) (stateNID types.StateSnapshotNID, err error) { + return d.addState(ctx, nil, roomNID, stateBlockNIDs, state) +} + +func (d *Database) addState( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, ) (stateNID types.StateSnapshotNID, err error) { if len(stateBlockNIDs) > 0 && len(state) > 0 { // Check to see if the event already appears in any of the existing state // blocks. If it does then we should not add it again, as this will just // result in excess state blocks and snapshots. // TODO: Investigate why this is happening - probably input_events.go! - blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) + blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs) if berr != nil { return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr) } @@ -180,7 +191,7 @@ func (d *Database) AddState( } } } - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { if len(state) > 0 { // If there's any state left to add then let's add new blocks. var stateBlockNID types.StateBlockNID @@ -205,7 +216,27 @@ func (d *Database) AddState( func (d *Database) EventNIDs( ctx context.Context, eventIDs []string, ) (map[string]types.EventNID, error) { - return d.EventsTable.BulkSelectEventNID(ctx, eventIDs) + return d.eventNIDs(ctx, nil, eventIDs, NoFilter) +} + +type UnsentFilter bool + +const ( + NoFilter UnsentFilter = false + FilterUnsentOnly UnsentFilter = true +) + +func (d *Database) eventNIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter, +) (map[string]types.EventNID, error) { + switch filter { + case FilterUnsentOnly: + return d.EventsTable.BulkSelectUnsentEventNID(ctx, txn, eventIDs) + case NoFilter: + return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs) + default: + panic("impossible case") + } } func (d *Database) SetState( @@ -219,24 +250,34 @@ func (d *Database) SetState( func (d *Database) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { - return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs) + return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs) } func (d *Database) SnapshotNIDFromEventID( ctx context.Context, eventID string, ) (types.StateSnapshotNID, error) { - _, stateNID, err := d.EventsTable.SelectEvent(ctx, nil, eventID) + return d.snapshotNIDFromEventID(ctx, nil, eventID) +} + +func (d *Database) snapshotNIDFromEventID( + ctx context.Context, txn *sql.Tx, eventID string, +) (types.StateSnapshotNID, error) { + _, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID) return stateNID, err } func (d *Database) EventIDs( ctx context.Context, eventNIDs []types.EventNID, ) (map[types.EventNID]string, error) { - return d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) } func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - nidMap, err := d.EventNIDs(ctx, eventIDs) + return d.eventsFromIDs(ctx, nil, eventIDs, NoFilter) +} + +func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter) ([]types.Event, error) { + nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter) if err != nil { return nil, err } @@ -246,7 +287,7 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type nids = append(nids, nid) } - return d.Events(ctx, nids) + return d.events(ctx, txn, nids) } func (d *Database) LatestEventIDs( @@ -271,21 +312,33 @@ func (d *Database) LatestEventIDs( func (d *Database) StateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { - return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs) + return d.stateBlockNIDs(ctx, nil, stateNIDs) +} + +func (d *Database) stateBlockNIDs( + ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, txn, stateNIDs) } func (d *Database) StateEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + return d.stateEntries(ctx, nil, stateBlockNIDs) +} + +func (d *Database) stateEntries( + ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( - ctx, stateBlockNIDs, + ctx, txn, stateBlockNIDs, ) if err != nil { return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) } lists := make([]types.StateEntryList, 0, len(entries)) for i, entry := range entries { - eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil) + eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, nil) if err != nil { return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) } @@ -304,17 +357,17 @@ func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string } func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { - return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias) + return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, nil, alias) } func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { - return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID) + return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, nil, roomID) } func (d *Database) GetCreatorIDForAlias( ctx context.Context, alias string, ) (string, error) { - return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias) + return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, nil, alias) } func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { @@ -335,7 +388,7 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req senderMembershipEventNID, senderMembership, isRoomforgotten, err := d.MembershipTable.SelectMembershipFromRoomAndTarget( - ctx, roomNID, requestSenderUserNID, + ctx, nil, roomNID, requestSenderUserNID, ) if err == sql.ErrNoRows { // The user has never been a member of that room @@ -349,14 +402,20 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req func (d *Database) GetMembershipEventNIDsForRoom( ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, +) ([]types.EventNID, error) { + return d.getMembershipEventNIDsForRoom(ctx, nil, roomNID, joinOnly, localOnly) +} + +func (d *Database) getMembershipEventNIDsForRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, joinOnly bool, localOnly bool, ) ([]types.EventNID, error) { if joinOnly { return d.MembershipTable.SelectMembershipsFromRoomAndMembership( - ctx, roomNID, tables.MembershipStateJoin, localOnly, + ctx, txn, roomNID, tables.MembershipStateJoin, localOnly, ) } - return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly) + return d.MembershipTable.SelectMembershipsFromRoom(ctx, txn, roomNID, localOnly) } func (d *Database) GetInvitesForUser( @@ -364,22 +423,28 @@ func (d *Database) GetInvitesForUser( roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) { - return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) + return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID) } func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { - eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) + return d.events(ctx, nil, eventNIDs) +} + +func (d *Database) events( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, +) ([]types.Event, error) { + eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs) if err != nil { return nil, err } - eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs) if err != nil { eventIDs = map[types.EventNID]string{} } var roomNIDs map[types.EventNID]types.RoomNID - roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, eventNIDs) + roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs) if err != nil { return nil, err } @@ -398,7 +463,7 @@ func (d *Database) Events( } fetchNIDList = append(fetchNIDList, n) } - dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, fetchNIDList) + dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList) if err != nil { return nil, err } @@ -440,19 +505,19 @@ func (d *Database) MembershipUpdater( return updater, err } -func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomInfo types.RoomInfo, -) (*LatestEventsUpdater, error) { - if d.GetLatestEventsForUpdateFn != nil { - return d.GetLatestEventsForUpdateFn(ctx, roomInfo) +func (d *Database) GetRoomUpdater( + ctx context.Context, roomInfo *types.RoomInfo, +) (*RoomUpdater, error) { + if d.GetRoomUpdaterFn != nil { + return d.GetRoomUpdaterFn(ctx, roomInfo) } txn, err := d.DB.Begin() if err != nil { return nil, err } - var updater *LatestEventsUpdater + var updater *RoomUpdater _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { - updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo) + updater, err = NewRoomUpdater(ctx, d, txn, roomInfo) return err }) return updater, err @@ -461,6 +526,13 @@ func (d *Database) GetLatestEventsForUpdate( func (d *Database) StoreEvent( ctx context.Context, event *gomatrixserverlib.Event, authEventNIDs []types.EventNID, isRejected bool, +) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { + return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected) +} + +func (d *Database) storeEvent( + ctx context.Context, updater *RoomUpdater, event *gomatrixserverlib.Event, + authEventNIDs []types.EventNID, isRejected bool, ) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { var ( roomNID types.RoomNID @@ -472,8 +544,11 @@ func (d *Database) StoreEvent( redactedEventID string err error ) - - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var txn *sql.Tx + if updater != nil && updater.txn != nil { + txn = updater.txn + } + err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { // TODO: Here we should aim to have two different code paths for new rooms // vs existing ones. @@ -520,6 +595,8 @@ func (d *Database) StoreEvent( if err == sql.ErrNoRows { // We've already inserted the event so select the numeric event ID eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID()) + } else if err != nil { + return fmt.Errorf("d.EventsTable.InsertEvent: %w", err) } if err != nil { return fmt.Errorf("d.EventsTable.SelectEvent: %w", err) @@ -546,42 +623,32 @@ func (d *Database) StoreEvent( // events updater because it somewhat works as a mutex, ensuring // that there's a row-level lock on the latest room events (well, // on Postgres at least). - var roomInfo *types.RoomInfo - var updater *LatestEventsUpdater if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { - roomInfo, err = d.RoomInfo(ctx, event.RoomID()) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) - } - if roomInfo == nil && len(prevEvents) > 0 { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) - } // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This // function only does SELECTs though so the created txn (at this point) is just a read txn like // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater // to do writes however then this will need to go inside `Writer.Do`. - updater, err = d.GetLatestEventsForUpdate(ctx, *roomInfo) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err) - } - // Ensure that we atomically store prev events AND commit them. If we don't wrap StorePreviousEvents - // and EndTransaction in a writer then it's possible for a new write txn to be made between the two - // function calls which will then fail with 'database is locked'. This new write txn would HAVE to be - // something like SetRoomAlias/RemoveRoomAlias as normal input events are already done sequentially due to - // SupportsConcurrentRoomInputs() == false on sqlite, though this does not apply to setting room aliases - // as they don't go via InputRoomEvents - err = d.Writer.Do(d.DB, updater.txn, func(txn *sql.Tx) error { - if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { - return fmt.Errorf("updater.StorePreviousEvents: %w", err) + succeeded := false + if updater == nil { + var roomInfo *types.RoomInfo + roomInfo, err = d.RoomInfo(ctx, event.RoomID()) + if err != nil { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) } - succeeded := true - err = sqlutil.EndTransaction(updater, &succeeded) - return err - }) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", err + if roomInfo == nil && len(prevEvents) > 0 { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) + } + updater, err = d.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err) + } + defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) } + if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err) + } + succeeded = true } return eventNID, roomNID, types.StateAtEvent{ @@ -603,7 +670,7 @@ func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) } func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { - return d.PublishedTable.SelectAllPublishedRooms(ctx, true) + return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true) } func (d *Database) assignRoomNID( @@ -629,9 +696,6 @@ func (d *Database) assignRoomNID( func (d *Database) assignEventTypeNID( ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { - if eventTypeNID, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok { - return eventTypeNID, nil - } // Check if we already have a numeric ID in the database. eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) if err == sql.ErrNoRows { @@ -642,18 +706,12 @@ func (d *Database) assignEventTypeNID( eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) } } - if err == nil { - d.Cache.StoreRoomServerEventTypeNID(eventType, eventTypeNID) - } return eventTypeNID, err } func (d *Database) assignStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { - if eventStateKeyNID, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok { - return eventStateKeyNID, nil - } // Check if we already have a numeric ID in the database. eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) if err == sql.ErrNoRows { @@ -664,9 +722,6 @@ func (d *Database) assignStateKeyNID( eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) } } - if err == nil { - d.Cache.StoreRoomServerStateKeyNID(eventStateKey, eventStateKeyNID) - } return eventStateKeyNID, err } @@ -875,14 +930,14 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s eventNIDs = append(eventNIDs, e.EventNID) } } - eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) if err != nil { eventIDs = map[types.EventNID]string{} } // return the event requested for _, e := range entries { if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID { - data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID}) + data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID}) if err != nil { return nil, err } @@ -922,11 +977,11 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership } return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err) } - roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState) + roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState) if err != nil { return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err) } - roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs) + roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs) if err != nil { return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err) } @@ -945,7 +1000,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } // we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which // isn't a failure. - eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes) + eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err) } @@ -965,7 +1020,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } - eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys) + eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err) } @@ -999,11 +1054,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } } } - eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) if err != nil { eventIDs = map[types.EventNID]string{} } - events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) + events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err) } @@ -1027,11 +1082,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { - roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs) + roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs) if err != nil { return nil, err } - userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs) + userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs) if err != nil { return nil, err } @@ -1041,7 +1096,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) stateKeyNIDs[i] = nid i++ } - nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs) + nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs) if err != nil { return nil, err } @@ -1057,12 +1112,12 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { - return d.MembershipTable.SelectLocalServerInRoom(ctx, roomNID) + return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID) } // GetServerInRoom returns true if we think a server is in a given room or false otherwise. func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { - return d.MembershipTable.SelectServerInRoom(ctx, roomNID, serverName) + return d.MembershipTable.SelectServerInRoom(ctx, nil, roomNID, serverName) } // GetKnownUsers searches all users that userID knows about. @@ -1071,17 +1126,17 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin if err != nil { return nil, err } - return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit) + return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) } // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { - return d.RoomsTable.SelectRoomIDs(ctx) + return d.RoomsTable.SelectRoomIDs(ctx, nil) } // ForgetRoom sets a users room to forgotten func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error { - roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, []string{roomID}) + roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, []string{roomID}) if err != nil { return err } diff --git a/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go b/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go index 8d0331748..8f5ab8fc5 100644 --- a/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go +++ b/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go @@ -179,7 +179,17 @@ func UpStateBlocksRefactor(tx *sql.Tx) error { return fmt.Errorf("assertion query failed: %s", err) } if count > 0 { - return fmt.Errorf("%d events exist in roomserver_events which have not been converted to a new state_snapshot_nid; this is a bug, please report", count) + var res sql.Result + var c int64 + res, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID) + if err != nil && err != sql.ErrNoRows { + return fmt.Errorf("failed to reset invalid state snapshots: %w", err) + } + if c, err = res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get row count for invalid state snapshots updated: %w", err) + } else if c != count { + return fmt.Errorf("expected to reset %d event(s) but only updated %d event(s)", count, c) + } } if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil { return fmt.Errorf("assertion query failed: %s", err) diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index 53b219294..f470ea326 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -76,15 +76,20 @@ func (s *eventJSONStatements) InsertEventJSON( } func (s *eventJSONStatements) BulkSelectEventJSON( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]tables.EventJSONPair, error) { iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) - - rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, selectOrig, iEventNIDs...) + } else { + rows, err = s.db.QueryContext(ctx, selectOrig, iEventNIDs...) + } if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index 62fbce2d0..bf12d5b83 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -112,15 +112,20 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( - ctx context.Context, eventStateKeys []string, + ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { iEventStateKeys := make([]interface{}, len(eventStateKeys)) for k, v := range eventStateKeys { iEventStateKeys[k] = v } selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1) - - rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, selectOrig, iEventStateKeys...) + } else { + rows, err = s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) + } if err != nil { return nil, err } @@ -138,15 +143,19 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKey( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs)) for k, v := range eventStateKeyNIDs { iEventStateKeyNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1) - - rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + stmt := sqlutil.TxStmt(txn, selectPrep) + rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 22df3fb22..f2c9c42fe 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -128,7 +128,7 @@ func (s *eventTypeStatements) SelectEventTypeNID( } func (s *eventTypeStatements) BulkSelectEventTypeNID( - ctx context.Context, eventTypes []string, + ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { /////////////// iEventTypes := make([]interface{}, len(eventTypes)) @@ -140,9 +140,10 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID( if err != nil { return nil, err } + stmt := sqlutil.TxStmt(txn, selectPrep) /////////////// - rows, err := selectPrep.QueryContext(ctx, iEventTypes...) + rows, err := stmt.QueryContext(ctx, iEventTypes...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 7483e2815..969a10ce5 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -49,7 +49,8 @@ const eventsSchema = ` const insertEventSQL = ` INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - ON CONFLICT DO NOTHING + ON CONFLICT DO UPDATE + SET is_rejected = $8 WHERE is_rejected = 0 RETURNING event_nid, state_snapshot_nid; ` @@ -98,6 +99,9 @@ const bulkSelectEventIDSQL = "" + const bulkSelectEventNIDSQL = "" + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)" +const bulkSelectUnsentEventNIDSQL = "" + + "SELECT event_id, event_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)" + const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" @@ -117,8 +121,9 @@ type eventStatements struct { bulkSelectStateAtEventAndReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt - bulkSelectEventNIDStmt *sql.Stmt - //selectRoomNIDsForEventNIDsStmt *sql.Stmt + //bulkSelectEventNIDStmt *sql.Stmt + //bulkSelectUnsentEventNIDStmt *sql.Stmt + //selectRoomNIDsForEventNIDsStmt *sql.Stmt } func createEventsTable(db *sql.DB) error { @@ -143,7 +148,8 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, - {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + //{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + //{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL}, //{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, }.Prepare(db) } @@ -184,7 +190,7 @@ func (s *eventStatements) SelectEvent( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateEntry, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) @@ -196,6 +202,7 @@ func (s *eventStatements) BulkSelectStateEventByID( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventIDs...) @@ -235,7 +242,7 @@ func (s *eventStatements) BulkSelectStateEventByID( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByNID( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { tuples := stateKeyTupleSorter(stateKeyTuples) @@ -263,6 +270,7 @@ func (s *eventStatements) BulkSelectStateEventByNID( if err != nil { return nil, fmt.Errorf("s.db.Prepare: %w", err) } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, params...) if err != nil { return nil, fmt.Errorf("selectStmt.QueryContext: %w", err) @@ -291,7 +299,7 @@ func (s *eventStatements) BulkSelectStateEventByNID( // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. func (s *eventStatements) BulkSelectStateAtEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateAtEvent, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) @@ -303,6 +311,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { @@ -381,6 +390,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( if err != nil { return nil, err } + selectPrep = sqlutil.TxStmt(txn, selectPrep) ////////////// rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) @@ -454,7 +464,7 @@ func (s *eventStatements) BulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { +func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { /////////////// iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { @@ -465,6 +475,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) @@ -488,19 +499,38 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ return results, nil } +// BulkSelectEventNIDs returns a map from string event ID to numeric event ID. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { + return s.bulkSelectEventNID(ctx, txn, eventIDs, false) +} + +// BulkSelectEventNIDs returns a map from string event ID to numeric event ID +// only for events that haven't already been sent to the roomserver output. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { + return s.bulkSelectEventNID(ctx, txn, eventIDs, true) +} + // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { iEventIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + var selectOrig string + if onlyUnsent { + selectOrig = strings.Replace(bulkSelectUnsentEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + } else { + selectOrig = strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + } selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { @@ -538,13 +568,14 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, } func (s *eventStatements) SelectRoomNIDsForEventNIDs( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) (map[types.EventNID]types.RoomNID, error) { sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) sqlPrep, err := s.db.Prepare(sqlStr) if err != nil { return nil, err } + sqlPrep = sqlutil.TxStmt(txn, sqlPrep) iEventNIDs := make([]interface{}, len(eventNIDs)) for i, v := range eventNIDs { iEventNIDs[i] = v diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index c1d7347ae..d54d313a9 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -88,8 +88,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { } func (s *inviteStatements) InsertInviteEvent( - ctx context.Context, - txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, + ctx context.Context, txn *sql.Tx, + inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { @@ -109,8 +109,8 @@ func (s *inviteStatements) InsertInviteEvent( } func (s *inviteStatements) UpdateInviteRetired( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { // gather all the event IDs we will retire stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) @@ -134,10 +134,11 @@ func (s *inviteStatements) UpdateInviteRetired( // selectInviteActiveForUserInRoom returns a list of sender state key NIDs func (s *inviteStatements) SelectInviteActiveForUserInRoom( - ctx context.Context, + ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, []string, error) { - rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( + stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt) + rows, err := stmt.QueryContext( ctx, targetUserNID, roomNID, ) if err != nil { diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 2e58431d3..181b4b4c9 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -184,17 +184,18 @@ func (s *membershipStatements) SelectMembershipForUpdate( } func (s *membershipStatements) SelectMembershipFromRoomAndTarget( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { - err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt) + err = stmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID, &forgotten) return } func (s *membershipStatements) SelectMembershipsFromRoom( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var selectStmt *sql.Stmt @@ -203,6 +204,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } else { selectStmt = s.selectMembershipsFromRoomStmt } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, roomNID) if err != nil { return nil, err @@ -220,7 +222,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var stmt *sql.Stmt @@ -229,6 +231,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( } else { stmt = s.selectMembershipsFromRoomAndMembershipStmt } + stmt = sqlutil.TxStmt(txn, stmt) rows, err := stmt.QueryContext(ctx, roomNID, membership) if err != nil { return @@ -258,9 +261,10 @@ func (s *membershipStatements) UpdateMembership( } func (s *membershipStatements) SelectRoomsWithMembership( - ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, + ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState, ) ([]types.RoomNID, error) { - rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt) + rows, err := stmt.QueryContext(ctx, membershipState, userID) if err != nil { return nil, err } @@ -276,13 +280,19 @@ func (s *membershipStatements) SelectRoomsWithMembership( return roomNIDs, nil } -func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { iRoomNIDs := make([]interface{}, len(roomNIDs)) for i, v := range roomNIDs { iRoomNIDs[i] = v } query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) - rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, query, iRoomNIDs...) + } else { + rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...) + } if err != nil { return nil, err } @@ -299,8 +309,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, return result, rows.Err() } -func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { - rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) +func (s *membershipStatements) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt) + rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) if err != nil { return nil, err } @@ -317,8 +328,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type } func (s *membershipStatements) UpdateForgetMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool, ) error { _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( @@ -327,9 +338,10 @@ func (s *membershipStatements) UpdateForgetMembership( return err } -func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { +func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) { var nid types.RoomNID - err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil @@ -340,9 +352,10 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room return found, nil } -func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { +func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { var nid types.RoomNID - err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index b07c0ac42..9e416ace3 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -75,9 +75,10 @@ func (s *publishedStatements) UpsertRoomPublished( } func (s *publishedStatements) SelectPublishedFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (published bool, err error) { - err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) + stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt) + err = stmt.QueryRowContext(ctx, roomID).Scan(&published) if err == sql.ErrNoRows { return false, nil } @@ -85,9 +86,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID( } func (s *publishedStatements) SelectAllPublishedRooms( - ctx context.Context, published bool, + ctx context.Context, txn *sql.Tx, published bool, ) ([]string, error) { - rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) + stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) + rows, err := stmt.QueryContext(ctx, published) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index 323945b88..7c7bead95 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -91,9 +91,10 @@ func (s *roomAliasesStatements) InsertRoomAlias( } func (s *roomAliasesStatements) SelectRoomIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (roomID string, err error) { - err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) + stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&roomID) if err == sql.ErrNoRows { return "", nil } @@ -101,10 +102,11 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias( } func (s *roomAliasesStatements) SelectAliasesFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (aliases []string, err error) { aliases = []string{} - rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) + stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt) + rows, err := stmt.QueryContext(ctx, roomID) if err != nil { return } @@ -124,9 +126,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID( } func (s *roomAliasesStatements) SelectCreatorIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (creatorID string, err error) { - err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID) if err == sql.ErrNoRows { return "", nil } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index c441daec0..5413475e2 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -107,8 +107,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { }.Prepare(db) } -func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { - rows, err := s.selectRoomIDsStmt.QueryContext(ctx) +func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt) + rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } @@ -124,10 +125,11 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { return roomIDs, nil } -func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { +func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { var info types.RoomInfo var latestNIDsJSON string - err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( + stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan( &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON, ) if err != nil { @@ -224,13 +226,14 @@ func (s *roomStatements) UpdateLatestEventNIDs( } func (s *roomStatements) SelectRoomVersionsForRoomNIDs( - ctx context.Context, roomNIDs []types.RoomNID, + ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) sqlPrep, err := s.db.Prepare(sqlStr) if err != nil { return nil, err } + sqlPrep = sqlutil.TxStmt(txn, sqlPrep) iRoomNIDs := make([]interface{}, len(roomNIDs)) for i, v := range roomNIDs { iRoomNIDs[i] = v @@ -252,13 +255,19 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs( return result, nil } -func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) { iRoomNIDs := make([]interface{}, len(roomNIDs)) for i, v := range roomNIDs { iRoomNIDs[i] = v } sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) - rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, sqlQuery, iRoomNIDs...) + } else { + rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) + } if err != nil { return nil, err } @@ -274,13 +283,19 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types return roomIDs, nil } -func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { +func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) { iRoomIDs := make([]interface{}, len(roomIDs)) for i, v := range roomIDs { iRoomIDs[i] = v } sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) - rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, sqlQuery, iRoomIDs...) + } else { + rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) + } if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 58b0b5dc2..d51fc492d 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -81,8 +81,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { } func (s *stateBlockStatements) BulkInsertStateData( - ctx context.Context, - txn *sql.Tx, + ctx context.Context, txn *sql.Tx, entries types.StateEntries, ) (id types.StateBlockNID, err error) { entries = entries[:util.SortAndUnique(entries)] @@ -94,14 +93,15 @@ func (s *stateBlockStatements) BulkInsertStateData( if err != nil { return 0, fmt.Errorf("json.Marshal: %w", err) } - err = s.insertStateDataStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt) + err = stmt.QueryRowContext( ctx, nids.Hash(), js, ).Scan(&id) return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs types.StateBlockNIDs, + ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs, ) ([][]types.EventNID, error) { intfs := make([]interface{}, len(stateBlockNIDs)) for i := range stateBlockNIDs { @@ -112,6 +112,7 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, intfs...) if err != nil { return nil, err diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 040d99ae6..3c4bde3f5 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -106,7 +106,7 @@ func (s *stateSnapshotStatements) InsertState( } func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, + ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { nids := make([]interface{}, len(stateNIDs)) for k, v := range stateNIDs { @@ -117,6 +117,7 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, nids...) if err != nil { diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 1fcc7989d..325c253b5 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -172,23 +172,23 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { return err } d.Database = shared.Database{ - DB: db, - Cache: cache, - Writer: sqlutil.NewExclusiveWriter(), - EventsTable: events, - EventTypesTable: eventTypes, - EventStateKeysTable: eventStateKeys, - EventJSONTable: eventJSON, - RoomsTable: rooms, - StateBlockTable: stateBlock, - StateSnapshotTable: stateSnapshot, - PrevEventsTable: prevEvents, - RoomAliasesTable: roomAliases, - InvitesTable: invites, - MembershipTable: membership, - PublishedTable: published, - RedactionsTable: redactions, - GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate, + DB: db, + Cache: cache, + Writer: sqlutil.NewExclusiveWriter(), + EventsTable: events, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + EventJSONTable: eventJSON, + RoomsTable: rooms, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + PrevEventsTable: prevEvents, + RoomAliasesTable: roomAliases, + InvitesTable: invites, + MembershipTable: membership, + PublishedTable: published, + RedactionsTable: redactions, + GetRoomUpdaterFn: d.GetRoomUpdater, } return nil } @@ -201,16 +201,16 @@ func (d *Database) SupportsConcurrentRoomInputs() bool { return false } -func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomInfo types.RoomInfo, -) (*shared.LatestEventsUpdater, error) { +func (d *Database) GetRoomUpdater( + ctx context.Context, roomInfo *types.RoomInfo, +) (*shared.RoomUpdater, error) { // TODO: Do not use transactions. We should be holding open this transaction but we cannot have // multiple write transactions on sqlite. The code will perform additional // write transactions independent of this one which will consistently cause // 'database is locked' errors. As sqlite doesn't support multi-process on the // same DB anyway, and we only execute updates sequentially, the only worries // are for rolling back when things go wrong. (atomicity) - return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo) + return shared.NewRoomUpdater(ctx, &d.Database, nil, roomInfo) } func (d *Database) MembershipUpdater( diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 6ad7ed2e8..e3fed700b 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -18,20 +18,20 @@ type EventJSONPair struct { type EventJSON interface { // Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions). InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error - BulkSelectEventJSON(ctx context.Context, eventNIDs []types.EventNID) ([]EventJSONPair, error) + BulkSelectEventJSON(ctx context.Context, tx *sql.Tx, eventNIDs []types.EventNID) ([]EventJSONPair, error) } type EventTypes interface { InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) - BulkSelectEventTypeNID(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + BulkSelectEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypes []string) (map[string]types.EventTypeNID, error) } type EventStateKeys interface { InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) - BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) - BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) + BulkSelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + BulkSelectEventStateKey(ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) } type Events interface { @@ -42,12 +42,12 @@ type Events interface { SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError - BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) - BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error) + BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateEntry, error) + BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error) // BulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. - BulkSelectStateAtEventByID(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + BulkSelectStateAtEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateAtEvent, error) UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error @@ -55,12 +55,13 @@ type Events interface { BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) // BulkSelectEventID returns a map from numeric event ID to string event ID. - BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. - BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) + BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) + BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) - SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) + SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) } type Rooms interface { @@ -69,29 +70,29 @@ type Rooms interface { SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error - SelectRoomVersionsForRoomNIDs(ctx context.Context, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) - SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) - SelectRoomIDs(ctx context.Context) ([]string, error) - BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) - BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) + SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) + SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) + SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) + BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) + BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) } type StateSnapshot interface { InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error) - BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) } type StateBlock interface { BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error) - BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error) + BulkSelectStateBlockEntries(ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error) //BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) } type RoomAliases interface { InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error) - SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error) - SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) - SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error) + SelectRoomIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (roomID string, err error) + SelectAliasesFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) ([]string, error) + SelectCreatorIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (creatorID string, err error) DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error) } @@ -106,7 +107,7 @@ type Invites interface { InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error) UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error) // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids. - SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error) + SelectInviteActiveForUserInRoom(ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error) } type MembershipState int64 @@ -121,24 +122,24 @@ const ( type Membership interface { InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error) - SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) - SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) - SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) + SelectMembershipFromRoomAndTarget(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) + SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) + SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error - SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) + SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // counts of how many rooms they are joined. - SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) - SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) + SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) + SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error - SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) - SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) + SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) + SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) } type Published interface { UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error) - SelectPublishedFromRoomID(ctx context.Context, roomID string) (published bool, err error) - SelectAllPublishedRooms(ctx context.Context, published bool) ([]string, error) + SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error) + SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, published bool) ([]string, error) } type RedactionInfo struct { diff --git a/roomserver/types/types.go b/roomserver/types/types.go index d7e03ad61..5d52ccfcd 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -83,6 +83,10 @@ type StateKeyTuple struct { EventStateKeyNID EventStateKeyNID } +func (a StateKeyTuple) IsCreate() bool { + return a.EventTypeNID == MRoomCreateNID && a.EventStateKeyNID == EmptyStateKeyNID +} + // LessThan returns true if this state key is less than the other state key. // The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries. func (a StateKeyTuple) LessThan(b StateKeyTuple) bool { @@ -209,6 +213,12 @@ type MissingEventError string func (e MissingEventError) Error() string { return string(e) } +// A RejectedError is returned when an event is stored as rejected. The error +// contains the reason why. +type RejectedError string + +func (e RejectedError) Error() string { return string(e) } + // RoomInfo contains metadata about a room type RoomInfo struct { RoomNID RoomNID diff --git a/setup/base/base.go b/setup/base/base.go index 819fe1ad4..e39977541 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -38,7 +38,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/gorilla/mux" @@ -273,8 +273,14 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI { // CreateAccountsDB creates a new instance of the accounts database. Should only // be called once per component. -func (b *BaseDendrite) CreateAccountsDB() accounts.Database { - db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS) +func (b *BaseDendrite) CreateAccountsDB() userdb.Database { + db, err := userdb.NewDatabase( + &b.Cfg.UserAPI.AccountDatabase, + b.Cfg.Global.ServerName, + b.Cfg.UserAPI.BCryptCost, + b.Cfg.UserAPI.OpenIDTokenLifetimeMS, + userapi.DefaultLoginTokenLifetime, + ) if err != nil { logrus.WithError(err).Panicf("failed to connect to accounts db") } diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 75f5e3df3..4590e752b 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -18,6 +18,10 @@ type ClientAPI struct { // If set, allows registration by anyone who also has the shared // secret, even if registration is otherwise disabled. RegistrationSharedSecret string `yaml:"registration_shared_secret"` + // If set, prevents guest accounts from being created. Only takes + // effect if registration is enabled, otherwise guests registration + // is forbidden either way. + GuestsDisabled bool `yaml:"guests_disabled"` // Boolean stating whether catpcha registration is enabled // and required diff --git a/setup/config/config_federationapi.go b/setup/config/config_federationapi.go index 4f5f49de8..95e705033 100644 --- a/setup/config/config_federationapi.go +++ b/setup/config/config_federationapi.go @@ -29,8 +29,6 @@ type FederationAPI struct { // on remote federation endpoints. This is not recommended in production! DisableTLSValidation bool `yaml:"disable_tls_validation"` - Proxy Proxy `yaml:"proxy_outbound"` - // Perspective keyservers, to use as a backup when direct key fetch // requests don't succeed KeyPerspectives KeyPerspectives `yaml:"key_perspectives"` @@ -50,8 +48,6 @@ func (c *FederationAPI) Defaults(generate bool) { c.FederationMaxRetries = 16 c.DisableTLSValidation = false - - c.Proxy.Defaults() } func (c *FederationAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 6f2306a6d..b947f2076 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -57,6 +57,9 @@ type Global struct { // DNS caching options for all outbound HTTP requests DNSCache DNSCacheOptions `yaml:"dns_cache"` + + // ServerNotices configuration used for sending server notices + ServerNotices ServerNotices `yaml:"server_notices"` } func (c *Global) Defaults(generate bool) { @@ -72,6 +75,7 @@ func (c *Global) Defaults(generate bool) { c.Metrics.Defaults(generate) c.DNSCache.Defaults() c.Sentry.Defaults() + c.ServerNotices.Defaults(generate) } func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { @@ -82,6 +86,7 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { c.Metrics.Verify(configErrs, isMonolith) c.Sentry.Verify(configErrs, isMonolith) c.DNSCache.Verify(configErrs, isMonolith) + c.ServerNotices.Verify(configErrs, isMonolith) } type OldVerifyKeys struct { @@ -123,6 +128,31 @@ func (c *Metrics) Defaults(generate bool) { func (c *Metrics) Verify(configErrs *ConfigErrors, isMonolith bool) { } +// ServerNotices defines the configuration used for sending server notices +type ServerNotices struct { + Enabled bool `yaml:"enabled"` + // The localpart to be used when sending notices + LocalPart string `yaml:"local_part"` + // The displayname to be used when sending notices + DisplayName string `yaml:"display_name"` + // The avatar of this user + AvatarURL string `yaml:"avatar"` + // The roomname to be used when creating messages + RoomName string `yaml:"room_name"` +} + +func (c *ServerNotices) Defaults(generate bool) { + if generate { + c.Enabled = true + c.LocalPart = "_server" + c.DisplayName = "Server Alert" + c.RoomName = "Server Alert" + c.AvatarURL = "" + } +} + +func (c *ServerNotices) Verify(errors *ConfigErrors, isMonolith bool) {} + // The configuration to use for Sentry error reporting type Sentry struct { Enabled bool `yaml:"enabled"` diff --git a/setup/config/config_jetstream.go b/setup/config/config_jetstream.go index 94e2d88b3..9271cd8b4 100644 --- a/setup/config/config_jetstream.go +++ b/setup/config/config_jetstream.go @@ -2,8 +2,6 @@ package config import ( "fmt" - - "github.com/nats-io/nats.go" ) type JetStream struct { @@ -25,8 +23,8 @@ func (c *JetStream) TopicFor(name string) string { return fmt.Sprintf("%s%s", c.TopicPrefix, name) } -func (c *JetStream) Durable(name string) nats.SubOpt { - return nats.Durable(c.TopicFor(name)) +func (c *JetStream) Durable(name string) string { + return c.TopicFor(name) } func (c *JetStream) Defaults(generate bool) { diff --git a/setup/config/config_test.go b/setup/config/config_test.go index 5aa54929e..8f7611f0a 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -58,6 +58,11 @@ global: basic_auth: username: metrics password: metrics + server_notices: + local_part: "_server" + display_name: "Server alerts" + avatar: "" + room_name: "Server Alerts" app_service_api: internal_api: listen: http://localhost:7777 @@ -118,11 +123,6 @@ federation_sender: conn_max_lifetime: -1 send_max_retries: 16 disable_tls_validation: false - proxy_outbound: - enabled: false - protocol: http - host: localhost - port: 8080 key_server: internal_api: listen: http://localhost:7779 diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go index b2cde2e96..1cb5eba18 100644 --- a/setup/config/config_userapi.go +++ b/setup/config/config_userapi.go @@ -16,9 +16,6 @@ type UserAPI struct { // The Account database stores the login details and account information // for local users. It is accessed by the UserAPI. AccountDatabase DatabaseOptions `yaml:"account_database"` - // The Device database stores session information for the devices of logged - // in local users. It is accessed by the UserAPI. - DeviceDatabase DatabaseOptions `yaml:"device_database"` } const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes @@ -27,10 +24,8 @@ func (c *UserAPI) Defaults(generate bool) { c.InternalAPI.Listen = "http://localhost:7781" c.InternalAPI.Connect = "http://localhost:7781" c.AccountDatabase.Defaults(10) - c.DeviceDatabase.Defaults(10) if generate { c.AccountDatabase.ConnectionString = "file:userapi_accounts.db" - c.DeviceDatabase.ConnectionString = "file:userapi_devices.db" } c.BCryptCost = bcrypt.DefaultCost c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS @@ -40,6 +35,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen)) checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect)) checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString)) - checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString)) checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS) } diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index 1891b96b3..544b5f0c3 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -1,12 +1,81 @@ package jetstream -import "github.com/nats-io/nats.go" +import ( + "context" + "fmt" -func WithJetStreamMessage(msg *nats.Msg, f func(msg *nats.Msg) bool) { - _ = msg.InProgress() - if f(msg) { - _ = msg.Ack() - } else { - _ = msg.Nak() + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" +) + +func JetStreamConsumer( + ctx context.Context, js nats.JetStreamContext, subj, durable string, + f func(ctx context.Context, msg *nats.Msg) bool, + opts ...nats.SubOpt, +) error { + defer func() { + // If there are existing consumers from before they were pull + // consumers, we need to clean up the old push consumers. However, + // in order to not affect the interest-based policies, we need to + // do this *after* creating the new pull consumers, which have + // "Pull" suffixed to their name. + if _, err := js.ConsumerInfo(subj, durable); err == nil { + if err := js.DeleteConsumer(subj, durable); err != nil { + logrus.WithContext(ctx).Warnf("Failed to clean up old consumer %q", durable) + } + } + }() + + name := durable + "Pull" + sub, err := js.PullSubscribe(subj, name, opts...) + if err != nil { + return fmt.Errorf("nats.SubscribeSync: %w", err) } + go func() { + for { + // The context behaviour here is surprising — we supply a context + // so that we can interrupt the fetch if we want, but NATS will still + // enforce its own deadline (roughly 5 seconds by default). Therefore + // it is our responsibility to check whether our context expired or + // not when a context error is returned. Footguns. Footguns everywhere. + msgs, err := sub.Fetch(1, nats.Context(ctx)) + if err != nil { + if err == context.Canceled || err == context.DeadlineExceeded { + // Work out whether it was the JetStream context that expired + // or whether it was our supplied context. + select { + case <-ctx.Done(): + // The supplied context expired, so we want to stop the + // consumer altogether. + return + default: + // The JetStream context expired, so the fetch probably + // just timed out and we should try again. + continue + } + } else { + // Something else went wrong, so we'll panic. + logrus.WithContext(ctx).WithField("subject", subj).Fatal(err) + } + } + if len(msgs) < 1 { + continue + } + msg := msgs[0] + if err = msg.InProgress(); err != nil { + logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err)) + continue + } + if f(ctx, msg) { + if err = msg.Ack(); err != nil { + logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Ack: %w", err)) + } + } else { + if err = msg.Nak(); err != nil { + logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err)) + } + } + } + }() + return nil } diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index 5d7937b5c..562b0131e 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -5,20 +5,17 @@ import ( "sync" "time" - "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/setup/config" "github.com/sirupsen/logrus" - saramajs "github.com/S7evinK/saramajetstream" natsserver "github.com/nats-io/nats-server/v2/server" - "github.com/nats-io/nats.go" natsclient "github.com/nats-io/nats.go" ) var natsServer *natsserver.Server var natsServerMutex sync.Mutex -func Prepare(cfg *config.JetStream) (nats.JetStreamContext, sarama.Consumer, sarama.SyncProducer) { +func Prepare(cfg *config.JetStream) natsclient.JetStreamContext { // check if we need an in-process NATS Server if len(cfg.Addresses) != 0 { return setupNATS(cfg, nil) @@ -27,13 +24,12 @@ func Prepare(cfg *config.JetStream) (nats.JetStreamContext, sarama.Consumer, sar if natsServer == nil { var err error natsServer, err = natsserver.NewServer(&natsserver.Options{ - ServerName: "monolith", - DontListen: true, - JetStream: true, - StoreDir: string(cfg.StoragePath), - NoSystemAccount: true, - AllowNewAccounts: false, - MaxPayload: 16 * 1024 * 1024, + ServerName: "monolith", + DontListen: true, + JetStream: true, + StoreDir: string(cfg.StoragePath), + NoSystemAccount: true, + MaxPayload: 16 * 1024 * 1024, }) if err != nil { panic(err) @@ -52,20 +48,20 @@ func Prepare(cfg *config.JetStream) (nats.JetStreamContext, sarama.Consumer, sar return setupNATS(cfg, nc) } -func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (nats.JetStreamContext, sarama.Consumer, sarama.SyncProducer) { +func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) natsclient.JetStreamContext { if nc == nil { var err error - nc, err = nats.Connect(strings.Join(cfg.Addresses, ",")) + nc, err = natsclient.Connect(strings.Join(cfg.Addresses, ",")) if err != nil { logrus.WithError(err).Panic("Unable to connect to NATS") - return nil, nil, nil + return nil } } s, err := nc.JetStream() if err != nil { logrus.WithError(err).Panic("Unable to get JetStream context") - return nil, nil, nil + return nil } for _, stream := range streams { // streams are defined in streams.go @@ -80,7 +76,7 @@ func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (nats.JetStreamContex // If we're trying to keep everything in memory (e.g. unit tests) // then overwrite the storage policy. if cfg.InMemory { - stream.Storage = nats.MemoryStorage + stream.Storage = natsclient.MemoryStorage } // Namespace the streams without modifying the original streams @@ -93,7 +89,5 @@ func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (nats.JetStreamContex } } - consumer := saramajs.NewJetStreamConsumer(nc, s, "") - producer := saramajs.NewJetStreamProducer(nc, s, "") - return s, consumer, producer + return s } diff --git a/setup/monolith.go b/setup/monolith.go index e6c955222..61125e4a9 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -30,7 +30,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -38,7 +38,7 @@ import ( // all components of Dendrite, for use in monolith mode. type Monolith struct { Config *config.Dendrite - AccountDB accounts.Database + AccountDB userdb.Database KeyRing *gomatrixserverlib.KeyRing Client *gomatrixserverlib.Client FedClient *gomatrixserverlib.FederationClient diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 8a35e4143..0af22c19a 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -82,9 +82,15 @@ type EventRelationshipResponse struct { Limited bool `json:"limited"` } -func toClientResponse(res *gomatrixserverlib.MSC2836EventRelationshipsResponse) *EventRelationshipResponse { +type MSC2836EventRelationshipsResponse struct { + gomatrixserverlib.MSC2836EventRelationshipsResponse + ParsedEvents []*gomatrixserverlib.Event + ParsedAuthChain []*gomatrixserverlib.Event +} + +func toClientResponse(res *MSC2836EventRelationshipsResponse) *EventRelationshipResponse { out := &EventRelationshipResponse{ - Events: gomatrixserverlib.ToClientEvents(res.Events, gomatrixserverlib.FormatAll), + Events: gomatrixserverlib.ToClientEvents(res.ParsedEvents, gomatrixserverlib.FormatAll), Limited: res.Limited, NextBatch: res.NextBatch, } @@ -210,7 +216,7 @@ func federatedEventRelationship( // add auth chain information requiredAuthEventsSet := make(map[string]bool) var requiredAuthEvents []string - for _, ev := range res.Events { + for _, ev := range res.ParsedEvents { for _, a := range ev.AuthEventIDs() { if requiredAuthEventsSet[a] { continue @@ -227,19 +233,24 @@ func federatedEventRelationship( // they may already have the auth events so don't fail this request util.GetLogger(ctx).WithError(err).Error("Failed to QueryAuthChain") } - res.AuthChain = make([]*gomatrixserverlib.Event, len(queryRes.AuthChain)) + res.AuthChain = make(gomatrixserverlib.EventJSONs, len(queryRes.AuthChain)) for i := range queryRes.AuthChain { - res.AuthChain[i] = queryRes.AuthChain[i].Unwrap() + res.AuthChain[i] = queryRes.AuthChain[i].JSON() + } + + res.Events = make(gomatrixserverlib.EventJSONs, len(res.ParsedEvents)) + for i := range res.ParsedEvents { + res.Events[i] = res.ParsedEvents[i].JSON() } return util.JSONResponse{ Code: 200, - JSON: res, + JSON: res.MSC2836EventRelationshipsResponse, } } -func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsResponse, *util.JSONResponse) { - var res gomatrixserverlib.MSC2836EventRelationshipsResponse +func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONResponse) { + var res MSC2836EventRelationshipsResponse var returnEvents []*gomatrixserverlib.HeaderedEvent // Can the user see (according to history visibility) event_id? If no, reject the request, else continue. event := rc.getLocalEvent(rc.req.EventID) @@ -290,11 +301,11 @@ func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsRespons ) returnEvents = append(returnEvents, events...) } - res.Events = make([]*gomatrixserverlib.Event, len(returnEvents)) + res.ParsedEvents = make([]*gomatrixserverlib.Event, len(returnEvents)) for i, ev := range returnEvents { // for each event, extract the children_count | hash and add it as unsigned data. rc.addChildMetadata(ev) - res.Events[i] = ev.Unwrap() + res.ParsedEvents[i] = ev.Unwrap() } res.Limited = remaining == 0 || walkLimited return &res, nil @@ -357,7 +368,7 @@ func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *gomatrixserverlib.H continue } rc.injectResponseToRoomserver(res) - for _, ev := range res.Events { + for _, ev := range res.ParsedEvents { if ev.EventID() == eventID { return ev.Headered(ev.Version()) } @@ -384,7 +395,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen if rc.hasUnexploredChildren(parentID) { // we need to do a remote request to pull in the children as we are missing them locally. serversToQuery := rc.getServersForEventID(parentID) - var result *gomatrixserverlib.MSC2836EventRelationshipsResponse + var result *MSC2836EventRelationshipsResponse for _, srv := range serversToQuery { res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ EventID: parentID, @@ -397,7 +408,12 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen if err != nil { util.GetLogger(rc.ctx).WithError(err).WithField("server", srv).Error("includeChildren: failed to call MSC2836EventRelationships") } else { - result = &res + mscRes := &MSC2836EventRelationshipsResponse{ + MSC2836EventRelationshipsResponse: res, + } + mscRes.ParsedEvents = res.Events.UntrustedEvents(rc.roomVersion) + mscRes.ParsedAuthChain = res.AuthChain.UntrustedEvents(rc.roomVersion) + result = mscRes break } } @@ -467,7 +483,7 @@ func walkThread( } // MSC2836EventRelationships performs an /event_relationships request to a remote server -func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { +func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*MSC2836EventRelationshipsResponse, error) { res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ EventID: eventID, DepthFirst: rc.req.DepthFirst, @@ -481,7 +497,12 @@ func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverli util.GetLogger(rc.ctx).WithError(err).Error("Failed to call MSC2836EventRelationships") return nil, err } - return &res, nil + mscRes := &MSC2836EventRelationshipsResponse{ + MSC2836EventRelationshipsResponse: res, + } + mscRes.ParsedEvents = res.Events.UntrustedEvents(ver) + mscRes.ParsedAuthChain = res.AuthChain.UntrustedEvents(ver) + return mscRes, nil } @@ -550,12 +571,12 @@ func (rc *reqCtx) getServersForEventID(eventID string) []gomatrixserverlib.Serve return serversToQuery } -func (rc *reqCtx) remoteEventRelationships(eventID string) *gomatrixserverlib.MSC2836EventRelationshipsResponse { +func (rc *reqCtx) remoteEventRelationships(eventID string) *MSC2836EventRelationshipsResponse { if rc.isFederatedRequest { return nil // we don't query remote servers for remote requests } serversToQuery := rc.getServersForEventID(eventID) - var res *gomatrixserverlib.MSC2836EventRelationshipsResponse + var res *MSC2836EventRelationshipsResponse var err error for _, srv := range serversToQuery { res, err = rc.MSC2836EventRelationships(eventID, srv, rc.roomVersion) @@ -577,7 +598,7 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent if queryRes != nil { // inject all the events into the roomserver then return the event in question rc.injectResponseToRoomserver(queryRes) - for _, ev := range queryRes.Events { + for _, ev := range queryRes.ParsedEvents { if ev.EventID() == eventID && rc.req.RoomID == ev.RoomID() { return ev.Headered(ev.Version()) } @@ -619,12 +640,12 @@ func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent // injectResponseToRoomserver injects the events // into the roomserver as KindOutlier, with auth chains. -func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836EventRelationshipsResponse) { - var stateEvents []*gomatrixserverlib.Event +func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsResponse) { + var stateEvents gomatrixserverlib.EventJSONs var messageEvents []*gomatrixserverlib.Event - for _, ev := range res.Events { + for _, ev := range res.ParsedEvents { if ev.StateKey() != nil { - stateEvents = append(stateEvents, ev) + stateEvents = append(stateEvents, ev.JSON()) } else { messageEvents = append(messageEvents, ev) } @@ -633,7 +654,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836Event AuthEvents: res.AuthChain, StateEvents: stateEvents, } - eventsInOrder, err := respState.Events() + eventsInOrder, err := respState.Events(rc.roomVersion) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("failed to calculate order to send events in MSC2836EventRelationshipsResponse") return diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index 1ec9beb04..c3650085f 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -34,7 +34,7 @@ import ( type OutputClientDataConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string db storage.Database stream types.StreamProvider @@ -63,45 +63,45 @@ func NewOutputClientDataConsumer( // Start consuming from room servers func (s *OutputClientDataConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) - return err + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } // onMessage is called when the sync server receives a new event from the client API server output log. // It is not safe for this function to be called from multiple goroutines, or else the // sync stream position may race and be incorrectly calculated. -func (s *OutputClientDataConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Parse out the event JSON - userID := msg.Header.Get(jetstream.UserID) - var output eventutil.AccountData - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("client API server output log: message parse failure") - sentry.CaptureException(err) - return true - } - - log.WithFields(log.Fields{ - "type": output.Type, - "room_id": output.RoomID, - }).Info("received data from client API server") - - streamPos, err := s.db.UpsertAccountData( - s.ctx, userID, output.RoomID, output.Type, - ) - if err != nil { - sentry.CaptureException(err) - log.WithFields(log.Fields{ - "type": output.Type, - "room_id": output.RoomID, - log.ErrorKey: err, - }).Panicf("could not save account data") - } - - s.stream.Advance(streamPos) - s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos}) - +func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + // Parse out the event JSON + userID := msg.Header.Get(jetstream.UserID) + var output eventutil.AccountData + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("client API server output log: message parse failure") + sentry.CaptureException(err) return true - }) + } + + log.WithFields(log.Fields{ + "type": output.Type, + "room_id": output.RoomID, + }).Debug("Received data from client API server") + + streamPos, err := s.db.UpsertAccountData( + s.ctx, userID, output.RoomID, output.Type, + ) + if err != nil { + sentry.CaptureException(err) + log.WithFields(log.Fields{ + "type": output.Type, + "room_id": output.RoomID, + log.ErrorKey: err, + }).Panicf("could not save account data") + } + + s.stream.Advance(streamPos) + s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos}) + + return true } diff --git a/syncapi/consumers/eduserver_receipts.go b/syncapi/consumers/eduserver_receipts.go index 57d69d6fb..392840ece 100644 --- a/syncapi/consumers/eduserver_receipts.go +++ b/syncapi/consumers/eduserver_receipts.go @@ -34,7 +34,7 @@ import ( type OutputReceiptEventConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string db storage.Database stream types.StreamProvider @@ -64,36 +64,36 @@ func NewOutputReceiptEventConsumer( // Start consuming from EDU api func (s *OutputReceiptEventConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) - return err + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } -func (s *OutputReceiptEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - var output api.OutputReceiptEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("EDU server output log: message parse failure") - sentry.CaptureException(err) - return true - } - - streamPos, err := s.db.StoreReceipt( - s.ctx, - output.RoomID, - output.Type, - output.UserID, - output.EventID, - output.Timestamp, - ) - if err != nil { - sentry.CaptureException(err) - return true - } - - s.stream.Advance(streamPos) - s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) - +func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var output api.OutputReceiptEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + sentry.CaptureException(err) return true - }) + } + + streamPos, err := s.db.StoreReceipt( + s.ctx, + output.RoomID, + output.Type, + output.UserID, + output.EventID, + output.Timestamp, + ) + if err != nil { + sentry.CaptureException(err) + return true + } + + s.stream.Advance(streamPos) + s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) + + return true } diff --git a/syncapi/consumers/eduserver_sendtodevice.go b/syncapi/consumers/eduserver_sendtodevice.go index 54e689fa1..b0beef063 100644 --- a/syncapi/consumers/eduserver_sendtodevice.go +++ b/syncapi/consumers/eduserver_sendtodevice.go @@ -36,7 +36,7 @@ import ( type OutputSendToDeviceEventConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string db storage.Database serverName gomatrixserverlib.ServerName // our server name @@ -68,52 +68,52 @@ func NewOutputSendToDeviceEventConsumer( // Start consuming from EDU api func (s *OutputSendToDeviceEventConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) - return err + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } -func (s *OutputSendToDeviceEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - var output api.OutputSendToDeviceEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("EDU server output log: message parse failure") - sentry.CaptureException(err) - return true - } - - _, domain, err := gomatrixserverlib.SplitID('@', output.UserID) - if err != nil { - sentry.CaptureException(err) - return true - } - if domain != s.serverName { - return true - } - - util.GetLogger(context.TODO()).WithFields(log.Fields{ - "sender": output.Sender, - "user_id": output.UserID, - "device_id": output.DeviceID, - "event_type": output.Type, - }).Info("sync API received send-to-device event from EDU server") - - streamPos, err := s.db.StoreNewSendForDeviceMessage( - s.ctx, output.UserID, output.DeviceID, output.SendToDeviceEvent, - ) - if err != nil { - sentry.CaptureException(err) - log.WithError(err).Errorf("failed to store send-to-device message") - return false - } - - s.stream.Advance(streamPos) - s.notifier.OnNewSendToDevice( - output.UserID, - []string{output.DeviceID}, - types.StreamingToken{SendToDevicePosition: streamPos}, - ) - +func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var output api.OutputSendToDeviceEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + sentry.CaptureException(err) return true - }) + } + + _, domain, err := gomatrixserverlib.SplitID('@', output.UserID) + if err != nil { + sentry.CaptureException(err) + return true + } + if domain != s.serverName { + return true + } + + util.GetLogger(context.TODO()).WithFields(log.Fields{ + "sender": output.Sender, + "user_id": output.UserID, + "device_id": output.DeviceID, + "event_type": output.Type, + }).Info("sync API received send-to-device event from EDU server") + + streamPos, err := s.db.StoreNewSendForDeviceMessage( + s.ctx, output.UserID, output.DeviceID, output.SendToDeviceEvent, + ) + if err != nil { + sentry.CaptureException(err) + log.WithError(err).Errorf("failed to store send-to-device message") + return false + } + + s.stream.Advance(streamPos) + s.notifier.OnNewSendToDevice( + output.UserID, + []string{output.DeviceID}, + types.StreamingToken{SendToDevicePosition: streamPos}, + ) + + return true } diff --git a/syncapi/consumers/eduserver_typing.go b/syncapi/consumers/eduserver_typing.go index de2f6f950..cae5df8a8 100644 --- a/syncapi/consumers/eduserver_typing.go +++ b/syncapi/consumers/eduserver_typing.go @@ -35,7 +35,7 @@ import ( type OutputTypingEventConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string eduCache *cache.EDUCache stream types.StreamProvider @@ -66,41 +66,41 @@ func NewOutputTypingEventConsumer( // Start consuming from EDU api func (s *OutputTypingEventConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) - return err + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } -func (s *OutputTypingEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - var output api.OutputTypingEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("EDU server output log: message parse failure") - sentry.CaptureException(err) - return true - } - - log.WithFields(log.Fields{ - "room_id": output.Event.RoomID, - "user_id": output.Event.UserID, - "typing": output.Event.Typing, - }).Debug("received data from EDU server") - - var typingPos types.StreamPosition - typingEvent := output.Event - if typingEvent.Typing { - typingPos = types.StreamPosition( - s.eduCache.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime), - ) - } else { - typingPos = types.StreamPosition( - s.eduCache.RemoveUser(typingEvent.UserID, typingEvent.RoomID), - ) - } - - s.stream.Advance(typingPos) - s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos}) - +func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var output api.OutputTypingEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + sentry.CaptureException(err) return true - }) + } + + log.WithFields(log.Fields{ + "room_id": output.Event.RoomID, + "user_id": output.Event.UserID, + "typing": output.Event.Typing, + }).Debug("received data from EDU server") + + var typingPos types.StreamPosition + typingEvent := output.Event + if typingEvent.Typing { + typingPos = types.StreamPosition( + s.eduCache.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime), + ) + } else { + typingPos = types.StreamPosition( + s.eduCache.RemoveUser(typingEvent.UserID, typingEvent.RoomID), + ) + } + + s.stream.Advance(typingPos) + s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos}) + + return true } diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 97685cc04..e806f76e6 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -18,84 +18,81 @@ import ( "context" "encoding/json" - "github.com/Shopify/sarama" "github.com/getsentry/sentry-go" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) // OutputKeyChangeEventConsumer consumes events that originated in the key server. type OutputKeyChangeEventConsumer struct { - ctx context.Context - keyChangeConsumer *internal.ContinualConsumer - db storage.Database - notifier *notifier.Notifier - stream types.StreamProvider - serverName gomatrixserverlib.ServerName // our server name - rsAPI roomserverAPI.RoomserverInternalAPI - keyAPI api.KeyInternalAPI + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + notifier *notifier.Notifier + stream types.StreamProvider + serverName gomatrixserverlib.ServerName // our server name + rsAPI roomserverAPI.RoomserverInternalAPI + keyAPI api.KeyInternalAPI } // NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer. // Call Start() to begin consuming from the key server. func NewOutputKeyChangeEventConsumer( process *process.ProcessContext, - serverName gomatrixserverlib.ServerName, + cfg *config.SyncAPI, topic string, - kafkaConsumer sarama.Consumer, + js nats.JetStreamContext, keyAPI api.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, store storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, ) *OutputKeyChangeEventConsumer { - - consumer := internal.ContinualConsumer{ - Process: process, - ComponentName: "syncapi/keychange", - Topic: topic, - Consumer: kafkaConsumer, - PartitionStore: store, - } - s := &OutputKeyChangeEventConsumer{ - ctx: process.Context(), - keyChangeConsumer: &consumer, - db: store, - serverName: serverName, - keyAPI: keyAPI, - rsAPI: rsAPI, - notifier: notifier, - stream: stream, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Durable("SyncAPIKeyChangeConsumer"), + topic: topic, + db: store, + serverName: cfg.Matrix.ServerName, + keyAPI: keyAPI, + rsAPI: rsAPI, + notifier: notifier, + stream: stream, } - consumer.ProcessMessage = s.onMessage - return s } // Start consuming from the key server func (s *OutputKeyChangeEventConsumer) Start() error { - return s.keyChangeConsumer.Start() + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } -func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { +func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { var m api.DeviceMessage - if err := json.Unmarshal(msg.Value, &m); err != nil { + if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("failed to read device message from key change topic") - return nil + return true } if m.DeviceKeys == nil && m.OutputCrossSigningKeyUpdate == nil { // This probably shouldn't happen but stops us from panicking if we come // across an update that doesn't satisfy either types. - return nil + return true } switch m.Type { case api.TypeCrossSigningUpdate: @@ -107,9 +104,9 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er } } -func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, deviceChangeID int64) error { +func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, deviceChangeID int64) bool { if m.DeviceKeys == nil { - return nil + return true } output := m.DeviceKeys // work out who we need to notify about the new key @@ -120,7 +117,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, d if err != nil { logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server") sentry.CaptureException(err) - return err + return true } // make sure we get our own key updates too! queryRes.UserIDsToCount[output.UserID] = 1 @@ -131,10 +128,10 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, d s.notifier.OnNewKeyChange(types.StreamingToken{DeviceListPosition: posUpdate}, userID, output.UserID) } - return nil + return true } -func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, deviceChangeID int64) error { +func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, deviceChangeID int64) bool { output := m.CrossSigningKeyUpdate // work out who we need to notify about the new key var queryRes roomserverAPI.QuerySharedUsersResponse @@ -144,7 +141,7 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage if err != nil { logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server") sentry.CaptureException(err) - return err + return true } // make sure we get our own key updates too! queryRes.UserIDsToCount[output.UserID] = 1 @@ -155,5 +152,5 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage s.notifier.OnNewKeyChange(types.StreamingToken{DeviceListPosition: posUpdate}, userID, output.UserID) } - return nil + return true } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index e9c4abe88..15485bb35 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -16,6 +16,7 @@ package consumers import ( "context" + "database/sql" "encoding/json" "fmt" @@ -38,7 +39,7 @@ type OutputRoomEventConsumer struct { cfg *config.SyncAPI rsAPI api.RoomserverInternalAPI jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string db storage.Database pduStream types.StreamProvider @@ -73,65 +74,61 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - _, err := s.jetstream.Subscribe( - s.topic, s.onMessage, s.durable, - nats.DeliverAll(), - nats.ManualAck(), + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), ) - return err } // onMessage is called when the sync server receives a new event from the room server output log. // It is not safe for this function to be called from multiple goroutines, or else the // sync stream position may race and be incorrectly calculated. -func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Parse out the event JSON - var err error - var output api.OutputEvent - if err = json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("roomserver output log: message parse failure") - return true - } - - switch output.Type { - case api.OutputTypeNewRoomEvent: - // Ignore redaction events. We will add them to the database when they are - // validated (when we receive OutputTypeRedactedEvent) - event := output.NewRoomEvent.Event - if event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil { - // in the special case where the event redacts itself, just pass the message through because - // we will never see the other part of the pair - if event.Redacts() != event.EventID() { - return true - } - } - err = s.onNewRoomEvent(s.ctx, *output.NewRoomEvent) - case api.OutputTypeOldRoomEvent: - err = s.onOldRoomEvent(s.ctx, *output.OldRoomEvent) - case api.OutputTypeNewInviteEvent: - s.onNewInviteEvent(s.ctx, *output.NewInviteEvent) - case api.OutputTypeRetireInviteEvent: - s.onRetireInviteEvent(s.ctx, *output.RetireInviteEvent) - case api.OutputTypeNewPeek: - s.onNewPeek(s.ctx, *output.NewPeek) - case api.OutputTypeRetirePeek: - s.onRetirePeek(s.ctx, *output.RetirePeek) - case api.OutputTypeRedactedEvent: - err = s.onRedactEvent(s.ctx, *output.RedactedEvent) - default: - log.WithField("type", output.Type).Debug( - "roomserver output log: ignoring unknown output type", - ) - } - if err != nil { - log.WithError(err).Error("roomserver output log: failed to process event") - return false - } - +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + // Parse out the event JSON + var err error + var output api.OutputEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") return true - }) + } + + switch output.Type { + case api.OutputTypeNewRoomEvent: + // Ignore redaction events. We will add them to the database when they are + // validated (when we receive OutputTypeRedactedEvent) + event := output.NewRoomEvent.Event + if event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil { + // in the special case where the event redacts itself, just pass the message through because + // we will never see the other part of the pair + if event.Redacts() != event.EventID() { + return true + } + } + err = s.onNewRoomEvent(s.ctx, *output.NewRoomEvent) + case api.OutputTypeOldRoomEvent: + err = s.onOldRoomEvent(s.ctx, *output.OldRoomEvent) + case api.OutputTypeNewInviteEvent: + s.onNewInviteEvent(s.ctx, *output.NewInviteEvent) + case api.OutputTypeRetireInviteEvent: + s.onRetireInviteEvent(s.ctx, *output.RetireInviteEvent) + case api.OutputTypeNewPeek: + s.onNewPeek(s.ctx, *output.NewPeek) + case api.OutputTypeRetirePeek: + s.onRetirePeek(s.ctx, *output.RetirePeek) + case api.OutputTypeRedactedEvent: + err = s.onRedactEvent(s.ctx, *output.RedactedEvent) + default: + log.WithField("type", output.Type).Debug( + "roomserver output log: ignoring unknown output type", + ) + } + if err != nil { + log.WithError(err).Error("roomserver output log: failed to process event") + return false + } + + return true } func (s *OutputRoomEventConsumer) onRedactEvent( @@ -311,7 +308,9 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( ctx context.Context, msg api.OutputRetireInviteEvent, ) { pduPos, err := s.db.RetireInviteEvent(ctx, msg.EventID) - if err != nil { + // It's possible we just haven't heard of this invite yet, so + // we should not panic if we try to retire it. + if err != nil && err != sql.ErrNoRows { sentry.CaptureException(err) // panic rather than continue with an inconsistent database log.WithFields(log.Fields{ diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 41efd4a07..37a9e2d39 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -18,8 +18,8 @@ import ( "context" "strings" - "github.com/Shopify/sarama" keyapi "github.com/matrix-org/dendrite/keyserver/api" + keytypes "github.com/matrix-org/dendrite/keyserver/types" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -64,8 +64,8 @@ func DeviceListCatchup( } // now also track users who we already share rooms with but who have updated their devices between the two tokens - offset := sarama.OffsetOldest - toOffset := sarama.OffsetNewest + offset := keytypes.OffsetOldest + toOffset := keytypes.OffsetNewest if to > 0 && to > from { toOffset = int64(to) } @@ -282,6 +282,8 @@ func membershipEvents(res *types.Response) (joinUserIDs, leaveUserIDs []string) if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil { if strings.Contains(string(ev.Content), `"join"`) { joinUserIDs = append(joinUserIDs, *ev.StateKey) + } else if strings.Contains(string(ev.Content), `"invite"`) { + joinUserIDs = append(joinUserIDs, *ev.StateKey) } else if strings.Contains(string(ev.Content), `"leave"`) { leaveUserIDs = append(leaveUserIDs, *ev.StateKey) } else if strings.Contains(string(ev.Content), `"ban"`) { diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index e2ff27395..005a33555 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -39,14 +39,14 @@ func Setup( rsAPI api.RoomserverInternalAPI, cfg *config.SyncAPI, ) { - r0mux := csMux.PathPrefix("/r0").Subrouter() + v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() // TODO: Add AS support for all handlers below. - r0mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingSyncRequest(req, device) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -54,7 +54,7 @@ func Setup( return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg, srp) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/user/{userId}/filter", + v3mux.Handle("/user/{userId}/filter", httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -64,7 +64,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/user/{userId}/filter/{filterId}", + v3mux.Handle("/user/{userId}/filter/{filterId}", httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -74,7 +74,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingKeyChangeRequest(req, device) })).Methods(http.MethodGet, http.MethodOptions) } diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 9cff4cad1..b464ad9cd 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -19,7 +19,6 @@ import ( eduAPI "github.com/matrix-org/dendrite/eduserver/api" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -27,8 +26,6 @@ import ( ) type Database interface { - internal.PartitionStorer - MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 39bc233ae..72462459c 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -48,7 +48,7 @@ func AddPublicRoutes( federation *gomatrixserverlib.FederationClient, cfg *config.SyncAPI, ) { - js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) syncDB, err := storage.NewSyncServerDatasource(&cfg.Database) if err != nil { @@ -65,8 +65,8 @@ func AddPublicRoutes( requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier) keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( - process, cfg.Matrix.ServerName, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), - consumer, keyAPI, rsAPI, syncDB, notifier, + process, cfg, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), + js, keyAPI, rsAPI, syncDB, notifier, streams.DeviceListStreamProvider, ) if err = keyChangeConsumer.Start(); err != nil { diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 68c308d83..c2e8ed01c 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -279,7 +279,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { parts := strings.Split(tok[1:], "_") var positions [7]StreamPosition for i, p := range parts { - if i > len(positions) { + if i >= len(positions) { break } var pos int diff --git a/sytest-whitelist b/sytest-whitelist index 7d26c610e..d739313ac 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -590,3 +590,6 @@ Can reject invites over federation for rooms with version 9 Can receive redactions from regular users over federation in room version 9 Forward extremities remain so even after the next events are populated as outliers If a device list update goes missing, the server resyncs on the next one +uploading self-signing key notifies over federation +uploading signed devices gets propagated over federation +Device list doesn't change if remote server is down \ No newline at end of file diff --git a/userapi/api/api.go b/userapi/api/api.go index 04609659c..2be662e55 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -18,12 +18,15 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) // UserInternalAPI is the internal API for information about users and devices. type UserInternalAPI interface { + LoginTokenInternalAPI + InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error @@ -351,6 +354,7 @@ type Device struct { // If the device is for an appservice user, // this is the appservice ID. AppserviceID string + AccountType AccountType } // Account represents a Matrix account on this home server. @@ -359,7 +363,7 @@ type Account struct { Localpart string ServerName gomatrixserverlib.ServerName AppServiceID string - // TODO: Other flags like IsAdmin, IsGuest + AccountType AccountType // TODO: Associations (e.g. with application services) } @@ -415,4 +419,8 @@ const ( AccountTypeUser AccountType = 1 // AccountTypeGuest indicates this is a guest account AccountTypeGuest AccountType = 2 + // AccountTypeAdmin indicates this is an admin account + AccountTypeAdmin AccountType = 3 + // AccountTypeAppService indicates this is an appservice account + AccountTypeAppService AccountType = 4 ) diff --git a/userapi/api/api_logintoken.go b/userapi/api/api_logintoken.go new file mode 100644 index 000000000..e2207bb53 --- /dev/null +++ b/userapi/api/api_logintoken.go @@ -0,0 +1,76 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "time" +) + +// DefaultLoginTokenLifetime determines how old a valid token may be. +// +// NOTSPEC: The current spec says "SHOULD be limited to around five +// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low. +// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325). +const DefaultLoginTokenLifetime = 2 * time.Minute + +type LoginTokenInternalAPI interface { + // PerformLoginTokenCreation creates a new login token and associates it with the provided data. + PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error + + // PerformLoginTokenDeletion ensures the token doesn't exist. Success + // is returned even if the token didn't exist, or was already deleted. + PerformLoginTokenDeletion(ctx context.Context, req *PerformLoginTokenDeletionRequest, res *PerformLoginTokenDeletionResponse) error + + // QueryLoginToken returns the data associated with a login token. If + // the token is not valid, success is returned, but res.Data == nil. + QueryLoginToken(ctx context.Context, req *QueryLoginTokenRequest, res *QueryLoginTokenResponse) error +} + +// LoginTokenData is the data that can be retrieved given a login token. This is +// provided by the calling code. +type LoginTokenData struct { + // UserID is the full mxid of the user. + UserID string +} + +// LoginTokenMetadata contains metadata created and maintained by the User API. +type LoginTokenMetadata struct { + Token string + Expiration time.Time +} + +type PerformLoginTokenCreationRequest struct { + Data LoginTokenData +} + +type PerformLoginTokenCreationResponse struct { + Metadata LoginTokenMetadata +} + +type PerformLoginTokenDeletionRequest struct { + Token string +} + +type PerformLoginTokenDeletionResponse struct{} + +type QueryLoginTokenRequest struct { + Token string +} + +type QueryLoginTokenResponse struct { + // Data is nil if the token was invalid. + Data *LoginTokenData +} diff --git a/userapi/api/api_trace_logintoken.go b/userapi/api/api_trace_logintoken.go new file mode 100644 index 000000000..e60dae594 --- /dev/null +++ b/userapi/api/api_trace_logintoken.go @@ -0,0 +1,39 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + + "github.com/matrix-org/util" +) + +func (t *UserInternalAPITrace) PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error { + err := t.Impl.PerformLoginTokenCreation(ctx, req, res) + util.GetLogger(ctx).Infof("PerformLoginTokenCreation req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) PerformLoginTokenDeletion(ctx context.Context, req *PerformLoginTokenDeletionRequest, res *PerformLoginTokenDeletionResponse) error { + err := t.Impl.PerformLoginTokenDeletion(ctx, req, res) + util.GetLogger(ctx).Infof("PerformLoginTokenDeletion req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) QueryLoginToken(ctx context.Context, req *QueryLoginTokenRequest, res *QueryLoginTokenResponse) error { + err := t.Impl.QueryLoginToken(ctx, req, res) + util.GetLogger(ctx).Infof("QueryLoginToken req=%+v res=%+v", js(req), js(res)) + return err +} diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 5d91383de..f54cc6137 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -21,22 +21,21 @@ import ( "errors" "fmt" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/userapi/storage" ) type UserInternalAPI struct { - AccountDB accounts.Database - DeviceDB devices.Database + DB storage.Database ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService @@ -54,20 +53,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc if req.DataType == "" { return fmt.Errorf("data type must not be empty") } - return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData) + return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData) } func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { - if req.AccountType == api.AccountTypeGuest { - acc, err := a.AccountDB.CreateGuestAccount(ctx) - if err != nil { - return err - } - res.AccountCreated = true - res.Account = acc - return nil - } - acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID) + acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists switch req.OnConflict { @@ -86,11 +76,18 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P Localpart: req.Localpart, ServerName: a.ServerName, UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + AccountType: req.AccountType, } return nil } - if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { + if req.AccountType == api.AccountTypeGuest { + res.AccountCreated = true + res.Account = acc + return nil + } + + if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { return err } @@ -100,7 +97,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P } func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { - if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil { + if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil { return err } res.PasswordUpdated = true @@ -113,7 +110,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe "device_id": req.DeviceID, "display_name": req.DeviceDisplayName, }).Info("PerformDeviceCreation") - dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) + dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) if err != nil { return err } @@ -138,12 +135,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { var devices []api.Device - devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) + devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) for _, d := range devices { deletedDeviceIDs = append(deletedDeviceIDs, d.ID) } } else { - err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) + err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs) } if err != nil { return err @@ -197,7 +194,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate( if err != nil { return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil { + if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil { return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) } return nil @@ -209,7 +206,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") return err } - dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID) + dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID) if err == sql.ErrNoRows { res.DeviceExists = false return nil @@ -224,7 +221,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf return nil } - err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) + err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) if err != nil { util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") return err @@ -262,7 +259,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil if domain != a.ServerName { return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) } - prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local) + prof, err := a.DB.GetProfileByLocalpart(ctx, local) if err != nil { if err == sql.ErrNoRows { return nil @@ -276,7 +273,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil } func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { - profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit) + profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit) if err != nil { return err } @@ -285,7 +282,7 @@ func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.Quer } func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { - devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs) + devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs) if err != nil { return err } @@ -313,10 +310,11 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice if domain != a.ServerName { return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName) } - devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local) + devs, err := a.DB.GetDevicesByLocalpart(ctx, local) if err != nil { return err } + res.UserExists = true res.Devices = devs return nil } @@ -331,7 +329,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc } if req.DataType != "" { var data json.RawMessage - data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) + data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) if err != nil { return err } @@ -349,7 +347,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc } return nil } - global, rooms, err := a.AccountDB.GetAccountData(ctx, local) + global, rooms, err := a.DB.GetAccountData(ctx, local) if err != nil { return err } @@ -368,13 +366,22 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc return nil } - device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken) + device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken) if err != nil { if err == sql.ErrNoRows { return nil } return err } + localPart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + return err + } + acc, err := a.DB.GetAccountByLocalpart(ctx, localPart) + if err != nil { + return err + } + device.AccountType = acc.AccountType res.Device = device return nil } @@ -401,6 +408,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe // AS dummy device has AS's token. AccessToken: token, AppserviceID: appService.ID, + AccountType: api.AccountTypeAppService, } localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName) @@ -410,7 +418,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe if localpart != "" { // AS is masquerading as another user // Verify that the user is registered - account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart) + account, err := a.DB.GetAccountByLocalpart(ctx, localpart) // Verify that the account exists and either appServiceID matches or // it belongs to the appservice user namespaces if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { @@ -428,7 +436,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe // PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again. func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { - err := a.AccountDB.DeactivateAccount(ctx, req.Localpart) + err := a.DB.DeactivateAccount(ctx, req.Localpart) res.AccountDeactivated = err == nil return err } @@ -437,7 +445,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error { token := util.RandomString(24) - exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID) + exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID) res.Token = api.OpenIDToken{ Token: token, @@ -450,7 +458,7 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a // QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { - openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token) + openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token) if err != nil { return err } @@ -472,7 +480,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } return nil } - exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version) + exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version) if err != nil { res.Error = fmt.Sprintf("failed to delete backup: %s", err) } @@ -485,7 +493,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } // Create metadata if req.Version == "" { - version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) + version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) if err != nil { res.Error = fmt.Sprintf("failed to create backup: %s", err) } @@ -498,7 +506,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } // Update metadata if len(req.Keys.Rooms) == 0 { - err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) + err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) if err != nil { res.Error = fmt.Sprintf("failed to update backup: %s", err) } @@ -519,7 +527,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { // you can only upload keys for the CURRENT version - version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "") + version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "") if err != nil { res.Error = fmt.Sprintf("failed to query version: %s", err) return @@ -547,7 +555,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform }) } } - count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads) + count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads) if err != nil { res.Error = fmt.Sprintf("failed to upsert keys: %s", err) return @@ -557,7 +565,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform } func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { - version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) + version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version) res.Version = version if err != nil { if err == sql.ErrNoRows { @@ -573,14 +581,14 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB res.Exists = !deleted if !req.ReturnKeys { - res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID) + res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID) if err != nil { res.Error = fmt.Sprintf("failed to count keys: %s", err) } return } - result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) + result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) if err != nil { res.Error = fmt.Sprintf("failed to query keys: %s", err) return diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go new file mode 100644 index 000000000..f1bf391e4 --- /dev/null +++ b/userapi/internal/api_logintoken.go @@ -0,0 +1,78 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// PerformLoginTokenCreation creates a new login token and associates it with the provided data. +func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *api.PerformLoginTokenCreationRequest, res *api.PerformLoginTokenCreationResponse) error { + util.GetLogger(ctx).WithField("user_id", req.Data.UserID).Info("PerformLoginTokenCreation") + _, domain, err := gomatrixserverlib.SplitID('@', req.Data.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName) + } + tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data) + if err != nil { + return err + } + res.Metadata = *tokenMeta + return nil +} + +// PerformLoginTokenDeletion ensures the token doesn't exist. +func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error { + util.GetLogger(ctx).Info("PerformLoginTokenDeletion") + return a.DB.RemoveLoginToken(ctx, req.Token) +} + +// QueryLoginToken returns the data associated with a login token. If +// the token is not valid, success is returned, but res.Data == nil. +func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error { + tokenData, err := a.DB.GetLoginTokenDataByToken(ctx, req.Token) + if err != nil { + res.Data = nil + if err == sql.ErrNoRows { + return nil + } + return err + } + localpart, domain, err := gomatrixserverlib.SplitID('@', tokenData.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName) + } + if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil { + res.Data = nil + if err == sql.ErrNoRows { + return nil + } + return err + } + res.Data = tokenData + return nil +} diff --git a/userapi/inthttp/client_logintoken.go b/userapi/inthttp/client_logintoken.go new file mode 100644 index 000000000..366a97099 --- /dev/null +++ b/userapi/inthttp/client_logintoken.go @@ -0,0 +1,65 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inthttp + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/opentracing/opentracing-go" +) + +const ( + PerformLoginTokenCreationPath = "/userapi/performLoginTokenCreation" + PerformLoginTokenDeletionPath = "/userapi/performLoginTokenDeletion" + QueryLoginTokenPath = "/userapi/queryLoginToken" +) + +func (h *httpUserInternalAPI) PerformLoginTokenCreation( + ctx context.Context, + request *api.PerformLoginTokenCreationRequest, + response *api.PerformLoginTokenCreationResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenCreation") + defer span.Finish() + + apiURL := h.apiURL + PerformLoginTokenCreationPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) PerformLoginTokenDeletion( + ctx context.Context, + request *api.PerformLoginTokenDeletionRequest, + response *api.PerformLoginTokenDeletionResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenDeletion") + defer span.Finish() + + apiURL := h.apiURL + PerformLoginTokenDeletionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) QueryLoginToken( + ctx context.Context, + request *api.QueryLoginTokenRequest, + response *api.QueryLoginTokenResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLoginToken") + defer span.Finish() + + apiURL := h.apiURL + QueryLoginTokenPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index ac05bcd09..d00ee042c 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -27,6 +27,8 @@ import ( // nolint: gocyclo func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { + addRoutesLoginToken(internalAPIMux, s) + internalAPIMux.Handle(PerformAccountCreationPath, httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { request := api.PerformAccountCreationRequest{} diff --git a/userapi/inthttp/server_logintoken.go b/userapi/inthttp/server_logintoken.go new file mode 100644 index 000000000..1f2eb34b9 --- /dev/null +++ b/userapi/inthttp/server_logintoken.go @@ -0,0 +1,68 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inthttp + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// addRoutesLoginToken adds routes for all login token API calls. +func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI) { + internalAPIMux.Handle(PerformLoginTokenCreationPath, + httputil.MakeInternalAPI("performLoginTokenCreation", func(req *http.Request) util.JSONResponse { + request := api.PerformLoginTokenCreationRequest{} + response := api.PerformLoginTokenCreationResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformLoginTokenCreation(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformLoginTokenDeletionPath, + httputil.MakeInternalAPI("performLoginTokenDeletion", func(req *http.Request) util.JSONResponse { + request := api.PerformLoginTokenDeletionRequest{} + response := api.PerformLoginTokenDeletionResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformLoginTokenDeletion(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(QueryLoginTokenPath, + httputil.MakeInternalAPI("queryLoginToken", func(req *http.Request) util.JSONResponse { + request := api.QueryLoginTokenRequest{} + response := api.QueryLoginTokenResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryLoginToken(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go deleted file mode 100644 index 2f8290623..000000000 --- a/userapi/storage/accounts/postgres/storage.go +++ /dev/null @@ -1,520 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "encoding/json" - "errors" - "fmt" - "strconv" - "time" - - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas" - "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" - - // Import the postgres database driver. - _ "github.com/lib/pq" -) - -// Database represents an account database -type Database struct { - db *sql.DB - writer sqlutil.Writer - sqlutil.PartitionOffsetStatements - accounts accountsStatements - profiles profilesStatements - accountDatas accountDataStatements - threepids threepidStatements - openIDTokens tokenStatements - keyBackupVersions keyBackupVersionStatements - keyBackups keyBackupStatements - serverName gomatrixserverlib.ServerName - bcryptCost int - openIDTokenLifetimeMS int64 -} - -// NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } - d := &Database{ - serverName: serverName, - db: db, - writer: sqlutil.NewDummyWriter(), - bcryptCost: bcryptCost, - openIDTokenLifetimeMS: openIDTokenLifetimeMS, - } - - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.accounts.execSchema(db); err != nil { - return nil, err - } - m := sqlutil.NewMigrations() - deltas.LoadIsActive(m) - if err = m.RunDeltas(db, dbProperties); err != nil { - return nil, err - } - - if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil { - return nil, err - } - if err = d.accounts.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.profiles.prepare(db); err != nil { - return nil, err - } - if err = d.accountDatas.prepare(db); err != nil { - return nil, err - } - if err = d.threepids.prepare(db); err != nil { - return nil, err - } - if err = d.openIDTokens.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.keyBackupVersions.prepare(db); err != nil { - return nil, err - } - if err = d.keyBackups.prepare(db); err != nil { - return nil, err - } - - return d, nil -} - -// GetAccountByPassword returns the account associated with the given localpart and password. -// Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByPassword( - ctx context.Context, localpart, plaintextPassword string, -) (*api.Account, error) { - hash, err := d.accounts.selectPasswordHash(ctx, localpart) - if err != nil { - return nil, err - } - if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { - return nil, err - } - return d.accounts.selectAccountByLocalpart(ctx, localpart) -} - -// GetProfileByLocalpart returns the profile associated with the given localpart. -// Returns sql.ErrNoRows if no profile exists which matches the given localpart. -func (d *Database) GetProfileByLocalpart( - ctx context.Context, localpart string, -) (*authtypes.Profile, error) { - return d.profiles.selectProfileByLocalpart(ctx, localpart) -} - -// SetAvatarURL updates the avatar URL of the profile associated with the given -// localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetAvatarURL( - ctx context.Context, localpart string, avatarURL string, -) error { - return d.profiles.setAvatarURL(ctx, localpart, avatarURL) -} - -// SetDisplayName updates the display name of the profile associated with the given -// localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetDisplayName( - ctx context.Context, localpart string, displayName string, -) error { - return d.profiles.setDisplayName(ctx, localpart, displayName) -} - -// SetPassword sets the account password to the given hash. -func (d *Database) SetPassword( - ctx context.Context, localpart, plaintextPassword string, -) error { - hash, err := d.hashPassword(plaintextPassword) - if err != nil { - return err - } - return d.accounts.updatePassword(ctx, localpart, hash) -} - -// CreateGuestAccount makes a new guest account and creates an empty profile -// for this account. -func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var numLocalpart int64 - numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) - if err != nil { - return err - } - localpart := strconv.FormatInt(numLocalpart, 10) - acc, err = d.createAccount(ctx, txn, localpart, "", "") - return err - }) - return acc, err -} - -// CreateAccount makes a new account with the given login name and password, and creates an empty profile -// for this account. If no password is supplied, the account will be a passwordless account. If the -// account already exists, it will return nil, sqlutil.ErrUserExists. -func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID string, -) (acc *api.Account, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) - return err - }) - return -} - -func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, -) (*api.Account, error) { - var account *api.Account - var err error - // Generate a password hash if this is not a password-less user - hash := "" - if plaintextPassword != "" { - hash, err = d.hashPassword(plaintextPassword) - if err != nil { - return nil, err - } - } - if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { - if sqlutil.IsUniqueConstraintViolationErr(err) { - return nil, sqlutil.ErrUserExists - } - return nil, err - } - if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { - return nil, err - } - if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`)); err != nil { - return nil, err - } - return account, nil -} - -// SaveAccountData saves new account data for a given user and a given room. -// If the account data is not specific to a room, the room ID should be an empty string -// If an account data already exists for a given set (user, room, data type), it will -// update the corresponding row with the new content -// Returns a SQL error if there was an issue with the insertion/update -func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) - }) -} - -// GetAccountData returns account data related to a given localpart -// If no account data could be found, returns an empty arrays -// Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountData(ctx context.Context, localpart string) ( - global map[string]json.RawMessage, - rooms map[string]map[string]json.RawMessage, - err error, -) { - return d.accountDatas.selectAccountData(ctx, localpart) -} - -// GetAccountDataByType returns account data matching a given -// localpart, room ID and type. -// If no account data could be found, returns nil -// Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, -) (data json.RawMessage, err error) { - return d.accountDatas.selectAccountDataByType( - ctx, localpart, roomID, dataType, - ) -} - -// GetNewNumericLocalpart generates and returns a new unused numeric localpart -func (d *Database) GetNewNumericLocalpart( - ctx context.Context, -) (int64, error) { - return d.accounts.selectNewNumericLocalpart(ctx, nil) -} - -func (d *Database) hashPassword(plaintext string) (hash string, err error) { - hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost) - return string(hashBytes), err -} - -// Err3PIDInUse is the error returned when trying to save an association involving -// a third-party identifier which is already associated to a local user. -var Err3PIDInUse = errors.New("this third-party identifier is already in use") - -// SaveThreePIDAssociation saves the association between a third party identifier -// and a local Matrix user (identified by the user's ID's local part). -// If the third-party identifier is already part of an association, returns Err3PIDInUse. -// Returns an error if there was a problem talking to the database. -func (d *Database) SaveThreePIDAssociation( - ctx context.Context, threepid, localpart, medium string, -) (err error) { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - user, err := d.threepids.selectLocalpartForThreePID( - ctx, txn, threepid, medium, - ) - if err != nil { - return err - } - - if len(user) > 0 { - return Err3PIDInUse - } - - return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) - }) -} - -// RemoveThreePIDAssociation removes the association involving a given third-party -// identifier. -// If no association exists involving this third-party identifier, returns nothing. -// If there was a problem talking to the database, returns an error. -func (d *Database) RemoveThreePIDAssociation( - ctx context.Context, threepid string, medium string, -) (err error) { - return d.threepids.deleteThreePID(ctx, threepid, medium) -} - -// GetLocalpartForThreePID looks up the localpart associated with a given third-party -// identifier. -// If no association involves the given third-party idenfitier, returns an empty -// string. -// Returns an error if there was a problem talking to the database. -func (d *Database) GetLocalpartForThreePID( - ctx context.Context, threepid string, medium string, -) (localpart string, err error) { - return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) -} - -// GetThreePIDsForLocalpart looks up the third-party identifiers associated with -// a given local user. -// If no association is known for this user, returns an empty slice. -// Returns an error if there was an issue talking to the database. -func (d *Database) GetThreePIDsForLocalpart( - ctx context.Context, localpart string, -) (threepids []authtypes.ThreePID, err error) { - return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) -} - -// CheckAccountAvailability checks if the username/localpart is already present -// in the database. -// If the DB returns sql.ErrNoRows the Localpart isn't taken. -func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { - _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) - if err == sql.ErrNoRows { - return true, nil - } - return false, err -} - -// GetAccountByLocalpart returns the account associated with the given localpart. -// This function assumes the request is authenticated or the account data is used only internally. -// Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, -) (*api.Account, error) { - return d.accounts.selectAccountByLocalpart(ctx, localpart) -} - -// SearchProfiles returns all profiles where the provided localpart or display name -// match any part of the profiles in the database. -func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, -) ([]authtypes.Profile, error) { - return d.profiles.selectProfilesBySearch(ctx, searchString, limit) -} - -// DeactivateAccount deactivates the user's account, removing all ability for the user to login again. -func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { - return d.accounts.deactivateAccount(ctx, localpart) -} - -// CreateOpenIDToken persists a new token that was issued through OpenID Connect -func (d *Database) CreateOpenIDToken( - ctx context.Context, - token, localpart string, -) (int64, error) { - expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS - err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) - }) - return expiresAtMS, err -} - -// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token -func (d *Database) GetOpenIDTokenAttributes( - ctx context.Context, - token string, -) (*api.OpenIDTokenAttributes, error) { - return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) -} - -func (d *Database) CreateKeyBackup( - ctx context.Context, userID, algorithm string, authData json.RawMessage, -) (version string, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") - return err - }) - return -} - -func (d *Database) UpdateKeyBackupAuthData( - ctx context.Context, userID, version string, authData json.RawMessage, -) (err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) - }) - return -} - -func (d *Database) DeleteKeyBackup( - ctx context.Context, userID, version string, -) (exists bool, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) GetKeyBackup( - ctx context.Context, userID, version string, -) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) GetBackupKeys( - ctx context.Context, version, userID, filterRoomID, filterSessionID string, -) (result map[string]map[string]api.KeyBackupSession, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if filterSessionID != "" { - result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) - return err - } - if filterRoomID != "" { - result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) - return err - } - result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) CountBackupKeys( - ctx context.Context, version, userID string, -) (count int64, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) - if err != nil { - return err - } - return nil - }) - return -} - -// nolint:nakedret -func (d *Database) UpsertBackupKeys( - ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, -) (count int64, etag string, err error) { - // wrap the following logic in a txn to ensure we atomically upload keys - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) - if err != nil { - return err - } - if deleted { - return fmt.Errorf("backup was deleted") - } - // pull out all keys for this (user_id, version) - existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) - if err != nil { - return err - } - - changed := false - // loop over all the new keys (which should be smaller than the set of backed up keys) - for _, newKey := range uploads { - // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. - existingRoom := existingKeys[newKey.RoomID] - if existingRoom != nil { - existingSession, ok := existingRoom[newKey.SessionID] - if ok { - if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { - err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) - changed = true - if err != nil { - return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) - } - } - // if we shouldn't replace the key we do nothing with it - continue - } - } - // if we're here, either the room or session are new, either way, we insert - err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) - changed = true - if err != nil { - return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) - } - } - - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) - if err != nil { - return err - } - if changed { - // update the etag - var newETag string - if oldETag == "" { - newETag = "1" - } else { - oldETagInt, err := strconv.ParseInt(oldETag, 10, 64) - if err != nil { - return fmt.Errorf("failed to parse old etag: %s", err) - } - newETag = strconv.FormatInt(oldETagInt+1, 10) - } - etag = newETag - return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) - } else { - etag = oldETag - } - return nil - }) - return -} diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go deleted file mode 100644 index 95fe99f33..000000000 --- a/userapi/storage/devices/interface.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package devices - -import ( - "context" - - "github.com/matrix-org/dendrite/userapi/api" -) - -type Database interface { - GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) - GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) - GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) - GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) - // CreateDevice makes a new device associated with the given user ID localpart. - // If there is already a device with the same device ID for this user, that access token will be revoked - // and replaced with the given accessToken. If the given accessToken is already in use for another device, - // an error will be returned. - // If no device ID is given one is generated. - // Returns the device on success. - CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) - UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error - UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error - RemoveDevice(ctx context.Context, deviceID, localpart string) error - RemoveDevices(ctx context.Context, localpart string, devices []string) error - // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. - RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) -} diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go deleted file mode 100644 index 485234331..000000000 --- a/userapi/storage/devices/postgres/storage.go +++ /dev/null @@ -1,212 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "crypto/rand" - "database/sql" - "encoding/base64" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas" - "github.com/matrix-org/gomatrixserverlib" -) - -// The length of generated device IDs -var deviceIDByteLength = 6 - -// Database represents a device database. -type Database struct { - db *sql.DB - devices devicesStatements -} - -// NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } - d := devicesStatements{} - - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.execSchema(db); err != nil { - return nil, err - } - m := sqlutil.NewMigrations() - deltas.LoadLastSeenTSIP(m) - if err = m.RunDeltas(db, dbProperties); err != nil { - return nil, err - } - - if err = d.prepare(db, serverName); err != nil { - return nil, err - } - - return &Database{db, d}, nil -} - -// GetDeviceByAccessToken returns the device matching the given access token. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByAccessToken( - ctx context.Context, token string, -) (*api.Device, error) { - return d.devices.selectDeviceByToken(ctx, token) -} - -// GetDeviceByID returns the device matching the given ID. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, -) (*api.Device, error) { - return d.devices.selectDeviceByID(ctx, localpart, deviceID) -} - -// GetDevicesByLocalpart returns the devices matching the given localpart. -func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, -) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") -} - -func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { - return d.devices.selectDevicesByID(ctx, deviceIDs) -} - -// CreateDevice makes a new device associated with the given user ID localpart. -// If there is already a device with the same device ID for this user, that access token will be revoked -// and replaced with the given accessToken. If the given accessToken is already in use for another device, -// an error will be returned. -// If no device ID is given one is generated. -// Returns the device on success. -func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, -) (dev *api.Device, returnErr error) { - if deviceID != nil { - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { - return err - } - - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - } else { - // We generate device IDs in a loop in case its already taken. - // We cap this at going round 5 times to ensure we don't spin forever - var newDeviceID string - for i := 1; i <= 5; i++ { - newDeviceID, returnErr = generateDeviceID() - if returnErr != nil { - return - } - - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - if returnErr == nil { - return - } - } - } - return -} - -// generateDeviceID creates a new device id. Returns an error if failed to generate -// random bytes. -func generateDeviceID() (string, error) { - b := make([]byte, deviceIDByteLength) - _, err := rand.Read(b) - if err != nil { - return "", err - } - // url-safe no padding - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// UpdateDevice updates the given device with the display name. -// Returns SQL error if there are problems and nil on success. -func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) - }) -} - -// RemoveDevice revokes a device by deleting the entry in the database -// matching with the given device ID and user ID localpart. -// If the device doesn't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevice( - ctx context.Context, deviceID, localpart string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveDevices revokes one or more devices by deleting the entry in the database -// matching with the given device IDs and user ID localpart. -// If the devices don't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveAllDevices revokes devices by deleting the entry in the -// database matching the given user ID localpart. -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, -) (devices []api.Device, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) - if err != nil { - return err - } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { - return err - } - return nil - }) - return -} - -// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) - }) -} diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go deleted file mode 100644 index 538644837..000000000 --- a/userapi/storage/devices/sqlite3/storage.go +++ /dev/null @@ -1,212 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "crypto/rand" - "database/sql" - "encoding/base64" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas" - "github.com/matrix-org/gomatrixserverlib" -) - -// The length of generated device IDs -var deviceIDByteLength = 6 - -// Database represents a device database. -type Database struct { - db *sql.DB - writer sqlutil.Writer - devices devicesStatements -} - -// NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } - writer := sqlutil.NewExclusiveWriter() - d := devicesStatements{} - - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.execSchema(db); err != nil { - return nil, err - } - m := sqlutil.NewMigrations() - deltas.LoadLastSeenTSIP(m) - if err = m.RunDeltas(db, dbProperties); err != nil { - return nil, err - } - if err = d.prepare(db, writer, serverName); err != nil { - return nil, err - } - return &Database{db, writer, d}, nil -} - -// GetDeviceByAccessToken returns the device matching the given access token. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByAccessToken( - ctx context.Context, token string, -) (*api.Device, error) { - return d.devices.selectDeviceByToken(ctx, token) -} - -// GetDeviceByID returns the device matching the given ID. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, -) (*api.Device, error) { - return d.devices.selectDeviceByID(ctx, localpart, deviceID) -} - -// GetDevicesByLocalpart returns the devices matching the given localpart. -func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, -) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") -} - -func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { - return d.devices.selectDevicesByID(ctx, deviceIDs) -} - -// CreateDevice makes a new device associated with the given user ID localpart. -// If there is already a device with the same device ID for this user, that access token will be revoked -// and replaced with the given accessToken. If the given accessToken is already in use for another device, -// an error will be returned. -// If no device ID is given one is generated. -// Returns the device on success. -func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, -) (dev *api.Device, returnErr error) { - if deviceID != nil { - returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { - return err - } - - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - } else { - // We generate device IDs in a loop in case its already taken. - // We cap this at going round 5 times to ensure we don't spin forever - var newDeviceID string - for i := 1; i <= 5; i++ { - newDeviceID, returnErr = generateDeviceID() - if returnErr != nil { - return - } - - returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - if returnErr == nil { - return - } - } - } - return -} - -// generateDeviceID creates a new device id. Returns an error if failed to generate -// random bytes. -func generateDeviceID() (string, error) { - b := make([]byte, deviceIDByteLength) - _, err := rand.Read(b) - if err != nil { - return "", err - } - // url-safe no padding - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// UpdateDevice updates the given device with the display name. -// Returns SQL error if there are problems and nil on success. -func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) - }) -} - -// RemoveDevice revokes a device by deleting the entry in the database -// matching with the given device ID and user ID localpart. -// If the device doesn't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevice( - ctx context.Context, deviceID, localpart string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveDevices revokes one or more devices by deleting the entry in the database -// matching with the given device IDs and user ID localpart. -// If the devices don't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveAllDevices revokes devices by deleting the entry in the -// database matching the given user ID localpart. -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, -) (devices []api.Device, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) - if err != nil { - return err - } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { - return err - } - return nil - }) - return -} - -// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) - }) -} diff --git a/userapi/storage/devices/storage.go b/userapi/storage/devices/storage.go deleted file mode 100644 index 3c2034300..000000000 --- a/userapi/storage/devices/storage.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !wasm -// +build !wasm - -package devices - -import ( - "fmt" - - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/devices/postgres" - "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" - "github.com/matrix-org/gomatrixserverlib" -) - -// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) -// and sets postgres connection parameters -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName) - case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, serverName) - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/userapi/storage/devices/storage_wasm.go b/userapi/storage/devices/storage_wasm.go deleted file mode 100644 index f360f9857..000000000 --- a/userapi/storage/devices/storage_wasm.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package devices - -import ( - "fmt" - - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" - "github.com/matrix-org/gomatrixserverlib" -) - -func NewDatabase( - dbProperties *config.DatabaseOptions, - serverName gomatrixserverlib.ServerName, -) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName) - case dbProperties.ConnectionString.IsPostgres(): - return nil, fmt.Errorf("can't use Postgres implementation") - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/interface.go similarity index 65% rename from userapi/storage/accounts/interface.go rename to userapi/storage/interface.go index f03b3774c..a131dac47 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/interface.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accounts +package storage import ( "context" @@ -32,8 +32,7 @@ type Database interface { // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. - CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*api.Account, error) - CreateGuestAccount(ctx context.Context) (*api.Account, error) + CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) // GetAccountDataByType returns account data matching a given @@ -61,6 +60,35 @@ type Database interface { UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error) GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error) CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error) + + GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) + GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) + GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) + GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) + // CreateDevice makes a new device associated with the given user ID localpart. + // If there is already a device with the same device ID for this user, that access token will be revoked + // and replaced with the given accessToken. If the given accessToken is already in use for another device, + // an error will be returned. + // If no device ID is given one is generated. + // Returns the device on success. + CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) + UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error + UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error + RemoveDevice(ctx context.Context, deviceID, localpart string) error + RemoveDevices(ctx context.Context, localpart string, devices []string) error + // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. + RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) + + // CreateLoginToken generates a token, stores and returns it. The lifetime is + // determined by the loginTokenLifetime given to the Database constructor. + CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) + + // RemoveLoginToken removes the named token (and may clean up other expired tokens). + RemoveLoginToken(ctx context.Context, token string) error + + // GetLoginTokenDataByToken returns the data associated with the given token. + // May return sql.ErrNoRows. + GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go similarity index 89% rename from userapi/storage/accounts/postgres/account_data_table.go rename to userapi/storage/postgres/account_data_table.go index 8ba890e75..67113367b 100644 --- a/userapi/storage/accounts/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const accountDataSchema = ` @@ -56,19 +57,20 @@ type accountDataStatements struct { selectAccountDataByTypeStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(accountDataSchema) +func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { + s := &accountDataStatements{} + _, err := db.Exec(accountDataSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertAccountDataStmt, insertAccountDataSQL}, {&s.selectAccountDataStmt, selectAccountDataSQL}, {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, }.Prepare(db) } -func (s *accountDataStatements) insertAccountData( +func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) @@ -76,7 +78,7 @@ func (s *accountDataStatements) insertAccountData( return } -func (s *accountDataStatements) selectAccountData( +func (s *accountDataStatements) SelectAccountData( ctx context.Context, localpart string, ) ( /* global */ map[string]json.RawMessage, @@ -114,7 +116,7 @@ func (s *accountDataStatements) selectAccountData( return global, rooms, rows.Err() } -func (s *accountDataStatements) selectAccountDataByType( +func (s *accountDataStatements) SelectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go similarity index 80% rename from userapi/storage/accounts/postgres/accounts_table.go rename to userapi/storage/postgres/accounts_table.go index b57aa901f..92311d56d 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -19,10 +19,12 @@ import ( "database/sql" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/userapi/storage/tables" log "github.com/sirupsen/logrus" ) @@ -39,16 +41,18 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- Identifies which application service this account belongs to, if any. appservice_id TEXT, -- If the account is currently active - is_deactivated BOOLEAN DEFAULT FALSE + is_deactivated BOOLEAN DEFAULT FALSE, + -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) + account_type SMALLINT NOT NULL -- TODO: - -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? + -- upgraded_ts, devices, any email reset stuff? ); -- Create sequence for autogenerated numeric usernames CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" const updatePasswordSQL = "" + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" @@ -57,7 +61,7 @@ const deactivateAccountSQL = "" + "UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" @@ -75,14 +79,15 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } -func (s *accountsStatements) execSchema(db *sql.DB) error { +func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { + s := &accountsStatements{ + serverName: serverName, + } _, err := db.Exec(accountsSchema) - return err -} - -func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - s.serverName = server - return sqlutil.StatementList{ + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertAccountStmt, insertAccountSQL}, {&s.updatePasswordStmt, updatePasswordSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL}, @@ -95,17 +100,17 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. -func (s *accountsStatements) insertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, +func (s *accountsStatements) InsertAccount( + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) var err error - if appserviceID == "" { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil) + if accountType != api.AccountTypeAppService { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) } else { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) } if err != nil { return nil, err @@ -116,38 +121,39 @@ func (s *accountsStatements) insertAccount( UserID: userutil.MakeUserID(localpart, s.serverName), ServerName: s.serverName, AppServiceID: appserviceID, + AccountType: accountType, }, nil } -func (s *accountsStatements) updatePassword( +func (s *accountsStatements) UpdatePassword( ctx context.Context, localpart, passwordHash string, ) (err error) { _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) return } -func (s *accountsStatements) deactivateAccount( +func (s *accountsStatements) DeactivateAccount( ctx context.Context, localpart string, ) (err error) { _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) return } -func (s *accountsStatements) selectPasswordHash( +func (s *accountsStatements) SelectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) return } -func (s *accountsStatements) selectAccountByLocalpart( +func (s *accountsStatements) SelectAccountByLocalpart( ctx context.Context, localpart string, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -164,7 +170,7 @@ func (s *accountsStatements) selectAccountByLocalpart( return &acc, nil } -func (s *accountsStatements) selectNewNumericLocalpart( +func (s *accountsStatements) SelectNewNumericLocalpart( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt diff --git a/userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go b/userapi/storage/postgres/deltas/20200929203058_is_active.go similarity index 92% rename from userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go rename to userapi/storage/postgres/deltas/20200929203058_is_active.go index 9e14286e0..32d3235be 100644 --- a/userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go +++ b/userapi/storage/postgres/deltas/20200929203058_is_active.go @@ -4,12 +4,14 @@ import ( "database/sql" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/pressly/goose" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) func LoadFromGoose() { goose.AddMigration(UpIsActive, DownIsActive) + goose.AddMigration(UpAddAccountType, DownAddAccountType) } func LoadIsActive(m *sqlutil.Migrations) { diff --git a/userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go similarity index 89% rename from userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go rename to userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go index 290f854c8..1bbb0a9d3 100644 --- a/userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go @@ -5,13 +5,8 @@ import ( "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/pressly/goose" ) -func LoadFromGoose() { - goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) -} - func LoadLastSeenTSIP(m *sqlutil.Migrations) { m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) } diff --git a/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go new file mode 100644 index 000000000..2fae00cb9 --- /dev/null +++ b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go @@ -0,0 +1,34 @@ +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +func LoadAddAccountType(m *sqlutil.Migrations) { + m.AddMigration(UpAddAccountType, DownAddAccountType) +} + +func UpAddAccountType(tx *sql.Tx) error { + // initially set every account to useraccount, change appservice and guest accounts afterwards + // (user = 1, guest = 2, admin = 3, appservice = 4) + _, err := tx.Exec(`ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1; +UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; +UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$'; +ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`, + ) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownAddAccountType(tx *sql.Tx) error { + _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN account_type;") + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go similarity index 85% rename from userapi/storage/devices/postgres/devices_table.go rename to userapi/storage/postgres/devices_table.go index 7de9f5f9e..7bc5dc69b 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -111,50 +112,32 @@ type devicesStatements struct { serverName gomatrixserverlib.ServerName } -func (s *devicesStatements) execSchema(db *sql.DB) error { +func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) { + s := &devicesStatements{ + serverName: serverName, + } _, err := db.Exec(devicesSchema) - return err -} - -func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { - return + if err != nil { + return nil, err } - if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { - return - } - if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { - return - } - if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { - return - } - if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { - return - } - if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { - return - } - if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { - return - } - if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { - return - } - if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { - return - } - if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil { - return - } - s.serverName = server - return + return s, sqlutil.StatementList{ + {&s.insertDeviceStmt, insertDeviceSQL}, + {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL}, + {&s.selectDeviceByIDStmt, selectDeviceByIDSQL}, + {&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL}, + {&s.updateDeviceNameStmt, updateDeviceNameSQL}, + {&s.deleteDeviceStmt, deleteDeviceSQL}, + {&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL}, + {&s.deleteDevicesStmt, deleteDevicesSQL}, + {&s.selectDevicesByIDStmt, selectDevicesByIDSQL}, + {&s.updateDeviceLastSeenStmt, updateDeviceLastSeen}, + }.Prepare(db) } // insertDevice creates a new device. Returns an error if any device with the same access token already exists. // Returns an error if the user already has a device with the given device ID. // Returns the device on success. -func (s *devicesStatements) insertDevice( +func (s *devicesStatements) InsertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { @@ -176,7 +159,7 @@ func (s *devicesStatements) insertDevice( } // deleteDevice removes a single device by id and user localpart. -func (s *devicesStatements) deleteDevice( +func (s *devicesStatements) DeleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) @@ -186,7 +169,7 @@ func (s *devicesStatements) deleteDevice( // deleteDevices removes a single or multiple devices by ids and user localpart. // Returns an error if the execution failed. -func (s *devicesStatements) deleteDevices( +func (s *devicesStatements) DeleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) @@ -196,7 +179,7 @@ func (s *devicesStatements) deleteDevices( // deleteDevicesByLocalpart removes all devices for the // given user localpart. -func (s *devicesStatements) deleteDevicesByLocalpart( +func (s *devicesStatements) DeleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) @@ -204,7 +187,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart( return err } -func (s *devicesStatements) updateDeviceName( +func (s *devicesStatements) UpdateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) @@ -212,7 +195,7 @@ func (s *devicesStatements) updateDeviceName( return err } -func (s *devicesStatements) selectDeviceByToken( +func (s *devicesStatements) SelectDeviceByToken( ctx context.Context, accessToken string, ) (*api.Device, error) { var dev api.Device @@ -228,7 +211,7 @@ func (s *devicesStatements) selectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID -func (s *devicesStatements) selectDeviceByID( +func (s *devicesStatements) SelectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device @@ -245,7 +228,7 @@ func (s *devicesStatements) selectDeviceByID( return &dev, err } -func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { +func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs)) if err != nil { return nil, err @@ -268,7 +251,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s return devices, rows.Err() } -func (s *devicesStatements) selectDevicesByLocalpart( +func (s *devicesStatements) SelectDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} @@ -310,7 +293,7 @@ func (s *devicesStatements) selectDevicesByLocalpart( return devices, rows.Err() } -func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go similarity index 91% rename from userapi/storage/accounts/postgres/key_backup_table.go rename to userapi/storage/postgres/key_backup_table.go index c1402d4d2..ac0e80617 100644 --- a/userapi/storage/accounts/postgres/key_backup_table.go +++ b/userapi/storage/postgres/key_backup_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupTableSchema = ` @@ -72,12 +73,13 @@ type keyBackupStatements struct { selectKeysByRoomIDAndSessionIDStmt *sql.Stmt } -func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupTableSchema) +func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { + s := &keyBackupStatements{} + _, err := db.Exec(keyBackupTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, @@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s keyBackupStatements) countKeys( +func (s keyBackupStatements) CountKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (count int64, err error) { err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) return } -func (s *keyBackupStatements) insertBackupKey( +func (s *keyBackupStatements) InsertBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( @@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey( return } -func (s *keyBackupStatements) updateBackupKey( +func (s *keyBackupStatements) UpdateBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( @@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey( return } -func (s *keyBackupStatements) selectKeys( +func (s *keyBackupStatements) SelectKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) @@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomID( +func (s *keyBackupStatements) SelectKeysByRoomID( ctx context.Context, txn *sql.Tx, userID, version, roomID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) @@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( +func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID( ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/postgres/key_backup_version_table.go similarity index 90% rename from userapi/storage/accounts/postgres/key_backup_version_table.go rename to userapi/storage/postgres/key_backup_version_table.go index d73447b49..e78e4cd51 100644 --- a/userapi/storage/accounts/postgres/key_backup_version_table.go +++ b/userapi/storage/postgres/key_backup_version_table.go @@ -22,6 +22,7 @@ import ( "strconv" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupVersionTableSchema = ` @@ -69,12 +70,13 @@ type keyBackupVersionStatements struct { updateKeyBackupETagStmt *sql.Stmt } -func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupVersionTableSchema) +func NewPostgresKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) { + s := &keyBackupVersionStatements{} + _, err := db.Exec(keyBackupVersionTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertKeyBackupStmt, insertKeyBackupSQL}, {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, {&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, @@ -84,7 +86,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *keyBackupVersionStatements) insertKeyBackup( +func (s *keyBackupVersionStatements) InsertKeyBackup( ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, ) (version string, err error) { var versionInt int64 @@ -92,7 +94,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup( return strconv.FormatInt(versionInt, 10), err } -func (s *keyBackupVersionStatements) updateKeyBackupAuthData( +func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData( ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -103,7 +105,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData( return err } -func (s *keyBackupVersionStatements) updateKeyBackupETag( +func (s *keyBackupVersionStatements) UpdateKeyBackupETag( ctx context.Context, txn *sql.Tx, userID, version, etag string, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -114,7 +116,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag( return err } -func (s *keyBackupVersionStatements) deleteKeyBackup( +func (s *keyBackupVersionStatements) DeleteKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (bool, error) { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -132,7 +134,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup( return ra == 1, nil } -func (s *keyBackupVersionStatements) selectKeyBackup( +func (s *keyBackupVersionStatements) SelectKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { var versionInt int64 diff --git a/userapi/storage/postgres/logintoken_table.go b/userapi/storage/postgres/logintoken_table.go new file mode 100644 index 000000000..4de96f839 --- /dev/null +++ b/userapi/storage/postgres/logintoken_table.go @@ -0,0 +1,103 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/util" +) + +const loginTokenSchema = ` +CREATE TABLE IF NOT EXISTS login_tokens ( + -- The random value of the token issued to a user + token TEXT NOT NULL PRIMARY KEY, + -- When the token expires + token_expires_at TIMESTAMP NOT NULL, + + -- The mxid for this account + user_id TEXT NOT NULL +); + +-- This index allows efficient garbage collection of expired tokens. +CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +` + +const insertLoginTokenSQL = "" + + "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" + +const deleteLoginTokenSQL = "" + + "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" + +const selectLoginTokenSQL = "" + + "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" + +type loginTokenStatements struct { + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectStmt *sql.Stmt +} + +func NewPostgresLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) { + s := &loginTokenStatements{} + _, err := db.Exec(loginTokenSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertStmt, insertLoginTokenSQL}, + {&s.deleteStmt, deleteLoginTokenSQL}, + {&s.selectStmt, selectLoginTokenSQL}, + }.Prepare(db) +} + +// insert adds an already generated token to the database. +func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { + stmt := sqlutil.TxStmt(txn, s.insertStmt) + _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) + return err +} + +// deleteByToken removes the named token. +// +// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. +// The login_tokens_expiration_idx index should make that efficient. +func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { + stmt := sqlutil.TxStmt(txn, s.deleteStmt) + res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) + if err != nil { + return err + } + if n, err := res.RowsAffected(); err == nil && n > 1 { + util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely additional expired token)", n, n-1) + } + return nil +} + +// selectByToken returns the data associated with the given token. May return sql.ErrNoRows. +func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + var data api.LoginTokenData + err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + if err != nil { + return nil, err + } + + return &data, nil +} diff --git a/userapi/storage/accounts/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go similarity index 74% rename from userapi/storage/accounts/postgres/openid_table.go rename to userapi/storage/postgres/openid_table.go index 190d141b7..29c3ddcb4 100644 --- a/userapi/storage/accounts/postgres/openid_table.go +++ b/userapi/storage/postgres/openid_table.go @@ -6,6 +6,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -22,33 +23,35 @@ CREATE TABLE IF NOT EXISTS open_id_tokens ( ); ` -const insertTokenSQL = "" + +const insertOpenIDTokenSQL = "" + "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" -const selectTokenSQL = "" + +const selectOpenIDTokenSQL = "" + "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" -type tokenStatements struct { +type openIDTokenStatements struct { insertTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt serverName gomatrixserverlib.ServerName } -func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - _, err = db.Exec(openIDTokenSchema) - if err != nil { - return +func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) { + s := &openIDTokenStatements{ + serverName: serverName, } - s.serverName = server - return sqlutil.StatementList{ - {&s.insertTokenStmt, insertTokenSQL}, - {&s.selectTokenStmt, selectTokenSQL}, + _, err := db.Exec(openIDTokenSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertTokenStmt, insertOpenIDTokenSQL}, + {&s.selectTokenStmt, selectOpenIDTokenSQL}, }.Prepare(db) } // insertToken inserts a new OpenID Connect token to the DB. // Returns new token, otherwise returns error if the token already exists. -func (s *tokenStatements) insertToken( +func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, token, localpart string, @@ -61,7 +64,7 @@ func (s *tokenStatements) insertToken( // selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB // Returns the existing token's attributes, or err if no token is found -func (s *tokenStatements) selectOpenIDTokenAtrributes( +func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( ctx context.Context, token string, ) (*api.OpenIDTokenAttributes, error) { diff --git a/userapi/storage/accounts/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go similarity index 85% rename from userapi/storage/accounts/postgres/profile_table.go rename to userapi/storage/postgres/profile_table.go index 9313864be..32a4b5506 100644 --- a/userapi/storage/accounts/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const profilesSchema = ` @@ -59,12 +60,13 @@ type profilesStatements struct { selectProfilesBySearchStmt *sql.Stmt } -func (s *profilesStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(profilesSchema) +func NewPostgresProfilesTable(db *sql.DB) (tables.ProfileTable, error) { + s := &profilesStatements{} + _, err := db.Exec(profilesSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertProfileStmt, insertProfileSQL}, {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, {&s.setAvatarURLStmt, setAvatarURLSQL}, @@ -73,14 +75,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *profilesStatements) insertProfile( +func (s *profilesStatements) InsertProfile( ctx context.Context, txn *sql.Tx, localpart string, ) (err error) { _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") return } -func (s *profilesStatements) selectProfileByLocalpart( +func (s *profilesStatements) SelectProfileByLocalpart( ctx context.Context, localpart string, ) (*authtypes.Profile, error) { var profile authtypes.Profile @@ -93,21 +95,21 @@ func (s *profilesStatements) selectProfileByLocalpart( return &profile, nil } -func (s *profilesStatements) setAvatarURL( - ctx context.Context, localpart string, avatarURL string, +func (s *profilesStatements) SetAvatarURL( + ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ) (err error) { _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) return } -func (s *profilesStatements) setDisplayName( - ctx context.Context, localpart string, displayName string, +func (s *profilesStatements) SetDisplayName( + ctx context.Context, txn *sql.Tx, localpart string, displayName string, ) (err error) { _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) return } -func (s *profilesStatements) selectProfilesBySearch( +func (s *profilesStatements) SelectProfilesBySearch( ctx context.Context, searchString string, limit int, ) ([]authtypes.Profile, error) { var profiles []authtypes.Profile diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go new file mode 100644 index 000000000..ac5c59b81 --- /dev/null +++ b/userapi/storage/postgres/storage.go @@ -0,0 +1,105 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "fmt" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/shared" + + // Import the postgres database driver. + _ "github.com/lib/pq" +) + +// NewDatabase creates a new accounts and profiles database +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, err + } + + m := sqlutil.NewMigrations() + if _, err = db.Exec(accountsSchema); err != nil { + // do this so that the migration can and we don't fail on + // preparing statements for columns that don't exist yet + return nil, err + } + deltas.LoadIsActive(m) + //deltas.LoadLastSeenTSIP(m) + deltas.LoadAddAccountType(m) + if err = m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } + + accountDataTable, err := NewPostgresAccountDataTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) + } + accountsTable, err := NewPostgresAccountsTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) + } + devicesTable, err := NewPostgresDevicesTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err) + } + keyBackupTable, err := NewPostgresKeyBackupTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresKeyBackupTable: %w", err) + } + keyBackupVersionTable, err := NewPostgresKeyBackupVersionTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresKeyBackupVersionTable: %w", err) + } + loginTokenTable, err := NewPostgresLoginTokenTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresLoginTokenTable: %w", err) + } + openIDTable, err := NewPostgresOpenIDTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewPostgresOpenIDTable: %w", err) + } + profilesTable, err := NewPostgresProfilesTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresProfilesTable: %w", err) + } + threePIDTable, err := NewPostgresThreePIDTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err) + } + return &shared.Database{ + AccountDatas: accountDataTable, + Accounts: accountsTable, + Devices: devicesTable, + KeyBackups: keyBackupTable, + KeyBackupVersions: keyBackupVersionTable, + LoginTokens: loginTokenTable, + OpenIDTokens: openIDTable, + Profiles: profilesTable, + ThreePIDs: threePIDTable, + ServerName: serverName, + DB: db, + Writer: sqlutil.NewDummyWriter(), + LoginTokenLifetime: loginTokenLifetime, + BcryptCost: bcryptCost, + OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, + }, nil +} diff --git a/userapi/storage/accounts/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go similarity index 83% rename from userapi/storage/accounts/postgres/threepid_table.go rename to userapi/storage/postgres/threepid_table.go index 9280fc87c..63c08d61f 100644 --- a/userapi/storage/accounts/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -58,12 +59,13 @@ type threepidStatements struct { deleteThreePIDStmt *sql.Stmt } -func (s *threepidStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(threepidSchema) +func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { + s := &threepidStatements{} + _, err := db.Exec(threepidSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, {&s.insertThreePIDStmt, insertThreePIDSQL}, @@ -71,7 +73,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *threepidStatements) selectLocalpartForThreePID( +func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, ) (localpart string, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) @@ -82,7 +84,7 @@ func (s *threepidStatements) selectLocalpartForThreePID( return } -func (s *threepidStatements) selectThreePIDsForLocalpart( +func (s *threepidStatements) SelectThreePIDsForLocalpart( ctx context.Context, localpart string, ) (threepids []authtypes.ThreePID, err error) { rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) @@ -106,7 +108,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( return } -func (s *threepidStatements) insertThreePID( +func (s *threepidStatements) InsertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) @@ -114,8 +116,9 @@ func (s *threepidStatements) insertThreePID( return } -func (s *threepidStatements) deleteThreePID( - ctx context.Context, threepid string, medium string) (err error) { - _, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium) +func (s *threepidStatements) DeleteThreePID( + ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { + stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium) return } diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/shared/storage.go similarity index 51% rename from userapi/storage/accounts/sqlite3/storage.go rename to userapi/storage/shared/storage.go index 2b731b759..5f1f95005 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/shared/storage.go @@ -12,117 +12,66 @@ // See the License for the specific language governing permissions and // limitations under the License. -package sqlite3 +package shared import ( "context" + "crypto/rand" "database/sql" + "encoding/base64" "encoding/json" "errors" "fmt" "strconv" - "sync" "time" + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas" - "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) // Database represents an account database type Database struct { - db *sql.DB - writer sqlutil.Writer - - sqlutil.PartitionOffsetStatements - accounts accountsStatements - profiles profilesStatements - accountDatas accountDataStatements - threepids threepidStatements - openIDTokens tokenStatements - keyBackupVersions keyBackupVersionStatements - keyBackups keyBackupStatements - serverName gomatrixserverlib.ServerName - bcryptCost int - openIDTokenLifetimeMS int64 - - accountsMu sync.Mutex - profilesMu sync.Mutex - accountDatasMu sync.Mutex - threepidsMu sync.Mutex + DB *sql.DB + Writer sqlutil.Writer + Accounts tables.AccountsTable + Profiles tables.ProfileTable + AccountDatas tables.AccountDataTable + ThreePIDs tables.ThreePIDTable + OpenIDTokens tables.OpenIDTable + KeyBackups tables.KeyBackupTable + KeyBackupVersions tables.KeyBackupVersionTable + Devices tables.DevicesTable + LoginTokens tables.LoginTokenTable + LoginTokenLifetime time.Duration + ServerName gomatrixserverlib.ServerName + BcryptCost int + OpenIDTokenLifetimeMS int64 } -// NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } - d := &Database{ - serverName: serverName, - db: db, - writer: sqlutil.NewExclusiveWriter(), - bcryptCost: bcryptCost, - openIDTokenLifetimeMS: openIDTokenLifetimeMS, - } - - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.accounts.execSchema(db); err != nil { - return nil, err - } - m := sqlutil.NewMigrations() - deltas.LoadIsActive(m) - if err = m.RunDeltas(db, dbProperties); err != nil { - return nil, err - } - - partitions := sqlutil.PartitionOffsetStatements{} - if err = partitions.Prepare(db, d.writer, "account"); err != nil { - return nil, err - } - if err = d.accounts.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.profiles.prepare(db); err != nil { - return nil, err - } - if err = d.accountDatas.prepare(db); err != nil { - return nil, err - } - if err = d.threepids.prepare(db); err != nil { - return nil, err - } - if err = d.openIDTokens.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.keyBackupVersions.prepare(db); err != nil { - return nil, err - } - if err = d.keyBackups.prepare(db); err != nil { - return nil, err - } - - return d, nil -} +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + loginTokenByteLength = 32 +) // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( ctx context.Context, localpart, plaintextPassword string, ) (*api.Account, error) { - hash, err := d.accounts.selectPasswordHash(ctx, localpart) + hash, err := d.Accounts.SelectPasswordHash(ctx, localpart) if err != nil { return nil, err } if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { return nil, err } - return d.accounts.selectAccountByLocalpart(ctx, localpart) + return d.Accounts.SelectAccountByLocalpart(ctx, localpart) } // GetProfileByLocalpart returns the profile associated with the given localpart. @@ -130,7 +79,7 @@ func (d *Database) GetAccountByPassword( func (d *Database) GetProfileByLocalpart( ctx context.Context, localpart string, ) (*authtypes.Profile, error) { - return d.profiles.selectProfileByLocalpart(ctx, localpart) + return d.Profiles.SelectProfileByLocalpart(ctx, localpart) } // SetAvatarURL updates the avatar URL of the profile associated with the given @@ -138,10 +87,8 @@ func (d *Database) GetProfileByLocalpart( func (d *Database) SetAvatarURL( ctx context.Context, localpart string, avatarURL string, ) error { - d.profilesMu.Lock() - defer d.profilesMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL) + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) }) } @@ -150,10 +97,8 @@ func (d *Database) SetAvatarURL( func (d *Database) SetDisplayName( ctx context.Context, localpart string, displayName string, ) error { - d.profilesMu.Lock() - defer d.profilesMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.profiles.setDisplayName(ctx, txn, localpart, displayName) + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) }) } @@ -165,53 +110,30 @@ func (d *Database) SetPassword( if err != nil { return err } - return d.writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.accounts.updatePassword(ctx, localpart, hash) + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.Accounts.UpdatePassword(ctx, localpart, hash) }) } -// CreateGuestAccount makes a new guest account and creates an empty profile -// for this account. -func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { - // We need to lock so we sequentially create numeric localparts. If we don't, two calls to - // this function will cause the same number to be selected and one will fail with 'database is locked' - // when the first txn upgrades to a write txn. We also need to lock the account creation else we can - // race with CreateAccount - // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed. - d.profilesMu.Lock() - d.accountDatasMu.Lock() - d.accountsMu.Lock() - defer d.profilesMu.Unlock() - defer d.accountDatasMu.Unlock() - defer d.accountsMu.Unlock() - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var numLocalpart int64 - numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) - if err != nil { - return err - } - localpart := strconv.FormatInt(numLocalpart, 10) - acc, err = d.createAccount(ctx, txn, localpart, "", "") - return err - }) - return acc, err -} - // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID string, + ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ) (acc *api.Account, err error) { - // Create one account at a time else we can get 'database is locked'. - d.profilesMu.Lock() - d.accountDatasMu.Lock() - d.accountsMu.Lock() - defer d.profilesMu.Unlock() - defer d.accountDatasMu.Unlock() - defer d.accountsMu.Unlock() - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // For guest accounts, we create a new numeric local part + if accountType == api.AccountTypeGuest { + var numLocalpart int64 + numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn) + if err != nil { + return err + } + localpart = strconv.FormatInt(numLocalpart, 10) + plaintextPassword = "" + appserviceID = "" + } + acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) return err }) return @@ -220,7 +142,7 @@ func (d *Database) CreateAccount( // WARNING! This function assumes that the relevant mutexes have already // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, + ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { var err error var account *api.Account @@ -232,13 +154,13 @@ func (d *Database) createAccount( return nil, err } } - if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { + if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { return nil, sqlutil.ErrUserExists } - if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { + if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { return nil, err } - if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ + if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ "global": { "content": [], "override": [], @@ -260,10 +182,8 @@ func (d *Database) createAccount( func (d *Database) SaveAccountData( ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { - d.accountDatasMu.Lock() - defer d.accountDatasMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content) }) } @@ -275,7 +195,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( rooms map[string]map[string]json.RawMessage, err error, ) { - return d.accountDatas.selectAccountData(ctx, localpart) + return d.AccountDatas.SelectAccountData(ctx, localpart) } // GetAccountDataByType returns account data matching a given @@ -285,7 +205,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( func (d *Database) GetAccountDataByType( ctx context.Context, localpart, roomID, dataType string, ) (data json.RawMessage, err error) { - return d.accountDatas.selectAccountDataByType( + return d.AccountDatas.SelectAccountDataByType( ctx, localpart, roomID, dataType, ) } @@ -294,11 +214,11 @@ func (d *Database) GetAccountDataByType( func (d *Database) GetNewNumericLocalpart( ctx context.Context, ) (int64, error) { - return d.accounts.selectNewNumericLocalpart(ctx, nil) + return d.Accounts.SelectNewNumericLocalpart(ctx, nil) } func (d *Database) hashPassword(plaintext string) (hash string, err error) { - hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost) + hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.BcryptCost) return string(hashBytes), err } @@ -313,10 +233,8 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use") func (d *Database) SaveThreePIDAssociation( ctx context.Context, threepid, localpart, medium string, ) (err error) { - d.threepidsMu.Lock() - defer d.threepidsMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - user, err := d.threepids.selectLocalpartForThreePID( + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + user, err := d.ThreePIDs.SelectLocalpartForThreePID( ctx, txn, threepid, medium, ) if err != nil { @@ -327,7 +245,7 @@ func (d *Database) SaveThreePIDAssociation( return Err3PIDInUse } - return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) + return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart) }) } @@ -338,10 +256,8 @@ func (d *Database) SaveThreePIDAssociation( func (d *Database) RemoveThreePIDAssociation( ctx context.Context, threepid string, medium string, ) (err error) { - d.threepidsMu.Lock() - defer d.threepidsMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.threepids.deleteThreePID(ctx, txn, threepid, medium) + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.ThreePIDs.DeleteThreePID(ctx, txn, threepid, medium) }) } @@ -353,7 +269,7 @@ func (d *Database) RemoveThreePIDAssociation( func (d *Database) GetLocalpartForThreePID( ctx context.Context, threepid string, medium string, ) (localpart string, err error) { - return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) + return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium) } // GetThreePIDsForLocalpart looks up the third-party identifiers associated with @@ -363,14 +279,14 @@ func (d *Database) GetLocalpartForThreePID( func (d *Database) GetThreePIDsForLocalpart( ctx context.Context, localpart string, ) (threepids []authtypes.ThreePID, err error) { - return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) + return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart) } // CheckAccountAvailability checks if the username/localpart is already present // in the database. // If the DB returns sql.ErrNoRows the Localpart isn't taken. func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { - _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) + _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart) if err == sql.ErrNoRows { return true, nil } @@ -382,20 +298,20 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, ) (*api.Account, error) { - return d.accounts.selectAccountByLocalpart(ctx, localpart) + return d.Accounts.SelectAccountByLocalpart(ctx, localpart) } // SearchProfiles returns all profiles where the provided localpart or display name // match any part of the profiles in the database. func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, ) ([]authtypes.Profile, error) { - return d.profiles.selectProfilesBySearch(ctx, searchString, limit) + return d.Profiles.SelectProfilesBySearch(ctx, searchString, limit) } // DeactivateAccount deactivates the user's account, removing all ability for the user to login again. func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { - return d.writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.accounts.deactivateAccount(ctx, localpart) + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.Accounts.DeactivateAccount(ctx, localpart) }) } @@ -404,9 +320,9 @@ func (d *Database) CreateOpenIDToken( ctx context.Context, token, localpart string, ) (int64, error) { - expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS - err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) + expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS) }) return expiresAtMS, err } @@ -416,14 +332,14 @@ func (d *Database) GetOpenIDTokenAttributes( ctx context.Context, token string, ) (*api.OpenIDTokenAttributes, error) { - return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) + return d.OpenIDTokens.SelectOpenIDTokenAtrributes(ctx, token) } func (d *Database) CreateKeyBackup( ctx context.Context, userID, algorithm string, authData json.RawMessage, ) (version string, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + version, err = d.KeyBackupVersions.InsertKeyBackup(ctx, txn, userID, algorithm, authData, "") return err }) return @@ -432,8 +348,8 @@ func (d *Database) CreateKeyBackup( func (d *Database) UpdateKeyBackupAuthData( ctx context.Context, userID, version string, authData json.RawMessage, ) (err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.KeyBackupVersions.UpdateKeyBackupAuthData(ctx, txn, userID, version, authData) }) return } @@ -441,8 +357,8 @@ func (d *Database) UpdateKeyBackupAuthData( func (d *Database) DeleteKeyBackup( ctx context.Context, userID, version string, ) (exists bool, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + exists, err = d.KeyBackupVersions.DeleteKeyBackup(ctx, txn, userID, version) return err }) return @@ -451,8 +367,8 @@ func (d *Database) DeleteKeyBackup( func (d *Database) GetKeyBackup( ctx context.Context, userID, version string, ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + versionResult, algorithm, authData, etag, deleted, err = d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version) return err }) return @@ -461,16 +377,16 @@ func (d *Database) GetKeyBackup( func (d *Database) GetBackupKeys( ctx context.Context, version, userID, filterRoomID, filterSessionID string, ) (result map[string]map[string]api.KeyBackupSession, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if filterSessionID != "" { - result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) + result, err = d.KeyBackups.SelectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) return err } if filterRoomID != "" { - result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) + result, err = d.KeyBackups.SelectKeysByRoomID(ctx, txn, userID, version, filterRoomID) return err } - result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) + result, err = d.KeyBackups.SelectKeys(ctx, txn, userID, version) return err }) return @@ -479,8 +395,8 @@ func (d *Database) GetBackupKeys( func (d *Database) CountBackupKeys( ctx context.Context, version, userID string, ) (count int64, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version) if err != nil { return err } @@ -494,8 +410,8 @@ func (d *Database) UpsertBackupKeys( ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, ) (count int64, etag string, err error) { // wrap the following logic in a txn to ensure we atomically upload keys - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + _, _, _, oldETag, deleted, err := d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version) if err != nil { return err } @@ -503,7 +419,7 @@ func (d *Database) UpsertBackupKeys( return fmt.Errorf("backup was deleted") } // pull out all keys for this (user_id, version) - existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) + existingKeys, err := d.KeyBackups.SelectKeys(ctx, txn, userID, version) if err != nil { return err } @@ -517,10 +433,10 @@ func (d *Database) UpsertBackupKeys( existingSession, ok := existingRoom[newKey.SessionID] if ok { if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { - err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) + err = d.KeyBackups.UpdateBackupKey(ctx, txn, userID, version, newKey) changed = true if err != nil { - return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) + return fmt.Errorf("d.KeyBackups.UpdateBackupKey: %w", err) } } // if we shouldn't replace the key we do nothing with it @@ -528,14 +444,14 @@ func (d *Database) UpsertBackupKeys( } } // if we're here, either the room or session are new, either way, we insert - err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) + err = d.KeyBackups.InsertBackupKey(ctx, txn, userID, version, newKey) changed = true if err != nil { - return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) + return fmt.Errorf("d.KeyBackups.InsertBackupKey: %w", err) } } - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version) if err != nil { return err } @@ -552,7 +468,7 @@ func (d *Database) UpsertBackupKeys( newETag = strconv.FormatInt(oldETagInt+1, 10) } etag = newETag - return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) + return d.KeyBackupVersions.UpdateKeyBackupETag(ctx, txn, userID, version, newETag) } else { etag = oldETag } @@ -561,3 +477,196 @@ func (d *Database) UpsertBackupKeys( }) return } + +// GetDeviceByAccessToken returns the device matching the given access token. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByAccessToken( + ctx context.Context, token string, +) (*api.Device, error) { + return d.Devices.SelectDeviceByToken(ctx, token) +} + +// GetDeviceByID returns the device matching the given ID. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*api.Device, error) { + return d.Devices.SelectDeviceByID(ctx, localpart, deviceID) +} + +// GetDevicesByLocalpart returns the devices matching the given localpart. +func (d *Database) GetDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]api.Device, error) { + return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "") +} + +func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + return d.Devices.SelectDevicesByID(ctx, deviceIDs) +} + +// CreateDevice makes a new device associated with the given user ID localpart. +// If there is already a device with the same device ID for this user, that access token will be revoked +// and replaced with the given accessToken. If the given accessToken is already in use for another device, +// an error will be returned. +// If no device ID is given one is generated. +// Returns the device on success. +func (d *Database) CreateDevice( + ctx context.Context, localpart string, deviceID *string, accessToken string, + displayName *string, ipAddr, userAgent string, +) (dev *api.Device, returnErr error) { + if deviceID != nil { + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil { + return err + } + + dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + } else { + // We generate device IDs in a loop in case its already taken. + // We cap this at going round 5 times to ensure we don't spin forever + var newDeviceID string + for i := 1; i <= 5; i++ { + newDeviceID, returnErr = generateDeviceID() + if returnErr != nil { + return + } + + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error + dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + if returnErr == nil { + return + } + } + } + return +} + +// generateDeviceID creates a new device id. Returns an error if failed to generate +// random bytes. +func generateDeviceID() (string, error) { + b := make([]byte, deviceIDByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + // url-safe no padding + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// UpdateDevice updates the given device with the display name. +// Returns SQL error if there are problems and nil on success. +func (d *Database) UpdateDevice( + ctx context.Context, localpart, deviceID string, displayName *string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName) + }) +} + +// RemoveDevice revokes a device by deleting the entry in the database +// matching with the given device ID and user ID localpart. +// If the device doesn't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevice( + ctx context.Context, deviceID, localpart string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveAllDevices revokes devices by deleting the entry in the +// database matching the given user ID localpart. +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveAllDevices( + ctx context.Context, localpart, exceptDeviceID string, +) (devices []api.Device, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) + if err != nil { + return err + } + if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { + return err + } + return nil + }) + return +} + +// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) + }) +} + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.LoginTokenLifetime), + } + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.LoginTokens.InsertLoginToken(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.LoginTokens.DeleteLoginToken(ctx, txn, token) + }) +} + +// GetLoginTokenDataByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.LoginTokens.SelectLoginToken(ctx, token) +} diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go similarity index 88% rename from userapi/storage/accounts/sqlite3/account_data_table.go rename to userapi/storage/sqlite3/account_data_table.go index 871f996e0..cfd8568a9 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/sqlite3/account_data_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const accountDataSchema = ` @@ -56,27 +57,29 @@ type accountDataStatements struct { selectAccountDataByTypeStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { - s.db = db - _, err = db.Exec(accountDataSchema) - if err != nil { - return +func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { + s := &accountDataStatements{ + db: db, } - return sqlutil.StatementList{ + _, err := db.Exec(accountDataSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertAccountDataStmt, insertAccountDataSQL}, {&s.selectAccountDataStmt, selectAccountDataSQL}, {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, }.Prepare(db) } -func (s *accountDataStatements) insertAccountData( +func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) error { _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) return err } -func (s *accountDataStatements) selectAccountData( +func (s *accountDataStatements) SelectAccountData( ctx context.Context, localpart string, ) ( /* global */ map[string]json.RawMessage, @@ -113,7 +116,7 @@ func (s *accountDataStatements) selectAccountData( return global, rooms, nil } -func (s *accountDataStatements) selectAccountDataByType( +func (s *accountDataStatements) SelectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go similarity index 80% rename from userapi/storage/accounts/sqlite3/accounts_table.go rename to userapi/storage/sqlite3/accounts_table.go index 8a7c8fba7..e6c37e58e 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -19,10 +19,12 @@ import ( "database/sql" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/userapi/storage/tables" log "github.com/sirupsen/logrus" ) @@ -39,14 +41,16 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- Identifies which application service this account belongs to, if any. appservice_id TEXT, -- If the account is currently active - is_deactivated BOOLEAN DEFAULT 0 + is_deactivated BOOLEAN DEFAULT 0, + -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) + account_type INTEGER NOT NULL -- TODO: - -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? + -- upgraded_ts, devices, any email reset stuff? ); ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" const updatePasswordSQL = "" + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" @@ -55,7 +59,7 @@ const deactivateAccountSQL = "" + "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" @@ -74,15 +78,16 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } -func (s *accountsStatements) execSchema(db *sql.DB) error { +func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { + s := &accountsStatements{ + db: db, + serverName: serverName, + } _, err := db.Exec(accountsSchema) - return err -} - -func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - s.db = db - s.serverName = server - return sqlutil.StatementList{ + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertAccountStmt, insertAccountSQL}, {&s.updatePasswordStmt, updatePasswordSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL}, @@ -95,17 +100,17 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. -func (s *accountsStatements) insertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, +func (s *accountsStatements) InsertAccount( + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt var err error - if appserviceID == "" { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) + if accountType != api.AccountTypeAppService { + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) } else { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) } if err != nil { return nil, err @@ -119,35 +124,35 @@ func (s *accountsStatements) insertAccount( }, nil } -func (s *accountsStatements) updatePassword( +func (s *accountsStatements) UpdatePassword( ctx context.Context, localpart, passwordHash string, ) (err error) { _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) return } -func (s *accountsStatements) deactivateAccount( +func (s *accountsStatements) DeactivateAccount( ctx context.Context, localpart string, ) (err error) { _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) return } -func (s *accountsStatements) selectPasswordHash( +func (s *accountsStatements) SelectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) return } -func (s *accountsStatements) selectAccountByLocalpart( +func (s *accountsStatements) SelectAccountByLocalpart( ctx context.Context, localpart string, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -164,7 +169,7 @@ func (s *accountsStatements) selectAccountByLocalpart( return &acc, nil } -func (s *accountsStatements) selectNewNumericLocalpart( +func (s *accountsStatements) SelectNewNumericLocalpart( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt diff --git a/userapi/storage/accounts/sqlite3/constraint_wasm.go b/userapi/storage/sqlite3/constraint_wasm.go similarity index 100% rename from userapi/storage/accounts/sqlite3/constraint_wasm.go rename to userapi/storage/sqlite3/constraint_wasm.go diff --git a/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go similarity index 96% rename from userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go rename to userapi/storage/sqlite3/deltas/20200929203058_is_active.go index 9fddb05a1..c69614e83 100644 --- a/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go +++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go @@ -4,12 +4,14 @@ import ( "database/sql" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/pressly/goose" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) func LoadFromGoose() { goose.AddMigration(UpIsActive, DownIsActive) + goose.AddMigration(UpAddAccountType, DownAddAccountType) } func LoadIsActive(m *sqlutil.Migrations) { diff --git a/userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go similarity index 94% rename from userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go rename to userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go index 262098265..ebf908001 100644 --- a/userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go @@ -5,13 +5,8 @@ import ( "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/pressly/goose" ) -func LoadFromGoose() { - goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) -} - func LoadLastSeenTSIP(m *sqlutil.Migrations) { m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) } diff --git a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go new file mode 100644 index 000000000..9b058dedd --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go @@ -0,0 +1,54 @@ +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/pressly/goose" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +func init() { + goose.AddMigration(UpAddAccountType, DownAddAccountType) +} + +func LoadAddAccountType(m *sqlutil.Migrations) { + m.AddMigration(UpAddAccountType, DownAddAccountType) +} + +func UpAddAccountType(tx *sql.Tx) error { + // initially set every account to useraccount, change appservice and guest accounts afterwards + // (user = 1, guest = 2, admin = 3, appservice = 4) + _, err := tx.Exec(`ALTER TABLE account_accounts RENAME TO account_accounts_tmp; +CREATE TABLE account_accounts ( + localpart TEXT NOT NULL PRIMARY KEY, + created_ts BIGINT NOT NULL, + password_hash TEXT, + appservice_id TEXT, + is_deactivated BOOLEAN DEFAULT 0, + account_type INTEGER NOT NULL +); +INSERT + INTO account_accounts ( + localpart, created_ts, password_hash, appservice_id, account_type + ) SELECT + localpart, created_ts, password_hash, appservice_id, 1 + FROM account_accounts_tmp +; +UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; +UPDATE account_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*'; +DROP TABLE account_accounts_tmp;`) + if err != nil { + return fmt.Errorf("failed to add column: %w", err) + } + return nil +} + +func DownAddAccountType(tx *sql.Tx) error { + _, err := tx.Exec(`ALTER TABLE account_accounts DROP COLUMN account_type;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go similarity index 82% rename from userapi/storage/devices/sqlite3/devices_table.go rename to userapi/storage/sqlite3/devices_table.go index 955d8ac7f..423640e90 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/gomatrixserverlib" @@ -84,7 +85,6 @@ const updateDeviceLastSeen = "" + type devicesStatements struct { db *sql.DB - writer sqlutil.Writer insertDeviceStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt @@ -98,52 +98,33 @@ type devicesStatements struct { serverName gomatrixserverlib.ServerName } -func (s *devicesStatements) execSchema(db *sql.DB) error { +func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) { + s := &devicesStatements{ + db: db, + serverName: serverName, + } _, err := db.Exec(devicesSchema) - return err -} - -func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { - s.db = db - s.writer = writer - if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { - return + if err != nil { + return nil, err } - if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil { - return - } - if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { - return - } - if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { - return - } - if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { - return - } - if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { - return - } - if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { - return - } - if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { - return - } - if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { - return - } - if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil { - return - } - s.serverName = server - return + return s, sqlutil.StatementList{ + {&s.insertDeviceStmt, insertDeviceSQL}, + {&s.selectDevicesCountStmt, selectDevicesCountSQL}, + {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL}, + {&s.selectDeviceByIDStmt, selectDeviceByIDSQL}, + {&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL}, + {&s.updateDeviceNameStmt, updateDeviceNameSQL}, + {&s.deleteDeviceStmt, deleteDeviceSQL}, + {&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL}, + {&s.selectDevicesByIDStmt, selectDevicesByIDSQL}, + {&s.updateDeviceLastSeenStmt, updateDeviceLastSeen}, + }.Prepare(db) } // insertDevice creates a new device. Returns an error if any device with the same access token already exists. // Returns an error if the user already has a device with the given device ID. // Returns the device on success. -func (s *devicesStatements) insertDevice( +func (s *devicesStatements) InsertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { @@ -169,7 +150,7 @@ func (s *devicesStatements) insertDevice( }, nil } -func (s *devicesStatements) deleteDevice( +func (s *devicesStatements) DeleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) @@ -177,7 +158,7 @@ func (s *devicesStatements) deleteDevice( return err } -func (s *devicesStatements) deleteDevices( +func (s *devicesStatements) DeleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) @@ -195,7 +176,7 @@ func (s *devicesStatements) deleteDevices( return err } -func (s *devicesStatements) deleteDevicesByLocalpart( +func (s *devicesStatements) DeleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) @@ -203,7 +184,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart( return err } -func (s *devicesStatements) updateDeviceName( +func (s *devicesStatements) UpdateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) @@ -211,7 +192,7 @@ func (s *devicesStatements) updateDeviceName( return err } -func (s *devicesStatements) selectDeviceByToken( +func (s *devicesStatements) SelectDeviceByToken( ctx context.Context, accessToken string, ) (*api.Device, error) { var dev api.Device @@ -227,7 +208,7 @@ func (s *devicesStatements) selectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID -func (s *devicesStatements) selectDeviceByID( +func (s *devicesStatements) SelectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device @@ -244,7 +225,7 @@ func (s *devicesStatements) selectDeviceByID( return &dev, err } -func (s *devicesStatements) selectDevicesByLocalpart( +func (s *devicesStatements) SelectDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} @@ -285,7 +266,7 @@ func (s *devicesStatements) selectDevicesByLocalpart( return devices, nil } -func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { +func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1) iDeviceIDs := make([]interface{}, len(deviceIDs)) for i := range deviceIDs { @@ -314,7 +295,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s return devices, rows.Err() } -func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) diff --git a/userapi/storage/accounts/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go similarity index 91% rename from userapi/storage/accounts/sqlite3/key_backup_table.go rename to userapi/storage/sqlite3/key_backup_table.go index 837d38cf1..81726edf9 100644 --- a/userapi/storage/accounts/sqlite3/key_backup_table.go +++ b/userapi/storage/sqlite3/key_backup_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupTableSchema = ` @@ -72,12 +73,13 @@ type keyBackupStatements struct { selectKeysByRoomIDAndSessionIDStmt *sql.Stmt } -func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupTableSchema) +func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { + s := &keyBackupStatements{} + _, err := db.Exec(keyBackupTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, @@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s keyBackupStatements) countKeys( +func (s keyBackupStatements) CountKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (count int64, err error) { err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) return } -func (s *keyBackupStatements) insertBackupKey( +func (s *keyBackupStatements) InsertBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( @@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey( return } -func (s *keyBackupStatements) updateBackupKey( +func (s *keyBackupStatements) UpdateBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( @@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey( return } -func (s *keyBackupStatements) selectKeys( +func (s *keyBackupStatements) SelectKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) @@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomID( +func (s *keyBackupStatements) SelectKeysByRoomID( ctx context.Context, txn *sql.Tx, userID, version, roomID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) @@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( +func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID( ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/sqlite3/key_backup_version_table.go similarity index 89% rename from userapi/storage/accounts/sqlite3/key_backup_version_table.go rename to userapi/storage/sqlite3/key_backup_version_table.go index 4211ed0f1..e85e6f08b 100644 --- a/userapi/storage/accounts/sqlite3/key_backup_version_table.go +++ b/userapi/storage/sqlite3/key_backup_version_table.go @@ -22,6 +22,7 @@ import ( "strconv" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupVersionTableSchema = ` @@ -67,12 +68,13 @@ type keyBackupVersionStatements struct { updateKeyBackupETagStmt *sql.Stmt } -func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupVersionTableSchema) +func NewSQLiteKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) { + s := &keyBackupVersionStatements{} + _, err := db.Exec(keyBackupVersionTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertKeyBackupStmt, insertKeyBackupSQL}, {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, {&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, @@ -82,7 +84,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *keyBackupVersionStatements) insertKeyBackup( +func (s *keyBackupVersionStatements) InsertKeyBackup( ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, ) (version string, err error) { var versionInt int64 @@ -90,7 +92,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup( return strconv.FormatInt(versionInt, 10), err } -func (s *keyBackupVersionStatements) updateKeyBackupAuthData( +func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData( ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -101,7 +103,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData( return err } -func (s *keyBackupVersionStatements) updateKeyBackupETag( +func (s *keyBackupVersionStatements) UpdateKeyBackupETag( ctx context.Context, txn *sql.Tx, userID, version, etag string, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -112,7 +114,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag( return err } -func (s *keyBackupVersionStatements) deleteKeyBackup( +func (s *keyBackupVersionStatements) DeleteKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (bool, error) { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -130,7 +132,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup( return ra == 1, nil } -func (s *keyBackupVersionStatements) selectKeyBackup( +func (s *keyBackupVersionStatements) SelectKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { var versionInt int64 diff --git a/userapi/storage/sqlite3/logintoken_table.go b/userapi/storage/sqlite3/logintoken_table.go new file mode 100644 index 000000000..78d42029a --- /dev/null +++ b/userapi/storage/sqlite3/logintoken_table.go @@ -0,0 +1,103 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/util" +) + +type loginTokenStatements struct { + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectStmt *sql.Stmt +} + +const loginTokenSchema = ` +CREATE TABLE IF NOT EXISTS login_tokens ( + -- The random value of the token issued to a user + token TEXT NOT NULL PRIMARY KEY, + -- When the token expires + token_expires_at TIMESTAMP NOT NULL, + + -- The mxid for this account + user_id TEXT NOT NULL +); + +-- This index allows efficient garbage collection of expired tokens. +CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +` + +const insertLoginTokenSQL = "" + + "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" + +const deleteLoginTokenSQL = "" + + "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" + +const selectLoginTokenSQL = "" + + "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" + +func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) { + s := &loginTokenStatements{} + _, err := db.Exec(loginTokenSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertStmt, insertLoginTokenSQL}, + {&s.deleteStmt, deleteLoginTokenSQL}, + {&s.selectStmt, selectLoginTokenSQL}, + }.Prepare(db) +} + +// insert adds an already generated token to the database. +func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { + stmt := sqlutil.TxStmt(txn, s.insertStmt) + _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) + return err +} + +// deleteByToken removes the named token. +// +// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. +// The login_tokens_expiration_idx index should make that efficient. +func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { + stmt := sqlutil.TxStmt(txn, s.deleteStmt) + res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) + if err != nil { + return err + } + if n, err := res.RowsAffected(); err == nil && n > 1 { + util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely additional expired token)", n, n-1) + } + return nil +} + +// selectByToken returns the data associated with the given token. May return sql.ErrNoRows. +func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + var data api.LoginTokenData + err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + if err != nil { + return nil, err + } + + return &data, nil +} diff --git a/userapi/storage/accounts/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go similarity index 74% rename from userapi/storage/accounts/sqlite3/openid_table.go rename to userapi/storage/sqlite3/openid_table.go index 98c0488b1..d6090e0da 100644 --- a/userapi/storage/accounts/sqlite3/openid_table.go +++ b/userapi/storage/sqlite3/openid_table.go @@ -6,6 +6,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -22,35 +23,37 @@ CREATE TABLE IF NOT EXISTS open_id_tokens ( ); ` -const insertTokenSQL = "" + +const insertOpenIDTokenSQL = "" + "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" -const selectTokenSQL = "" + +const selectOpenIDTokenSQL = "" + "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" -type tokenStatements struct { +type openIDTokenStatements struct { db *sql.DB insertTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt serverName gomatrixserverlib.ServerName } -func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - s.db = db - _, err = db.Exec(openIDTokenSchema) - if err != nil { - return err +func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) { + s := &openIDTokenStatements{ + db: db, + serverName: serverName, } - s.serverName = server - return sqlutil.StatementList{ - {&s.insertTokenStmt, insertTokenSQL}, - {&s.selectTokenStmt, selectTokenSQL}, + _, err := db.Exec(openIDTokenSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertTokenStmt, insertOpenIDTokenSQL}, + {&s.selectTokenStmt, selectOpenIDTokenSQL}, }.Prepare(db) } // insertToken inserts a new OpenID Connect token to the DB. // Returns new token, otherwise returns error if the token already exists. -func (s *tokenStatements) insertToken( +func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, token, localpart string, @@ -63,7 +66,7 @@ func (s *tokenStatements) insertToken( // selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB // Returns the existing token's attributes, or err if no token is found -func (s *tokenStatements) selectOpenIDTokenAtrributes( +func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( ctx context.Context, token string, ) (*api.OpenIDTokenAttributes, error) { diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go similarity index 89% rename from userapi/storage/accounts/sqlite3/profile_table.go rename to userapi/storage/sqlite3/profile_table.go index a92e95663..d85b19c7b 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const profilesSchema = ` @@ -60,13 +61,15 @@ type profilesStatements struct { selectProfilesBySearchStmt *sql.Stmt } -func (s *profilesStatements) prepare(db *sql.DB) (err error) { - s.db = db - _, err = db.Exec(profilesSchema) - if err != nil { - return +func NewSQLiteProfilesTable(db *sql.DB) (tables.ProfileTable, error) { + s := &profilesStatements{ + db: db, } - return sqlutil.StatementList{ + _, err := db.Exec(profilesSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertProfileStmt, insertProfileSQL}, {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, {&s.setAvatarURLStmt, setAvatarURLSQL}, @@ -75,14 +78,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *profilesStatements) insertProfile( +func (s *profilesStatements) InsertProfile( ctx context.Context, txn *sql.Tx, localpart string, ) error { _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") return err } -func (s *profilesStatements) selectProfileByLocalpart( +func (s *profilesStatements) SelectProfileByLocalpart( ctx context.Context, localpart string, ) (*authtypes.Profile, error) { var profile authtypes.Profile @@ -95,7 +98,7 @@ func (s *profilesStatements) selectProfileByLocalpart( return &profile, nil } -func (s *profilesStatements) setAvatarURL( +func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) @@ -103,7 +106,7 @@ func (s *profilesStatements) setAvatarURL( return } -func (s *profilesStatements) setDisplayName( +func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, localpart string, displayName string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) @@ -111,7 +114,7 @@ func (s *profilesStatements) setDisplayName( return } -func (s *profilesStatements) selectProfilesBySearch( +func (s *profilesStatements) SelectProfilesBySearch( ctx context.Context, searchString string, limit int, ) ([]authtypes.Profile, error) { var profiles []authtypes.Profile diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go new file mode 100644 index 000000000..98c244977 --- /dev/null +++ b/userapi/storage/sqlite3/storage.go @@ -0,0 +1,106 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "fmt" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + + "github.com/matrix-org/dendrite/userapi/storage/shared" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + + // Import the postgres database driver. + _ "github.com/lib/pq" +) + +// NewDatabase creates a new accounts and profiles database +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, err + } + + m := sqlutil.NewMigrations() + if _, err = db.Exec(accountsSchema); err != nil { + // do this so that the migration can and we don't fail on + // preparing statements for columns that don't exist yet + return nil, err + } + deltas.LoadIsActive(m) + //deltas.LoadLastSeenTSIP(m) + deltas.LoadAddAccountType(m) + if err = m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } + + accountDataTable, err := NewSQLiteAccountDataTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) + } + accountsTable, err := NewSQLiteAccountsTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) + } + devicesTable, err := NewSQLiteDevicesTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err) + } + keyBackupTable, err := NewSQLiteKeyBackupTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteKeyBackupTable: %w", err) + } + keyBackupVersionTable, err := NewSQLiteKeyBackupVersionTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteKeyBackupVersionTable: %w", err) + } + loginTokenTable, err := NewSQLiteLoginTokenTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteLoginTokenTable: %w", err) + } + openIDTable, err := NewSQLiteOpenIDTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewSQLiteOpenIDTable: %w", err) + } + profilesTable, err := NewSQLiteProfilesTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteProfilesTable: %w", err) + } + threePIDTable, err := NewSQLiteThreePIDTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err) + } + return &shared.Database{ + AccountDatas: accountDataTable, + Accounts: accountsTable, + Devices: devicesTable, + KeyBackups: keyBackupTable, + KeyBackupVersions: keyBackupVersionTable, + LoginTokens: loginTokenTable, + OpenIDTokens: openIDTable, + Profiles: profilesTable, + ThreePIDs: threePIDTable, + ServerName: serverName, + DB: db, + Writer: sqlutil.NewExclusiveWriter(), + LoginTokenLifetime: loginTokenLifetime, + BcryptCost: bcryptCost, + OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, + }, nil +} diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go similarity index 88% rename from userapi/storage/accounts/sqlite3/threepid_table.go rename to userapi/storage/sqlite3/threepid_table.go index 9dc0e2d22..fa174eed5 100644 --- a/userapi/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/sqlite3/threepid_table.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -60,13 +61,15 @@ type threepidStatements struct { deleteThreePIDStmt *sql.Stmt } -func (s *threepidStatements) prepare(db *sql.DB) (err error) { - s.db = db - _, err = db.Exec(threepidSchema) - if err != nil { - return +func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { + s := &threepidStatements{ + db: db, } - return sqlutil.StatementList{ + _, err := db.Exec(threepidSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, {&s.insertThreePIDStmt, insertThreePIDSQL}, @@ -74,7 +77,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *threepidStatements) selectLocalpartForThreePID( +func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, ) (localpart string, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) @@ -85,7 +88,7 @@ func (s *threepidStatements) selectLocalpartForThreePID( return } -func (s *threepidStatements) selectThreePIDsForLocalpart( +func (s *threepidStatements) SelectThreePIDsForLocalpart( ctx context.Context, localpart string, ) (threepids []authtypes.ThreePID, err error) { rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) @@ -109,7 +112,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( return threepids, rows.Err() } -func (s *threepidStatements) insertThreePID( +func (s *threepidStatements) InsertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) @@ -117,7 +120,7 @@ func (s *threepidStatements) insertThreePID( return err } -func (s *threepidStatements) deleteThreePID( +func (s *threepidStatements) DeleteThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) _, err = stmt.ExecContext(ctx, threepid, medium) diff --git a/userapi/storage/accounts/storage.go b/userapi/storage/storage.go similarity index 81% rename from userapi/storage/accounts/storage.go rename to userapi/storage/storage.go index a21f7d94e..4711439af 100644 --- a/userapi/storage/accounts/storage.go +++ b/userapi/storage/storage.go @@ -15,25 +15,27 @@ //go:build !wasm // +build !wasm -package accounts +package storage import ( "fmt" + "time" + + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres" - "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/userapi/storage/postgres" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" ) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) + return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) + return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/userapi/storage/accounts/storage_wasm.go b/userapi/storage/storage_wasm.go similarity index 87% rename from userapi/storage/accounts/storage_wasm.go rename to userapi/storage/storage_wasm.go index 11a88a20a..701dcd833 100644 --- a/userapi/storage/accounts/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accounts +package storage import ( "fmt" + "time" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) @@ -27,10 +28,11 @@ func NewDatabase( serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, + loginTokenLifetime time.Duration, ) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) + return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go new file mode 100644 index 000000000..12939ced5 --- /dev/null +++ b/userapi/storage/tables/interface.go @@ -0,0 +1,95 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/userapi/api" +) + +type AccountDataTable interface { + InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error + SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) + SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) +} + +type AccountsTable interface { + InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) + UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error) + DeactivateAccount(ctx context.Context, localpart string) (err error) + SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error) + SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) + SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error) +} + +type DevicesTable interface { + InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) + DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error + DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error + DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error + UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error + SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error) + SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) + SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error) + SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) + UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error +} + +type KeyBackupTable interface { + CountKeys(ctx context.Context, txn *sql.Tx, userID, version string) (count int64, err error) + InsertBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error) + UpdateBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error) + SelectKeys(ctx context.Context, txn *sql.Tx, userID, version string) (map[string]map[string]api.KeyBackupSession, error) + SelectKeysByRoomID(ctx context.Context, txn *sql.Tx, userID, version, roomID string) (map[string]map[string]api.KeyBackupSession, error) + SelectKeysByRoomIDAndSessionID(ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string) (map[string]map[string]api.KeyBackupSession, error) +} + +type KeyBackupVersionTable interface { + InsertKeyBackup(ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string) (version string, err error) + UpdateKeyBackupAuthData(ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage) error + UpdateKeyBackupETag(ctx context.Context, txn *sql.Tx, userID, version, etag string) error + DeleteKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (bool, error) + SelectKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) +} + +type LoginTokenTable interface { + InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error + DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error + SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) +} + +type OpenIDTable interface { + InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error) + SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) +} + +type ProfileTable interface { + InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error + SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error) + SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error) + SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) +} + +type ThreePIDTable interface { + SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error) + SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) + InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) + DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) +} diff --git a/userapi/userapi.go b/userapi/userapi.go index 74702020a..4a5793abb 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -15,14 +15,15 @@ package userapi import ( + "time" + "github.com/gorilla/mux" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/dendrite/userapi/storage" "github.com/sirupsen/logrus" ) @@ -35,16 +36,24 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, + accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, ) api.UserInternalAPI { - deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName) + db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime) if err != nil { logrus.WithError(err).Panicf("failed to connect to device db") } + return newInternalAPI(db, cfg, appServices, keyAPI) +} + +func newInternalAPI( + db storage.Database, + cfg *config.UserAPI, + appServices []config.ApplicationService, + keyAPI keyapi.KeyInternalAPI, +) api.UserInternalAPI { return &internal.UserInternalAPI{ - AccountDB: accountDB, - DeviceDB: deviceDB, + DB: db, ServerName: cfg.Matrix.ServerName, AppServices: appServices, KeyAPI: keyAPI, diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 0141258e6..4214c07f7 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -1,4 +1,18 @@ -package userapi_test +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package userapi import ( "context" @@ -6,49 +20,56 @@ import ( "net/http" "reflect" "testing" + "time" "github.com/gorilla/mux" + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/inthttp" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/userapi/storage" ) const ( serverName = gomatrixserverlib.ServerName("example.com") ) -func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database) { - accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ - ConnectionString: "file::memory:", - }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS) +type apiTestOpts struct { + loginTokenLifetime time.Duration +} + +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, storage.Database) { + if opts.loginTokenLifetime == 0 { + opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond + } + dbopts := &config.DatabaseOptions{ + ConnectionString: "file::memory:", + MaxOpenConnections: 1, + MaxIdleConnections: 1, + } + accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime) if err != nil { t.Fatalf("failed to create account DB: %s", err) } + cfg := &config.UserAPI{ - DeviceDatabase: config.DatabaseOptions{ - ConnectionString: "file::memory:", - MaxOpenConnections: 1, - MaxIdleConnections: 1, - }, Matrix: &config.Global{ ServerName: serverName, }, } - return userapi.NewInternalAPI(accountDB, cfg, nil, nil), accountDB + return newInternalAPI(accountDB, cfg, nil, nil), accountDB } func TestQueryProfile(t *testing.T) { aliceAvatarURL := "mxc://example.com/alice" aliceDisplayName := "Alice" - userAPI, accountDB := MustMakeInternalAPI(t) - _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "") + userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) + _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser) if err != nil { t.Fatalf("failed to make account: %s", err) } @@ -106,7 +127,7 @@ func TestQueryProfile(t *testing.T) { t.Run("HTTP API", func(t *testing.T) { router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - userapi.AddInternalRoutes(router, userAPI) + AddInternalRoutes(router, userAPI) apiURL, cancel := test.ListenAndServe(t, router, false) defer cancel() httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) @@ -119,3 +140,115 @@ func TestQueryProfile(t *testing.T) { runCases(userAPI) }) } + +func TestLoginToken(t *testing.T) { + ctx := context.Background() + + t.Run("tokenLoginFlow", func(t *testing.T) { + userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) + + _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser) + if err != nil { + t.Fatalf("failed to make account: %s", err) + } + + t.Log("Creating a login token like the SSO callback would...") + + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "@auser:example.com"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } + + if cresp.Metadata.Token == "" { + t.Errorf("PerformLoginTokenCreation Token: got %q, want non-empty", cresp.Metadata.Token) + } + if cresp.Metadata.Expiration.Before(time.Now()) { + t.Errorf("PerformLoginTokenCreation Expiration: got %v, want non-expired", cresp.Metadata.Expiration) + } + + t.Log("Querying the login token like /login with m.login.token would...") + + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } + + if qresp.Data == nil { + t.Errorf("QueryLoginToken Data: got %v, want non-nil", qresp.Data) + } else if want := "@auser:example.com"; qresp.Data.UserID != want { + t.Errorf("QueryLoginToken UserID: got %q, want %q", qresp.Data.UserID, want) + } + + t.Log("Deleting the login token like /login with m.login.token would...") + + dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + }) + + t.Run("expiredTokenIsNotReturned", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}) + + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "@auser:example.com"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } + + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } + + if qresp.Data != nil { + t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data) + } + }) + + t.Run("deleteWorks", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{}) + + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "@auser:example.com"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } + + dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } + + if qresp.Data != nil { + t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data) + } + }) + + t.Run("deleteUnknownIsNoOp", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{}) + + dreq := api.PerformLoginTokenDeletionRequest{Token: "non-existent token"} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + }) +}