diff --git a/.github/ISSUE_TEMPLATE/BUG_REPORT.md b/.github/ISSUE_TEMPLATE/BUG_REPORT.md index f40c56609..a3b8c0754 100644 --- a/.github/ISSUE_TEMPLATE/BUG_REPORT.md +++ b/.github/ISSUE_TEMPLATE/BUG_REPORT.md @@ -17,7 +17,6 @@ see: https://www.matrix.org/security-disclosure-policy/ ### Background information - **Dendrite version or git SHA**: -- **Monolith or Polylith?**: - **SQLite3 or Postgres?**: - **Running in Docker?**: - **`go version`**: diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 1de39850d..2f615a6a4 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -25,7 +25,7 @@ jobs: - name: Install Go uses: actions/setup-go@v3 with: - go-version: 1.18 + go-version: "stable" cache: true - name: Install Node @@ -62,14 +62,14 @@ jobs: - name: Install Go uses: actions/setup-go@v3 with: - go-version: 1.18 + go-version: "stable" - name: golangci-lint uses: golangci/golangci-lint-action@v3 # run go test with different go versions test: timeout-minutes: 10 - name: Unit tests (Go ${{ matrix.go }}) + name: Unit tests runs-on: ubuntu-latest # Service containers to run with `container-job` services: @@ -91,25 +91,21 @@ jobs: --health-interval 10s --health-timeout 5s --health-retries 5 - strategy: - fail-fast: false - matrix: - go: ["1.19"] steps: - uses: actions/checkout@v3 - name: Setup go uses: actions/setup-go@v3 with: - go-version: ${{ matrix.go }} + go-version: "stable" - uses: actions/cache@v3 # manually set up caches, as they otherwise clash with different steps using setup-go with cache=true with: path: | ~/.cache/go-build ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}-unit-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-go-stable-unit-${{ hashFiles('**/go.sum') }} restore-keys: | - ${{ runner.os }}-go${{ matrix.go }}-unit- + ${{ runner.os }}-go-stable-unit- - name: Set up gotestfmt uses: gotesttools/gotestfmt-action@v2 with: @@ -130,7 +126,6 @@ jobs: strategy: fail-fast: false matrix: - go: ["1.18", "1.19"] goos: ["linux"] goarch: ["amd64", "386"] steps: @@ -138,15 +133,15 @@ jobs: - name: Setup go uses: actions/setup-go@v3 with: - go-version: ${{ matrix.go }} + go-version: "stable" - uses: actions/cache@v3 with: path: | ~/.cache/go-build ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-go-stable-${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} restore-keys: | - key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}- + key: ${{ runner.os }}-go-stable-${{ matrix.goos }}-${{ matrix.goarch }}- - name: Install dependencies x86 if: ${{ matrix.goarch == '386' }} run: sudo apt update && sudo apt-get install -y gcc-multilib @@ -164,23 +159,22 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: ["1.18", "1.19"] goos: ["windows"] goarch: ["amd64"] steps: - uses: actions/checkout@v3 - - name: Setup Go ${{ matrix.go }} + - name: Setup Go uses: actions/setup-go@v3 with: - go-version: ${{ matrix.go }} + go-version: "stable" - uses: actions/cache@v3 with: path: | ~/.cache/go-build ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-go-stable-${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} restore-keys: | - key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}- + key: ${{ runner.os }}-go-stable-${{ matrix.goos }}-${{ matrix.goarch }}- - name: Install dependencies run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc - env: @@ -206,7 +200,7 @@ jobs: integration: timeout-minutes: 20 needs: initial-tests-done - name: Integration tests (Go ${{ matrix.go }}) + name: Integration tests runs-on: ubuntu-latest # Service containers to run with `container-job` services: @@ -228,16 +222,12 @@ jobs: --health-interval 10s --health-timeout 5s --health-retries 5 - strategy: - fail-fast: false - matrix: - go: ["1.19"] steps: - uses: actions/checkout@v3 - name: Setup go uses: actions/setup-go@v3 with: - go-version: ${{ matrix.go }} + go-version: "stable" - name: Set up gotestfmt uses: gotesttools/gotestfmt-action@v2 with: @@ -248,9 +238,9 @@ jobs: path: | ~/.cache/go-build ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-go-stable-test-race-${{ hashFiles('**/go.sum') }} restore-keys: | - ${{ runner.os }}-go${{ matrix.go }}-test-race- + ${{ runner.os }}-go-stable-test-race- - run: go test -race -json -v -coverpkg=./... -coverprofile=cover.out $(go list ./... | grep -v /cmd/dendrite*) 2>&1 | gotestfmt env: POSTGRES_HOST: localhost @@ -261,6 +251,7 @@ jobs: uses: codecov/codecov-action@v3 with: flags: unittests + fail_ci_if_error: true # run database upgrade tests upgrade_test: @@ -273,7 +264,7 @@ jobs: - name: Setup go uses: actions/setup-go@v3 with: - go-version: "1.18" + go-version: "stable" cache: true - name: Build upgrade-tests run: go build ./cmd/dendrite-upgrade-tests @@ -293,7 +284,7 @@ jobs: - name: Setup go uses: actions/setup-go@v3 with: - go-version: "1.18" + go-version: "stable" cache: true - name: Build upgrade-tests run: go build ./cmd/dendrite-upgrade-tests @@ -317,19 +308,9 @@ jobs: - label: SQLite Cgo cgo: 1 - - label: SQLite native, full HTTP APIs - api: full-http - - - label: SQLite Cgo, full HTTP APIs - api: full-http - cgo: 1 - - label: PostgreSQL postgres: postgres - - label: PostgreSQL, full HTTP APIs - postgres: postgres - api: full-http container: image: matrixdotorg/sytest-dendrite volumes: @@ -338,7 +319,6 @@ jobs: - /root/.cache/go-mod:/gopath/pkg/mod env: POSTGRES: ${{ matrix.postgres && 1}} - API: ${{ matrix.api && 1 }} SYTEST_BRANCH: ${{ github.head_ref }} CGO_ENABLED: ${{ matrix.cgo && 1 }} steps: @@ -390,22 +370,9 @@ jobs: - label: SQLite Cgo cgo: 1 - - label: SQLite native, full HTTP APIs - api: full-http - cgo: 0 - - - label: SQLite Cgo, full HTTP APIs - api: full-http - cgo: 1 - - label: PostgreSQL postgres: Postgres cgo: 0 - - - label: PostgreSQL, full HTTP APIs - postgres: Postgres - api: full-http - cgo: 0 steps: # Env vars are set file a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on env to run Complement. # See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path @@ -447,7 +414,7 @@ jobs: (wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break done # Build initial Dendrite image - - run: docker build --build-arg=CGO=${{ matrix.cgo }} -t complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . + - run: docker build --build-arg=CGO=${{ matrix.cgo }} -t complement-dendrite:${{ matrix.postgres }}${{ matrix.cgo }} -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . working-directory: dendrite env: DOCKER_BUILDKIT: 1 @@ -459,8 +426,8 @@ jobs: shell: bash name: Run Complement Tests env: - COMPLEMENT_BASE_IMAGE: complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} - API: ${{ matrix.api && 1 }} + COMPLEMENT_BASE_IMAGE: complement-dendrite:${{ matrix.postgres }}${{ matrix.cgo }} + COMPLEMENT_SHARE_ENV_PREFIX: COMPLEMENT_DENDRITE_ working-directory: complement integration-tests-done: diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 2e17539d8..0c3053a56 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -61,7 +61,6 @@ jobs: cache-to: type=gha,mode=max context: . build-args: FLAGS=-X github.com/matrix-org/dendrite/internal.branch=${{ env.BRANCH }} -X github.com/matrix-org/dendrite/internal.build=${{ env.BUILD }} - target: monolith platforms: ${{ env.PLATFORMS }} push: true tags: | @@ -77,7 +76,6 @@ jobs: cache-to: type=gha,mode=max context: . build-args: FLAGS=-X github.com/matrix-org/dendrite/internal.branch=${{ env.BRANCH }} -X github.com/matrix-org/dendrite/internal.build=${{ env.BUILD }} - target: monolith platforms: ${{ env.PLATFORMS }} push: true tags: | @@ -98,86 +96,6 @@ jobs: with: sarif_file: "trivy-results.sarif" - polylith: - name: Polylith image - runs-on: ubuntu-latest - permissions: - contents: read - packages: write - security-events: write # To upload Trivy sarif files - steps: - - name: Checkout - uses: actions/checkout@v3 - - name: Get release tag & build flags - if: github.event_name == 'release' # Only for GitHub releases - run: | - echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV - BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) - [ ${BRANCH} == "main" ] && BRANCH="" - echo "BRANCH=${BRANCH}" >> $GITHUB_ENV - - name: Set up QEMU - uses: docker/setup-qemu-action@v1 - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 - - name: Login to Docker Hub - uses: docker/login-action@v2 - with: - username: ${{ env.DOCKER_HUB_USER }} - password: ${{ secrets.DOCKER_TOKEN }} - - name: Login to GitHub Containers - uses: docker/login-action@v2 - with: - registry: ghcr.io - username: ${{ github.repository_owner }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Build main polylith image - if: github.ref_name == 'main' - id: docker_build_polylith - uses: docker/build-push-action@v3 - with: - cache-from: type=gha - cache-to: type=gha,mode=max - context: . - build-args: FLAGS=-X github.com/matrix-org/dendrite/internal.branch=${{ env.BRANCH }} -X github.com/matrix-org/dendrite/internal.build=${{ env.BUILD }} - target: polylith - platforms: ${{ env.PLATFORMS }} - push: true - tags: | - ${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} - ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} - - - name: Build release polylith image - if: github.event_name == 'release' # Only for GitHub releases - id: docker_build_polylith_release - uses: docker/build-push-action@v3 - with: - cache-from: type=gha - cache-to: type=gha,mode=max - context: . - build-args: FLAGS=-X github.com/matrix-org/dendrite/internal.branch=${{ env.BRANCH }} -X github.com/matrix-org/dendrite/internal.build=${{ env.BUILD }} - target: polylith - platforms: ${{ env.PLATFORMS }} - push: true - tags: | - ${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:latest - ${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }} - ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:latest - ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }} - - - name: Run Trivy vulnerability scanner - uses: aquasecurity/trivy-action@master - with: - image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} - format: "sarif" - output: "trivy-results.sarif" - - - name: Upload Trivy scan results to GitHub Security tab - uses: github/codeql-action/upload-sarif@v2 - with: - sarif_file: "trivy-results.sarif" - demo-pinecone: name: Pinecone demo image runs-on: ubuntu-latest diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index 7bf08bba1..9df3cceae 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -4,7 +4,7 @@ name: Deploy GitHub Pages dependencies preinstalled on: # Runs on pushes targeting the default branch push: - branches: ["main", "gh-pages"] + branches: ["gh-pages"] paths: - 'docs/**' # only execute if we have docs changes diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index d2a1f6e1f..dff9b34c9 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -19,17 +19,13 @@ jobs: fail-fast: false matrix: include: - - label: SQLite + - label: SQLite native - - label: SQLite, full HTTP APIs - api: full-http + - label: SQLite Cgo + cgo: 1 - label: PostgreSQL postgres: postgres - - - label: PostgreSQL, full HTTP APIs - postgres: postgres - api: full-http container: image: matrixdotorg/sytest-dendrite:latest volumes: @@ -38,9 +34,9 @@ jobs: - /root/.cache/go-mod:/gopath/pkg/mod env: POSTGRES: ${{ matrix.postgres && 1}} - API: ${{ matrix.api && 1 }} SYTEST_BRANCH: ${{ github.head_ref }} RACE_DETECTION: 1 + COVER: 1 steps: - uses: actions/checkout@v3 - uses: actions/cache@v3 @@ -73,3 +69,192 @@ jobs: path: | /logs/results.tap /logs/**/*.log* + + sytest-coverage: + timeout-minutes: 5 + name: "Sytest Coverage" + runs-on: ubuntu-latest + needs: sytest # only run once Sytest is done + if: ${{ always() }} + steps: + - uses: actions/checkout@v3 + - name: Install Go + uses: actions/setup-go@v3 + with: + go-version: 'stable' + cache: true + - name: Download all artifacts + uses: actions/download-artifact@v3 + - name: Install gocovmerge + run: go install github.com/wadey/gocovmerge@latest + - name: Run gocovmerge + run: | + find -name 'integrationcover.log' -printf '"%p"\n' | xargs gocovmerge | grep -Ev 'relayapi|setup/mscs|api_trace' > sytest.cov + go tool cover -func=sytest.cov + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + files: ./sytest.cov + flags: sytest + fail_ci_if_error: true + + # run Complement + complement: + name: "Complement (${{ matrix.label }})" + timeout-minutes: 60 + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - label: SQLite native + cgo: 0 + + - label: SQLite Cgo + cgo: 1 + + - label: PostgreSQL + postgres: Postgres + cgo: 0 + steps: + # Env vars are set file a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on env to run Complement. + # See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path + - name: "Set Go Version" + run: | + echo "$GOROOT_1_17_X64/bin" >> $GITHUB_PATH + echo "~/go/bin" >> $GITHUB_PATH + - name: "Install Complement Dependencies" + # We don't need to install Go because it is included on the Ubuntu 20.04 image: + # See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md specifically GOROOT_1_17_X64 + run: | + sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev + go get -v github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest + - name: Run actions/checkout@v3 for dendrite + uses: actions/checkout@v3 + with: + path: dendrite + + # Attempt to check out the same branch of Complement as the PR. If it + # doesn't exist, fallback to main. + - name: Checkout complement + shell: bash + run: | + mkdir -p complement + # Attempt to use the version of complement which best matches the current + # build. Depending on whether this is a PR or release, etc. we need to + # use different fallbacks. + # + # 1. First check if there's a similarly named branch (GITHUB_HEAD_REF + # for pull requests, otherwise GITHUB_REF). + # 2. Attempt to use the base branch, e.g. when merging into release-vX.Y + # (GITHUB_BASE_REF for pull requests). + # 3. Use the default complement branch ("master"). + for BRANCH_NAME in "$GITHUB_HEAD_REF" "$GITHUB_BASE_REF" "${GITHUB_REF#refs/heads/}" "master"; do + # Skip empty branch names and merge commits. + if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then + continue + fi + (wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break + done + # Build initial Dendrite image + - run: docker build --build-arg=CGO=${{ matrix.cgo }} -t complement-dendrite:${{ matrix.postgres }}${{ matrix.cgo }} -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . + working-directory: dendrite + env: + DOCKER_BUILDKIT: 1 + + - name: Create post test script + run: | + cat < /tmp/posttest.sh + #!/bin/bash + mkdir -p /tmp/Complement/logs/\$2/\$1/ + docker cp \$1:/dendrite/complementcover.log /tmp/Complement/logs/\$2/\$1/ + EOF + + chmod +x /tmp/posttest.sh + # Run Complement + - run: | + set -o pipefail && + go test -v -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt + shell: bash + name: Run Complement Tests + env: + COMPLEMENT_BASE_IMAGE: complement-dendrite:${{ matrix.postgres }}${{ matrix.cgo }} + COMPLEMENT_SHARE_ENV_PREFIX: COMPLEMENT_DENDRITE_ + COMPLEMENT_DENDRITE_COVER: 1 + COMPLEMENT_POST_TEST_SCRIPT: /tmp/posttest.sh + working-directory: complement + + - name: Upload Complement logs + uses: actions/upload-artifact@v2 + if: ${{ always() }} + with: + name: Complement Logs - (Dendrite, ${{ join(matrix.*, ', ') }}) + path: | + /tmp/Complement/**/complementcover.log + + complement-coverage: + timeout-minutes: 5 + name: "Complement Coverage" + runs-on: ubuntu-latest + needs: complement # only run once Complement is done + if: ${{ always() }} + steps: + - uses: actions/checkout@v3 + - name: Install Go + uses: actions/setup-go@v3 + with: + go-version: 'stable' + cache: true + - name: Download all artifacts + uses: actions/download-artifact@v3 + - name: Install gocovmerge + run: go install github.com/wadey/gocovmerge@latest + - name: Run gocovmerge + run: | + find -name 'complementcover.log' -printf '"%p"\n' | xargs gocovmerge | grep -Ev 'relayapi|setup/mscs|api_trace' > complement.cov + go tool cover -func=complement.cov + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + files: ./complement.cov + flags: complement + fail_ci_if_error: true + + element_web: + timeout-minutes: 120 + runs-on: ubuntu-latest + steps: + - uses: tecolicom/actions-use-apt-tools@v1 + with: + # Our test suite includes some screenshot tests with unusual diacritics, which are + # supposed to be covered by STIXGeneral. + tools: fonts-stix + - uses: actions/checkout@v2 + with: + repository: matrix-org/matrix-react-sdk + - uses: actions/setup-node@v3 + with: + cache: 'yarn' + - name: Fetch layered build + run: scripts/ci/layered.sh + - name: Copy config + run: cp element.io/develop/config.json config.json + working-directory: ./element-web + - name: Build + env: + CI_PACKAGE: true + NODE_OPTIONS: "--openssl-legacy-provider" + run: yarn build + working-directory: ./element-web + - name: Edit Test Config + run: | + sed -i '/HOMESERVER/c\ HOMESERVER: "dendrite",' cypress.config.ts + - name: "Run cypress tests" + uses: cypress-io/github-action@v4.1.1 + with: + browser: chrome + start: npx serve -p 8080 ./element-web/webapp + wait-on: 'http://localhost:8080' + env: + PUPPETEER_SKIP_CHROMIUM_DOWNLOAD: true + TMPDIR: ${{ runner.temp }} diff --git a/.gitignore b/.gitignore index e4f0112c4..fe5e82797 100644 --- a/.gitignore +++ b/.gitignore @@ -56,6 +56,7 @@ dendrite.yaml # Database files *.db +*.db-journal # Log files *.log* diff --git a/CHANGES.md b/CHANGES.md index f5a82cfe2..f0cd2b4b9 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,55 @@ # Changelog +## Dendrite 0.11.1 (2023-02-10) + +**⚠️ DEPRECATION WARNING: This is the last release to have polylith and HTTP API mode. Future releases are monolith only.** + +### Features + +* Dendrite can now be compiled against Go 1.20 +* Initial store and forward support has been added +* A landing page showing that Dendrite is running has been added (contributed by [LukasLJL](https://github.com/LukasLJL)) + +### Fixes + +- `/sync` is now using significantly less database round trips when using Postgres, resulting in faster initial syncs, allowing larger accounts to login again +- Many under the hood pinecone improvements +- Publishing rooms is now possible again + +## Dendrite 0.11.0 (2023-01-20) + +The last three missing federation API Sytests have been fixed - bringing us to 100% server-server Synapse parity, with client-server parity at 93% 🎉 + +### Features + +* Added `/_dendrite/admin/purgeRoom/{roomID}` to clean up the database +* The default room version was updated to 10 (contributed by [FSG-Cat](https://github.com/FSG-Cat)) + +### Fixes + +* An oversight in the `create-config` binary, which now correctly sets the media path if specified (contributed by [BieHDC](https://github.com/BieHDC)) +* The Helm chart now uses the `$.Chart.AppVersion` as the default image version to pull, with the possibility to override it (contributed by [genofire](https://github.com/genofire)) + +## Dendrite 0.10.9 (2023-01-17) + +### Features + +* Stale device lists are now cleaned up on startup, removing entries for users the server doesn't share a room with anymore +* Dendrite now has its own Helm chart +* Guest access is now handled correctly (disallow joins, kick guests on revocation of guest access, as well as over federation) + +### Fixes + +* Push rules have seen several tweaks and fixes, which should, for example, fix notifications for `m.read_receipts` +* Outgoing presence will now correctly be sent to newly joined hosts +* Fixes the `/_dendrite/admin/resetPassword/{userID}` admin endpoint to use the correct variable +* Federated backfilling for medium/large rooms has been fixed +* `/login` causing wrong device list updates has been resolved +* `/sync` should now return the correct room summary heroes +* The default config options for `recaptcha_sitekey_class` and `recaptcha_form_field` are now set correctly +* `/messages` now omits empty `state` to be more spec compliant (contributed by [handlerug](https://github.com/handlerug)) +* `/sync` has been optimised to only query state events for history visibility if they are really needed + ## Dendrite 0.10.8 (2022-11-29) ### Features diff --git a/Dockerfile b/Dockerfile index 6da555c04..2487ea5bc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ # # base installs required dependencies and runs go mod download to cache dependencies # -FROM --platform=${BUILDPLATFORM} docker.io/golang:1.19-alpine AS base +FROM --platform=${BUILDPLATFORM} docker.io/golang:1.20-alpine AS base RUN apk --update --no-cache add bash build-base curl # @@ -23,44 +23,27 @@ RUN --mount=target=. \ CGO_ENABLED=$([ "$TARGETARCH" = "$USERARCH" ] && echo "1" || echo "0") \ go build -v -ldflags="${FLAGS}" -trimpath -o /out/ ./cmd/... + # -# The dendrite base image +# Builds the Dendrite image containing all required binaries # -FROM alpine:latest AS dendrite-base +FROM alpine:latest RUN apk --update --no-cache add curl +LABEL org.opencontainers.image.title="Dendrite" LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go" LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite" LABEL org.opencontainers.image.licenses="Apache-2.0" LABEL org.opencontainers.image.documentation="https://matrix-org.github.io/dendrite/" LABEL org.opencontainers.image.vendor="The Matrix.org Foundation C.I.C." -# -# Builds the polylith image and only contains the polylith binary -# -FROM dendrite-base AS polylith -LABEL org.opencontainers.image.title="Dendrite (Polylith)" - -COPY --from=build /out/dendrite-polylith-multi /usr/bin/ - -VOLUME /etc/dendrite -WORKDIR /etc/dendrite - -ENTRYPOINT ["/usr/bin/dendrite-polylith-multi"] - -# -# Builds the monolith image and contains all required binaries -# -FROM dendrite-base AS monolith -LABEL org.opencontainers.image.title="Dendrite (Monolith)" - COPY --from=build /out/create-account /usr/bin/create-account COPY --from=build /out/generate-config /usr/bin/generate-config COPY --from=build /out/generate-keys /usr/bin/generate-keys -COPY --from=build /out/dendrite-monolith-server /usr/bin/dendrite-monolith-server +COPY --from=build /out/dendrite /usr/bin/dendrite VOLUME /etc/dendrite WORKDIR /etc/dendrite -ENTRYPOINT ["/usr/bin/dendrite-monolith-server"] +ENTRYPOINT ["/usr/bin/dendrite"] EXPOSE 8008 8448 diff --git a/README.md b/README.md index dfef11bae..ba6960f35 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,6 @@ This does not mean: - Dendrite is ready for massive homeserver deployments. There is no sharding of microservices (although it is possible to run them on separate machines) and there is no high-availability/clustering support. Currently, we expect Dendrite to function well for small (10s/100s of users) homeserver deployments as well as P2P Matrix nodes in-browser or on mobile devices. -In the future, we will be able to scale up to gigantic servers (equivalent to `matrix.org`) via polylith mode. If you have further questions, please take a look at [our FAQ](docs/FAQ.md) or join us in: @@ -88,7 +87,7 @@ Then point your favourite Matrix client at `http://localhost:8008` or `https://l We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it -updates with CI. As of August 2022 we're at around 90% CS API coverage and 95% Federation coverage, though check +updates with CI. As of January 2023, we have 100% server-server parity with Synapse, and the client-server parity is at 93% , though check CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse servers such as matrix.org reasonably well, although there are still some missing features (like SSO and Third-party ID APIs). diff --git a/appservice/appservice.go b/appservice/appservice.go index 753850de7..5b1b93de2 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -21,14 +21,12 @@ import ( "sync" "time" - "github.com/gorilla/mux" "github.com/sirupsen/logrus" "github.com/matrix-org/gomatrixserverlib" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/appservice/consumers" - "github.com/matrix-org/dendrite/appservice/inthttp" "github.com/matrix-org/dendrite/appservice/query" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" @@ -36,16 +34,11 @@ import ( userapi "github.com/matrix-org/dendrite/userapi/api" ) -// AddInternalRoutes registers HTTP handlers for internal API calls -func AddInternalRoutes(router *mux.Router, queryAPI appserviceAPI.AppServiceInternalAPI, enableMetrics bool) { - inthttp.AddRoutes(queryAPI, router, enableMetrics) -} - // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( base *base.BaseDendrite, - userAPI userapi.UserInternalAPI, + userAPI userapi.AppserviceUserAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ) appserviceAPI.AppServiceInternalAPI { client := &http.Client{ diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index 83c551fea..de9f5aaf1 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -10,12 +10,8 @@ import ( "strings" "testing" - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/appservice/inthttp" - "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" @@ -129,26 +125,10 @@ func TestAppserviceInternalAPI(t *testing.T) { // Create required internal APIs rsAPI := roomserver.NewInternalAPI(base) - usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) + usrAPI := userapi.NewInternalAPI(base, rsAPI, nil) asAPI := appservice.NewInternalAPI(base, usrAPI, rsAPI) - // Finally execute the tests - t.Run("HTTP API", func(t *testing.T) { - router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - appservice.AddInternalRoutes(router, asAPI, base.EnableMetrics) - apiURL, cancel := test.ListenAndServe(t, router, false) - defer cancel() - - asHTTPApi, err := inthttp.NewAppserviceClient(apiURL, &http.Client{}) - if err != nil { - t.Fatalf("failed to create HTTP client: %s", err) - } - runCases(t, asHTTPApi) - }) - - t.Run("Monolith", func(t *testing.T) { - runCases(t, asAPI) - }) + runCases(t, asAPI) }) } diff --git a/appservice/inthttp/client.go b/appservice/inthttp/client.go deleted file mode 100644 index f7f164877..000000000 --- a/appservice/inthttp/client.go +++ /dev/null @@ -1,84 +0,0 @@ -package inthttp - -import ( - "context" - "errors" - "net/http" - - "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/internal/httputil" -) - -// HTTP paths for the internal HTTP APIs -const ( - AppServiceRoomAliasExistsPath = "/appservice/RoomAliasExists" - AppServiceUserIDExistsPath = "/appservice/UserIDExists" - AppServiceLocationsPath = "/appservice/locations" - AppServiceUserPath = "/appservice/users" - AppServiceProtocolsPath = "/appservice/protocols" -) - -// httpAppServiceQueryAPI contains the URL to an appservice query API and a -// reference to a httpClient used to reach it -type httpAppServiceQueryAPI struct { - appserviceURL string - httpClient *http.Client -} - -// NewAppserviceClient creates a AppServiceQueryAPI implemented by talking -// to a HTTP POST API. -// If httpClient is nil an error is returned -func NewAppserviceClient( - appserviceURL string, - httpClient *http.Client, -) (api.AppServiceInternalAPI, error) { - if httpClient == nil { - return nil, errors.New("NewRoomserverAliasAPIHTTP: httpClient is ") - } - return &httpAppServiceQueryAPI{appserviceURL, httpClient}, nil -} - -// RoomAliasExists implements AppServiceQueryAPI -func (h *httpAppServiceQueryAPI) RoomAliasExists( - ctx context.Context, - request *api.RoomAliasExistsRequest, - response *api.RoomAliasExistsResponse, -) error { - return httputil.CallInternalRPCAPI( - "RoomAliasExists", h.appserviceURL+AppServiceRoomAliasExistsPath, - h.httpClient, ctx, request, response, - ) -} - -// UserIDExists implements AppServiceQueryAPI -func (h *httpAppServiceQueryAPI) UserIDExists( - ctx context.Context, - request *api.UserIDExistsRequest, - response *api.UserIDExistsResponse, -) error { - return httputil.CallInternalRPCAPI( - "UserIDExists", h.appserviceURL+AppServiceUserIDExistsPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpAppServiceQueryAPI) Locations(ctx context.Context, request *api.LocationRequest, response *api.LocationResponse) error { - return httputil.CallInternalRPCAPI( - "ASLocation", h.appserviceURL+AppServiceLocationsPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpAppServiceQueryAPI) User(ctx context.Context, request *api.UserRequest, response *api.UserResponse) error { - return httputil.CallInternalRPCAPI( - "ASUser", h.appserviceURL+AppServiceUserPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpAppServiceQueryAPI) Protocols(ctx context.Context, request *api.ProtocolRequest, response *api.ProtocolResponse) error { - return httputil.CallInternalRPCAPI( - "ASProtocols", h.appserviceURL+AppServiceProtocolsPath, - h.httpClient, ctx, request, response, - ) -} diff --git a/appservice/inthttp/server.go b/appservice/inthttp/server.go deleted file mode 100644 index b70fad673..000000000 --- a/appservice/inthttp/server.go +++ /dev/null @@ -1,36 +0,0 @@ -package inthttp - -import ( - "github.com/gorilla/mux" - - "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/internal/httputil" -) - -// AddRoutes adds the AppServiceQueryAPI handlers to the http.ServeMux. -func AddRoutes(a api.AppServiceInternalAPI, internalAPIMux *mux.Router, enableMetrics bool) { - internalAPIMux.Handle( - AppServiceRoomAliasExistsPath, - httputil.MakeInternalRPCAPI("AppserviceRoomAliasExists", enableMetrics, a.RoomAliasExists), - ) - - internalAPIMux.Handle( - AppServiceUserIDExistsPath, - httputil.MakeInternalRPCAPI("AppserviceUserIDExists", enableMetrics, a.UserIDExists), - ) - - internalAPIMux.Handle( - AppServiceProtocolsPath, - httputil.MakeInternalRPCAPI("AppserviceProtocols", enableMetrics, a.Protocols), - ) - - internalAPIMux.Handle( - AppServiceLocationsPath, - httputil.MakeInternalRPCAPI("AppserviceLocations", enableMetrics, a.Locations), - ) - - internalAPIMux.Handle( - AppServiceUserPath, - httputil.MakeInternalRPCAPI("AppserviceUser", enableMetrics, a.User), - ) -} diff --git a/are-we-synapse-yet.list b/are-we-synapse-yet.list index 81c0f8049..585374738 100644 --- a/are-we-synapse-yet.list +++ b/are-we-synapse-yet.list @@ -936,4 +936,12 @@ fst Room state after a rejected message event is the same as before fst Room state after a rejected state event is the same as before fpb Federation publicRoom Name/topic keys are correct fed New federated private chats get full presence information (SYN-115) (10 subtests) -dvk Rejects invalid device keys \ No newline at end of file +dvk Rejects invalid device keys +rmv User can create and send/receive messages in a room with version 10 +rmv local user can join room with version 10 +rmv User can invite local user to room with version 10 +rmv remote user can join room with version 10 +rmv User can invite remote user to room with version 10 +rmv Remote user can backfill in a room with version 10 +rmv Can reject invites over federation for rooms with version 10 +rmv Can receive redactions from regular users over federation in room version 10 \ No newline at end of file diff --git a/build/dendritejs-pinecone/main.go b/build/dendritejs-pinecone/main.go index f44a77488..44e52286f 100644 --- a/build/dendritejs-pinecone/main.go +++ b/build/dendritejs-pinecone/main.go @@ -30,7 +30,6 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" @@ -177,19 +176,17 @@ func startup() { if err := cfg.Derive(); err != nil { logrus.Fatalf("Failed to derive values from config: %s", err) } - base := base.NewBaseDendrite(cfg, "Monolith") + base := base.NewBaseDendrite(cfg) defer base.Close() // nolint: errcheck rsAPI := roomserver.NewInternalAPI(base) federation := conn.CreateFederationClient(base, pSessions) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, federation) asQuery := appservice.NewInternalAPI( base, userAPI, rsAPI, @@ -208,14 +205,12 @@ func startup() { FederationAPI: fedSenderAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - KeyAPI: keyAPI, //ServerKeyAPI: serverKeyAPI, ExtPublicRoomsProvider: rooms.NewPineconeRoomProvider(pRouter, pSessions, fedSenderAPI, federation), } monolith.AddAllPublicRoutes(base) httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() - httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) diff --git a/build/docker/README.md b/build/docker/README.md index 7eb20d88f..b66cb864b 100644 --- a/build/docker/README.md +++ b/build/docker/README.md @@ -5,7 +5,6 @@ These are Docker images for Dendrite! They can be found on Docker Hub: - [matrixdotorg/dendrite-monolith](https://hub.docker.com/r/matrixdotorg/dendrite-monolith) for monolith deployments -- [matrixdotorg/dendrite-polylith](https://hub.docker.com/r/matrixdotorg/dendrite-polylith) for polylith deployments ## Dockerfiles @@ -15,7 +14,6 @@ repository, run: ``` docker build . --target monolith -t matrixdotorg/dendrite-monolith -docker build . --target polylith -t matrixdotorg/dendrite-monolith docker build . --target demo-pinecone -t matrixdotorg/dendrite-demo-pinecone docker build . --target demo-yggdrasil -t matrixdotorg/dendrite-demo-yggdrasil ``` @@ -25,7 +23,6 @@ docker build . --target demo-yggdrasil -t matrixdotorg/dendrite-demo-yggdrasil There are two sample `docker-compose` files: - `docker-compose.monolith.yml` which runs a monolith Dendrite deployment -- `docker-compose.polylith.yml` which runs a polylith Dendrite deployment ## Configuration @@ -51,9 +48,9 @@ docker run --rm --entrypoint="" \ The key files will now exist in your current working directory, and can be mounted into place. -## Starting Dendrite as a monolith deployment +## Starting Dendrite -Create your config based on the [`dendrite-sample.monolith.yaml`](https://github.com/matrix-org/dendrite/blob/main/dendrite-sample.monolith.yaml) sample configuration file. +Create your config based on the [`dendrite-sample.yaml`](https://github.com/matrix-org/dendrite/blob/main/dendrite-sample.yaml) sample configuration file. Then start the deployment: @@ -61,16 +58,6 @@ Then start the deployment: docker-compose -f docker-compose.monolith.yml up ``` -## Starting Dendrite as a polylith deployment - -Create your config based on the [`dendrite-sample.polylith.yaml`](https://github.com/matrix-org/dendrite/blob/main/dendrite-sample.polylith.yaml) sample configuration file. - -Then start the deployment: - -``` -docker-compose -f docker-compose.polylith.yml up -``` - ## Building the images The `build/docker/images-build.sh` script will build the base image, followed by diff --git a/build/docker/docker-compose.polylith.yml b/build/docker/docker-compose.polylith.yml deleted file mode 100644 index de0ab0aa2..000000000 --- a/build/docker/docker-compose.polylith.yml +++ /dev/null @@ -1,143 +0,0 @@ -version: "3.4" -services: - postgres: - hostname: postgres - image: postgres:14 - restart: always - volumes: - - ./postgres/create_db.sh:/docker-entrypoint-initdb.d/20-create_db.sh - # To persist your PostgreSQL databases outside of the Docker image, - # to prevent data loss, modify the following ./path_to path: - - ./path_to/postgresql:/var/lib/postgresql/data - environment: - POSTGRES_PASSWORD: itsasecret - POSTGRES_USER: dendrite - healthcheck: - test: ["CMD-SHELL", "pg_isready -U dendrite"] - interval: 5s - timeout: 5s - retries: 5 - networks: - - internal - - jetstream: - hostname: jetstream - image: nats:latest - command: | - --jetstream - --store_dir /var/lib/nats - --cluster_name Dendrite - volumes: - # To persist your NATS JetStream streams outside of the Docker image, - # prevent data loss, modify the following ./path_to path: - - ./path_to/nats:/var/lib/nats - networks: - - internal - - client_api: - hostname: client_api - image: matrixdotorg/dendrite-polylith:latest - command: clientapi - volumes: - - ./config:/etc/dendrite - depends_on: - - jetstream - - postgres - networks: - - internal - restart: unless-stopped - - media_api: - hostname: media_api - image: matrixdotorg/dendrite-polylith:latest - command: mediaapi - volumes: - - ./config:/etc/dendrite - - ./media:/var/dendrite/media - networks: - - internal - restart: unless-stopped - - sync_api: - hostname: sync_api - image: matrixdotorg/dendrite-polylith:latest - command: syncapi - volumes: - - ./config:/etc/dendrite - depends_on: - - jetstream - - postgres - networks: - - internal - restart: unless-stopped - - room_server: - hostname: room_server - image: matrixdotorg/dendrite-polylith:latest - command: roomserver - volumes: - - ./config:/etc/dendrite - depends_on: - - jetstream - - postgres - networks: - - internal - restart: unless-stopped - - federation_api: - hostname: federation_api - image: matrixdotorg/dendrite-polylith:latest - command: federationapi - volumes: - - ./config:/etc/dendrite - depends_on: - - jetstream - - postgres - networks: - - internal - restart: unless-stopped - - key_server: - hostname: key_server - image: matrixdotorg/dendrite-polylith:latest - command: keyserver - volumes: - - ./config:/etc/dendrite - depends_on: - - jetstream - - postgres - networks: - - internal - restart: unless-stopped - - user_api: - hostname: user_api - image: matrixdotorg/dendrite-polylith:latest - command: userapi - volumes: - - ./config:/etc/dendrite - depends_on: - - jetstream - - postgres - networks: - - internal - restart: unless-stopped - - appservice_api: - hostname: appservice_api - image: matrixdotorg/dendrite-polylith:latest - command: appservice - volumes: - - ./config:/etc/dendrite - networks: - - internal - depends_on: - - jetstream - - postgres - - room_server - - user_api - restart: unless-stopped - -networks: - internal: - attachable: true diff --git a/build/docker/images-build.sh b/build/docker/images-build.sh index d97a701ed..bf227968f 100755 --- a/build/docker/images-build.sh +++ b/build/docker/images-build.sh @@ -7,6 +7,5 @@ TAG=${1:-latest} echo "Building tag '${TAG}'" docker build . --target monolith -t matrixdotorg/dendrite-monolith:${TAG} -docker build . --target polylith -t matrixdotorg/dendrite-monolith:${TAG} docker build . --target demo-pinecone -t matrixdotorg/dendrite-demo-pinecone:${TAG} docker build . --target demo-yggdrasil -t matrixdotorg/dendrite-demo-yggdrasil:${TAG} \ No newline at end of file diff --git a/build/docker/images-pull.sh b/build/docker/images-pull.sh index f3f98ce7c..7772ca747 100755 --- a/build/docker/images-pull.sh +++ b/build/docker/images-pull.sh @@ -5,4 +5,3 @@ TAG=${1:-latest} echo "Pulling tag '${TAG}'" docker pull matrixdotorg/dendrite-monolith:${TAG} -docker pull matrixdotorg/dendrite-polylith:${TAG} \ No newline at end of file diff --git a/build/docker/images-push.sh b/build/docker/images-push.sh index 248fdee2b..d166d355a 100755 --- a/build/docker/images-push.sh +++ b/build/docker/images-push.sh @@ -5,4 +5,3 @@ TAG=${1:-latest} echo "Pushing tag '${TAG}'" docker push matrixdotorg/dendrite-monolith:${TAG} -docker push matrixdotorg/dendrite-polylith:${TAG} \ No newline at end of file diff --git a/build/gobind-pinecone/build.sh b/build/gobind-pinecone/build.sh old mode 100644 new mode 100755 index 0f1b1aab9..51844506c --- a/build/gobind-pinecone/build.sh +++ b/build/gobind-pinecone/build.sh @@ -7,7 +7,7 @@ do case "$option" in a) gomobile bind -v -target android -trimpath -ldflags="-s -w" github.com/matrix-org/dendrite/build/gobind-pinecone ;; - i) gomobile bind -v -target ios -trimpath -ldflags="" github.com/matrix-org/dendrite/build/gobind-pinecone ;; + i) gomobile bind -v -target ios -trimpath -ldflags="" -o ~/DendriteBindings/Gobind.xcframework . ;; *) echo "No target specified, specify -a or -i"; exit 1 ;; esac done \ No newline at end of file diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index b8f8111d2..16797eec0 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -18,50 +18,25 @@ import ( "context" "crypto/ed25519" "crypto/rand" - "crypto/tls" "encoding/hex" "fmt" - "io" "net" - "net/http" - "os" "path/filepath" "strings" - "sync" - "time" - "go.uber.org/atomic" - - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conduit" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/monolith" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" - "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" - "github.com/matrix-org/dendrite/roomserver" - "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/userapi" userapiAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/pinecone/types" "github.com/sirupsen/logrus" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" - pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeRouter "github.com/matrix-org/pinecone/router" - pineconeEvents "github.com/matrix-org/pinecone/router/events" - pineconeSessions "github.com/matrix-org/pinecone/sessions" - "github.com/matrix-org/pinecone/types" _ "golang.org/x/mobile/bind" ) @@ -71,36 +46,37 @@ const ( PeerTypeMulticast = pineconeRouter.PeerTypeMulticast PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth PeerTypeBonjour = pineconeRouter.PeerTypeBonjour + + MaxFrameSize = types.MaxFrameSize ) +// Re-export Conduit in this package for bindings. +type Conduit struct { + conduit.Conduit +} + type DendriteMonolith struct { - logger logrus.Logger - PineconeRouter *pineconeRouter.Router - PineconeMulticast *pineconeMulticast.Multicast - PineconeQUIC *pineconeSessions.Sessions - PineconeManager *pineconeConnections.ConnectionManager - StorageDirectory string - CacheDirectory string - listener net.Listener - httpServer *http.Server - processContext *process.ProcessContext - userAPI userapiAPI.UserInternalAPI + logger logrus.Logger + p2pMonolith monolith.P2PMonolith + StorageDirectory string + CacheDirectory string + listener net.Listener } func (m *DendriteMonolith) PublicKey() string { - return m.PineconeRouter.PublicKey().String() + return m.p2pMonolith.Router.PublicKey().String() } func (m *DendriteMonolith) BaseURL() string { - return fmt.Sprintf("http://%s", m.listener.Addr().String()) + return fmt.Sprintf("http://%s", m.p2pMonolith.Addr()) } func (m *DendriteMonolith) PeerCount(peertype int) int { - return m.PineconeRouter.PeerCount(peertype) + return m.p2pMonolith.Router.PeerCount(peertype) } func (m *DendriteMonolith) SessionCount() int { - return len(m.PineconeQUIC.Protocol("matrix").Sessions()) + return len(m.p2pMonolith.Sessions.Protocol(monolith.SessionProtocol).Sessions()) } type InterfaceInfo struct { @@ -142,55 +118,156 @@ func (m *DendriteMonolith) RegisterNetworkCallback(intfCallback InterfaceRetriev } return intfs } - m.PineconeMulticast.RegisterNetworkCallback(callback) + m.p2pMonolith.Multicast.RegisterNetworkCallback(callback) } func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) { if enabled { - m.PineconeMulticast.Start() + m.p2pMonolith.Multicast.Start() } else { - m.PineconeMulticast.Stop() + m.p2pMonolith.Multicast.Stop() m.DisconnectType(int(pineconeRouter.PeerTypeMulticast)) } } func (m *DendriteMonolith) SetStaticPeer(uri string) { - m.PineconeManager.RemovePeers() + m.p2pMonolith.ConnManager.RemovePeers() for _, uri := range strings.Split(uri, ",") { - m.PineconeManager.AddPeer(strings.TrimSpace(uri)) + m.p2pMonolith.ConnManager.AddPeer(strings.TrimSpace(uri)) } } +func getServerKeyFromString(nodeID string) (gomatrixserverlib.ServerName, error) { + var nodeKey gomatrixserverlib.ServerName + if userID, err := gomatrixserverlib.NewUserID(nodeID, false); err == nil { + hexKey, decodeErr := hex.DecodeString(string(userID.Domain())) + if decodeErr != nil || len(hexKey) != ed25519.PublicKeySize { + return "", fmt.Errorf("UserID domain is not a valid ed25519 public key: %v", userID.Domain()) + } else { + nodeKey = userID.Domain() + } + } else { + hexKey, decodeErr := hex.DecodeString(nodeID) + if decodeErr != nil || len(hexKey) != ed25519.PublicKeySize { + return "", fmt.Errorf("Relay server uri is not a valid ed25519 public key: %v", nodeID) + } else { + nodeKey = gomatrixserverlib.ServerName(nodeID) + } + } + + return nodeKey, nil +} + +func (m *DendriteMonolith) SetRelayServers(nodeID string, uris string) { + relays := []gomatrixserverlib.ServerName{} + for _, uri := range strings.Split(uris, ",") { + uri = strings.TrimSpace(uri) + if len(uri) == 0 { + continue + } + + nodeKey, err := getServerKeyFromString(uri) + if err != nil { + logrus.Errorf(err.Error()) + continue + } + relays = append(relays, nodeKey) + } + + nodeKey, err := getServerKeyFromString(nodeID) + if err != nil { + logrus.Errorf(err.Error()) + return + } + + if string(nodeKey) == m.PublicKey() { + logrus.Infof("Setting own relay servers to: %v", relays) + m.p2pMonolith.RelayRetriever.SetRelayServers(relays) + } else { + relay.UpdateNodeRelayServers( + gomatrixserverlib.ServerName(nodeKey), + relays, + m.p2pMonolith.BaseDendrite.Context(), + m.p2pMonolith.GetFederationAPI(), + ) + } +} + +func (m *DendriteMonolith) GetRelayServers(nodeID string) string { + nodeKey, err := getServerKeyFromString(nodeID) + if err != nil { + logrus.Errorf(err.Error()) + return "" + } + + relaysString := "" + if string(nodeKey) == m.PublicKey() { + relays := m.p2pMonolith.RelayRetriever.GetRelayServers() + + for i, relay := range relays { + if i != 0 { + // Append a comma to the previous entry if there is one. + relaysString += "," + } + relaysString += string(relay) + } + } else { + request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(nodeKey)} + response := api.P2PQueryRelayServersResponse{} + err := m.p2pMonolith.GetFederationAPI().P2PQueryRelayServers(m.p2pMonolith.BaseDendrite.Context(), &request, &response) + if err != nil { + logrus.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) + return "" + } + + for i, relay := range response.RelayServers { + if i != 0 { + // Append a comma to the previous entry if there is one. + relaysString += "," + } + relaysString += string(relay) + } + } + + return relaysString +} + +func (m *DendriteMonolith) RelayingEnabled() bool { + return m.p2pMonolith.GetRelayAPI().RelayingEnabled() +} + +func (m *DendriteMonolith) SetRelayingEnabled(enabled bool) { + m.p2pMonolith.GetRelayAPI().SetRelayingEnabled(enabled) +} + func (m *DendriteMonolith) DisconnectType(peertype int) { - for _, p := range m.PineconeRouter.Peers() { + for _, p := range m.p2pMonolith.Router.Peers() { if int(peertype) == p.PeerType { - m.PineconeRouter.Disconnect(types.SwitchPortID(p.Port), nil) + m.p2pMonolith.Router.Disconnect(types.SwitchPortID(p.Port), nil) } } } func (m *DendriteMonolith) DisconnectZone(zone string) { - for _, p := range m.PineconeRouter.Peers() { + for _, p := range m.p2pMonolith.Router.Peers() { if zone == p.Zone { - m.PineconeRouter.Disconnect(types.SwitchPortID(p.Port), nil) + m.p2pMonolith.Router.Disconnect(types.SwitchPortID(p.Port), nil) } } } func (m *DendriteMonolith) DisconnectPort(port int) { - m.PineconeRouter.Disconnect(types.SwitchPortID(port), nil) + m.p2pMonolith.Router.Disconnect(types.SwitchPortID(port), nil) } func (m *DendriteMonolith) Conduit(zone string, peertype int) (*Conduit, error) { l, r := net.Pipe() - conduit := &Conduit{conn: r, port: 0} + newConduit := Conduit{conduit.NewConduit(r, 0)} go func() { - conduit.portMutex.Lock() - defer conduit.portMutex.Unlock() - logrus.Errorf("Attempting authenticated connect") + var port types.SwitchPortID var err error - if conduit.port, err = m.PineconeRouter.Connect( + if port, err = m.p2pMonolith.Router.Connect( l, pineconeRouter.ConnectionZone(zone), pineconeRouter.ConnectionPeerType(peertype), @@ -198,16 +275,17 @@ func (m *DendriteMonolith) Conduit(zone string, peertype int) (*Conduit, error) logrus.Errorf("Authenticated connect failed: %s", err) _ = l.Close() _ = r.Close() - _ = conduit.Close() + _ = newConduit.Close() return } - logrus.Infof("Authenticated connect succeeded (port %d)", conduit.port) + newConduit.SetPort(port) + logrus.Infof("Authenticated connect succeeded (port %d)", newConduit.Port()) }() - return conduit, nil + return &newConduit, nil } func (m *DendriteMonolith) RegisterUser(localpart, password string) (string, error) { - pubkey := m.PineconeRouter.PublicKey() + pubkey := m.p2pMonolith.Router.PublicKey() userID := userutil.MakeUserID( localpart, gomatrixserverlib.ServerName(hex.EncodeToString(pubkey[:])), @@ -218,7 +296,7 @@ func (m *DendriteMonolith) RegisterUser(localpart, password string) (string, err Password: password, } userRes := &userapiAPI.PerformAccountCreationResponse{} - if err := m.userAPI.PerformAccountCreation(context.Background(), userReq, userRes); err != nil { + if err := m.p2pMonolith.GetUserAPI().PerformAccountCreation(context.Background(), userReq, userRes); err != nil { return userID, fmt.Errorf("userAPI.PerformAccountCreation: %w", err) } return userID, nil @@ -236,7 +314,7 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e AccessToken: hex.EncodeToString(accessTokenBytes[:n]), } loginRes := &userapiAPI.PerformDeviceCreationResponse{} - if err := m.userAPI.PerformDeviceCreation(context.Background(), loginReq, loginRes); err != nil { + if err := m.p2pMonolith.GetUserAPI().PerformDeviceCreation(context.Background(), loginReq, loginRes); err != nil { return "", fmt.Errorf("userAPI.PerformDeviceCreation: %w", err) } if !loginRes.DeviceCreated { @@ -245,51 +323,10 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e return loginRes.Device.AccessToken, nil } -// nolint:gocyclo func (m *DendriteMonolith) Start() { - var sk ed25519.PrivateKey - var pk ed25519.PublicKey - keyfile := filepath.Join(m.StorageDirectory, "p2p.pem") - if _, err := os.Stat(keyfile); os.IsNotExist(err) { - oldkeyfile := filepath.Join(m.StorageDirectory, "p2p.key") - if _, err = os.Stat(oldkeyfile); os.IsNotExist(err) { - if err = test.NewMatrixKey(keyfile); err != nil { - panic("failed to generate a new PEM key: " + err.Error()) - } - if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { - panic("failed to load PEM key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - } else { - if sk, err = os.ReadFile(oldkeyfile); err != nil { - panic("failed to read the old private key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - if err = test.SaveMatrixKey(keyfile, sk); err != nil { - panic("failed to convert the private key to PEM format: " + err.Error()) - } - } - } else { - if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { - panic("failed to load PEM key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - } - - pk = sk.Public().(ed25519.PublicKey) - - var err error - m.listener, err = net.Listen("tcp", "localhost:65432") - if err != nil { - panic(err) - } + oldKeyfile := filepath.Join(m.StorageDirectory, "p2p.key") + sk, pk := monolith.GetOrCreateKey(keyfile, oldKeyfile) m.logger = logrus.Logger{ Out: BindLogger{}, @@ -297,226 +334,25 @@ func (m *DendriteMonolith) Start() { m.logger.SetOutput(BindLogger{}) logrus.SetOutput(BindLogger{}) - pineconeEventChannel := make(chan pineconeEvents.Event) - m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) - m.PineconeRouter.EnableHopLimiting() - m.PineconeRouter.EnableWakeupBroadcasts() - m.PineconeRouter.Subscribe(pineconeEventChannel) - - m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"}) - m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter) - m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter, nil) + m.p2pMonolith = monolith.P2PMonolith{} + m.p2pMonolith.SetupPinecone(sk) prefix := hex.EncodeToString(pk) - cfg := &config.Dendrite{} - cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, - }) + cfg := monolith.GenerateDefaultConfig(sk, m.StorageDirectory, m.CacheDirectory, prefix) cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) - cfg.Global.PrivateKey = sk cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.InMemory = false - cfg.Global.JetStream.StoragePath = config.Path(filepath.Join(m.CacheDirectory, prefix)) - cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.MediaAPI.BasePath = config.Path(filepath.Join(m.CacheDirectory, "media")) - cfg.MediaAPI.AbsBasePath = config.Path(filepath.Join(m.CacheDirectory, "media")) - cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} - cfg.ClientAPI.RegistrationDisabled = false - cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true - cfg.SyncAPI.Fulltext.Enabled = true - cfg.SyncAPI.Fulltext.IndexPath = config.Path(filepath.Join(m.CacheDirectory, "search")) - if err = cfg.Derive(); err != nil { - panic(err) - } + // NOTE : disabled for now since there is a 64 bit alignment panic on 32 bit systems + // This isn't actually fixed: https://github.com/blevesearch/zapx/pull/147 + cfg.SyncAPI.Fulltext.Enabled = false - base := base.NewBaseDendrite(cfg, "Monolith") - base.ConfigureAdminEndpoints() - defer base.Close() // nolint: errcheck - - federation := conn.CreateFederationClient(base, m.PineconeQUIC) - - serverKeyAPI := &signing.YggdrasilKeys{} - keyRing := serverKeyAPI.KeyRing() - - rsAPI := roomserver.NewInternalAPI(base) - - fsAPI := federationapi.NewInternalAPI( - base, federation, rsAPI, base.Caches, keyRing, true, - ) - - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) - m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(m.userAPI) - - asAPI := appservice.NewInternalAPI(base, m.userAPI, rsAPI) - - // The underlying roomserver implementation needs to be able to call the fedsender. - // This is different to rsAPI which can be the http client which doesn't need this dependency - rsAPI.SetFederationAPI(fsAPI, keyRing) - - userProvider := users.NewPineconeUserProvider(m.PineconeRouter, m.PineconeQUIC, m.userAPI, federation) - roomProvider := rooms.NewPineconeRoomProvider(m.PineconeRouter, m.PineconeQUIC, fsAPI, federation) - - monolith := setup.Monolith{ - Config: base.Cfg, - Client: conn.CreateClient(base, m.PineconeQUIC), - FedClient: federation, - KeyRing: keyRing, - - AppserviceAPI: asAPI, - FederationAPI: fsAPI, - RoomserverAPI: rsAPI, - UserAPI: m.userAPI, - KeyAPI: keyAPI, - ExtPublicRoomsProvider: roomProvider, - ExtUserDirectoryProvider: userProvider, - } - monolith.AddAllPublicRoutes(base) - - httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() - httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) - httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) - httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) - httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) - httpRouter.HandleFunc("/pinecone", m.PineconeRouter.ManholeHandler) - - pMux := mux.NewRouter().SkipClean(true).UseEncodedPath() - pMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles) - pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) - pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - - pHTTP := m.PineconeQUIC.Protocol("matrix").HTTP() - pHTTP.Mux().Handle(users.PublicURL, pMux) - pHTTP.Mux().Handle(httputil.PublicFederationPathPrefix, pMux) - pHTTP.Mux().Handle(httputil.PublicMediaPathPrefix, pMux) - - // Build both ends of a HTTP multiplex. - h2s := &http2.Server{} - m.httpServer = &http.Server{ - Addr: ":0", - TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 30 * time.Second, - BaseContext: func(_ net.Listener) context.Context { - return context.Background() - }, - Handler: h2c.NewHandler(pMux, h2s), - } - - m.processContext = base.ProcessContext - - go func() { - m.logger.Info("Listening on ", cfg.Global.ServerName) - - switch m.httpServer.Serve(m.PineconeQUIC.Protocol("matrix")) { - case net.ErrClosed, http.ErrServerClosed: - m.logger.Info("Stopped listening on ", cfg.Global.ServerName) - default: - m.logger.Fatal(err) - } - }() - go func() { - logrus.Info("Listening on ", m.listener.Addr()) - - switch http.Serve(m.listener, httpRouter) { - case net.ErrClosed, http.ErrServerClosed: - m.logger.Info("Stopped listening on ", cfg.Global.ServerName) - default: - m.logger.Fatal(err) - } - }() - - go func(ch <-chan pineconeEvents.Event) { - eLog := logrus.WithField("pinecone", "events") - - for event := range ch { - switch e := event.(type) { - case pineconeEvents.PeerAdded: - case pineconeEvents.PeerRemoved: - case pineconeEvents.TreeParentUpdate: - case pineconeEvents.SnakeDescUpdate: - case pineconeEvents.TreeRootAnnUpdate: - case pineconeEvents.SnakeEntryAdded: - case pineconeEvents.SnakeEntryRemoved: - case pineconeEvents.BroadcastReceived: - eLog.Info("Broadcast received from: ", e.PeerID) - - req := &api.PerformWakeupServersRequest{ - ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, - } - res := &api.PerformWakeupServersResponse{} - if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { - logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) - } - case pineconeEvents.BandwidthReport: - default: - } - } - }(pineconeEventChannel) + enableRelaying := false + enableMetrics := false + enableWebsockets := false + m.p2pMonolith.SetupDendrite(cfg, 65432, enableRelaying, enableMetrics, enableWebsockets) + m.p2pMonolith.StartMonolith() } func (m *DendriteMonolith) Stop() { - m.processContext.ShutdownDendrite() - _ = m.listener.Close() - m.PineconeMulticast.Stop() - _ = m.PineconeQUIC.Close() - _ = m.PineconeRouter.Close() - m.processContext.WaitForComponentsToFinish() -} - -const MaxFrameSize = types.MaxFrameSize - -type Conduit struct { - closed atomic.Bool - conn net.Conn - port types.SwitchPortID - portMutex sync.Mutex -} - -func (c *Conduit) Port() int { - c.portMutex.Lock() - defer c.portMutex.Unlock() - return int(c.port) -} - -func (c *Conduit) Read(b []byte) (int, error) { - if c.closed.Load() { - return 0, io.EOF - } - return c.conn.Read(b) -} - -func (c *Conduit) ReadCopy() ([]byte, error) { - if c.closed.Load() { - return nil, io.EOF - } - var buf [65535 * 2]byte - n, err := c.conn.Read(buf[:]) - if err != nil { - return nil, err - } - return buf[:n], nil -} - -func (c *Conduit) Write(b []byte) (int, error) { - if c.closed.Load() { - return 0, io.EOF - } - return c.conn.Write(b) -} - -func (c *Conduit) Close() error { - if c.closed.Load() { - return io.ErrClosedPipe - } - c.closed.Store(true) - return c.conn.Close() + m.p2pMonolith.Stop() } diff --git a/build/gobind-pinecone/monolith_test.go b/build/gobind-pinecone/monolith_test.go new file mode 100644 index 000000000..434e07ef2 --- /dev/null +++ b/build/gobind-pinecone/monolith_test.go @@ -0,0 +1,152 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gobind + +import ( + "strings" + "testing" + + "github.com/matrix-org/gomatrixserverlib" +) + +func TestMonolithStarts(t *testing.T) { + monolith := DendriteMonolith{} + monolith.Start() + monolith.PublicKey() + monolith.Stop() +} + +func TestMonolithSetRelayServers(t *testing.T) { + testCases := []struct { + name string + nodeID string + relays string + expectedRelays string + expectSelf bool + }{ + { + name: "assorted valid, invalid, empty & self keys", + nodeID: "@valid:abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + relays: "@valid:123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd,@invalid:notakey,,", + expectedRelays: "123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectSelf: true, + }, + { + name: "invalid node key", + nodeID: "@invalid:notakey", + relays: "@valid:123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd,@invalid:notakey,,", + expectedRelays: "", + expectSelf: false, + }, + { + name: "node is self", + nodeID: "self", + relays: "@valid:123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd,@invalid:notakey,,", + expectedRelays: "123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectSelf: false, + }, + } + + for _, tc := range testCases { + monolith := DendriteMonolith{} + monolith.Start() + + inputRelays := tc.relays + expectedRelays := tc.expectedRelays + if tc.expectSelf { + inputRelays += "," + monolith.PublicKey() + expectedRelays += "," + monolith.PublicKey() + } + nodeID := tc.nodeID + if nodeID == "self" { + nodeID = monolith.PublicKey() + } + + monolith.SetRelayServers(nodeID, inputRelays) + relays := monolith.GetRelayServers(nodeID) + monolith.Stop() + + if !containSameKeys(strings.Split(relays, ","), strings.Split(expectedRelays, ",")) { + t.Fatalf("%s: expected %s got %s", tc.name, expectedRelays, relays) + } + } +} + +func containSameKeys(expected []string, actual []string) bool { + if len(expected) != len(actual) { + return false + } + + for _, expectedKey := range expected { + hasMatch := false + for _, actualKey := range actual { + if actualKey == expectedKey { + hasMatch = true + } + } + + if !hasMatch { + return false + } + } + + return true +} + +func TestParseServerKey(t *testing.T) { + testCases := []struct { + name string + serverKey string + expectedErr bool + expectedKey gomatrixserverlib.ServerName + }{ + { + name: "valid userid as key", + serverKey: "@valid:abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectedErr: false, + expectedKey: "abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + }, + { + name: "valid key", + serverKey: "abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectedErr: false, + expectedKey: "abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + }, + { + name: "invalid userid key", + serverKey: "@invalid:notakey", + expectedErr: true, + expectedKey: "", + }, + { + name: "invalid key", + serverKey: "@invalid:notakey", + expectedErr: true, + expectedKey: "", + }, + } + + for _, tc := range testCases { + key, err := getServerKeyFromString(tc.serverKey) + if tc.expectedErr && err == nil { + t.Fatalf("%s: expected an error", tc.name) + } else if !tc.expectedErr && err != nil { + t.Fatalf("%s: didn't expect an error: %s", tc.name, err.Error()) + } + if tc.expectedKey != key { + t.Fatalf("%s: keys not equal. expected: %s got: %s", tc.name, tc.expectedKey, key) + } + } +} diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 8c2d0a006..32af611ae 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" @@ -127,8 +126,8 @@ func (m *DendriteMonolith) Start() { cfg := &config.Dendrite{} cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: true, }) cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) cfg.Global.PrivateKey = sk @@ -149,7 +148,7 @@ func (m *DendriteMonolith) Start() { panic(err) } - base := base.NewBaseDendrite(cfg, "Monolith") + base := base.NewBaseDendrite(cfg) base.ConfigureAdminEndpoints() m.processContext = base.ProcessContext defer base.Close() // nolint: errcheck @@ -165,9 +164,7 @@ func (m *DendriteMonolith) Start() { base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI) - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, federation) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) rsAPI.SetAppserviceAPI(asAPI) @@ -186,7 +183,6 @@ func (m *DendriteMonolith) Start() { FederationAPI: fsAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - KeyAPI: keyAPI, ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, ), @@ -194,7 +190,6 @@ func (m *DendriteMonolith) Start() { monolith.AddAllPublicRoutes(base) httpRouter := mux.NewRouter() - httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 3a00fbdf0..70bbe8f95 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -16,8 +16,8 @@ RUN --mount=target=. \ --mount=type=cache,target=/root/.cache/go-build \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ - CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server && \ - CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server && \ + CGO_ENABLED=${CGO} go build -o /dendrite/dendrite ./cmd/dendrite && \ + CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite && \ cp build/scripts/complement-cmd.sh /complement-cmd.sh WORKDIR /dendrite diff --git a/build/scripts/ComplementLocal.Dockerfile b/build/scripts/ComplementLocal.Dockerfile index e3fbe1aa8..0b80cfc40 100644 --- a/build/scripts/ComplementLocal.Dockerfile +++ b/build/scripts/ComplementLocal.Dockerfile @@ -19,13 +19,13 @@ WORKDIR /runtime # This script compiles Dendrite for us. RUN echo '\ #!/bin/bash -eux \n\ - if test -f "/runtime/dendrite-monolith-server" && test -f "/runtime/dendrite-monolith-server-cover"; then \n\ + if test -f "/runtime/dendrite" && test -f "/runtime/dendrite-cover"; then \n\ echo "Skipping compilation; binaries exist" \n\ exit 0 \n\ fi \n\ cd /dendrite \n\ - go build -v -o /runtime /dendrite/cmd/dendrite-monolith-server \n\ - go test -c -cover -covermode=atomic -o /runtime/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." /dendrite/cmd/dendrite-monolith-server \n\ + go build -v -o /runtime /dendrite/cmd/dendrite \n\ + go test -c -cover -covermode=atomic -o /runtime/dendrite-cover -coverpkg "github.com/matrix-org/..." /dendrite/cmd/dendrite \n\ ' > compile.sh && chmod +x compile.sh # This script runs Dendrite for us. Must be run in the /runtime directory. @@ -35,8 +35,8 @@ RUN echo '\ ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\ - [ ${COVER} -eq 1 ] && exec ./dendrite-monolith-server-cover --test.coverprofile=integrationcover.log --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ - exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ + [ ${COVER} -eq 1 ] && exec ./dendrite-cover --test.coverprofile=integrationcover.log --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ + exec ./dendrite --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ ' > run.sh && chmod +x run.sh diff --git a/build/scripts/ComplementPostgres.Dockerfile b/build/scripts/ComplementPostgres.Dockerfile index 444cb947d..d4b6d3f75 100644 --- a/build/scripts/ComplementPostgres.Dockerfile +++ b/build/scripts/ComplementPostgres.Dockerfile @@ -34,8 +34,8 @@ RUN --mount=target=. \ --mount=type=cache,target=/root/.cache/go-build \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ - CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server && \ - CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server && \ + CGO_ENABLED=${CGO} go build -o /dendrite/dendrite ./cmd/dendrite && \ + CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite && \ cp build/scripts/complement-cmd.sh /complement-cmd.sh WORKDIR /dendrite diff --git a/build/scripts/complement-cmd.sh b/build/scripts/complement-cmd.sh index 061bd18eb..52b063d01 100755 --- a/build/scripts/complement-cmd.sh +++ b/build/scripts/complement-cmd.sh @@ -4,19 +4,17 @@ if [[ "${COVER}" -eq 1 ]]; then echo "Running with coverage" - exec /dendrite/dendrite-monolith-server-cover \ + exec /dendrite/dendrite-cover \ --really-enable-open-registration \ --tls-cert server.crt \ --tls-key server.key \ --config dendrite.yaml \ - -api=${API:-0} \ - --test.coverprofile=integrationcover.log + --test.coverprofile=complementcover.log else echo "Not running with coverage" - exec /dendrite/dendrite-monolith-server \ + exec /dendrite/dendrite \ --really-enable-open-registration \ --tls-cert server.crt \ --tls-key server.key \ - --config dendrite.yaml \ - -api=${API:-0} + --config dendrite.yaml fi diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 0d973f350..300d3a88a 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -7,9 +7,11 @@ import ( "testing" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/tidwall/gjson" @@ -37,11 +39,9 @@ func TestAdminResetPassword(t *testing.T) { rsAPI := roomserver.NewInternalAPI(base) // Needed for changing the password/login - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(base, nil, nil, nil, nil, nil, userAPI, nil, nil, nil) + AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil) // Create the users in the userapi and login accessTokens := map[*test.User]string{ @@ -112,6 +112,7 @@ func TestAdminResetPassword(t *testing.T) { } for _, tc := range testCases { + tc := tc // ensure we don't accidentally only test the last test case t.Run(tc.name, func(t *testing.T) { req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID) if tc.requestOpt != nil { @@ -132,3 +133,98 @@ func TestAdminResetPassword(t *testing.T) { } }) } + +func TestPurgeRoom(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t) + room := test.NewRoom(t, aliceAdmin, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite Bob + room.CreateAndInsert(t, aliceAdmin, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + fedClient := base.CreateFederationClient() + rsAPI := roomserver.NewInternalAPI(base) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + + // this starts the JetStream consumers + syncapi.AddPublicRoutes(base, userAPI, rsAPI) + federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) + rsAPI.SetFederationAPI(nil, nil) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil) + + // Create the users in the userapi and login + accessTokens := map[*test.User]string{ + aliceAdmin: "", + } + for u := range accessTokens { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + password := util.RandomString(8) + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: password, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ + "type": authtypes.LoginTypePassword, + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": u.ID, + }, + "password": password, + })) + rec := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("failed to login: %s", rec.Body.String()) + } + accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String() + } + + testCases := []struct { + name string + roomID string + wantOK bool + }{ + {name: "Can purge existing room", wantOK: true, roomID: room.ID}, + {name: "Can not purge non-existent room", wantOK: false, roomID: "!doesnotexist:localhost"}, + {name: "rejects invalid room ID", wantOK: false, roomID: "@doesnotexist:localhost"}, + } + + for _, tc := range testCases { + tc := tc // ensure we don't accidentally only test the last test case + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/purgeRoom/"+tc.roomID) + + req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin]) + + rec := httptest.NewRecorder() + base.DendriteAdminMux.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + + }) +} diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 2d17e0928..e9985d43f 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -15,6 +15,7 @@ package clientapi import ( + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" @@ -23,11 +24,9 @@ import ( "github.com/matrix-org/dendrite/clientapi/routing" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/transactions" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" - userapi "github.com/matrix-org/dendrite/userapi/api" ) // AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component. @@ -40,7 +39,6 @@ func AddPublicRoutes( fsAPI federationAPI.ClientFederationAPI, userAPI userapi.ClientUserAPI, userDirectoryProvider userapi.QuerySearchProfilesAPI, - keyAPI keyserverAPI.ClientKeyAPI, extRoomsProvider api.ExtraPublicRoomsProvider, ) { cfg := &base.Cfg.ClientAPI @@ -61,7 +59,7 @@ func AddPublicRoutes( base, cfg, rsAPI, asAPI, userAPI, userDirectoryProvider, federation, - syncProducer, transactionsCache, fsAPI, keyAPI, + syncProducer, transactionsCache, fsAPI, extRoomsProvider, mscCfg, natsClient, ) } diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index dbd913376..a01f6b944 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -1,6 +1,7 @@ package routing import ( + "context" "encoding/json" "fmt" "net/http" @@ -15,14 +16,13 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/api" ) -func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -55,7 +55,7 @@ func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi } } -func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -98,7 +98,38 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi } } -func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { +func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + roomID, ok := vars["roomID"] + if !ok { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("Expecting room ID."), + } + } + res := &roomserverAPI.PerformAdminPurgeRoomResponse{} + if err := rsAPI.PerformAdminPurgeRoom( + context.Background(), + &roomserverAPI.PerformAdminPurgeRoomRequest{ + RoomID: roomID, + }, + res, + ); err != nil { + return util.ErrorResponse(err) + } + if err := res.Error; err != nil { + return err.JSONResponse() + } + return util.JSONResponse{ + Code: 200, + JSON: res, + } +} + +func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.Device, userAPI api.ClientUserAPI) util.JSONResponse { if req.Body == nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -118,8 +149,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap JSON: jsonerror.InvalidArgumentValue(err.Error()), } } - accAvailableResp := &userapi.QueryAccountAvailabilityResponse{} - if err = userAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{ + accAvailableResp := &api.QueryAccountAvailabilityResponse{} + if err = userAPI.QueryAccountAvailability(req.Context(), &api.QueryAccountAvailabilityRequest{ Localpart: localpart, ServerName: serverName, }, accAvailableResp); err != nil { @@ -154,13 +185,13 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap return *internal.PasswordResponse(err) } - updateReq := &userapi.PerformPasswordUpdateRequest{ + updateReq := &api.PerformPasswordUpdateRequest{ Localpart: localpart, ServerName: serverName, Password: request.Password, LogoutDevices: true, } - updateRes := &userapi.PerformPasswordUpdateResponse{} + updateRes := &api.PerformPasswordUpdateResponse{} if err := userAPI.PerformPasswordUpdate(req.Context(), updateReq, updateRes); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -177,7 +208,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap } } -func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, natsClient *nats.Conn) util.JSONResponse { +func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *api.Device, natsClient *nats.Conn) util.JSONResponse { _, err := natsClient.RequestMsg(nats.NewMsg(cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex)), time.Second*10) if err != nil { logrus.WithError(err).Error("failed to publish nats message") @@ -223,7 +254,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien } } -func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) diff --git a/clientapi/routing/auth_fallback_test.go b/clientapi/routing/auth_fallback_test.go index 0d77f9a01..534581bdd 100644 --- a/clientapi/routing/auth_fallback_test.go +++ b/clientapi/routing/auth_fallback_test.go @@ -22,7 +22,7 @@ func Test_AuthFallback(t *testing.T) { for _, wantErr := range []bool{false, true} { t.Run(fmt.Sprintf("useHCaptcha(%v) - recaptchaEnabled(%v) - wantErr(%v)", useHCaptcha, recaptchaEnabled, wantErr), func(t *testing.T) { // Set the defaults for each test - base.Cfg.ClientAPI.Defaults(config.DefaultOpts{Generate: true, Monolithic: true}) + base.Cfg.ClientAPI.Defaults(config.DefaultOpts{Generate: true, SingleDatabase: true}) base.Cfg.ClientAPI.RecaptchaEnabled = recaptchaEnabled base.Cfg.ClientAPI.RecaptchaPublicKey = "pub" base.Cfg.ClientAPI.RecaptchaPrivateKey = "priv" @@ -33,7 +33,7 @@ func Test_AuthFallback(t *testing.T) { base.Cfg.ClientAPI.RecaptchaSitekeyClass = "h-captcha" } cfgErrs := &config.ConfigErrors{} - base.Cfg.ClientAPI.Verify(cfgErrs, true) + base.Cfg.ClientAPI.Verify(cfgErrs) if len(*cfgErrs) > 0 { t.Fatalf("(hCaptcha=%v) unexpected config errors: %s", useHCaptcha, cfgErrs.Error()) } diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go index 9e8208e6d..1450ef4bd 100644 --- a/clientapi/routing/joinroom_test.go +++ b/clientapi/routing/joinroom_test.go @@ -10,7 +10,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" @@ -29,8 +28,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) { defer baseClose() rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 2570db09c..267ba1dc5 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -21,9 +21,8 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" ) @@ -34,8 +33,8 @@ type crossSigningRequest struct { func UploadCrossSigningDeviceKeys( req *http.Request, userInteractiveAuth *auth.UserInteractive, - keyserverAPI api.ClientKeyAPI, device *userapi.Device, - accountAPI userapi.ClientUserAPI, cfg *config.ClientAPI, + keyserverAPI api.ClientKeyAPI, device *api.Device, + accountAPI api.ClientUserAPI, cfg *config.ClientAPI, ) util.JSONResponse { uploadReq := &crossSigningRequest{} uploadRes := &api.PerformUploadDeviceKeysResponse{} @@ -107,7 +106,7 @@ func UploadCrossSigningDeviceKeys( } } -func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { +func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse { uploadReq := &api.PerformUploadDeviceSignaturesRequest{} uploadRes := &api.PerformUploadDeviceSignaturesResponse{} diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 0c12b1117..3d60fcc3a 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -23,8 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/keyserver/api" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/api" ) type uploadKeysRequest struct { @@ -32,7 +31,7 @@ type uploadKeysRequest struct { OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"` } -func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { +func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse { var r uploadKeysRequest resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { @@ -106,7 +105,7 @@ func (r *queryKeysRequest) GetTimeout() time.Duration { return timeout } -func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { +func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse { var r queryKeysRequest resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { diff --git a/clientapi/routing/login_test.go b/clientapi/routing/login_test.go index d429d7f8c..b72db9d8b 100644 --- a/clientapi/routing/login_test.go +++ b/clientapi/routing/login_test.go @@ -9,7 +9,6 @@ import ( "testing" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" @@ -39,12 +38,10 @@ func TestLogin(t *testing.T) { rsAPI := roomserver.NewInternalAPI(base) // Needed for /login - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - Setup(base, &base.Cfg.ClientAPI, nil, nil, userAPI, nil, nil, nil, nil, nil, keyAPI, nil, &base.Cfg.MSCs, nil) + Setup(base, &base.Cfg.ClientAPI, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, &base.Cfg.MSCs, nil) // Create password password := util.RandomString(8) diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index bccc1b79b..651e3d3d6 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -30,7 +30,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" @@ -201,8 +200,8 @@ func TestValidationOfApplicationServices(t *testing.T) { // Set up a config fakeConfig := &config.Dendrite{} fakeConfig.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: true, }) fakeConfig.Global.ServerName = "localhost" fakeConfig.ClientAPI.Derived.ApplicationServices = []config.ApplicationService{fakeApplicationService} @@ -409,9 +408,7 @@ func Test_register(t *testing.T) { defer baseClose() rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -582,9 +579,7 @@ func TestRegisterUserWithDisplayName(t *testing.T) { base.Cfg.Global.ServerName = "server" rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) deviceName, deviceID := "deviceName", "deviceID" expectedDisplayName := "DisplayName" response := completeRegistration( @@ -623,9 +618,7 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) { base.Cfg.ClientAPI.RegistrationSharedSecret = sharedSecret rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) expectedDisplayName := "rabbit" jsonStr := []byte(`{"admin":true,"mac":"24dca3bba410e43fe64b9b5c28306693bf3baa9f","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice","displayname":"rabbit"}`) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 09c2cd02f..028d02e97 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -18,9 +18,11 @@ import ( "context" "net/http" "strings" + "sync" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/setup/base" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/nats-io/nats.go" @@ -36,11 +38,9 @@ import ( federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/transactions" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" - userapi "github.com/matrix-org/dendrite/userapi/api" ) // Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client @@ -60,7 +60,6 @@ func Setup( syncProducer *producers.SyncAPIProducer, transactionsCache *transactions.Cache, federationSender federationAPI.ClientFederationAPI, - keyAPI keyserverAPI.ClientKeyAPI, extRoomsProvider api.ExtraPublicRoomsProvider, mscCfg *config.MSCs, natsClient *nats.Conn, ) { @@ -165,6 +164,12 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/purgeRoom/{roomID}", + httputil.MakeAdminAPI("admin_purge_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return AdminPurgeRoom(req, cfg, device, rsAPI) + }), + ).Methods(http.MethodPost, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/resetPassword/{userID}", httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminResetPassword(req, cfg, device, userAPI) @@ -185,25 +190,31 @@ func Setup( dendriteAdminRouter.Handle("/admin/refreshDevices/{userID}", httputil.MakeAdminAPI("admin_refresh_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return AdminMarkAsStale(req, cfg, keyAPI) + return AdminMarkAsStale(req, cfg, userAPI) }), ).Methods(http.MethodPost, http.MethodOptions) // server notifications if cfg.Matrix.ServerNotices.Enabled { logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") - serverNotificationSender, err := getSenderDevice(context.Background(), rsAPI, userAPI, cfg) - if err != nil { - logrus.WithError(err).Fatal("unable to get account for sending sending server notices") - } + var serverNotificationSender *userapi.Device + var err error + notificationSenderOnce := &sync.Once{} synapseAdminRouter.Handle("/admin/v1/send_server_notice/{txnID}", httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + notificationSenderOnce.Do(func() { + serverNotificationSender, err = getSenderDevice(context.Background(), rsAPI, userAPI, cfg) + if err != nil { + logrus.WithError(err).Fatal("unable to get account for sending sending server notices") + } + }) // not specced, but ensure we're rate limiting requests to this endpoint if r := rateLimits.Limit(req, device); r != nil { return *r } - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + var vars map[string]string + vars, err = httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -219,6 +230,12 @@ func Setup( synapseAdminRouter.Handle("/admin/v1/send_server_notice", httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + notificationSenderOnce.Do(func() { + serverNotificationSender, err = getSenderDevice(context.Background(), rsAPI, userAPI, cfg) + if err != nil { + logrus.WithError(err).Fatal("unable to get account for sending sending server notices") + } + }) // not specced, but ensure we're rate limiting requests to this endpoint if r := rateLimits.Limit(req, device); r != nil { return *r @@ -1353,11 +1370,11 @@ func Setup( // Cross-signing device keys postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, keyAPI, device, userAPI, cfg) + return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, userAPI, device, userAPI, cfg) }) postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadCrossSigningDeviceSignatures(req, keyAPI, device) + return UploadCrossSigningDeviceSignatures(req, userAPI, device) }, httputil.WithAllowGuests()) v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) @@ -1369,22 +1386,22 @@ func Setup( // Supplying a device ID is deprecated. v3mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadKeys(req, keyAPI, device) + return UploadKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/upload", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadKeys(req, keyAPI, device) + return UploadKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/query", httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return QueryKeys(req, keyAPI, device) + return QueryKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/claim", httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return ClaimKeys(req, keyAPI) + return ClaimKeys(req, userAPI) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", diff --git a/cmd/dendrite-demo-pinecone/README.md b/cmd/dendrite-demo-pinecone/README.md index d6dd95905..5cacd0924 100644 --- a/cmd/dendrite-demo-pinecone/README.md +++ b/cmd/dendrite-demo-pinecone/README.md @@ -24,3 +24,42 @@ Then point your favourite Matrix client to the homeserver URL`http://localhost: If your peering connection is operational then you should see a `Connected TCP:` line in the log output. If not then try a different peer. Once logged in, you should be able to open the room directory or join a room by its ID. + +## Store & Forward Relays + +To test out the store & forward relay functionality, you need a minimum of 3 instances. +One instance will act as the relay, and the other two instances will be the users trying to communicate. +Then you can send messages between the two nodes and watch as the relay is used if the receiving node is offline. + +### Launching the Nodes + +Relay Server: +``` +go run cmd/dendrite-demo-pinecone/main.go -dir relay/ -listen "[::]:49000" +``` + +Node 1: +``` +go run cmd/dendrite-demo-pinecone/main.go -dir node-1/ -peer "[::]:49000" -port 8007 +``` + +Node 2: +``` +go run cmd/dendrite-demo-pinecone/main.go -dir node-2/ -peer "[::]:49000" -port 8009 +``` + +### Database Setup + +At the moment, the database must be manually configured. +For both `Node 1` and `Node 2` add the following entries to their respective `relay_server` table in the federationapi database: +``` +server_name: {node_1_public_key}, relay_server_name: {relay_public_key} +server_name: {node_2_public_key}, relay_server_name: {relay_public_key} +``` + +After editing the database you will need to relaunch the nodes for the changes to be picked up by dendrite. + +### Testing + +Now you can run two separate instances of element and connect them to `Node 1` and `Node 2`. +You can shutdown one of the nodes and continue sending messages. If you wait long enough, the message will be sent to the relay server. (you can see this in the log output of the relay server) diff --git a/cmd/dendrite-demo-pinecone/conduit/conduit.go b/cmd/dendrite-demo-pinecone/conduit/conduit.go new file mode 100644 index 000000000..be139c19c --- /dev/null +++ b/cmd/dendrite-demo-pinecone/conduit/conduit.go @@ -0,0 +1,84 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package conduit + +import ( + "io" + "net" + "sync" + + "github.com/matrix-org/pinecone/types" + "go.uber.org/atomic" +) + +type Conduit struct { + closed atomic.Bool + conn net.Conn + portMutex sync.Mutex + port types.SwitchPortID +} + +func NewConduit(conn net.Conn, port int) Conduit { + return Conduit{ + conn: conn, + port: types.SwitchPortID(port), + } +} + +func (c *Conduit) Port() int { + c.portMutex.Lock() + defer c.portMutex.Unlock() + return int(c.port) +} + +func (c *Conduit) SetPort(port types.SwitchPortID) { + c.portMutex.Lock() + defer c.portMutex.Unlock() + c.port = port +} + +func (c *Conduit) Read(b []byte) (int, error) { + if c.closed.Load() { + return 0, io.EOF + } + return c.conn.Read(b) +} + +func (c *Conduit) ReadCopy() ([]byte, error) { + if c.closed.Load() { + return nil, io.EOF + } + var buf [65535 * 2]byte + n, err := c.conn.Read(buf[:]) + if err != nil { + return nil, err + } + return buf[:n], nil +} + +func (c *Conduit) Write(b []byte) (int, error) { + if c.closed.Load() { + return 0, io.EOF + } + return c.conn.Write(b) +} + +func (c *Conduit) Close() error { + if c.closed.Load() { + return io.ErrClosedPipe + } + c.closed.Store(true) + return c.conn.Close() +} diff --git a/cmd/dendrite-demo-pinecone/conduit/conduit_test.go b/cmd/dendrite-demo-pinecone/conduit/conduit_test.go new file mode 100644 index 000000000..d8cd3133f --- /dev/null +++ b/cmd/dendrite-demo-pinecone/conduit/conduit_test.go @@ -0,0 +1,121 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package conduit + +import ( + "fmt" + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +var TestBuf = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + +type TestNetConn struct { + net.Conn + shouldFail bool +} + +func (t *TestNetConn) Read(b []byte) (int, error) { + if t.shouldFail { + return 0, fmt.Errorf("Failed") + } else { + n := copy(b, TestBuf) + return n, nil + } +} + +func (t *TestNetConn) Write(b []byte) (int, error) { + if t.shouldFail { + return 0, fmt.Errorf("Failed") + } else { + return len(b), nil + } +} + +func (t *TestNetConn) Close() error { + if t.shouldFail { + return fmt.Errorf("Failed") + } else { + return nil + } +} + +func TestConduitStoresPort(t *testing.T) { + conduit := Conduit{port: 7} + assert.Equal(t, 7, conduit.Port()) +} + +func TestConduitRead(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + b := make([]byte, len(TestBuf)) + bytes, err := conduit.Read(b) + assert.NoError(t, err) + assert.Equal(t, len(TestBuf), bytes) + assert.Equal(t, TestBuf, b) +} + +func TestConduitReadCopy(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + result, err := conduit.ReadCopy() + assert.NoError(t, err) + assert.Equal(t, TestBuf, result) +} + +func TestConduitWrite(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + bytes, err := conduit.Write(TestBuf) + assert.NoError(t, err) + assert.Equal(t, len(TestBuf), bytes) +} + +func TestConduitClose(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + assert.True(t, conduit.closed.Load()) +} + +func TestConduitReadClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + b := make([]byte, len(TestBuf)) + _, err = conduit.Read(b) + assert.Error(t, err) +} + +func TestConduitReadCopyClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + _, err = conduit.ReadCopy() + assert.Error(t, err) +} + +func TestConduitWriteClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + _, err = conduit.Write(TestBuf) + assert.Error(t, err) +} + +func TestConduitReadCopyFails(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{shouldFail: true}} + _, err := conduit.ReadCopy() + assert.Error(t, err) +} diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 3f627b41d..7706a73f5 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -15,58 +15,35 @@ package main import ( - "context" "crypto/ed25519" - "crypto/tls" "encoding/hex" "flag" "fmt" "net" - "net/http" "os" "path/filepath" "strings" - "time" - "github.com/gorilla/mux" - "github.com/gorilla/websocket" - "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/embed" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/monolith" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" - "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" - "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" - - pineconeConnections "github.com/matrix-org/pinecone/connections" - pineconeMulticast "github.com/matrix-org/pinecone/multicast" - pineconeRouter "github.com/matrix-org/pinecone/router" - pineconeEvents "github.com/matrix-org/pinecone/router/events" - pineconeSessions "github.com/matrix-org/pinecone/sessions" - "github.com/sirupsen/logrus" + + pineconeRouter "github.com/matrix-org/pinecone/router" ) var ( - instanceName = flag.String("name", "dendrite-p2p-pinecone", "the name of this P2P demo instance") - instancePort = flag.Int("port", 8008, "the port that the client API will listen on") - instancePeer = flag.String("peer", "", "the static Pinecone peers to connect to, comma separated-list") - instanceListen = flag.String("listen", ":0", "the port Pinecone peers can connect to") - instanceDir = flag.String("dir", ".", "the directory to store the databases in (if --config not specified)") + instanceName = flag.String("name", "dendrite-p2p-pinecone", "the name of this P2P demo instance") + instancePort = flag.Int("port", 8008, "the port that the client API will listen on") + instancePeer = flag.String("peer", "", "the static Pinecone peers to connect to, comma separated-list") + instanceListen = flag.String("listen", ":0", "the port Pinecone peers can connect to") + instanceDir = flag.String("dir", ".", "the directory to store the databases in (if --config not specified)") + instanceRelayingEnabled = flag.Bool("relay", false, "whether to enable store & forward relaying for other nodes") ) -// nolint:gocyclo func main() { flag.Parse() internal.SetupPprof() @@ -83,7 +60,7 @@ func main() { } } - cfg := &config.Dendrite{} + var cfg *config.Dendrite // use custom config if config flag is set if configFlagSet { @@ -92,88 +69,30 @@ func main() { pk = sk.Public().(ed25519.PublicKey) } else { keyfile := filepath.Join(*instanceDir, *instanceName) + ".pem" - if _, err := os.Stat(keyfile); os.IsNotExist(err) { - oldkeyfile := *instanceName + ".key" - if _, err = os.Stat(oldkeyfile); os.IsNotExist(err) { - if err = test.NewMatrixKey(keyfile); err != nil { - panic("failed to generate a new PEM key: " + err.Error()) - } - if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { - panic("failed to load PEM key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - } else { - if sk, err = os.ReadFile(oldkeyfile); err != nil { - panic("failed to read the old private key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - if err := test.SaveMatrixKey(keyfile, sk); err != nil { - panic("failed to convert the private key to PEM format: " + err.Error()) - } - } - } else { - var err error - if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { - panic("failed to load PEM key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - } - - pk = sk.Public().(ed25519.PublicKey) - - cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, - }) - cfg.Global.PrivateKey = sk - cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", filepath.Join(*instanceDir, *instanceName))) - cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", filepath.Join(*instanceDir, *instanceName))) - cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", filepath.Join(*instanceDir, *instanceName))) - cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", filepath.Join(*instanceDir, *instanceName))) - cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(*instanceDir, *instanceName))) - cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(*instanceDir, *instanceName))) - cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", filepath.Join(*instanceDir, *instanceName))) - cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} - cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName))) - cfg.ClientAPI.RegistrationDisabled = false - cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true - cfg.MediaAPI.BasePath = config.Path(*instanceDir) - cfg.SyncAPI.Fulltext.Enabled = true - cfg.SyncAPI.Fulltext.IndexPath = config.Path(*instanceDir) - if err := cfg.Derive(); err != nil { - panic(err) - } + oldKeyfile := *instanceName + ".key" + sk, pk = monolith.GetOrCreateKey(keyfile, oldKeyfile) + cfg = monolith.GenerateDefaultConfig(sk, *instanceDir, *instanceDir, *instanceName) } cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - base := base.NewBaseDendrite(cfg, "Monolith") - base.ConfigureAdminEndpoints() - defer base.Close() // nolint: errcheck + p2pMonolith := monolith.P2PMonolith{} + p2pMonolith.SetupPinecone(sk) + p2pMonolith.Multicast.Start() - pineconeEventChannel := make(chan pineconeEvents.Event) - pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) - pRouter.EnableHopLimiting() - pRouter.EnableWakeupBroadcasts() - pRouter.Subscribe(pineconeEventChannel) - - pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) - pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) - pManager := pineconeConnections.NewConnectionManager(pRouter, nil) - pMulticast.Start() if instancePeer != nil && *instancePeer != "" { for _, peer := range strings.Split(*instancePeer, ",") { - pManager.AddPeer(strings.Trim(peer, " \t\r\n")) + p2pMonolith.ConnManager.AddPeer(strings.Trim(peer, " \t\r\n")) } } + enableMetrics := true + enableWebsockets := true + p2pMonolith.SetupDendrite(cfg, *instancePort, *instanceRelayingEnabled, enableMetrics, enableWebsockets) + p2pMonolith.StartMonolith() + p2pMonolith.WaitForShutdown() + go func() { listener, err := net.Listen("tcp", *instanceListen) if err != nil { @@ -189,7 +108,7 @@ func main() { continue } - port, err := pRouter.Connect( + port, err := p2pMonolith.Router.Connect( conn, pineconeRouter.ConnectionPeerType(pineconeRouter.PeerTypeRemote), ) @@ -201,135 +120,4 @@ func main() { fmt.Println("Inbound connection", conn.RemoteAddr(), "is connected to port", port) } }() - - federation := conn.CreateFederationClient(base, pQUIC) - - serverKeyAPI := &signing.YggdrasilKeys{} - keyRing := serverKeyAPI.KeyRing() - - rsComponent := roomserver.NewInternalAPI(base) - rsAPI := rsComponent - fsAPI := federationapi.NewInternalAPI( - base, federation, rsAPI, base.Caches, keyRing, true, - ) - - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsComponent) - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) - - asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) - - rsComponent.SetFederationAPI(fsAPI, keyRing) - - userProvider := users.NewPineconeUserProvider(pRouter, pQUIC, userAPI, federation) - roomProvider := rooms.NewPineconeRoomProvider(pRouter, pQUIC, fsAPI, federation) - - monolith := setup.Monolith{ - Config: base.Cfg, - Client: conn.CreateClient(base, pQUIC), - FedClient: federation, - KeyRing: keyRing, - - AppserviceAPI: asAPI, - FederationAPI: fsAPI, - RoomserverAPI: rsAPI, - UserAPI: userAPI, - KeyAPI: keyAPI, - ExtPublicRoomsProvider: roomProvider, - ExtUserDirectoryProvider: userProvider, - } - monolith.AddAllPublicRoutes(base) - - wsUpgrader := websocket.Upgrader{ - CheckOrigin: func(_ *http.Request) bool { - return true - }, - } - httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() - httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) - httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) - httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) - httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) - httpRouter.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { - c, err := wsUpgrader.Upgrade(w, r, nil) - if err != nil { - logrus.WithError(err).Error("Failed to upgrade WebSocket connection") - return - } - conn := conn.WrapWebSocketConn(c) - if _, err = pRouter.Connect( - conn, - pineconeRouter.ConnectionZone("websocket"), - pineconeRouter.ConnectionPeerType(pineconeRouter.PeerTypeRemote), - ); err != nil { - logrus.WithError(err).Error("Failed to connect WebSocket peer to Pinecone switch") - } - }) - httpRouter.HandleFunc("/pinecone", pRouter.ManholeHandler) - embed.Embed(httpRouter, *instancePort, "Pinecone Demo") - - pMux := mux.NewRouter().SkipClean(true).UseEncodedPath() - pMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles) - pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) - pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - - pHTTP := pQUIC.Protocol("matrix").HTTP() - pHTTP.Mux().Handle(users.PublicURL, pMux) - pHTTP.Mux().Handle(httputil.PublicFederationPathPrefix, pMux) - pHTTP.Mux().Handle(httputil.PublicMediaPathPrefix, pMux) - - // Build both ends of a HTTP multiplex. - httpServer := &http.Server{ - Addr: ":0", - TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 60 * time.Second, - BaseContext: func(_ net.Listener) context.Context { - return context.Background() - }, - Handler: pMux, - } - - go func() { - pubkey := pRouter.PublicKey() - logrus.Info("Listening on ", hex.EncodeToString(pubkey[:])) - logrus.Fatal(httpServer.Serve(pQUIC.Protocol("matrix"))) - }() - go func() { - httpBindAddr := fmt.Sprintf(":%d", *instancePort) - logrus.Info("Listening on ", httpBindAddr) - logrus.Fatal(http.ListenAndServe(httpBindAddr, httpRouter)) - }() - - go func(ch <-chan pineconeEvents.Event) { - eLog := logrus.WithField("pinecone", "events") - - for event := range ch { - switch e := event.(type) { - case pineconeEvents.PeerAdded: - case pineconeEvents.PeerRemoved: - case pineconeEvents.TreeParentUpdate: - case pineconeEvents.SnakeDescUpdate: - case pineconeEvents.TreeRootAnnUpdate: - case pineconeEvents.SnakeEntryAdded: - case pineconeEvents.SnakeEntryRemoved: - case pineconeEvents.BroadcastReceived: - eLog.Info("Broadcast received from: ", e.PeerID) - - req := &api.PerformWakeupServersRequest{ - ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, - } - res := &api.PerformWakeupServersResponse{} - if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { - logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) - } - case pineconeEvents.BandwidthReport: - default: - } - } - }(pineconeEventChannel) - - base.WaitForShutdown() } diff --git a/cmd/dendrite-demo-pinecone/monolith/keys.go b/cmd/dendrite-demo-pinecone/monolith/keys.go new file mode 100644 index 000000000..637f24a43 --- /dev/null +++ b/cmd/dendrite-demo-pinecone/monolith/keys.go @@ -0,0 +1,63 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package monolith + +import ( + "crypto/ed25519" + "os" + + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" +) + +func GetOrCreateKey(keyfile string, oldKeyfile string) (ed25519.PrivateKey, ed25519.PublicKey) { + var sk ed25519.PrivateKey + var pk ed25519.PublicKey + + if _, err := os.Stat(keyfile); os.IsNotExist(err) { + if _, err = os.Stat(oldKeyfile); os.IsNotExist(err) { + if err = test.NewMatrixKey(keyfile); err != nil { + panic("failed to generate a new PEM key: " + err.Error()) + } + if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { + panic("failed to load PEM key: " + err.Error()) + } + if len(sk) != ed25519.PrivateKeySize { + panic("the private key is not long enough") + } + } else { + if sk, err = os.ReadFile(oldKeyfile); err != nil { + panic("failed to read the old private key: " + err.Error()) + } + if len(sk) != ed25519.PrivateKeySize { + panic("the private key is not long enough") + } + if err = test.SaveMatrixKey(keyfile, sk); err != nil { + panic("failed to convert the private key to PEM format: " + err.Error()) + } + } + } else { + if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { + panic("failed to load PEM key: " + err.Error()) + } + if len(sk) != ed25519.PrivateKeySize { + panic("the private key is not long enough") + } + } + + pk = sk.Public().(ed25519.PublicKey) + + return sk, pk +} diff --git a/cmd/dendrite-demo-pinecone/monolith/monolith.go b/cmd/dendrite-demo-pinecone/monolith/monolith.go new file mode 100644 index 000000000..ea8e985c7 --- /dev/null +++ b/cmd/dendrite-demo-pinecone/monolith/monolith.go @@ -0,0 +1,387 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package monolith + +import ( + "context" + "crypto/ed25519" + "crypto/tls" + "encoding/hex" + "fmt" + "net" + "net/http" + "path/filepath" + "time" + + "github.com/gorilla/mux" + "github.com/gorilla/websocket" + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/embed" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + "github.com/matrix-org/dendrite/federationapi" + federationAPI "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/relayapi" + relayAPI "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/setup" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/userapi" + userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + + pineconeConnections "github.com/matrix-org/pinecone/connections" + pineconeMulticast "github.com/matrix-org/pinecone/multicast" + pineconeRouter "github.com/matrix-org/pinecone/router" + pineconeEvents "github.com/matrix-org/pinecone/router/events" + pineconeSessions "github.com/matrix-org/pinecone/sessions" +) + +const SessionProtocol = "matrix" + +type P2PMonolith struct { + BaseDendrite *base.BaseDendrite + Sessions *pineconeSessions.Sessions + Multicast *pineconeMulticast.Multicast + ConnManager *pineconeConnections.ConnectionManager + Router *pineconeRouter.Router + EventChannel chan pineconeEvents.Event + RelayRetriever relay.RelayServerRetriever + + dendrite setup.Monolith + port int + httpMux *mux.Router + pineconeMux *mux.Router + httpServer *http.Server + listener net.Listener + httpListenAddr string + stopHandlingEvents chan bool +} + +func GenerateDefaultConfig(sk ed25519.PrivateKey, storageDir string, cacheDir string, dbPrefix string) *config.Dendrite { + cfg := config.Dendrite{} + cfg.Defaults(config.DefaultOpts{ + Generate: true, + SingleDatabase: true, + }) + cfg.Global.PrivateKey = sk + cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", filepath.Join(cacheDir, dbPrefix))) + cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", filepath.Join(storageDir, dbPrefix))) + cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", filepath.Join(storageDir, dbPrefix))) + cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", filepath.Join(storageDir, dbPrefix))) + cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(storageDir, dbPrefix))) + cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(storageDir, dbPrefix))) + cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", filepath.Join(storageDir, dbPrefix))) + cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(storageDir, dbPrefix))) + cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} + cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(storageDir, dbPrefix))) + cfg.ClientAPI.RegistrationDisabled = false + cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true + cfg.MediaAPI.BasePath = config.Path(filepath.Join(cacheDir, "media")) + cfg.MediaAPI.AbsBasePath = config.Path(filepath.Join(cacheDir, "media")) + cfg.SyncAPI.Fulltext.Enabled = true + cfg.SyncAPI.Fulltext.IndexPath = config.Path(filepath.Join(cacheDir, "search")) + if err := cfg.Derive(); err != nil { + panic(err) + } + + return &cfg +} + +func (p *P2PMonolith) SetupPinecone(sk ed25519.PrivateKey) { + p.EventChannel = make(chan pineconeEvents.Event) + p.Router = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) + p.Router.EnableHopLimiting() + p.Router.EnableWakeupBroadcasts() + p.Router.Subscribe(p.EventChannel) + + p.Sessions = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), p.Router, []string{SessionProtocol}) + p.Multicast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), p.Router) + p.ConnManager = pineconeConnections.NewConnectionManager(p.Router, nil) +} + +func (p *P2PMonolith) SetupDendrite(cfg *config.Dendrite, port int, enableRelaying bool, enableMetrics bool, enableWebsockets bool) { + if enableMetrics { + p.BaseDendrite = base.NewBaseDendrite(cfg) + } else { + p.BaseDendrite = base.NewBaseDendrite(cfg, base.DisableMetrics) + } + p.port = port + p.BaseDendrite.ConfigureAdminEndpoints() + + federation := conn.CreateFederationClient(p.BaseDendrite, p.Sessions) + + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + + rsComponent := roomserver.NewInternalAPI(p.BaseDendrite) + rsAPI := rsComponent + fsAPI := federationapi.NewInternalAPI( + p.BaseDendrite, federation, rsAPI, p.BaseDendrite.Caches, keyRing, true, + ) + + userAPI := userapi.NewInternalAPI(p.BaseDendrite, rsAPI, federation) + + asAPI := appservice.NewInternalAPI(p.BaseDendrite, userAPI, rsAPI) + + rsComponent.SetFederationAPI(fsAPI, keyRing) + + userProvider := users.NewPineconeUserProvider(p.Router, p.Sessions, userAPI, federation) + roomProvider := rooms.NewPineconeRoomProvider(p.Router, p.Sessions, fsAPI, federation) + + js, _ := p.BaseDendrite.NATS.Prepare(p.BaseDendrite.ProcessContext, &p.BaseDendrite.Cfg.Global.JetStream) + producer := &producers.SyncAPIProducer{ + JetStream: js, + TopicReceiptEvent: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + Config: &p.BaseDendrite.Cfg.FederationAPI, + UserAPI: userAPI, + } + relayAPI := relayapi.NewRelayInternalAPI(p.BaseDendrite, federation, rsAPI, keyRing, producer, enableRelaying) + logrus.Infof("Relaying enabled: %v", relayAPI.RelayingEnabled()) + + p.dendrite = setup.Monolith{ + Config: p.BaseDendrite.Cfg, + Client: conn.CreateClient(p.BaseDendrite, p.Sessions), + FedClient: federation, + KeyRing: keyRing, + + AppserviceAPI: asAPI, + FederationAPI: fsAPI, + RoomserverAPI: rsAPI, + UserAPI: userAPI, + RelayAPI: relayAPI, + ExtPublicRoomsProvider: roomProvider, + ExtUserDirectoryProvider: userProvider, + } + p.dendrite.AddAllPublicRoutes(p.BaseDendrite) + + p.setupHttpServers(userProvider, enableWebsockets) +} + +func (p *P2PMonolith) GetFederationAPI() federationAPI.FederationInternalAPI { + return p.dendrite.FederationAPI +} + +func (p *P2PMonolith) GetRelayAPI() relayAPI.RelayInternalAPI { + return p.dendrite.RelayAPI +} + +func (p *P2PMonolith) GetUserAPI() userAPI.UserInternalAPI { + return p.dendrite.UserAPI +} + +func (p *P2PMonolith) StartMonolith() { + p.startHTTPServers() + p.startEventHandler() +} + +func (p *P2PMonolith) Stop() { + logrus.Info("Stopping monolith") + _ = p.BaseDendrite.Close() + p.WaitForShutdown() + logrus.Info("Stopped monolith") +} + +func (p *P2PMonolith) WaitForShutdown() { + p.BaseDendrite.WaitForShutdown() + p.closeAllResources() +} + +func (p *P2PMonolith) closeAllResources() { + logrus.Info("Closing monolith resources") + if p.httpServer != nil { + _ = p.httpServer.Shutdown(context.Background()) + } + + select { + case p.stopHandlingEvents <- true: + default: + } + + if p.listener != nil { + _ = p.listener.Close() + } + + if p.Multicast != nil { + p.Multicast.Stop() + } + + if p.Sessions != nil { + _ = p.Sessions.Close() + } + + if p.Router != nil { + _ = p.Router.Close() + } + logrus.Info("Monolith resources closed") +} + +func (p *P2PMonolith) Addr() string { + return p.httpListenAddr +} + +func (p *P2PMonolith) setupHttpServers(userProvider *users.PineconeUserProvider, enableWebsockets bool) { + p.httpMux = mux.NewRouter().SkipClean(true).UseEncodedPath() + p.httpMux.PathPrefix(httputil.PublicClientPathPrefix).Handler(p.BaseDendrite.PublicClientAPIMux) + p.httpMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(p.BaseDendrite.PublicMediaAPIMux) + p.httpMux.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(p.BaseDendrite.DendriteAdminMux) + p.httpMux.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(p.BaseDendrite.SynapseAdminMux) + + if enableWebsockets { + wsUpgrader := websocket.Upgrader{ + CheckOrigin: func(_ *http.Request) bool { + return true + }, + } + p.httpMux.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { + c, err := wsUpgrader.Upgrade(w, r, nil) + if err != nil { + logrus.WithError(err).Error("Failed to upgrade WebSocket connection") + return + } + conn := conn.WrapWebSocketConn(c) + if _, err = p.Router.Connect( + conn, + pineconeRouter.ConnectionZone("websocket"), + pineconeRouter.ConnectionPeerType(pineconeRouter.PeerTypeRemote), + ); err != nil { + logrus.WithError(err).Error("Failed to connect WebSocket peer to Pinecone switch") + } + }) + } + + p.httpMux.HandleFunc("/pinecone", p.Router.ManholeHandler) + + if enableWebsockets { + embed.Embed(p.httpMux, p.port, "Pinecone Demo") + } + + p.pineconeMux = mux.NewRouter().SkipClean(true).UseEncodedPath() + p.pineconeMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles) + p.pineconeMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(p.BaseDendrite.PublicFederationAPIMux) + p.pineconeMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(p.BaseDendrite.PublicMediaAPIMux) + + pHTTP := p.Sessions.Protocol(SessionProtocol).HTTP() + pHTTP.Mux().Handle(users.PublicURL, p.pineconeMux) + pHTTP.Mux().Handle(httputil.PublicFederationPathPrefix, p.pineconeMux) + pHTTP.Mux().Handle(httputil.PublicMediaPathPrefix, p.pineconeMux) +} + +func (p *P2PMonolith) startHTTPServers() { + go func() { + // Build both ends of a HTTP multiplex. + p.httpServer = &http.Server{ + Addr: ":0", + TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 30 * time.Second, + BaseContext: func(_ net.Listener) context.Context { + return context.Background() + }, + Handler: p.pineconeMux, + } + + pubkey := p.Router.PublicKey() + pubkeyString := hex.EncodeToString(pubkey[:]) + logrus.Info("Listening on ", pubkeyString) + + switch p.httpServer.Serve(p.Sessions.Protocol(SessionProtocol)) { + case net.ErrClosed, http.ErrServerClosed: + logrus.Info("Stopped listening on ", pubkeyString) + default: + logrus.Error("Stopped listening on ", pubkeyString) + } + logrus.Info("Stopped goroutine listening on ", pubkeyString) + }() + + p.httpListenAddr = fmt.Sprintf(":%d", p.port) + go func() { + logrus.Info("Listening on ", p.httpListenAddr) + switch http.ListenAndServe(p.httpListenAddr, p.httpMux) { + case net.ErrClosed, http.ErrServerClosed: + logrus.Info("Stopped listening on ", p.httpListenAddr) + default: + logrus.Error("Stopped listening on ", p.httpListenAddr) + } + logrus.Info("Stopped goroutine listening on ", p.httpListenAddr) + }() +} + +func (p *P2PMonolith) startEventHandler() { + p.stopHandlingEvents = make(chan bool) + stopRelayServerSync := make(chan bool) + eLog := logrus.WithField("pinecone", "events") + p.RelayRetriever = relay.NewRelayServerRetriever( + context.Background(), + gomatrixserverlib.ServerName(p.Router.PublicKey().String()), + p.dendrite.FederationAPI, + p.dendrite.RelayAPI, + stopRelayServerSync, + ) + p.RelayRetriever.InitializeRelayServers(eLog) + + go func(ch <-chan pineconeEvents.Event) { + for { + select { + case event := <-ch: + switch e := event.(type) { + case pineconeEvents.PeerAdded: + p.RelayRetriever.StartSync() + case pineconeEvents.PeerRemoved: + if p.RelayRetriever.IsRunning() && p.Router.TotalPeerCount() == 0 { + // NOTE: Don't block on channel + select { + case stopRelayServerSync <- true: + default: + } + } + case pineconeEvents.BroadcastReceived: + // eLog.Info("Broadcast received from: ", e.PeerID) + + req := &federationAPI.PerformWakeupServersRequest{ + ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, + } + res := &federationAPI.PerformWakeupServersResponse{} + if err := p.dendrite.FederationAPI.PerformWakeupServers(p.BaseDendrite.Context(), req, res); err != nil { + eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) + } + } + case <-p.stopHandlingEvents: + logrus.Info("Stopping processing pinecone events") + // NOTE: Don't block on channel + select { + case stopRelayServerSync <- true: + default: + } + logrus.Info("Stopped processing pinecone events") + return + } + } + }(p.EventChannel) +} diff --git a/cmd/dendrite-demo-pinecone/relay/retriever.go b/cmd/dendrite-demo-pinecone/relay/retriever.go new file mode 100644 index 000000000..6b34f6416 --- /dev/null +++ b/cmd/dendrite-demo-pinecone/relay/retriever.go @@ -0,0 +1,238 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relay + +import ( + "context" + "sync" + "time" + + federationAPI "github.com/matrix-org/dendrite/federationapi/api" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "go.uber.org/atomic" +) + +const ( + relayServerRetryInterval = time.Second * 30 +) + +type RelayServerRetriever struct { + ctx context.Context + serverName gomatrixserverlib.ServerName + federationAPI federationAPI.FederationInternalAPI + relayAPI relayServerAPI.RelayInternalAPI + relayServersQueried map[gomatrixserverlib.ServerName]bool + queriedServersMutex sync.Mutex + running atomic.Bool + quit chan bool +} + +func NewRelayServerRetriever( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + federationAPI federationAPI.FederationInternalAPI, + relayAPI relayServerAPI.RelayInternalAPI, + quit chan bool, +) RelayServerRetriever { + return RelayServerRetriever{ + ctx: ctx, + serverName: serverName, + federationAPI: federationAPI, + relayAPI: relayAPI, + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + running: *atomic.NewBool(false), + quit: quit, + } +} + +func (r *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { + request := federationAPI.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(r.serverName)} + response := federationAPI.P2PQueryRelayServersResponse{} + err := r.federationAPI.P2PQueryRelayServers(r.ctx, &request, &response) + if err != nil { + eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) + } + + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + for _, server := range response.RelayServers { + r.relayServersQueried[server] = false + } + + eLog.Infof("Registered relay servers: %v", response.RelayServers) +} + +func (r *RelayServerRetriever) SetRelayServers(servers []gomatrixserverlib.ServerName) { + UpdateNodeRelayServers(r.serverName, servers, r.ctx, r.federationAPI) + + // Replace list of servers to sync with and mark them all as unsynced. + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + r.relayServersQueried = make(map[gomatrixserverlib.ServerName]bool) + for _, server := range servers { + r.relayServersQueried[server] = false + } + + r.StartSync() +} + +func (r *RelayServerRetriever) GetRelayServers() []gomatrixserverlib.ServerName { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + relayServers := []gomatrixserverlib.ServerName{} + for server := range r.relayServersQueried { + relayServers = append(relayServers, server) + } + + return relayServers +} + +func (r *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + + result := map[gomatrixserverlib.ServerName]bool{} + for server, queried := range r.relayServersQueried { + result[server] = queried + } + return result +} + +func (r *RelayServerRetriever) StartSync() { + if !r.running.Load() { + logrus.Info("Starting relay server sync") + go r.SyncRelayServers(r.quit) + } +} + +func (r *RelayServerRetriever) IsRunning() bool { + return r.running.Load() +} + +func (r *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { + defer r.running.Store(false) + + t := time.NewTimer(relayServerRetryInterval) + for { + relayServersToQuery := []gomatrixserverlib.ServerName{} + func() { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + for server, complete := range r.relayServersQueried { + if !complete { + relayServersToQuery = append(relayServersToQuery, server) + } + } + }() + if len(relayServersToQuery) == 0 { + // All relay servers have been synced. + logrus.Info("Finished syncing with all known relays") + return + } + r.queryRelayServers(relayServersToQuery) + t.Reset(relayServerRetryInterval) + + select { + case <-stop: + if !t.Stop() { + <-t.C + } + logrus.Info("Stopped relay server retriever") + return + case <-t.C: + } + } +} + +func (r *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { + logrus.Info("Querying relay servers for any available transactions") + for _, server := range relayServers { + userID, err := gomatrixserverlib.NewUserID("@user:"+string(r.serverName), false) + if err != nil { + return + } + + logrus.Infof("Syncing with relay: %s", string(server)) + err = r.relayAPI.PerformRelayServerSync(context.Background(), *userID, server) + if err == nil { + func() { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + r.relayServersQueried[server] = true + }() + // TODO : What happens if your relay receives new messages after this point? + // Should you continue to check with them, or should they try and contact you? + // They could send a "new_async_events" message your way maybe? + // Then you could mark them as needing to be queried again. + // What if you miss this message? + // Maybe you should try querying them again after a certain period of time as a backup? + } else { + logrus.Errorf("Failed querying relay server: %s", err.Error()) + } + } +} + +func UpdateNodeRelayServers( + node gomatrixserverlib.ServerName, + relays []gomatrixserverlib.ServerName, + ctx context.Context, + fedAPI federationAPI.FederationInternalAPI, +) { + // Get the current relay list + request := federationAPI.P2PQueryRelayServersRequest{Server: node} + response := federationAPI.P2PQueryRelayServersResponse{} + err := fedAPI.P2PQueryRelayServers(ctx, &request, &response) + if err != nil { + logrus.Warnf("Failed obtaining list of relay servers for %s: %s", node, err.Error()) + } + + // Remove old, non-matching relays + var serversToRemove []gomatrixserverlib.ServerName + for _, existingServer := range response.RelayServers { + shouldRemove := true + for _, newServer := range relays { + if newServer == existingServer { + shouldRemove = false + break + } + } + + if shouldRemove { + serversToRemove = append(serversToRemove, existingServer) + } + } + removeRequest := federationAPI.P2PRemoveRelayServersRequest{ + Server: node, + RelayServers: serversToRemove, + } + removeResponse := federationAPI.P2PRemoveRelayServersResponse{} + err = fedAPI.P2PRemoveRelayServers(ctx, &removeRequest, &removeResponse) + if err != nil { + logrus.Warnf("Failed removing old relay servers for %s: %s", node, err.Error()) + } + + // Add new relays + addRequest := federationAPI.P2PAddRelayServersRequest{ + Server: node, + RelayServers: relays, + } + addResponse := federationAPI.P2PAddRelayServersResponse{} + err = fedAPI.P2PAddRelayServers(ctx, &addRequest, &addResponse) + if err != nil { + logrus.Warnf("Failed adding relay servers for %s: %s", node, err.Error()) + } +} diff --git a/cmd/dendrite-demo-pinecone/relay/retriever_test.go b/cmd/dendrite-demo-pinecone/relay/retriever_test.go new file mode 100644 index 000000000..6c4c3a529 --- /dev/null +++ b/cmd/dendrite-demo-pinecone/relay/retriever_test.go @@ -0,0 +1,99 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relay + +import ( + "context" + "testing" + "time" + + federationAPI "github.com/matrix-org/dendrite/federationapi/api" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "gotest.tools/v3/poll" +) + +var testRelayServers = []gomatrixserverlib.ServerName{"relay1", "relay2"} + +type FakeFedAPI struct { + federationAPI.FederationInternalAPI +} + +func (f *FakeFedAPI) P2PQueryRelayServers( + ctx context.Context, + req *federationAPI.P2PQueryRelayServersRequest, + res *federationAPI.P2PQueryRelayServersResponse, +) error { + res.RelayServers = testRelayServers + return nil +} + +type FakeRelayAPI struct { + relayServerAPI.RelayInternalAPI +} + +func (r *FakeRelayAPI) PerformRelayServerSync( + ctx context.Context, + userID gomatrixserverlib.UserID, + relayServer gomatrixserverlib.ServerName, +) error { + return nil +} + +func TestRelayRetrieverInitialization(t *testing.T) { + retriever := NewRelayServerRetriever( + context.Background(), + "server", + &FakeFedAPI{}, + &FakeRelayAPI{}, + make(chan bool), + ) + + retriever.InitializeRelayServers(logrus.WithField("test", "relay")) + relayServers := retriever.GetQueriedServerStatus() + assert.Equal(t, 2, len(relayServers)) +} + +func TestRelayRetrieverSync(t *testing.T) { + retriever := NewRelayServerRetriever( + context.Background(), + "server", + &FakeFedAPI{}, + &FakeRelayAPI{}, + make(chan bool), + ) + + retriever.InitializeRelayServers(logrus.WithField("test", "relay")) + relayServers := retriever.GetQueriedServerStatus() + assert.Equal(t, 2, len(relayServers)) + + stopRelayServerSync := make(chan bool) + go retriever.SyncRelayServers(stopRelayServerSync) + + check := func(log poll.LogT) poll.Result { + relayServers := retriever.GetQueriedServerStatus() + for _, queried := range relayServers { + if !queried { + return poll.Continue("waiting for all servers to be queried") + } + } + + stopRelayServerSync <- true + return poll.Success() + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 3ea4a08b0..d759c6a73 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -39,7 +39,6 @@ import ( "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" @@ -117,8 +116,8 @@ func main() { cfg = setup.ParseFlags(true) } else { cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: true, }) cfg.Global.PrivateKey = sk cfg.Global.JetStream.StoragePath = config.Path(filepath.Join(*instanceDir, *instanceName)) @@ -143,7 +142,7 @@ func main() { cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - base := base.NewBaseDendrite(cfg, "Monolith") + base := base.NewBaseDendrite(cfg) base.ConfigureAdminEndpoints() defer base.Close() // nolint: errcheck @@ -157,16 +156,11 @@ func main() { serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - rsComponent := roomserver.NewInternalAPI( + rsAPI := roomserver.NewInternalAPI( base, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsComponent) - - rsAPI := rsComponent - - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, federation) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) rsAPI.SetAppserviceAPI(asAPI) @@ -174,7 +168,7 @@ func main() { base, federation, rsAPI, base.Caches, keyRing, true, ) - rsComponent.SetFederationAPI(fsAPI, keyRing) + rsAPI.SetFederationAPI(fsAPI, keyRing) monolith := setup.Monolith{ Config: base.Cfg, @@ -186,7 +180,6 @@ func main() { FederationAPI: fsAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - KeyAPI: keyAPI, ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, ), @@ -197,7 +190,6 @@ func main() { } httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() - httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go deleted file mode 100644 index 6836b6426..000000000 --- a/cmd/dendrite-monolith-server/main.go +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "flag" - "os" - - "github.com/sirupsen/logrus" - - "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/keyserver" - "github.com/matrix-org/dendrite/roomserver" - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/mscs" - "github.com/matrix-org/dendrite/userapi" - uapi "github.com/matrix-org/dendrite/userapi/api" -) - -var ( - httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") - httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") - apiBindAddr = flag.String("api-bind-address", "localhost:18008", "The HTTP listening port for the internal HTTP APIs (if -api is enabled)") - certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") - keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS") - enableHTTPAPIs = flag.Bool("api", false, "Use HTTP APIs instead of short-circuiting (warning: exposes API endpoints!)") - traceInternal = os.Getenv("DENDRITE_TRACE_INTERNAL") == "1" -) - -func main() { - cfg := setup.ParseFlags(true) - httpAddr := config.HTTPAddress("http://" + *httpBindAddr) - httpsAddr := config.HTTPAddress("https://" + *httpsBindAddr) - httpAPIAddr := httpAddr - options := []basepkg.BaseDendriteOptions{} - if *enableHTTPAPIs { - logrus.Warnf("DANGER! The -api option is enabled, exposing internal APIs on %q!", *apiBindAddr) - httpAPIAddr = config.HTTPAddress("http://" + *apiBindAddr) - // If the HTTP APIs are enabled then we need to update the Listen - // statements in the configuration so that we know where to find - // the API endpoints. They'll listen on the same port as the monolith - // itself. - cfg.AppServiceAPI.InternalAPI.Connect = httpAPIAddr - cfg.ClientAPI.InternalAPI.Connect = httpAPIAddr - cfg.FederationAPI.InternalAPI.Connect = httpAPIAddr - cfg.KeyServer.InternalAPI.Connect = httpAPIAddr - cfg.MediaAPI.InternalAPI.Connect = httpAPIAddr - cfg.RoomServer.InternalAPI.Connect = httpAPIAddr - cfg.SyncAPI.InternalAPI.Connect = httpAPIAddr - cfg.UserAPI.InternalAPI.Connect = httpAPIAddr - options = append(options, basepkg.UseHTTPAPIs) - } - - base := basepkg.NewBaseDendrite(cfg, "Monolith", options...) - defer base.Close() // nolint: errcheck - - federation := base.CreateFederationClient() - - rsImpl := roomserver.NewInternalAPI(base) - // call functions directly on the impl unless running in HTTP mode - rsAPI := rsImpl - if base.UseHTTPAPIs { - roomserver.AddInternalRoutes(base.InternalAPIMux, rsImpl, base.EnableMetrics) - rsAPI = base.RoomserverHTTPClient() - } - if traceInternal { - rsAPI = &api.RoomserverInternalAPITrace{ - Impl: rsAPI, - } - } - - fsAPI := federationapi.NewInternalAPI( - base, federation, rsAPI, base.Caches, nil, false, - ) - fsImplAPI := fsAPI - if base.UseHTTPAPIs { - federationapi.AddInternalRoutes(base.InternalAPIMux, fsAPI, base.EnableMetrics) - fsAPI = base.FederationAPIHTTPClient() - } - keyRing := fsAPI.KeyRing() - - keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) - keyAPI := keyImpl - if base.UseHTTPAPIs { - keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI, base.EnableMetrics) - keyAPI = base.KeyServerHTTPClient() - } - - pgClient := base.PushGatewayHTTPClient() - userImpl := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient) - userAPI := userImpl - if base.UseHTTPAPIs { - userapi.AddInternalRoutes(base.InternalAPIMux, userAPI, base.EnableMetrics) - userAPI = base.UserAPIClient() - } - if traceInternal { - userAPI = &uapi.UserInternalAPITrace{ - Impl: userAPI, - } - } - - // TODO: This should use userAPI, not userImpl, but the appservice setup races with - // the listeners and panics at startup if it tries to create appservice accounts - // before the listeners are up. - asAPI := appservice.NewInternalAPI(base, userImpl, rsAPI) - if base.UseHTTPAPIs { - appservice.AddInternalRoutes(base.InternalAPIMux, asAPI, base.EnableMetrics) - asAPI = base.AppserviceHTTPClient() - } - - // The underlying roomserver implementation needs to be able to call the fedsender. - // This is different to rsAPI which can be the http client which doesn't need this - // dependency. Other components also need updating after their dependencies are up. - rsImpl.SetFederationAPI(fsAPI, keyRing) - rsImpl.SetAppserviceAPI(asAPI) - rsImpl.SetUserAPI(userAPI) - keyImpl.SetUserAPI(userAPI) - - monolith := setup.Monolith{ - Config: base.Cfg, - Client: base.CreateClient(), - FedClient: federation, - KeyRing: keyRing, - - AppserviceAPI: asAPI, - // always use the concrete impl here even in -http mode because adding public routes - // must be done on the concrete impl not an HTTP client else fedapi will call itself - FederationAPI: fsImplAPI, - RoomserverAPI: rsAPI, - UserAPI: userAPI, - KeyAPI: keyAPI, - } - monolith.AddAllPublicRoutes(base) - - if len(base.Cfg.MSCs.MSCs) > 0 { - if err := mscs.Enable(base, &monolith); err != nil { - logrus.WithError(err).Fatalf("Failed to enable MSCs") - } - } - - // Expose the matrix APIs directly rather than putting them under a /api path. - go func() { - base.SetupAndServeHTTP( - httpAPIAddr, // internal API - httpAddr, // external API - nil, nil, // TLS settings - ) - }() - // Handle HTTPS if certificate and key are provided - if *certFile != "" && *keyFile != "" { - go func() { - base.SetupAndServeHTTP( - basepkg.NoListener, // internal API - httpsAddr, // external API - certFile, keyFile, // TLS settings - ) - }() - } - - // We want to block forever to let the HTTP and HTTPS handler serve the APIs - base.WaitForShutdown() -} diff --git a/cmd/dendrite-polylith-multi/main.go b/cmd/dendrite-polylith-multi/main.go deleted file mode 100644 index c6a560b19..000000000 --- a/cmd/dendrite-polylith-multi/main.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "flag" - "os" - "strings" - - "github.com/matrix-org/dendrite/cmd/dendrite-polylith-multi/personalities" - "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/sirupsen/logrus" -) - -type entrypoint func(base *base.BaseDendrite, cfg *config.Dendrite) - -func main() { - cfg := setup.ParseFlags(false) - - component := "" - if flag.NFlag() > 0 { - component = flag.Arg(0) // ./dendrite-polylith-multi --config=... clientapi - } else if len(os.Args) > 1 { - component = os.Args[1] // ./dendrite-polylith-multi clientapi - } - - components := map[string]entrypoint{ - "appservice": personalities.Appservice, - "clientapi": personalities.ClientAPI, - "federationapi": personalities.FederationAPI, - "keyserver": personalities.KeyServer, - "mediaapi": personalities.MediaAPI, - "roomserver": personalities.RoomServer, - "syncapi": personalities.SyncAPI, - "userapi": personalities.UserAPI, - } - - start, ok := components[component] - if !ok { - if component == "" { - logrus.Errorf("No component specified") - logrus.Info("The first argument on the command line must be the name of the component to run") - } else { - logrus.Errorf("Unknown component %q specified", component) - } - - var list []string - for c := range components { - list = append(list, c) - } - logrus.Infof("Valid components: %s", strings.Join(list, ", ")) - - os.Exit(1) - } - - logrus.Infof("Starting %q component", component) - - base := base.NewBaseDendrite(cfg, component, base.PolylithMode) // TODO - defer base.Close() // nolint: errcheck - - go start(base, cfg) - base.WaitForShutdown() -} diff --git a/cmd/dendrite-polylith-multi/personalities/appservice.go b/cmd/dendrite-polylith-multi/personalities/appservice.go deleted file mode 100644 index 0547d57f0..000000000 --- a/cmd/dendrite-polylith-multi/personalities/appservice.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package personalities - -import ( - "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/setup/base" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func Appservice(base *base.BaseDendrite, cfg *config.Dendrite) { - userAPI := base.UserAPIClient() - rsAPI := base.RoomserverHTTPClient() - - intAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) - appservice.AddInternalRoutes(base.InternalAPIMux, intAPI, base.EnableMetrics) - - base.SetupAndServeHTTP( - base.Cfg.AppServiceAPI.InternalAPI.Listen, // internal listener - basepkg.NoListener, // external listener - nil, nil, - ) -} diff --git a/cmd/dendrite-polylith-multi/personalities/clientapi.go b/cmd/dendrite-polylith-multi/personalities/clientapi.go deleted file mode 100644 index a5d69d07c..000000000 --- a/cmd/dendrite-polylith-multi/personalities/clientapi.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package personalities - -import ( - "github.com/matrix-org/dendrite/clientapi" - "github.com/matrix-org/dendrite/internal/transactions" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func ClientAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { - federation := base.CreateFederationClient() - - asQuery := base.AppserviceHTTPClient() - rsAPI := base.RoomserverHTTPClient() - fsAPI := base.FederationAPIHTTPClient() - userAPI := base.UserAPIClient() - keyAPI := base.KeyServerHTTPClient() - - clientapi.AddPublicRoutes( - base, federation, rsAPI, asQuery, - transactions.New(), fsAPI, userAPI, userAPI, - keyAPI, nil, - ) - - base.SetupAndServeHTTP( - base.Cfg.ClientAPI.InternalAPI.Listen, - base.Cfg.ClientAPI.ExternalAPI.Listen, - nil, nil, - ) -} diff --git a/cmd/dendrite-polylith-multi/personalities/federationapi.go b/cmd/dendrite-polylith-multi/personalities/federationapi.go deleted file mode 100644 index 48da42fbf..000000000 --- a/cmd/dendrite-polylith-multi/personalities/federationapi.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package personalities - -import ( - "github.com/matrix-org/dendrite/federationapi" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func FederationAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { - userAPI := base.UserAPIClient() - federation := base.CreateFederationClient() - rsAPI := base.RoomserverHTTPClient() - keyAPI := base.KeyServerHTTPClient() - fsAPI := federationapi.NewInternalAPI(base, federation, rsAPI, base.Caches, nil, true) - keyRing := fsAPI.KeyRing() - - federationapi.AddPublicRoutes( - base, - userAPI, federation, keyRing, - rsAPI, fsAPI, keyAPI, nil, - ) - - federationapi.AddInternalRoutes(base.InternalAPIMux, fsAPI, base.EnableMetrics) - - base.SetupAndServeHTTP( - base.Cfg.FederationAPI.InternalAPI.Listen, - base.Cfg.FederationAPI.ExternalAPI.Listen, - nil, nil, - ) -} diff --git a/cmd/dendrite-polylith-multi/personalities/keyserver.go b/cmd/dendrite-polylith-multi/personalities/keyserver.go deleted file mode 100644 index ad0bd0e54..000000000 --- a/cmd/dendrite-polylith-multi/personalities/keyserver.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package personalities - -import ( - "github.com/matrix-org/dendrite/keyserver" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func KeyServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) { - fsAPI := base.FederationAPIHTTPClient() - rsAPI := base.RoomserverHTTPClient() - intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) - intAPI.SetUserAPI(base.UserAPIClient()) - - keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI, base.EnableMetrics) - - base.SetupAndServeHTTP( - base.Cfg.KeyServer.InternalAPI.Listen, // internal listener - basepkg.NoListener, // external listener - nil, nil, - ) -} diff --git a/cmd/dendrite-polylith-multi/personalities/roomserver.go b/cmd/dendrite-polylith-multi/personalities/roomserver.go deleted file mode 100644 index 974559bd2..000000000 --- a/cmd/dendrite-polylith-multi/personalities/roomserver.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package personalities - -import ( - "github.com/matrix-org/dendrite/roomserver" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func RoomServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) { - asAPI := base.AppserviceHTTPClient() - fsAPI := base.FederationAPIHTTPClient() - rsAPI := roomserver.NewInternalAPI(base) - rsAPI.SetFederationAPI(fsAPI, fsAPI.KeyRing()) - rsAPI.SetAppserviceAPI(asAPI) - roomserver.AddInternalRoutes(base.InternalAPIMux, rsAPI, base.EnableMetrics) - - base.SetupAndServeHTTP( - base.Cfg.RoomServer.InternalAPI.Listen, // internal listener - basepkg.NoListener, // external listener - nil, nil, - ) -} diff --git a/cmd/dendrite-polylith-multi/personalities/syncapi.go b/cmd/dendrite-polylith-multi/personalities/syncapi.go deleted file mode 100644 index 41637fe1d..000000000 --- a/cmd/dendrite-polylith-multi/personalities/syncapi.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package personalities - -import ( - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/syncapi" -) - -func SyncAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { - userAPI := base.UserAPIClient() - - rsAPI := base.RoomserverHTTPClient() - - syncapi.AddPublicRoutes( - base, - userAPI, rsAPI, - base.KeyServerHTTPClient(), - ) - - base.SetupAndServeHTTP( - base.Cfg.SyncAPI.InternalAPI.Listen, - base.Cfg.SyncAPI.ExternalAPI.Listen, - nil, nil, - ) -} diff --git a/cmd/dendrite-polylith-multi/personalities/userapi.go b/cmd/dendrite-polylith-multi/personalities/userapi.go deleted file mode 100644 index 1bc88cb5f..000000000 --- a/cmd/dendrite-polylith-multi/personalities/userapi.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package personalities - -import ( - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi" -) - -func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { - userAPI := userapi.NewInternalAPI( - base, &cfg.UserAPI, cfg.Derived.ApplicationServices, - base.KeyServerHTTPClient(), base.RoomserverHTTPClient(), - base.PushGatewayHTTPClient(), - ) - - userapi.AddInternalRoutes(base.InternalAPIMux, userAPI, base.EnableMetrics) - - base.SetupAndServeHTTP( - base.Cfg.UserAPI.InternalAPI.Listen, // internal listener - basepkg.NoListener, // external listener - nil, nil, - ) -} diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index 39b9320cb..174a80a3e 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -45,6 +45,10 @@ var ( const HEAD = "HEAD" +// The binary was renamed after v0.11.1, so everything after that should use the new name +var binaryChangeVersion, _ = semver.NewVersion("v0.11.1") +var latest, _ = semver.NewVersion("v6.6.6") // Dummy version, used as "HEAD" + // Embed the Dockerfile to use when building dendrite versions. // We cannot use the dockerfile associated with the repo with each version sadly due to changes in // Docker versions. Specifically, earlier Dendrite versions are incompatible with newer Docker clients @@ -54,12 +58,13 @@ const HEAD = "HEAD" const DockerfilePostgreSQL = `FROM golang:1.18-stretch as build RUN apt-get update && apt-get install -y postgresql WORKDIR /build +ARG BINARY # Copy the build context to the repo as this is the right dendrite code. This is different to the # Complement Dockerfile which wgets a branch. COPY . . -RUN go build ./cmd/dendrite-monolith-server +RUN go build ./cmd/${BINARY} RUN go build ./cmd/generate-keys RUN go build ./cmd/generate-config RUN go build ./cmd/create-account @@ -88,22 +93,24 @@ done \n\ \n\ sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml \n\ PARAMS="--tls-cert server.crt --tls-key server.key --config dendrite.yaml" \n\ -./dendrite-monolith-server --really-enable-open-registration ${PARAMS} || ./dendrite-monolith-server ${PARAMS} \n\ +./${BINARY} --really-enable-open-registration ${PARAMS} || ./${BINARY} ${PARAMS} \n\ ' > run_dendrite.sh && chmod +x run_dendrite.sh ENV SERVER_NAME=localhost +ENV BINARY=dendrite EXPOSE 8008 8448 -CMD /build/run_dendrite.sh ` +CMD /build/run_dendrite.sh` const DockerfileSQLite = `FROM golang:1.18-stretch as build RUN apt-get update && apt-get install -y postgresql WORKDIR /build +ARG BINARY # Copy the build context to the repo as this is the right dendrite code. This is different to the # Complement Dockerfile which wgets a branch. COPY . . -RUN go build ./cmd/dendrite-monolith-server +RUN go build ./cmd/${BINARY} RUN go build ./cmd/generate-keys RUN go build ./cmd/generate-config RUN go build ./cmd/create-account @@ -118,10 +125,11 @@ RUN sed -i "s%connection_string:.file:%connection_string: file:\/var\/lib\/postg RUN echo '\ sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml \n\ PARAMS="--tls-cert server.crt --tls-key server.key --config dendrite.yaml" \n\ -./dendrite-monolith-server --really-enable-open-registration ${PARAMS} || ./dendrite-monolith-server ${PARAMS} \n\ +./${BINARY} --really-enable-open-registration ${PARAMS} || ./${BINARY} ${PARAMS} \n\ ' > run_dendrite.sh && chmod +x run_dendrite.sh ENV SERVER_NAME=localhost +ENV BINARY=dendrite EXPOSE 8008 8448 CMD /build/run_dendrite.sh ` @@ -182,7 +190,7 @@ func downloadArchive(cli *http.Client, tmpDir, archiveURL string, dockerfile []b } // buildDendrite builds Dendrite on the branchOrTagName given. Returns the image ID or an error -func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir, branchOrTagName string) (string, error) { +func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir string, branchOrTagName, binary string) (string, error) { var tarball *bytes.Buffer var err error // If a custom HEAD location is given, use that, else pull from github. Mostly useful for CI @@ -216,6 +224,9 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir, log.Printf("%s: Building version %s\n", branchOrTagName, branchOrTagName) res, err := dockerClient.ImageBuild(context.Background(), tarball, types.ImageBuildOptions{ Tags: []string{"dendrite-upgrade"}, + BuildArgs: map[string]*string{ + "BINARY": &binary, + }, }) if err != nil { return "", fmt.Errorf("failed to start building image: %s", err) @@ -272,7 +283,7 @@ func getAndSortVersionsFromGithub(httpClient *http.Client) (semVers []*semver.Ve return semVers, nil } -func calculateVersions(cli *http.Client, from, to string, direct bool) []string { +func calculateVersions(cli *http.Client, from, to string, direct bool) []*semver.Version { semvers, err := getAndSortVersionsFromGithub(cli) if err != nil { log.Fatalf("failed to collect semvers from github: %s", err) @@ -320,28 +331,25 @@ func calculateVersions(cli *http.Client, from, to string, direct bool) []string } semvers = semvers[:i+1] } - var versions []string - for _, sv := range semvers { - versions = append(versions, sv.Original()) - } + if to == HEAD { - versions = append(versions, HEAD) + semvers = append(semvers, latest) } if direct { - versions = []string{versions[0], versions[len(versions)-1]} + semvers = []*semver.Version{semvers[0], semvers[len(semvers)-1]} } - return versions + return semvers } -func buildDendriteImages(httpClient *http.Client, dockerClient *client.Client, baseTempDir string, concurrency int, branchOrTagNames []string) map[string]string { +func buildDendriteImages(httpClient *http.Client, dockerClient *client.Client, baseTempDir string, concurrency int, versions []*semver.Version) map[string]string { // concurrently build all versions, this can be done in any order. The mutex protects the map branchToImageID := make(map[string]string) var mu sync.Mutex var wg sync.WaitGroup wg.Add(concurrency) - ch := make(chan string, len(branchOrTagNames)) - for _, branchName := range branchOrTagNames { + ch := make(chan *semver.Version, len(versions)) + for _, branchName := range versions { ch <- branchName } close(ch) @@ -349,11 +357,13 @@ func buildDendriteImages(httpClient *http.Client, dockerClient *client.Client, b for i := 0; i < concurrency; i++ { go func() { defer wg.Done() - for branchName := range ch { + for version := range ch { + branchName, binary := versionToBranchAndBinary(version) + log.Printf("Building version %s with binary %s", branchName, binary) tmpDir := baseTempDir + alphaNumerics.ReplaceAllString(branchName, "") - imgID, err := buildDendrite(httpClient, dockerClient, tmpDir, branchName) + imgID, err := buildDendrite(httpClient, dockerClient, tmpDir, branchName, binary) if err != nil { - log.Fatalf("%s: failed to build dendrite image: %s", branchName, err) + log.Fatalf("%s: failed to build dendrite image: %s", version, err) } mu.Lock() branchToImageID[branchName] = imgID @@ -365,13 +375,14 @@ func buildDendriteImages(httpClient *http.Client, dockerClient *client.Client, b return branchToImageID } -func runImage(dockerClient *client.Client, volumeName, version, imageID string) (csAPIURL, containerID string, err error) { - log.Printf("%s: running image %s\n", version, imageID) +func runImage(dockerClient *client.Client, volumeName string, branchNameToImageID map[string]string, version *semver.Version) (csAPIURL, containerID string, err error) { + branchName, binary := versionToBranchAndBinary(version) + imageID := branchNameToImageID[branchName] ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) defer cancel() body, err := dockerClient.ContainerCreate(ctx, &container.Config{ Image: imageID, - Env: []string{"SERVER_NAME=hs1"}, + Env: []string{"SERVER_NAME=hs1", fmt.Sprintf("BINARY=%s", binary)}, Labels: map[string]string{ dendriteUpgradeTestLabel: "yes", }, @@ -384,7 +395,7 @@ func runImage(dockerClient *client.Client, volumeName, version, imageID string) Target: "/var/lib/postgresql/9.6/main", }, }, - }, nil, nil, "dendrite_upgrade_test_"+version) + }, nil, nil, "dendrite_upgrade_test_"+branchName) if err != nil { return "", "", fmt.Errorf("failed to ContainerCreate: %s", err) } @@ -451,8 +462,8 @@ func destroyContainer(dockerClient *client.Client, containerID string) { } } -func loadAndRunTests(dockerClient *client.Client, volumeName, v string, branchToImageID map[string]string) error { - csAPIURL, containerID, err := runImage(dockerClient, volumeName, v, branchToImageID[v]) +func loadAndRunTests(dockerClient *client.Client, volumeName string, v *semver.Version, branchToImageID map[string]string) error { + csAPIURL, containerID, err := runImage(dockerClient, volumeName, branchToImageID, v) if err != nil { return fmt.Errorf("failed to run container for branch %v: %v", v, err) } @@ -470,9 +481,10 @@ func loadAndRunTests(dockerClient *client.Client, volumeName, v string, branchTo } // test that create-account is working -func testCreateAccount(dockerClient *client.Client, v string, containerID string) error { - createUser := strings.ToLower("createaccountuser-" + v) - log.Printf("%s: Creating account %s with create-account\n", v, createUser) +func testCreateAccount(dockerClient *client.Client, version *semver.Version, containerID string) error { + branchName, _ := versionToBranchAndBinary(version) + createUser := strings.ToLower("createaccountuser-" + branchName) + log.Printf("%s: Creating account %s with create-account\n", branchName, createUser) respID, err := dockerClient.ContainerExecCreate(context.Background(), containerID, types.ExecConfig{ AttachStderr: true, @@ -504,9 +516,21 @@ func testCreateAccount(dockerClient *client.Client, v string, containerID string return nil } -func verifyTests(dockerClient *client.Client, volumeName string, versions []string, branchToImageID map[string]string) error { +func versionToBranchAndBinary(version *semver.Version) (branchName, binary string) { + binary = "dendrite-monolith-server" + branchName = version.Original() + if version.GreaterThan(binaryChangeVersion) { + binary = "dendrite" + if version.Equal(latest) { + branchName = HEAD + } + } + return +} + +func verifyTests(dockerClient *client.Client, volumeName string, versions []*semver.Version, branchToImageID map[string]string) error { lastVer := versions[len(versions)-1] - csAPIURL, containerID, err := runImage(dockerClient, volumeName, lastVer, branchToImageID[lastVer]) + csAPIURL, containerID, err := runImage(dockerClient, volumeName, branchToImageID, lastVer) if err != nil { return fmt.Errorf("failed to run container for branch %v: %v", lastVer, err) } diff --git a/cmd/dendrite-upgrade-tests/tests.go b/cmd/dendrite-upgrade-tests/tests.go index 5c9589df2..03438bd4d 100644 --- a/cmd/dendrite-upgrade-tests/tests.go +++ b/cmd/dendrite-upgrade-tests/tests.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/Masterminds/semver/v3" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" ) @@ -22,7 +23,8 @@ type user struct { // - register alice and bob with branch name muxed into the localpart // - create a DM room for the 2 users and exchange messages // - create/join a public #global room and exchange messages -func runTests(baseURL, branchName string) error { +func runTests(baseURL string, v *semver.Version) error { + branchName, _ := versionToBranchAndBinary(v) // register 2 users users := []user{ { @@ -164,15 +166,16 @@ func runTests(baseURL, branchName string) error { } // verifyTestsRan checks that the HS has the right rooms/messages -func verifyTestsRan(baseURL string, branchNames []string) error { +func verifyTestsRan(baseURL string, versions []*semver.Version) error { log.Println("Verifying tests....") // check we can login as all users var resp *gomatrix.RespLogin - for _, branchName := range branchNames { + for _, version := range versions { client, err := gomatrix.NewClient(baseURL, "", "") if err != nil { return err } + branchName, _ := versionToBranchAndBinary(version) userLocalparts := []string{ "alice" + branchName, "bob" + branchName, @@ -224,7 +227,7 @@ func verifyTestsRan(baseURL string, branchNames []string) error { msgCount += 1 } } - wantMsgCount := len(branchNames) * 4 + wantMsgCount := len(versions) * 4 if msgCount != wantMsgCount { return fmt.Errorf("got %d messages in global room, want %d", msgCount, wantMsgCount) } diff --git a/cmd/dendrite/main.go b/cmd/dendrite/main.go new file mode 100644 index 000000000..e8ff0a478 --- /dev/null +++ b/cmd/dendrite/main.go @@ -0,0 +1,103 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "flag" + + "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/setup" + basepkg "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/mscs" + "github.com/matrix-org/dendrite/userapi" +) + +var ( + httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") + httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") + certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") + keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS") +) + +func main() { + cfg := setup.ParseFlags(true) + httpAddr := config.HTTPAddress("http://" + *httpBindAddr) + httpsAddr := config.HTTPAddress("https://" + *httpsBindAddr) + options := []basepkg.BaseDendriteOptions{} + + base := basepkg.NewBaseDendrite(cfg, options...) + defer base.Close() // nolint: errcheck + + federation := base.CreateFederationClient() + + rsAPI := roomserver.NewInternalAPI(base) + + fsAPI := federationapi.NewInternalAPI( + base, federation, rsAPI, base.Caches, nil, false, + ) + + keyRing := fsAPI.KeyRing() + + userAPI := userapi.NewInternalAPI(base, rsAPI, federation) + + asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + + // The underlying roomserver implementation needs to be able to call the fedsender. + // This is different to rsAPI which can be the http client which doesn't need this + // dependency. Other components also need updating after their dependencies are up. + rsAPI.SetFederationAPI(fsAPI, keyRing) + rsAPI.SetAppserviceAPI(asAPI) + rsAPI.SetUserAPI(userAPI) + + monolith := setup.Monolith{ + Config: base.Cfg, + Client: base.CreateClient(), + FedClient: federation, + KeyRing: keyRing, + + AppserviceAPI: asAPI, + // always use the concrete impl here even in -http mode because adding public routes + // must be done on the concrete impl not an HTTP client else fedapi will call itself + FederationAPI: fsAPI, + RoomserverAPI: rsAPI, + UserAPI: userAPI, + } + monolith.AddAllPublicRoutes(base) + + if len(base.Cfg.MSCs.MSCs) > 0 { + if err := mscs.Enable(base, &monolith); err != nil { + logrus.WithError(err).Fatalf("Failed to enable MSCs") + } + } + + // Expose the matrix APIs directly rather than putting them under a /api path. + go func() { + base.SetupAndServeHTTP(httpAddr, nil, nil) + }() + // Handle HTTPS if certificate and key are provided + if *certFile != "" && *keyFile != "" { + go func() { + base.SetupAndServeHTTP(httpsAddr, certFile, keyFile) + }() + } + + // We want to block forever to let the HTTP and HTTPS handler serve the APIs + base.WaitForShutdown() +} diff --git a/cmd/dendrite-monolith-server/main_test.go b/cmd/dendrite/main_test.go similarity index 100% rename from cmd/dendrite-monolith-server/main_test.go rename to cmd/dendrite/main_test.go diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index 33b18c471..b0707be11 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -18,7 +18,6 @@ func main() { dbURI := flag.String("db", "", "The DB URI to use for all components (PostgreSQL only)") dirPath := flag.String("dir", "./", "The folder to use for paths (like SQLite databases, media storage)") normalise := flag.String("normalise", "", "Normalise an existing configuration file by adding new/missing options and defaults") - polylith := flag.Bool("polylith", false, "Generate a config that makes sense for polylith deployments") flag.Parse() var cfg *config.Dendrite @@ -27,14 +26,14 @@ func main() { Version: config.Version, } cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: !*polylith, + Generate: true, + SingleDatabase: true, }) if *serverName != "" { cfg.Global.ServerName = gomatrixserverlib.ServerName(*serverName) } uri := config.DataSource(*dbURI) - if *polylith || uri.IsSQLite() || uri == "" { + if uri.IsSQLite() || uri == "" { for name, db := range map[string]*config.DatabaseOptions{ "federationapi": &cfg.FederationAPI.Database, "keyserver": &cfg.KeyServer.Database, @@ -43,6 +42,7 @@ func main() { "roomserver": &cfg.RoomServer.Database, "syncapi": &cfg.SyncAPI.Database, "userapi": &cfg.UserAPI.AccountDatabase, + "relayapi": &cfg.RelayAPI.Database, } { if uri == "" { path := filepath.Join(*dirPath, fmt.Sprintf("dendrite_%s.db", name)) @@ -54,6 +54,9 @@ func main() { } else { cfg.Global.DatabaseOptions.ConnectionString = uri } + cfg.MediaAPI.BasePath = config.Path(filepath.Join(*dirPath, "media")) + cfg.Global.JetStream.StoragePath = config.Path(*dirPath) + cfg.SyncAPI.Fulltext.IndexPath = config.Path(filepath.Join(*dirPath, "searchindex")) cfg.Logging = []config.LogrusHook{ { Type: "file", @@ -67,6 +70,7 @@ func main() { cfg.AppServiceAPI.DisableTLSValidation = true cfg.ClientAPI.RateLimiting.Enabled = false cfg.FederationAPI.DisableTLSValidation = false + cfg.FederationAPI.DisableHTTPKeepalives = true // don't hit matrix.org when running tests!!! cfg.FederationAPI.KeyPerspectives = config.KeyPerspectives{} cfg.MediaAPI.BasePath = config.Path(filepath.Join(*dirPath, "media")) @@ -92,7 +96,7 @@ func main() { } } else { var err error - if cfg, err = config.Load(*normalise, !*polylith); err != nil { + if cfg, err = config.Load(*normalise); err != nil { panic(err) } } diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index f8bb130c7..e3840bbcf 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -40,7 +40,7 @@ func main() { Level: "error", }) cfg.ClientAPI.RegistrationDisabled = true - base := base.NewBaseDendrite(cfg, "ResolveState", base.DisableMetrics) + base := base.NewBaseDendrite(cfg, base.DisableMetrics) args := flag.Args() fmt.Println("Room version", *roomVersion) @@ -87,7 +87,7 @@ func main() { } var eventEntries []types.Event - eventEntries, err = roomserverDB.Events(ctx, eventNIDs) + eventEntries, err = roomserverDB.Events(ctx, 0, eventNIDs) if err != nil { panic(err) } @@ -145,7 +145,7 @@ func main() { } fmt.Println("Fetching", len(eventNIDMap), "state events") - eventEntries, err := roomserverDB.Events(ctx, eventNIDs) + eventEntries, err := roomserverDB.Events(ctx, 0, eventNIDs) if err != nil { panic(err) } @@ -165,7 +165,7 @@ func main() { } fmt.Println("Fetching", len(authEventIDs), "auth events") - authEventEntries, err := roomserverDB.EventsFromIDs(ctx, authEventIDs) + authEventEntries, err := roomserverDB.EventsFromIDs(ctx, 0, authEventIDs) if err != nil { panic(err) } diff --git a/dendrite-sample.polylith.yaml b/dendrite-sample.polylith.yaml deleted file mode 100644 index ecc3f4051..000000000 --- a/dendrite-sample.polylith.yaml +++ /dev/null @@ -1,417 +0,0 @@ -# This is the Dendrite configuration file. -# -# The configuration is split up into sections - each Dendrite component has a -# configuration section, in addition to the "global" section which applies to -# all components. - -# The version of the configuration file. -version: 2 - -# Global Matrix configuration. This configuration applies to all components. -global: - # The domain name of this homeserver. - server_name: localhost - - # The path to the signing private key file, used to sign requests and events. - # Note that this is NOT the same private key as used for TLS! To generate a - # signing key, use "./bin/generate-keys --private-key matrix_key.pem". - private_key: matrix_key.pem - - # The paths and expiry timestamps (as a UNIX timestamp in millisecond precision) - # to old signing keys that were formerly in use on this domain name. These - # keys will not be used for federation request or event signing, but will be - # provided to any other homeserver that asks when trying to verify old events. - old_private_keys: - # If the old private key file is available: - # - private_key: old_matrix_key.pem - # expired_at: 1601024554498 - # If only the public key (in base64 format) and key ID are known: - # - public_key: mn59Kxfdq9VziYHSBzI7+EDPDcBS2Xl7jeUdiiQcOnM= - # key_id: ed25519:mykeyid - # expired_at: 1601024554498 - - # How long a remote server can cache our server signing key before requesting it - # again. Increasing this number will reduce the number of requests made by other - # servers for our key but increases the period that a compromised key will be - # considered valid by other homeservers. - key_validity_period: 168h0m0s - - # Configuration for in-memory caches. Caches can often improve performance by - # keeping frequently accessed items (like events, identifiers etc.) in memory - # rather than having to read them from the database. - cache: - # The estimated maximum size for the global cache in bytes, or in terabytes, - # gigabytes, megabytes or kilobytes when the appropriate 'tb', 'gb', 'mb' or - # 'kb' suffix is specified. Note that this is not a hard limit, nor is it a - # memory limit for the entire process. A cache that is too small may ultimately - # provide little or no benefit. - max_size_estimated: 1gb - - # The maximum amount of time that a cache entry can live for in memory before - # it will be evicted and/or refreshed from the database. Lower values result in - # easier admission of new cache entries but may also increase database load in - # comparison to higher values, so adjust conservatively. Higher values may make - # it harder for new items to make it into the cache, e.g. if new rooms suddenly - # become popular. - max_age: 1h - - # The server name to delegate server-server communications to, with optional port - # e.g. localhost:443 - well_known_server_name: "" - - # The server name to delegate client-server communications to, with optional port - # e.g. localhost:443 - well_known_client_name: "" - - # Lists of domains that the server will trust as identity servers to verify third - # party identifiers such as phone numbers and email addresses. - trusted_third_party_id_servers: - - matrix.org - - vector.im - - # Disables federation. Dendrite will not be able to communicate with other servers - # in the Matrix federation and the federation API will not be exposed. - disable_federation: false - - # Configures the handling of presence events. Inbound controls whether we receive - # presence events from other servers, outbound controls whether we send presence - # events for our local users to other servers. - presence: - enable_inbound: false - enable_outbound: false - - # Configures phone-home statistics reporting. These statistics contain the server - # name, number of active users and some information on your deployment config. - # We use this information to understand how Dendrite is being used in the wild. - report_stats: - enabled: false - endpoint: https://matrix.org/report-usage-stats/push - - # Server notices allows server admins to send messages to all users on the server. - server_notices: - enabled: false - # The local part, display name and avatar URL (as a mxc:// URL) for the user that - # will send the server notices. These are visible to all users on the deployment. - local_part: "_server" - display_name: "Server Alerts" - avatar_url: "" - # The room name to be used when sending server notices. This room name will - # appear in user clients. - room_name: "Server Alerts" - - # Configuration for NATS JetStream - jetstream: - # A list of NATS Server addresses to connect to. If none are specified, an - # internal NATS server will be started automatically when running Dendrite in - # monolith mode. For polylith deployments, it is required to specify the address - # of at least one NATS Server node. - addresses: - - hostname:4222 - - # Disable the validation of TLS certificates of NATS. This is - # not recommended in production since it may allow NATS traffic - # to be sent to an insecure endpoint. - disable_tls_validation: false - - # The prefix to use for stream names for this homeserver - really only useful - # if you are running more than one Dendrite server on the same NATS deployment. - topic_prefix: Dendrite - - # Configuration for Prometheus metric collection. - metrics: - enabled: false - basic_auth: - username: metrics - password: metrics - - # Optional DNS cache. The DNS cache may reduce the load on DNS servers if there - # is no local caching resolver available for use. - dns_cache: - enabled: false - cache_size: 256 - cache_lifetime: "5m" # 5 minutes; https://pkg.go.dev/time@master#ParseDuration - -# Configuration for the Appservice API. -app_service_api: - internal_api: - listen: http://[::]:7777 # The listen address for incoming API requests - connect: http://app_service_api:7777 # The connect address for other components to use - - # Disable the validation of TLS certificates of appservices. This is - # not recommended in production since it may allow appservice traffic - # to be sent to an insecure endpoint. - disable_tls_validation: false - - # Appservice configuration files to load into this homeserver. - config_files: - # - /path/to/appservice_registration.yaml - -# Configuration for the Client API. -client_api: - internal_api: - listen: http://[::]:7771 # The listen address for incoming API requests - connect: http://client_api:7771 # The connect address for other components to use - external_api: - listen: http://[::]:8071 - - # Prevents new users from being able to register on this homeserver, except when - # using the registration shared secret below. - registration_disabled: true - - # Prevents new guest accounts from being created. Guest registration is also - # disabled implicitly by setting 'registration_disabled' above. - guests_disabled: true - - # If set, allows registration by anyone who knows the shared secret, regardless - # of whether registration is otherwise disabled. - registration_shared_secret: "" - - # Whether to require reCAPTCHA for registration. If you have enabled registration - # then this is HIGHLY RECOMMENDED to reduce the risk of your homeserver being used - # for coordinated spam attacks. - enable_registration_captcha: false - - # Settings for ReCAPTCHA. - recaptcha_public_key: "" - recaptcha_private_key: "" - recaptcha_bypass_secret: "" - - # To use hcaptcha.com instead of ReCAPTCHA, set the following parameters, otherwise just keep them empty. - # recaptcha_siteverify_api: "https://hcaptcha.com/siteverify" - # recaptcha_api_js_url: "https://js.hcaptcha.com/1/api.js" - # recaptcha_form_field: "h-captcha-response" - # recaptcha_sitekey_class: "h-captcha" - - - # TURN server information that this homeserver should send to clients. - turn: - turn_user_lifetime: "5m" - turn_uris: - # - turn:turn.server.org?transport=udp - # - turn:turn.server.org?transport=tcp - turn_shared_secret: "" - # If your TURN server requires static credentials, then you will need to enter - # them here instead of supplying a shared secret. Note that these credentials - # will be visible to clients! - # turn_username: "" - # turn_password: "" - - # Settings for rate-limited endpoints. Rate limiting kicks in after the threshold - # number of "slots" have been taken by requests from a specific host. Each "slot" - # will be released after the cooloff time in milliseconds. Server administrators - # and appservice users are exempt from rate limiting by default. - rate_limiting: - enabled: true - threshold: 20 - cooloff_ms: 500 - exempt_user_ids: - # - "@user:domain.com" - -# Configuration for the Federation API. -federation_api: - internal_api: - listen: http://[::]:7772 # The listen address for incoming API requests - connect: http://federation_api:7772 # The connect address for other components to use - external_api: - listen: http://[::]:8072 - database: - connection_string: postgresql://username:password@hostname/dendrite_federationapi?sslmode=disable - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 - - # How many times we will try to resend a failed transaction to a specific server. The - # backoff is 2**x seconds, so 1 = 2 seconds, 2 = 4 seconds, 3 = 8 seconds etc. Once - # the max retries are exceeded, Dendrite will no longer try to send transactions to - # that server until it comes back to life and connects to us again. - send_max_retries: 16 - - # Disable the validation of TLS certificates of remote federated homeservers. Do not - # enable this option in production as it presents a security risk! - disable_tls_validation: false - - # Disable HTTP keepalives, which also prevents connection reuse. Dendrite will typically - # keep HTTP connections open to remote hosts for 5 minutes as they can be reused much - # more quickly than opening new connections each time. Disabling keepalives will close - # HTTP connections immediately after a successful request but may result in more CPU and - # memory being used on TLS handshakes for each new connection instead. - disable_http_keepalives: false - - # Perspective keyservers to use as a backup when direct key fetches fail. This may - # be required to satisfy key requests for servers that are no longer online when - # joining some rooms. - key_perspectives: - - server_name: matrix.org - keys: - - key_id: ed25519:auto - public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw - - key_id: ed25519:a_RXGa - public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ - - # This option will control whether Dendrite will prefer to look up keys directly - # or whether it should try perspective servers first, using direct fetches as a - # last resort. - prefer_direct_fetch: false - -# Configuration for the Key Server (for end-to-end encryption). -key_server: - internal_api: - listen: http://[::]:7779 # The listen address for incoming API requests - connect: http://key_server:7779 # The connect address for other components to use - database: - connection_string: postgresql://username:password@hostname/dendrite_keyserver?sslmode=disable - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 - -# Configuration for the Media API. -media_api: - internal_api: - listen: http://[::]:7774 # The listen address for incoming API requests - connect: http://media_api:7774 # The connect address for other components to use - external_api: - listen: http://[::]:8074 - database: - connection_string: postgresql://username:password@hostname/dendrite_mediaapi?sslmode=disable - max_open_conns: 5 - max_idle_conns: 2 - conn_max_lifetime: -1 - - # Storage path for uploaded media. May be relative or absolute. - base_path: ./media_store - - # The maximum allowed file size (in bytes) for media uploads to this homeserver - # (0 = unlimited). If using a reverse proxy, ensure it allows requests at least - #this large (e.g. the client_max_body_size setting in nginx). - max_file_size_bytes: 10485760 - - # Whether to dynamically generate thumbnails if needed. - dynamic_thumbnails: false - - # The maximum number of simultaneous thumbnail generators to run. - max_thumbnail_generators: 10 - - # A list of thumbnail sizes to be generated for media content. - thumbnail_sizes: - - width: 32 - height: 32 - method: crop - - width: 96 - height: 96 - method: crop - - width: 640 - height: 480 - method: scale - -# Configuration for enabling experimental MSCs on this homeserver. -mscs: - mscs: - # - msc2836 # (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836) - # - msc2946 # (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946) - database: - connection_string: postgresql://username:password@hostname/dendrite_mscs?sslmode=disable - max_open_conns: 5 - max_idle_conns: 2 - conn_max_lifetime: -1 - -# Configuration for the Room Server. -room_server: - internal_api: - listen: http://[::]:7770 # The listen address for incoming API requests - connect: http://room_server:7770 # The connect address for other components to use - database: - connection_string: postgresql://username:password@hostname/dendrite_roomserver?sslmode=disable - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 - -# Configuration for the Sync API. -sync_api: - internal_api: - listen: http://[::]:7773 # The listen address for incoming API requests - connect: http://sync_api:7773 # The connect address for other components to use - external_api: - listen: http://[::]:8073 - database: - connection_string: postgresql://username:password@hostname/dendrite_syncapi?sslmode=disable - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 - - # Configuration for the full-text search engine. - search: - # Whether or not search is enabled. - enabled: false - - # The path where the search index will be created in. - index_path: "./searchindex" - - # The language most likely to be used on the server - used when indexing, to - # ensure the returned results match expectations. A full list of possible languages - # can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang - language: "en" - - # This option controls which HTTP header to inspect to find the real remote IP - # address of the client. This is likely required if Dendrite is running behind - # a reverse proxy server. - # real_ip_header: X-Real-IP - -# Configuration for the User API. -user_api: - internal_api: - listen: http://[::]:7781 # The listen address for incoming API requests - connect: http://user_api:7781 # The connect address for other components to use - account_database: - connection_string: postgresql://username:password@hostname/dendrite_userapi?sslmode=disable - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 - - # The cost when hashing passwords on registration/login. Default: 10. Min: 4, Max: 31 - # See https://pkg.go.dev/golang.org/x/crypto/bcrypt for more information. - # Setting this lower makes registration/login consume less CPU resources at the cost - # of security should the database be compromised. Setting this higher makes registration/login - # consume more CPU resources but makes it harder to brute force password hashes. This value - # can be lowered if performing tests or on embedded Dendrite instances (e.g WASM builds). - bcrypt_cost: 10 - - # The length of time that a token issued for a relying party from - # /_matrix/client/r0/user/{userId}/openid/request_token endpoint - # is considered to be valid in milliseconds. - # The default lifetime is 3600000ms (60 minutes). - # openid_token_lifetime_ms: 3600000 - - # Users who register on this homeserver will automatically be joined to the rooms listed under "auto_join_rooms" option. - # By default, any room aliases included in this list will be created as a publicly joinable room - # when the first user registers for the homeserver. If the room already exists, - # make certain it is a publicly joinable room, i.e. the join rule of the room must be set to 'public'. - # As Spaces are just rooms under the hood, Space aliases may also be used. - auto_join_rooms: - # - "#main:matrix.org" - -# Configuration for Opentracing. -# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on -# how this works and how to set it up. -tracing: - enabled: false - jaeger: - serviceName: "" - disabled: false - rpc_metrics: false - tags: [] - sampler: null - reporter: null - headers: null - baggage_restrictions: null - throttler: null - -# Logging configuration. The "std" logging type controls the logs being sent to -# stdout. The "file" logging type controls logs being written to a log folder on -# the disk. Supported log levels are "debug", "info", "warn", "error". -logging: - - type: std - level: info - - type: file - level: info - params: - path: ./logs diff --git a/dendrite-sample.monolith.yaml b/dendrite-sample.yaml similarity index 98% rename from dendrite-sample.monolith.yaml rename to dendrite-sample.yaml index d86e9da94..39be8d8b8 100644 --- a/dendrite-sample.monolith.yaml +++ b/dendrite-sample.yaml @@ -38,7 +38,7 @@ global: # Global database connection pool, for PostgreSQL monolith deployments only. If # this section is populated then you can omit the "database" blocks in all other - # sections. For polylith deployments, or monolith deployments using SQLite databases, + # sections. For monolith deployments using SQLite databases, # you must configure the "database" block for each component instead. database: connection_string: postgresql://username:password@hostname/dendrite?sslmode=disable @@ -113,8 +113,7 @@ global: jetstream: # A list of NATS Server addresses to connect to. If none are specified, an # internal NATS server will be started automatically when running Dendrite in - # monolith mode. For polylith deployments, it is required to specify the address - # of at least one NATS Server node. + # monolith mode. addresses: # - localhost:4222 diff --git a/docs/FAQ.md b/docs/FAQ.md index 23f029535..ecf43d693 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -35,11 +35,6 @@ possible to migrate an existing Synapse deployment to Dendrite. No, Dendrite has a very different database schema to Synapse and the two are not interchangeable. -## Should I run a monolith or a polylith deployment? - -Monolith deployments are always preferred where possible, and at this time, are far better tested than polylith deployments are. The only reason to consider a polylith deployment is if you wish to run different Dendrite components on separate physical machines, but this is an advanced configuration which we don't -recommend. - ## Can I configure which port Dendrite listens on? Yes, use the cli flag `-http-bind-address`. diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index 509a8cbcf..a61786c1d 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -1,7 +1,7 @@ GEM remote: https://rubygems.org/ specs: - activesupport (6.0.5) + activesupport (6.0.6.1) concurrent-ruby (~> 1.0, >= 1.0.2) i18n (>= 0.7, < 2) minitest (~> 5.1) @@ -14,8 +14,8 @@ GEM execjs coffee-script-source (1.11.1) colorator (1.1.0) - commonmarker (0.23.6) - concurrent-ruby (1.1.10) + commonmarker (0.23.7) + concurrent-ruby (1.2.0) dnsruby (1.61.9) simpleidn (~> 0.1) em-websocket (0.5.3) @@ -229,7 +229,7 @@ GEM jekyll (>= 3.5, < 5.0) jekyll-feed (~> 0.9) jekyll-seo-tag (~> 2.1) - minitest (5.15.0) + minitest (5.17.0) multipart-post (2.1.1) nokogiri (1.13.10-arm64-darwin) racc (~> 1.4) @@ -265,13 +265,13 @@ GEM thread_safe (0.3.6) typhoeus (1.4.0) ethon (>= 0.9.0) - tzinfo (1.2.10) + tzinfo (1.2.11) thread_safe (~> 0.1) unf (0.1.4) unf_ext unf_ext (0.0.8.1) unicode-display_width (1.8.0) - zeitwerk (2.5.4) + zeitwerk (2.6.6) PLATFORMS arm64-darwin-21 diff --git a/docs/INSTALL.md b/docs/INSTALL.md index add822108..ccfc58107 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -9,7 +9,5 @@ or alternatively, in the [installation](installation/) folder: 3. [Preparing database storage](installation/3_database.md) 4. [Generating signing keys](installation/4_signingkey.md) 5. [Installing as a monolith](installation/5_install_monolith.md) -6. [Installing as a polylith](installation/6_install_polylith.md) -7. [Populate the configuration](installation/7_configuration.md) -8. [Starting the monolith](installation/8_starting_monolith.md) -9. [Starting the polylith](installation/9_starting_polylith.md) +6. [Populate the configuration](installation/7_configuration.md) +7. [Starting the monolith](installation/8_starting_monolith.md) diff --git a/docs/development/coverage.md b/docs/development/coverage.md index 7a3b7cb9e..f3e39ddd7 100644 --- a/docs/development/coverage.md +++ b/docs/development/coverage.md @@ -57,22 +57,16 @@ github.com/matrix-org/util/unique.go:55: UniqueStrings 100.0% total: (statements) 53.7% ``` The total coverage for this run is the last line at the bottom. However, this value is misleading because Dendrite can run in many different configurations, -which will never be tested in a single test run (e.g sqlite or postgres, monolith or polylith). To get a more accurate value, additional processing is required +which will never be tested in a single test run (e.g sqlite or postgres). To get a more accurate value, additional processing is required to remove packages which will never be tested and extension MSCs: ```bash # These commands are all similar but change which package paths are _removed_ from the output. -# For Postgres (monolith) +# For Postgres go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|sqlite|setup/mscs|api_trace' > coverage.txt -# For Postgres (polylith) -go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'sqlite|setup/mscs|api_trace' > coverage.txt - -# For SQLite (monolith) +# For SQLite go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|postgres|setup/mscs|api_trace' > coverage.txt - -# For SQLite (polylith) -go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'postgres|setup/mscs|api_trace' > coverage.txt ``` A total value can then be calculated using: diff --git a/docs/development/tracing/setup.md b/docs/development/tracing/setup.md index 06f89bf85..a9e90c643 100644 --- a/docs/development/tracing/setup.md +++ b/docs/development/tracing/setup.md @@ -46,10 +46,10 @@ tracing: param: 1 ``` -then run the monolith server with `--api true` to use polylith components which do tracing spans: +then run the monolith server: ``` -./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml --api true +./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml ``` ## Checking traces diff --git a/docs/installation/10_starting_polylith.md b/docs/installation/10_starting_polylith.md deleted file mode 100644 index 0c2e2af2b..000000000 --- a/docs/installation/10_starting_polylith.md +++ /dev/null @@ -1,73 +0,0 @@ ---- -title: Starting the polylith -parent: Installation -has_toc: true -nav_order: 10 -permalink: /installation/start/polylith ---- - -# Starting the polylith - -Once you have completed all of the preparation and installation steps, -you can start your Dendrite polylith deployment by starting the various components -using the `dendrite-polylith-multi` personalities. - -## Start the reverse proxy - -Ensure that your reverse proxy is started and is proxying the correct -endpoints to the correct components. Software such as [NGINX](https://www.nginx.com) or -[HAProxy](http://www.haproxy.org) can be used for this purpose. A [sample configuration -for NGINX](https://github.com/matrix-org/dendrite/blob/main/docs/nginx/polylith-sample.conf) -is provided. - -## Starting the components - -Each component must be started individually: - -### Client API - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml clientapi -``` - -### Sync API - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml syncapi -``` - -### Media API - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml mediaapi -``` - -### Federation API - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml federationapi -``` - -### Roomserver - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml roomserver -``` - -### Appservice API - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml appservice -``` - -### User API - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml userapi -``` - -### Key server - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml keyserver -``` diff --git a/docs/installation/1_planning.md b/docs/installation/1_planning.md index 3aa5b4d85..36d90abda 100644 --- a/docs/installation/1_planning.md +++ b/docs/installation/1_planning.md @@ -16,12 +16,6 @@ Users can run Dendrite in one of two modes which dictate how these components ar server with generally low overhead. This mode dramatically simplifies deployment complexity and offers the best balance between performance and resource usage for low-to-mid volume deployments. -* **Polylith mode** runs all components in isolated processes. Components communicate through an external NATS - server and HTTP APIs, which incur considerable overhead. While this mode allows for more granular control of - resources dedicated toward individual processes, given the additional communications overhead, it is only - necessary for very large deployments. - -Given our current state of development, **we recommend monolith mode** for all deployments. ## Databases @@ -85,21 +79,15 @@ If using the PostgreSQL database engine, you should install PostgreSQL 12 or lat ### NATS Server -Monolith deployments come with a built-in [NATS Server](https://github.com/nats-io/nats-server) and -therefore do not need this to be manually installed. If you are planning a monolith installation, you +Dendrite comes with a built-in [NATS Server](https://github.com/nats-io/nats-server) and +therefore does not need this to be manually installed. If you are planning a monolith installation, you do not need to do anything. -Polylith deployments, however, currently need a standalone NATS Server installation with JetStream -enabled. - -To do so, follow the [NATS Server installation instructions](https://docs.nats.io/running-a-nats-service/introduction/installation) and then [start your NATS deployment](https://docs.nats.io/running-a-nats-service/introduction/running). JetStream must be enabled, either by passing the `-js` flag to `nats-server`, -or by specifying the `store_dir` option in the the `jetstream` configuration. ### Reverse proxy A reverse proxy such as [Caddy](https://caddyserver.com), [NGINX](https://www.nginx.com) or -[HAProxy](http://www.haproxy.org) is required for polylith deployments and is useful for monolith -deployments. Configuring those is not covered in this documentation, although sample configurations +[HAProxy](http://www.haproxy.org) is useful for deployments. Configuring those is not covered in this documentation, although sample configurations for [Caddy](https://github.com/matrix-org/dendrite/blob/main/docs/caddy) and [NGINX](https://github.com/matrix-org/dendrite/blob/main/docs/nginx) are provided. diff --git a/docs/installation/6_install_polylith.md b/docs/installation/6_install_polylith.md deleted file mode 100644 index ec4a77628..000000000 --- a/docs/installation/6_install_polylith.md +++ /dev/null @@ -1,34 +0,0 @@ ---- -title: Installing as a polylith -parent: Installation -has_toc: true -nav_order: 6 -permalink: /installation/install/polylith ---- - -# Installing as a polylith - -You can install the Dendrite polylith binary into `$GOPATH/bin` by using `go install`: - -```sh -go install ./cmd/dendrite-polylith-multi -``` - -Alternatively, you can specify a custom path for the binary to be written to using `go build`: - -```sh -go build -o /usr/local/bin/ ./cmd/dendrite-polylith-multi -``` - -The `dendrite-polylith-multi` binary is a "multi-personality" binary which can run as -any of the components depending on the supplied command line parameters. - -## Reverse proxy - -Polylith deployments require a reverse proxy in order to ensure that requests are -sent to the correct endpoint. You must ensure that a suitable reverse proxy is installed -and configured. - -Sample configurations are provided -for [Caddy](https://github.com/matrix-org/dendrite/blob/main/docs/caddy/polylith/Caddyfile) -and [NGINX](https://github.com/matrix-org/dendrite/blob/main/docs/nginx/polylith-sample.conf). \ No newline at end of file diff --git a/docs/installation/7_configuration.md b/docs/installation/7_configuration.md index 19958c92f..5f123bfca 100644 --- a/docs/installation/7_configuration.md +++ b/docs/installation/7_configuration.md @@ -7,11 +7,10 @@ permalink: /installation/configuration # Configuring Dendrite -A YAML configuration file is used to configure Dendrite. Sample configuration files are +A YAML configuration file is used to configure Dendrite. A sample configuration file is present in the top level of the Dendrite repository: * [`dendrite-sample.monolith.yaml`](https://github.com/matrix-org/dendrite/blob/main/dendrite-sample.monolith.yaml) -* [`dendrite-sample.polylith.yaml`](https://github.com/matrix-org/dendrite/blob/main/dendrite-sample.polylith.yaml) You will need to duplicate the sample, calling it `dendrite.yaml` for example, and then tailor it to your installation. At a minimum, you will need to populate the following @@ -46,10 +45,9 @@ global: ## JetStream configuration Monolith deployments can use the built-in NATS Server rather than running a standalone -server. If you are building a polylith deployment, or you want to use a standalone NATS -Server anyway, you can also configure that too. +server. If you want to use a standalone NATS Server anyway, you can also configure that too. -### Built-in NATS Server (monolith only) +### Built-in NATS Server In the `global` section, under the `jetstream` key, ensure that no server addresses are configured and set a `storage_path` to a persistent folder on the filesystem: @@ -63,7 +61,7 @@ global: topic_prefix: Dendrite ``` -### Standalone NATS Server (monolith and polylith) +### Standalone NATS Server To use a standalone NATS Server instance, you will need to configure `addresses` field to point to the port that your NATS Server is listening on: @@ -86,7 +84,7 @@ one address in the `addresses` field. Configuring database connections varies based on the [database configuration](database) that you chose. -### Global connection pool (monolith with a single PostgreSQL database only) +### Global connection pool If you are running a monolith deployment and want to use a single connection pool to a single PostgreSQL database, then you must uncomment and configure the `database` section @@ -109,7 +107,7 @@ override the `global` database configuration. ### Per-component connections (all other configurations) -If you are building a polylith deployment, are using SQLite databases or separate PostgreSQL +If you are are using SQLite databases or separate PostgreSQL databases per component, then you must instead configure the `database` sections under each of the component blocks ,e.g. under the `app_service_api`, `federation_api`, `key_server`, `media_api`, `mscs`, `room_server`, `sync_api` and `user_api` blocks. diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 50d0339e4..e4c0b2714 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -18,6 +18,7 @@ type FederationInternalAPI interface { gomatrixserverlib.KeyDatabase ClientFederationAPI RoomserverFederationAPI + P2PFederationAPI QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) @@ -30,7 +31,6 @@ type FederationInternalAPI interface { request *PerformBroadcastEDURequest, response *PerformBroadcastEDUResponse, ) error - PerformWakeupServers( ctx context.Context, request *PerformWakeupServersRequest, @@ -71,6 +71,29 @@ type RoomserverFederationAPI interface { LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } +type P2PFederationAPI interface { + // Get the relay servers associated for the given server. + P2PQueryRelayServers( + ctx context.Context, + request *P2PQueryRelayServersRequest, + response *P2PQueryRelayServersResponse, + ) error + + // Add relay server associations to the given server. + P2PAddRelayServers( + ctx context.Context, + request *P2PAddRelayServersRequest, + response *P2PAddRelayServersResponse, + ) error + + // Remove relay server associations from the given server. + P2PRemoveRelayServers( + ctx context.Context, + request *P2PRemoveRelayServersRequest, + response *P2PRemoveRelayServersResponse, + ) error +} + // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError @@ -82,6 +105,7 @@ type KeyserverFederationAPI interface { // an interface for gmsl.FederationClient - contains functions called by federationapi only. type FederationClient interface { + P2PFederationClient gomatrixserverlib.KeyClient SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) @@ -110,6 +134,11 @@ type FederationClient interface { LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } +type P2PFederationClient interface { + P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) + P2PGetTransactionFromRelay(ctx context.Context, u gomatrixserverlib.UserID, prev gomatrixserverlib.RelayEntry, relayServer gomatrixserverlib.ServerName) (res gomatrixserverlib.RespGetRelayTransaction, err error) +} + // FederationClientError is returned from FederationClient methods in the event of a problem. type FederationClientError struct { Err string @@ -233,3 +262,27 @@ type InputPublicKeysRequest struct { type InputPublicKeysResponse struct { } + +type P2PQueryRelayServersRequest struct { + Server gomatrixserverlib.ServerName +} + +type P2PQueryRelayServersResponse struct { + RelayServers []gomatrixserverlib.ServerName +} + +type P2PAddRelayServersRequest struct { + Server gomatrixserverlib.ServerName + RelayServers []gomatrixserverlib.ServerName +} + +type P2PAddRelayServersResponse struct { +} + +type P2PRemoveRelayServersRequest struct { + Server gomatrixserverlib.ServerName + RelayServers []gomatrixserverlib.ServerName +} + +type P2PRemoveRelayServersResponse struct { +} diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 601257d4b..7d9df3d78 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -26,11 +26,11 @@ import ( "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/types" - "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/api" ) // KeyChangeConsumer consumes events that originate in key server. diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 0c1080afa..82a4db3f7 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/federationapi/queue" @@ -90,8 +91,10 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms msg := msgs[0] // Guaranteed to exist if onMessage is called receivedType := api.OutputType(msg.Header.Get(jetstream.RoomEventType)) - // Only handle events we care about - if receivedType != api.OutputTypeNewRoomEvent && receivedType != api.OutputTypeNewInboundPeek { + // Only handle events we care about, avoids unneeded unmarshalling + switch receivedType { + case api.OutputTypeNewRoomEvent, api.OutputTypeNewInboundPeek, api.OutputTypePurgeRoom: + default: return true } @@ -126,6 +129,14 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms return false } + case api.OutputTypePurgeRoom: + log.WithField("room_id", output.PurgeRoom.RoomID).Warn("Purging room from federation API") + if err := s.db.PurgeRoom(ctx, output.PurgeRoom.RoomID); err != nil { + logrus.WithField("room_id", output.PurgeRoom.RoomID).WithError(err).Error("Failed to purge room from federation API") + } else { + logrus.WithField("room_id", output.PurgeRoom.RoomID).Warn("Room purged from federation API") + } + default: log.WithField("type", output.Type).Debug( "roomserver output log: ignoring unknown output type", @@ -195,7 +206,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew } // If we added new hosts, inform them about our known presence events for this room - if len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil { + if s.cfg.Matrix.Presence.EnableOutbound && len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil { membership, _ := ore.Event.Membership() if membership == gomatrixserverlib.Join { s.sendPresence(ore.Event.RoomID(), addsJoinedHosts) diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 87eb751f5..ec482659a 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -17,20 +17,17 @@ package federationapi import ( "time" - "github.com/gorilla/mux" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/federationapi/api" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/consumers" "github.com/matrix-org/dendrite/federationapi/internal" - "github.com/matrix-org/dendrite/federationapi/inthttp" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/internal/caching" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" @@ -41,21 +38,14 @@ import ( "github.com/matrix-org/dendrite/federationapi/routing" ) -// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions -// on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.FederationInternalAPI, enableMetrics bool) { - inthttp.AddRoutes(intAPI, router, enableMetrics) -} - // AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. func AddPublicRoutes( base *base.BaseDendrite, - userAPI userapi.UserInternalAPI, + userAPI userapi.FederationUserAPI, federation *gomatrixserverlib.FederationClient, keyRing gomatrixserverlib.JSONVerifier, rsAPI roomserverAPI.FederationRoomserverAPI, fedAPI federationAPI.FederationInternalAPI, - keyAPI keyserverAPI.FederationKeyAPI, servers federationAPI.ServersInRoomProvider, ) { cfg := &base.Cfg.FederationAPI @@ -85,12 +75,9 @@ func AddPublicRoutes( } routing.Setup( - base.PublicFederationAPIMux, - base.PublicKeyAPIMux, - base.PublicWellKnownAPIMux, - cfg, + base, rsAPI, f, keyRing, - federation, userAPI, keyAPI, mscCfg, + federation, userAPI, mscCfg, servers, producer, ) } @@ -116,7 +103,10 @@ func NewInternalAPI( _ = federationDB.RemoveAllServersFromBlacklist() } - stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1) + stats := statistics.NewStatistics( + federationDB, + cfg.FederationMaxRetries+1, + cfg.P2PFederationRetriesUntilAssumedOffline+1) js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go index cc03cdece..bb6ee8935 100644 --- a/federationapi/federationapi_keys_test.go +++ b/federationapi/federationapi_keys_test.go @@ -77,8 +77,8 @@ func TestMain(m *testing.M) { // API to work. cfg := &config.Dendrite{} cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: false, }) cfg.Global.ServerName = gomatrixserverlib.ServerName(s.name) cfg.Global.PrivateKey = testPriv @@ -109,7 +109,7 @@ func TestMain(m *testing.M) { ) // Finally, build the server key APIs. - sbase := base.NewBaseDendrite(cfg, "Monolith", base.DisableMetrics) + sbase := base.NewBaseDendrite(cfg, base.DisableMetrics) s.api = NewInternalAPI(sbase, s.fedclient, nil, s.cache, nil, true) } diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 68a06a033..57d4b9644 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -17,13 +17,13 @@ import ( "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/internal" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" + userapi "github.com/matrix-org/dendrite/userapi/api" ) type fedRoomserverAPI struct { @@ -230,9 +230,9 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { // Inject a keyserver key change event and ensure we try to send it out. If we don't, then the // federationapi is incorrectly waiting for an output room event to arrive to update the joined // hosts table. - key := keyapi.DeviceMessage{ - Type: keyapi.TypeDeviceKeyUpdate, - DeviceKeys: &keyapi.DeviceKeys{ + key := userapi.DeviceMessage{ + Type: userapi.TypeDeviceKeyUpdate, + DeviceKeys: &userapi.DeviceKeys{ UserID: joiningUser.ID, DeviceID: "MY_DEVICE", DisplayName: "BLARGLE", @@ -266,19 +266,19 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { _, privKey, _ := ed25519.GenerateKey(nil) cfg := &config.Dendrite{} cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: false, }) cfg.Global.KeyID = gomatrixserverlib.KeyID("ed25519:auto") cfg.Global.ServerName = gomatrixserverlib.ServerName("localhost") cfg.Global.PrivateKey = privKey cfg.Global.JetStream.InMemory = true - base := base.NewBaseDendrite(cfg, "Monolith") + b := base.NewBaseDendrite(cfg, base.DisableMetrics) keyRing := &test.NopJSONVerifier{} // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. - federationapi.AddPublicRoutes(base, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil, nil) - baseURL, cancel := test.ListenAndServe(t, base.PublicFederationAPIMux, true) + federationapi.AddPublicRoutes(b, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil) + baseURL, cancel := test.ListenAndServe(t, b.PublicFederationAPIMux, true) defer cancel() serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) diff --git a/federationapi/internal/api.go b/federationapi/internal/api.go index 14056eafc..99773a750 100644 --- a/federationapi/internal/api.go +++ b/federationapi/internal/api.go @@ -109,13 +109,14 @@ func NewFederationInternalAPI( func (a *FederationInternalAPI) isBlacklistedOrBackingOff(s gomatrixserverlib.ServerName) (*statistics.ServerStatistics, error) { stats := a.statistics.ForServer(s) - until, blacklisted := stats.BackoffInfo() - if blacklisted { + if stats.Blacklisted() { return stats, &api.FederationClientError{ Blacklisted: true, } } + now := time.Now() + until := stats.BackoffInfo() if until != nil && now.Before(*until) { return stats, &api.FederationClientError{ RetryAfter: time.Until(*until), @@ -163,7 +164,7 @@ func (a *FederationInternalAPI) doRequestIfNotBackingOffOrBlacklisted( RetryAfter: retryAfter, } } - stats.Success() + stats.Success(statistics.SendDirect) return res, nil } @@ -171,7 +172,7 @@ func (a *FederationInternalAPI) doRequestIfNotBlacklisted( s gomatrixserverlib.ServerName, request func() (interface{}, error), ) (interface{}, error) { stats := a.statistics.ForServer(s) - if _, blacklisted := stats.BackoffInfo(); blacklisted { + if blacklisted := stats.Blacklisted(); blacklisted { return stats, &api.FederationClientError{ Err: fmt.Sprintf("server %q is blacklisted", s), Blacklisted: true, diff --git a/federationapi/internal/federationclient_test.go b/federationapi/internal/federationclient_test.go new file mode 100644 index 000000000..49137e2d8 --- /dev/null +++ b/federationapi/internal/federationclient_test.go @@ -0,0 +1,202 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "fmt" + "testing" + + "github.com/matrix-org/dendrite/federationapi/queue" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + FailuresUntilAssumedOffline = 3 + FailuresUntilBlacklist = 8 +) + +func (t *testFedClient) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) { + t.queryKeysCalled = true + if t.shouldFail { + return gomatrixserverlib.RespQueryKeys{}, fmt.Errorf("Failure") + } + return gomatrixserverlib.RespQueryKeys{}, nil +} + +func (t *testFedClient) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) { + t.claimKeysCalled = true + if t.shouldFail { + return gomatrixserverlib.RespClaimKeys{}, fmt.Errorf("Failure") + } + return gomatrixserverlib.RespClaimKeys{}, nil +} + +func TestFederationClientQueryKeys(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil) + assert.Nil(t, err) + assert.True(t, fedClient.queryKeysCalled) +} + +func TestFederationClientQueryKeysBlacklisted(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + testDB.AddServerToBlacklist("server") + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil) + assert.NotNil(t, err) + assert.False(t, fedClient.queryKeysCalled) +} + +func TestFederationClientQueryKeysFailure(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{shouldFail: true} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil) + assert.NotNil(t, err) + assert.True(t, fedClient.queryKeysCalled) +} + +func TestFederationClientClaimKeys(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.ClaimKeys(context.Background(), "origin", "server", nil) + assert.Nil(t, err) + assert.True(t, fedClient.claimKeysCalled) +} + +func TestFederationClientClaimKeysBlacklisted(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + testDB.AddServerToBlacklist("server") + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.ClaimKeys(context.Background(), "origin", "server", nil) + assert.NotNil(t, err) + assert.False(t, fedClient.claimKeysCalled) +} diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index d86d07e03..dadb2b2b3 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -14,6 +14,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/consumers" + "github.com/matrix-org/dendrite/federationapi/statistics" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/version" ) @@ -24,6 +25,10 @@ func (r *FederationInternalAPI) PerformDirectoryLookup( request *api.PerformDirectoryLookupRequest, response *api.PerformDirectoryLookupResponse, ) (err error) { + if !r.shouldAttemptDirectFederation(request.ServerName) { + return fmt.Errorf("relay servers have no meaningful response for directory lookup.") + } + dir, err := r.federation.LookupRoomAlias( ctx, r.cfg.Matrix.ServerName, @@ -36,7 +41,7 @@ func (r *FederationInternalAPI) PerformDirectoryLookup( } response.RoomID = dir.RoomID response.ServerNames = dir.Servers - r.statistics.ForServer(request.ServerName).Success() + r.statistics.ForServer(request.ServerName).Success(statistics.SendDirect) return nil } @@ -116,8 +121,6 @@ func (r *FederationInternalAPI) PerformJoin( var httpErr gomatrix.HTTPError if ok := errors.As(lastErr, &httpErr); ok { httpErr.Message = string(httpErr.Contents) - // Clear the wrapped error, else serialising to JSON (in polylith mode) will fail - httpErr.WrappedError = nil response.LastError = &httpErr } else { response.LastError = &gomatrix.HTTPError{ @@ -144,6 +147,10 @@ func (r *FederationInternalAPI) performJoinUsingServer( supportedVersions []gomatrixserverlib.RoomVersion, unsigned map[string]interface{}, ) error { + if !r.shouldAttemptDirectFederation(serverName) { + return fmt.Errorf("relay servers have no meaningful response for join.") + } + _, origin, err := r.cfg.Matrix.SplitLocalID('@', userID) if err != nil { return err @@ -164,7 +171,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( r.statistics.ForServer(serverName).Failure() return fmt.Errorf("r.federation.MakeJoin: %w", err) } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) // Set all the fields to be what they should be, this should be a no-op // but it's possible that the remote server returned us something "odd" @@ -219,7 +226,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( r.statistics.ForServer(serverName).Failure() return fmt.Errorf("r.federation.SendJoin: %w", err) } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) // If the remote server returned an event in the "event" key of // the send_join request then we should use that instead. It may @@ -382,8 +389,6 @@ func (r *FederationInternalAPI) PerformOutboundPeek( var httpErr gomatrix.HTTPError if ok := errors.As(lastErr, &httpErr); ok { httpErr.Message = string(httpErr.Contents) - // Clear the wrapped error, else serialising to JSON (in polylith mode) will fail - httpErr.WrappedError = nil response.LastError = &httpErr } else { response.LastError = &gomatrix.HTTPError{ @@ -407,6 +412,10 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( serverName gomatrixserverlib.ServerName, supportedVersions []gomatrixserverlib.RoomVersion, ) error { + if !r.shouldAttemptDirectFederation(serverName) { + return fmt.Errorf("relay servers have no meaningful response for outbound peek.") + } + // create a unique ID for this peek. // for now we just use the room ID again. In future, if we ever // support concurrent peeks to the same room with different filters @@ -446,7 +455,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( r.statistics.ForServer(serverName).Failure() return fmt.Errorf("r.federation.Peek: %w", err) } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) // Work out if we support the room version that has been supplied in // the peek response. @@ -516,6 +525,10 @@ func (r *FederationInternalAPI) PerformLeave( // Try each server that we were provided until we land on one that // successfully completes the make-leave send-leave dance. for _, serverName := range request.ServerNames { + if !r.shouldAttemptDirectFederation(serverName) { + continue + } + // Try to perform a make_leave using the information supplied in the // request. respMakeLeave, err := r.federation.MakeLeave( @@ -585,7 +598,7 @@ func (r *FederationInternalAPI) PerformLeave( continue } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) return nil } @@ -616,6 +629,12 @@ func (r *FederationInternalAPI) PerformInvite( return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } + // TODO (devon): This should be allowed via a relay. Currently only transactions + // can be sent to relays. Would need to extend relays to handle invites. + if !r.shouldAttemptDirectFederation(destination) { + return fmt.Errorf("relay servers have no meaningful response for invite.") + } + logrus.WithFields(logrus.Fields{ "event_id": request.Event.EventID(), "user_id": *request.Event.StateKey(), @@ -682,12 +701,8 @@ func (r *FederationInternalAPI) PerformWakeupServers( func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) { for _, srv := range destinations { - // Check the statistics cache for the blacklist status to prevent hitting - // the database unnecessarily. - if r.queues.IsServerBlacklisted(srv) { - _ = r.db.RemoveServerFromBlacklist(srv) - } - r.queues.RetryServer(srv) + wasBlacklisted := r.statistics.ForServer(srv).MarkServerAlive() + r.queues.RetryServer(srv, wasBlacklisted) } } @@ -719,7 +734,9 @@ func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error { return fmt.Errorf("auth chain response is missing m.room.create event") } -func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder) gomatrixserverlib.RoomVersion { +func setDefaultRoomVersionFromJoinEvent( + joinEvent gomatrixserverlib.EventBuilder, +) gomatrixserverlib.RoomVersion { // if auth events are not event references we know it must be v3+ // we have to do these shenanigans to satisfy sytest, specifically for: // "Outbound federation rejects m.room.create events with an unknown room version" @@ -802,3 +819,61 @@ func federatedAuthProvider( return returning, nil } } + +// P2PQueryRelayServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) P2PQueryRelayServers( + ctx context.Context, + request *api.P2PQueryRelayServersRequest, + response *api.P2PQueryRelayServersResponse, +) error { + logrus.Infof("Getting relay servers for: %s", request.Server) + relayServers, err := r.db.P2PGetRelayServersForServer(ctx, request.Server) + if err != nil { + return err + } + + response.RelayServers = relayServers + return nil +} + +// P2PAddRelayServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) P2PAddRelayServers( + ctx context.Context, + request *api.P2PAddRelayServersRequest, + response *api.P2PAddRelayServersResponse, +) error { + logrus.Infof("Adding relay servers for: %s", request.Server) + err := r.db.P2PAddRelayServersForServer(ctx, request.Server, request.RelayServers) + if err != nil { + return err + } + + return nil +} + +// P2PRemoveRelayServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) P2PRemoveRelayServers( + ctx context.Context, + request *api.P2PRemoveRelayServersRequest, + response *api.P2PRemoveRelayServersResponse, +) error { + logrus.Infof("Adding relay servers for: %s", request.Server) + err := r.db.P2PRemoveRelayServersForServer(ctx, request.Server, request.RelayServers) + if err != nil { + return err + } + + return nil +} + +func (r *FederationInternalAPI) shouldAttemptDirectFederation( + destination gomatrixserverlib.ServerName, +) bool { + var shouldRelay bool + stats := r.statistics.ForServer(destination) + if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 { + shouldRelay = true + } + + return !shouldRelay +} diff --git a/federationapi/internal/perform_test.go b/federationapi/internal/perform_test.go new file mode 100644 index 000000000..e6e366f99 --- /dev/null +++ b/federationapi/internal/perform_test.go @@ -0,0 +1,231 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/queue" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type testFedClient struct { + api.FederationClient + queryKeysCalled bool + claimKeysCalled bool + shouldFail bool +} + +func (t *testFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) { + return gomatrixserverlib.RespDirectory{}, nil +} + +func TestPerformWakeupServers(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + testDB.AddServerToBlacklist(server) + testDB.SetServerAssumedOffline(context.Background(), server) + blacklisted, err := testDB.IsServerBlacklisted(server) + assert.NoError(t, err) + assert.True(t, blacklisted) + offline, err := testDB.IsServerAssumedOffline(context.Background(), server) + assert.NoError(t, err) + assert.True(t, offline) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.PerformWakeupServersRequest{ + ServerNames: []gomatrixserverlib.ServerName{server}, + } + res := api.PerformWakeupServersResponse{} + err = fedAPI.PerformWakeupServers(context.Background(), &req, &res) + assert.NoError(t, err) + + blacklisted, err = testDB.IsServerBlacklisted(server) + assert.NoError(t, err) + assert.False(t, blacklisted) + offline, err = testDB.IsServerAssumedOffline(context.Background(), server) + assert.NoError(t, err) + assert.False(t, offline) +} + +func TestQueryRelayServers(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} + err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers) + assert.NoError(t, err) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.P2PQueryRelayServersRequest{ + Server: server, + } + res := api.P2PQueryRelayServersResponse{} + err = fedAPI.P2PQueryRelayServers(context.Background(), &req, &res) + assert.NoError(t, err) + + assert.Equal(t, len(relayServers), len(res.RelayServers)) +} + +func TestRemoveRelayServers(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} + err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers) + assert.NoError(t, err) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.P2PRemoveRelayServersRequest{ + Server: server, + RelayServers: []gomatrixserverlib.ServerName{"relay1"}, + } + res := api.P2PRemoveRelayServersResponse{} + err = fedAPI.P2PRemoveRelayServers(context.Background(), &req, &res) + assert.NoError(t, err) + + finalRelays, err := testDB.P2PGetRelayServersForServer(context.Background(), server) + assert.NoError(t, err) + assert.Equal(t, 1, len(finalRelays)) + assert.Equal(t, gomatrixserverlib.ServerName("relay2"), finalRelays[0]) +} + +func TestPerformDirectoryLookup(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.PerformDirectoryLookupRequest{ + RoomAlias: "room", + ServerName: "server", + } + res := api.PerformDirectoryLookupResponse{} + err := fedAPI.PerformDirectoryLookup(context.Background(), &req, &res) + assert.NoError(t, err) +} + +func TestPerformDirectoryLookupRelaying(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + testDB.SetServerAssumedOffline(context.Background(), server) + testDB.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{"relay"}) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: server, + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.PerformDirectoryLookupRequest{ + RoomAlias: "room", + ServerName: server, + } + res := api.PerformDirectoryLookupResponse{} + err := fedAPI.PerformDirectoryLookup(context.Background(), &req, &res) + assert.Error(t, err) +} diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go deleted file mode 100644 index 6eefdc7cd..000000000 --- a/federationapi/inthttp/client.go +++ /dev/null @@ -1,512 +0,0 @@ -package inthttp - -import ( - "context" - "errors" - "net/http" - - "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" -) - -// HTTP paths for the internal HTTP API -const ( - FederationAPIQueryJoinedHostServerNamesInRoomPath = "/federationapi/queryJoinedHostServerNamesInRoom" - FederationAPIQueryServerKeysPath = "/federationapi/queryServerKeys" - - FederationAPIPerformDirectoryLookupRequestPath = "/federationapi/performDirectoryLookup" - FederationAPIPerformJoinRequestPath = "/federationapi/performJoinRequest" - FederationAPIPerformLeaveRequestPath = "/federationapi/performLeaveRequest" - FederationAPIPerformInviteRequestPath = "/federationapi/performInviteRequest" - FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest" - FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" - FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers" - - FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" - FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" - FederationAPIQueryKeysPath = "/federationapi/client/queryKeys" - FederationAPIBackfillPath = "/federationapi/client/backfill" - FederationAPILookupStatePath = "/federationapi/client/lookupState" - FederationAPILookupStateIDsPath = "/federationapi/client/lookupStateIDs" - FederationAPILookupMissingEventsPath = "/federationapi/client/lookupMissingEvents" - FederationAPIGetEventPath = "/federationapi/client/getEvent" - FederationAPILookupServerKeysPath = "/federationapi/client/lookupServerKeys" - FederationAPIEventRelationshipsPath = "/federationapi/client/msc2836eventRelationships" - FederationAPISpacesSummaryPath = "/federationapi/client/msc2946spacesSummary" - FederationAPIGetEventAuthPath = "/federationapi/client/getEventAuth" - - FederationAPIInputPublicKeyPath = "/federationapi/inputPublicKey" - FederationAPIQueryPublicKeyPath = "/federationapi/queryPublicKey" -) - -// NewFederationAPIClient creates a FederationInternalAPI implemented by talking to a HTTP POST API. -// If httpClient is nil an error is returned -func NewFederationAPIClient(federationSenderURL string, httpClient *http.Client, cache caching.ServerKeyCache) (api.FederationInternalAPI, error) { - if httpClient == nil { - return nil, errors.New("NewFederationInternalAPIHTTP: httpClient is ") - } - return &httpFederationInternalAPI{ - federationAPIURL: federationSenderURL, - httpClient: httpClient, - cache: cache, - }, nil -} - -type httpFederationInternalAPI struct { - federationAPIURL string - httpClient *http.Client - cache caching.ServerKeyCache -} - -// Handle an instruction to make_leave & send_leave with a remote server. -func (h *httpFederationInternalAPI) PerformLeave( - ctx context.Context, - request *api.PerformLeaveRequest, - response *api.PerformLeaveResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformLeave", h.federationAPIURL+FederationAPIPerformLeaveRequestPath, - h.httpClient, ctx, request, response, - ) -} - -// Handle sending an invite to a remote server. -func (h *httpFederationInternalAPI) PerformInvite( - ctx context.Context, - request *api.PerformInviteRequest, - response *api.PerformInviteResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformInvite", h.federationAPIURL+FederationAPIPerformInviteRequestPath, - h.httpClient, ctx, request, response, - ) -} - -// Handle starting a peek on a remote server. -func (h *httpFederationInternalAPI) PerformOutboundPeek( - ctx context.Context, - request *api.PerformOutboundPeekRequest, - response *api.PerformOutboundPeekResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformOutboundPeek", h.federationAPIURL+FederationAPIPerformOutboundPeekRequestPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryJoinedHostServerNamesInRoom implements FederationInternalAPI -func (h *httpFederationInternalAPI) QueryJoinedHostServerNamesInRoom( - ctx context.Context, - request *api.QueryJoinedHostServerNamesInRoomRequest, - response *api.QueryJoinedHostServerNamesInRoomResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryJoinedHostServerNamesInRoom", h.federationAPIURL+FederationAPIQueryJoinedHostServerNamesInRoomPath, - h.httpClient, ctx, request, response, - ) -} - -// Handle an instruction to make_join & send_join with a remote server. -func (h *httpFederationInternalAPI) PerformJoin( - ctx context.Context, - request *api.PerformJoinRequest, - response *api.PerformJoinResponse, -) { - if err := httputil.CallInternalRPCAPI( - "PerformJoinRequest", h.federationAPIURL+FederationAPIPerformJoinRequestPath, - h.httpClient, ctx, request, response, - ); err != nil { - response.LastError = &gomatrix.HTTPError{ - Message: err.Error(), - Code: 0, - WrappedError: err, - } - } -} - -// Handle an instruction to make_join & send_join with a remote server. -func (h *httpFederationInternalAPI) PerformDirectoryLookup( - ctx context.Context, - request *api.PerformDirectoryLookupRequest, - response *api.PerformDirectoryLookupResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformDirectoryLookup", h.federationAPIURL+FederationAPIPerformDirectoryLookupRequestPath, - h.httpClient, ctx, request, response, - ) -} - -// Handle an instruction to broadcast an EDU to all servers in rooms we are joined to. -func (h *httpFederationInternalAPI) PerformBroadcastEDU( - ctx context.Context, - request *api.PerformBroadcastEDURequest, - response *api.PerformBroadcastEDUResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformBroadcastEDU", h.federationAPIURL+FederationAPIPerformBroadcastEDUPath, - h.httpClient, ctx, request, response, - ) -} - -// Handle an instruction to remove the respective servers from being blacklisted. -func (h *httpFederationInternalAPI) PerformWakeupServers( - ctx context.Context, - request *api.PerformWakeupServersRequest, - response *api.PerformWakeupServersResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformWakeupServers", h.federationAPIURL+FederationAPIPerformWakeupServers, - h.httpClient, ctx, request, response, - ) -} - -type getUserDevices struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - UserID string -} - -func (h *httpFederationInternalAPI) GetUserDevices( - ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string, -) (gomatrixserverlib.RespUserDevices, error) { - return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError]( - "GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient, - ctx, &getUserDevices{ - S: s, - Origin: origin, - UserID: userID, - }, - ) -} - -type claimKeys struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - OneTimeKeys map[string]map[string]string -} - -func (h *httpFederationInternalAPI) ClaimKeys( - ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, -) (gomatrixserverlib.RespClaimKeys, error) { - return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError]( - "ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient, - ctx, &claimKeys{ - S: s, - Origin: origin, - OneTimeKeys: oneTimeKeys, - }, - ) -} - -type queryKeys struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - Keys map[string][]string -} - -func (h *httpFederationInternalAPI) QueryKeys( - ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string, -) (gomatrixserverlib.RespQueryKeys, error) { - return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError]( - "QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient, - ctx, &queryKeys{ - S: s, - Origin: origin, - Keys: keys, - }, - ) -} - -type backfill struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - RoomID string - Limit int - EventIDs []string -} - -func (h *httpFederationInternalAPI) Backfill( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, -) (gomatrixserverlib.Transaction, error) { - return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError]( - "Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient, - ctx, &backfill{ - S: s, - Origin: origin, - RoomID: roomID, - Limit: limit, - EventIDs: eventIDs, - }, - ) -} - -type lookupState struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - RoomID string - EventID string - RoomVersion gomatrixserverlib.RoomVersion -} - -func (h *httpFederationInternalAPI) LookupState( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, -) (gomatrixserverlib.RespState, error) { - return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError]( - "LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient, - ctx, &lookupState{ - S: s, - Origin: origin, - RoomID: roomID, - EventID: eventID, - RoomVersion: roomVersion, - }, - ) -} - -type lookupStateIDs struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - RoomID string - EventID string -} - -func (h *httpFederationInternalAPI) LookupStateIDs( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, -) (gomatrixserverlib.RespStateIDs, error) { - return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError]( - "LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient, - ctx, &lookupStateIDs{ - S: s, - Origin: origin, - RoomID: roomID, - EventID: eventID, - }, - ) -} - -type lookupMissingEvents struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - RoomID string - Missing gomatrixserverlib.MissingEvents - RoomVersion gomatrixserverlib.RoomVersion -} - -func (h *httpFederationInternalAPI) LookupMissingEvents( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, - missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, -) (res gomatrixserverlib.RespMissingEvents, err error) { - return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError]( - "LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient, - ctx, &lookupMissingEvents{ - S: s, - Origin: origin, - RoomID: roomID, - Missing: missing, - RoomVersion: roomVersion, - }, - ) -} - -type getEvent struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - EventID string -} - -func (h *httpFederationInternalAPI) GetEvent( - ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string, -) (gomatrixserverlib.Transaction, error) { - return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError]( - "GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient, - ctx, &getEvent{ - S: s, - Origin: origin, - EventID: eventID, - }, - ) -} - -type getEventAuth struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - RoomVersion gomatrixserverlib.RoomVersion - RoomID string - EventID string -} - -func (h *httpFederationInternalAPI) GetEventAuth( - ctx context.Context, origin, s gomatrixserverlib.ServerName, - roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, -) (gomatrixserverlib.RespEventAuth, error) { - return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError]( - "GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient, - ctx, &getEventAuth{ - S: s, - Origin: origin, - RoomVersion: roomVersion, - RoomID: roomID, - EventID: eventID, - }, - ) -} - -func (h *httpFederationInternalAPI) QueryServerKeys( - ctx context.Context, req *api.QueryServerKeysRequest, res *api.QueryServerKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryServerKeys", h.federationAPIURL+FederationAPIQueryServerKeysPath, - h.httpClient, ctx, req, res, - ) -} - -type lookupServerKeys struct { - S gomatrixserverlib.ServerName - KeyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp -} - -func (h *httpFederationInternalAPI) LookupServerKeys( - ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, -) ([]gomatrixserverlib.ServerKeys, error) { - return httputil.CallInternalProxyAPI[lookupServerKeys, []gomatrixserverlib.ServerKeys, *api.FederationClientError]( - "LookupServerKeys", h.federationAPIURL+FederationAPILookupServerKeysPath, h.httpClient, - ctx, &lookupServerKeys{ - S: s, - KeyRequests: keyRequests, - }, - ) -} - -type eventRelationships struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - Req gomatrixserverlib.MSC2836EventRelationshipsRequest - RoomVer gomatrixserverlib.RoomVersion -} - -func (h *httpFederationInternalAPI) MSC2836EventRelationships( - ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, - roomVersion gomatrixserverlib.RoomVersion, -) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { - return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError]( - "MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient, - ctx, &eventRelationships{ - S: s, - Origin: origin, - Req: r, - RoomVer: roomVersion, - }, - ) -} - -type spacesReq struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - SuggestedOnly bool - RoomID string -} - -func (h *httpFederationInternalAPI) MSC2946Spaces( - ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, -) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { - return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError]( - "MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient, - ctx, &spacesReq{ - S: dst, - Origin: origin, - SuggestedOnly: suggestedOnly, - RoomID: roomID, - }, - ) -} - -func (s *httpFederationInternalAPI) KeyRing() *gomatrixserverlib.KeyRing { - // This is a bit of a cheat - we tell gomatrixserverlib that this API is - // both the key database and the key fetcher. While this does have the - // rather unfortunate effect of preventing gomatrixserverlib from handling - // key fetchers directly, we can at least reimplement this behaviour on - // the other end of the API. - return &gomatrixserverlib.KeyRing{ - KeyDatabase: s, - KeyFetchers: []gomatrixserverlib.KeyFetcher{}, - } -} - -func (s *httpFederationInternalAPI) FetcherName() string { - return "httpServerKeyInternalAPI" -} - -func (s *httpFederationInternalAPI) StoreKeys( - _ context.Context, - results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, -) error { - // Run in a background context - we don't want to stop this work just - // because the caller gives up waiting. - ctx := context.Background() - request := api.InputPublicKeysRequest{ - Keys: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult), - } - response := api.InputPublicKeysResponse{} - for req, res := range results { - request.Keys[req] = res - s.cache.StoreServerKey(req, res) - } - return s.InputPublicKeys(ctx, &request, &response) -} - -func (s *httpFederationInternalAPI) FetchKeys( - _ context.Context, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, -) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { - // Run in a background context - we don't want to stop this work just - // because the caller gives up waiting. - ctx := context.Background() - result := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) - request := api.QueryPublicKeysRequest{ - Requests: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp), - } - response := api.QueryPublicKeysResponse{ - Results: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult), - } - for req, ts := range requests { - if res, ok := s.cache.GetServerKey(req, ts); ok { - result[req] = res - continue - } - request.Requests[req] = ts - } - err := s.QueryPublicKeys(ctx, &request, &response) - if err != nil { - return nil, err - } - for req, res := range response.Results { - result[req] = res - s.cache.StoreServerKey(req, res) - } - return result, nil -} - -func (h *httpFederationInternalAPI) InputPublicKeys( - ctx context.Context, - request *api.InputPublicKeysRequest, - response *api.InputPublicKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "InputPublicKey", h.federationAPIURL+FederationAPIInputPublicKeyPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpFederationInternalAPI) QueryPublicKeys( - ctx context.Context, - request *api.QueryPublicKeysRequest, - response *api.QueryPublicKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryPublicKeys", h.federationAPIURL+FederationAPIQueryPublicKeyPath, - h.httpClient, ctx, request, response, - ) -} diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go deleted file mode 100644 index 9068dc400..000000000 --- a/federationapi/inthttp/server.go +++ /dev/null @@ -1,257 +0,0 @@ -package inthttp - -import ( - "context" - "encoding/json" - "net/http" - "net/url" - - "github.com/gorilla/mux" - "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - - "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/httputil" -) - -// AddRoutes adds the FederationInternalAPI handlers to the http.ServeMux. -// nolint:gocyclo -func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router, enableMetrics bool) { - internalAPIMux.Handle( - FederationAPIQueryJoinedHostServerNamesInRoomPath, - httputil.MakeInternalRPCAPI("FederationAPIQueryJoinedHostServerNamesInRoom", enableMetrics, intAPI.QueryJoinedHostServerNamesInRoom), - ) - - internalAPIMux.Handle( - FederationAPIPerformInviteRequestPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformInvite", enableMetrics, intAPI.PerformInvite), - ) - - internalAPIMux.Handle( - FederationAPIPerformLeaveRequestPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformLeave", enableMetrics, intAPI.PerformLeave), - ) - - internalAPIMux.Handle( - FederationAPIPerformDirectoryLookupRequestPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformDirectoryLookupRequest", enableMetrics, intAPI.PerformDirectoryLookup), - ) - - internalAPIMux.Handle( - FederationAPIPerformBroadcastEDUPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", enableMetrics, intAPI.PerformBroadcastEDU), - ) - - internalAPIMux.Handle( - FederationAPIPerformWakeupServers, - httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", enableMetrics, intAPI.PerformWakeupServers), - ) - - internalAPIMux.Handle( - FederationAPIPerformJoinRequestPath, - httputil.MakeInternalRPCAPI( - "FederationAPIPerformJoinRequest", enableMetrics, - func(ctx context.Context, req *api.PerformJoinRequest, res *api.PerformJoinResponse) error { - intAPI.PerformJoin(ctx, req, res) - return nil - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIGetUserDevicesPath, - httputil.MakeInternalProxyAPI( - "FederationAPIGetUserDevices", enableMetrics, - func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) { - res, err := intAPI.GetUserDevices(ctx, req.Origin, req.S, req.UserID) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIClaimKeysPath, - httputil.MakeInternalProxyAPI( - "FederationAPIClaimKeys", enableMetrics, - func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) { - res, err := intAPI.ClaimKeys(ctx, req.Origin, req.S, req.OneTimeKeys) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIQueryKeysPath, - httputil.MakeInternalProxyAPI( - "FederationAPIQueryKeys", enableMetrics, - func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) { - res, err := intAPI.QueryKeys(ctx, req.Origin, req.S, req.Keys) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIBackfillPath, - httputil.MakeInternalProxyAPI( - "FederationAPIBackfill", enableMetrics, - func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) { - res, err := intAPI.Backfill(ctx, req.Origin, req.S, req.RoomID, req.Limit, req.EventIDs) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPILookupStatePath, - httputil.MakeInternalProxyAPI( - "FederationAPILookupState", enableMetrics, - func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) { - res, err := intAPI.LookupState(ctx, req.Origin, req.S, req.RoomID, req.EventID, req.RoomVersion) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPILookupStateIDsPath, - httputil.MakeInternalProxyAPI( - "FederationAPILookupStateIDs", enableMetrics, - func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) { - res, err := intAPI.LookupStateIDs(ctx, req.Origin, req.S, req.RoomID, req.EventID) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPILookupMissingEventsPath, - httputil.MakeInternalProxyAPI( - "FederationAPILookupMissingEvents", enableMetrics, - func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) { - res, err := intAPI.LookupMissingEvents(ctx, req.Origin, req.S, req.RoomID, req.Missing, req.RoomVersion) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIGetEventPath, - httputil.MakeInternalProxyAPI( - "FederationAPIGetEvent", enableMetrics, - func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) { - res, err := intAPI.GetEvent(ctx, req.Origin, req.S, req.EventID) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIGetEventAuthPath, - httputil.MakeInternalProxyAPI( - "FederationAPIGetEventAuth", enableMetrics, - func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) { - res, err := intAPI.GetEventAuth(ctx, req.Origin, req.S, req.RoomVersion, req.RoomID, req.EventID) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIQueryServerKeysPath, - httputil.MakeInternalRPCAPI("FederationAPIQueryServerKeys", enableMetrics, intAPI.QueryServerKeys), - ) - - internalAPIMux.Handle( - FederationAPILookupServerKeysPath, - httputil.MakeInternalProxyAPI( - "FederationAPILookupServerKeys", enableMetrics, - func(ctx context.Context, req *lookupServerKeys) (*[]gomatrixserverlib.ServerKeys, error) { - res, err := intAPI.LookupServerKeys(ctx, req.S, req.KeyRequests) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIEventRelationshipsPath, - httputil.MakeInternalProxyAPI( - "FederationAPIMSC2836EventRelationships", enableMetrics, - func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { - res, err := intAPI.MSC2836EventRelationships(ctx, req.Origin, req.S, req.Req, req.RoomVer) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPISpacesSummaryPath, - httputil.MakeInternalProxyAPI( - "FederationAPIMSC2946SpacesSummary", enableMetrics, - func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) { - res, err := intAPI.MSC2946Spaces(ctx, req.Origin, req.S, req.RoomID, req.SuggestedOnly) - return &res, federationClientError(err) - }, - ), - ) - - // TODO: Look at this shape - internalAPIMux.Handle(FederationAPIQueryPublicKeyPath, - httputil.MakeInternalAPI("FederationAPIQueryPublicKeys", enableMetrics, func(req *http.Request) util.JSONResponse { - request := api.QueryPublicKeysRequest{} - response := api.QueryPublicKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - keys, err := intAPI.FetchKeys(req.Context(), request.Requests) - if err != nil { - return util.ErrorResponse(err) - } - response.Results = keys - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - - // TODO: Look at this shape - internalAPIMux.Handle(FederationAPIInputPublicKeyPath, - httputil.MakeInternalAPI("FederationAPIInputPublicKeys", enableMetrics, func(req *http.Request) util.JSONResponse { - request := api.InputPublicKeysRequest{} - response := api.InputPublicKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.StoreKeys(req.Context(), request.Keys); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) -} - -func federationClientError(err error) error { - switch ferr := err.(type) { - case nil: - return nil - case api.FederationClientError: - return &ferr - case *api.FederationClientError: - return ferr - case gomatrix.HTTPError: - return &api.FederationClientError{ - Code: ferr.Code, - } - case *url.Error: // e.g. certificate error, unable to connect - return &api.FederationClientError{ - Err: ferr.Error(), - Code: 400, - } - default: - // We don't know what exactly failed, but we probably don't - // want to retry the request immediately in the device list updater - return &api.FederationClientError{ - Err: err.Error(), - Code: 400, - } - } -} diff --git a/federationapi/producers/syncapi.go b/federationapi/producers/syncapi.go index 7cce13a7d..6bcfafa39 100644 --- a/federationapi/producers/syncapi.go +++ b/federationapi/producers/syncapi.go @@ -41,7 +41,7 @@ type SyncAPIProducer struct { TopicSigningKeyUpdate string JetStream nats.JetStreamContext Config *config.FederationAPI - UserAPI userapi.UserInternalAPI + UserAPI userapi.FederationUserAPI } func (p *SyncAPIProducer) SendReceipt( diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index a4a87fe99..12e6db9fa 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -29,7 +29,7 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" ) @@ -70,7 +70,7 @@ type destinationQueue struct { // Send event adds the event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, receipt *shared.Receipt) { +func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, dbReceipt *receipt.Receipt) { if event == nil { logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) return @@ -84,8 +84,8 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re oq.pendingMutex.Lock() if len(oq.pendingPDUs) < maxPDUsInMemory { oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ - pdu: event, - receipt: receipt, + pdu: event, + dbReceipt: dbReceipt, }) } else { oq.overflowed.Store(true) @@ -101,7 +101,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re // sendEDU adds the EDU event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *shared.Receipt) { +func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, dbReceipt *receipt.Receipt) { if event == nil { logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination) return @@ -115,8 +115,8 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share oq.pendingMutex.Lock() if len(oq.pendingEDUs) < maxEDUsInMemory { oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ - edu: event, - receipt: receipt, + edu: event, + dbReceipt: dbReceipt, }) } else { oq.overflowed.Store(true) @@ -210,10 +210,10 @@ func (oq *destinationQueue) getPendingFromDatabase() { gotPDUs := map[string]struct{}{} gotEDUs := map[string]struct{}{} for _, pdu := range oq.pendingPDUs { - gotPDUs[pdu.receipt.String()] = struct{}{} + gotPDUs[pdu.dbReceipt.String()] = struct{}{} } for _, edu := range oq.pendingEDUs { - gotEDUs[edu.receipt.String()] = struct{}{} + gotEDUs[edu.dbReceipt.String()] = struct{}{} } overflowed := false @@ -371,7 +371,7 @@ func (oq *destinationQueue) backgroundSend() { // If we have pending PDUs or EDUs then construct a transaction. // Try sending the next transaction and see what happens. - terr := oq.nextTransaction(toSendPDUs, toSendEDUs) + terr, sendMethod := oq.nextTransaction(toSendPDUs, toSendEDUs) if terr != nil { // We failed to send the transaction. Mark it as a failure. _, blacklisted := oq.statistics.Failure() @@ -388,18 +388,19 @@ func (oq *destinationQueue) backgroundSend() { return } } else { - oq.handleTransactionSuccess(pduCount, eduCount) + oq.handleTransactionSuccess(pduCount, eduCount, sendMethod) } } } // nextTransaction creates a new transaction from the pending event // queue and sends it. -// Returns an error if the transaction wasn't sent. +// Returns an error if the transaction wasn't sent. And whether the success +// was to a relay server or not. func (oq *destinationQueue) nextTransaction( pdus []*queuedPDU, edus []*queuedEDU, -) error { +) (err error, sendMethod statistics.SendMethod) { // Create the transaction. t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus) logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) @@ -407,7 +408,52 @@ func (oq *destinationQueue) nextTransaction( // Try to send the transaction to the destination server. ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5) defer cancel() - _, err := oq.client.SendTransaction(ctx, t) + + relayServers := oq.statistics.KnownRelayServers() + hasRelayServers := len(relayServers) > 0 + shouldSendToRelays := oq.statistics.AssumedOffline() && hasRelayServers + if !shouldSendToRelays { + sendMethod = statistics.SendDirect + _, err = oq.client.SendTransaction(ctx, t) + } else { + // Try sending directly to the destination first in case they came back online. + sendMethod = statistics.SendDirect + _, err = oq.client.SendTransaction(ctx, t) + if err != nil { + // The destination is still offline, try sending to relays. + sendMethod = statistics.SendViaRelay + relaySuccess := false + logrus.Infof("Sending %q to relay servers: %v", t.TransactionID, relayServers) + // TODO : how to pass through actual userID here?!?!?!?! + userID, userErr := gomatrixserverlib.NewUserID("@user:"+string(oq.destination), false) + if userErr != nil { + return userErr, sendMethod + } + + // Attempt sending to each known relay server. + for _, relayServer := range relayServers { + _, relayErr := oq.client.P2PSendTransactionToRelay(ctx, *userID, t, relayServer) + if relayErr != nil { + err = relayErr + } else { + // If sending to one of the relay servers succeeds, consider the send successful. + relaySuccess = true + + // TODO : what about if the dest comes back online but can't see their relay? + // How do I sync with the dest in that case? + // Should change the database to have a "relay success" flag on events and if + // I see the node back online, maybe directly send through the backlog of events + // with "relay success"... could lead to duplicate events, but only those that + // I sent. And will lead to a much more consistent experience. + } + } + + // Clear the error if sending to any of the relay servers succeeded. + if relaySuccess { + err = nil + } + } + } switch errResponse := err.(type) { case nil: // Clean up the transaction in the database. @@ -427,7 +473,7 @@ func (oq *destinationQueue) nextTransaction( oq.transactionIDMutex.Lock() oq.transactionID = "" oq.transactionIDMutex.Unlock() - return nil + return nil, sendMethod case gomatrix.HTTPError: // Report that we failed to send the transaction and we // will retry again, subject to backoff. @@ -437,13 +483,13 @@ func (oq *destinationQueue) nextTransaction( // to a 400-ish error code := errResponse.Code logrus.Debug("Transaction failed with HTTP", code) - return err + return err, sendMethod default: logrus.WithFields(logrus.Fields{ "destination": oq.destination, logrus.ErrorKey: err, }).Debugf("Failed to send transaction %q", t.TransactionID) - return err + return err, sendMethod } } @@ -453,7 +499,7 @@ func (oq *destinationQueue) nextTransaction( func (oq *destinationQueue) createTransaction( pdus []*queuedPDU, edus []*queuedEDU, -) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) { +) (gomatrixserverlib.Transaction, []*receipt.Receipt, []*receipt.Receipt) { // If there's no projected transaction ID then generate one. If // the transaction succeeds then we'll set it back to "" so that // we generate a new one next time. If it fails, we'll preserve @@ -474,8 +520,8 @@ func (oq *destinationQueue) createTransaction( t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) t.TransactionID = oq.transactionID - var pduReceipts []*shared.Receipt - var eduReceipts []*shared.Receipt + var pduReceipts []*receipt.Receipt + var eduReceipts []*receipt.Receipt // Go through PDUs that we retrieved from the database, if any, // and add them into the transaction. @@ -487,7 +533,7 @@ func (oq *destinationQueue) createTransaction( // Append the JSON of the event, since this is a json.RawMessage type in the // gomatrixserverlib.Transaction struct t.PDUs = append(t.PDUs, pdu.pdu.JSON()) - pduReceipts = append(pduReceipts, pdu.receipt) + pduReceipts = append(pduReceipts, pdu.dbReceipt) } // Do the same for pending EDUS in the queue. @@ -497,7 +543,7 @@ func (oq *destinationQueue) createTransaction( continue } t.EDUs = append(t.EDUs, *edu.edu) - eduReceipts = append(eduReceipts, edu.receipt) + eduReceipts = append(eduReceipts, edu.dbReceipt) } return t, pduReceipts, eduReceipts @@ -530,10 +576,11 @@ func (oq *destinationQueue) blacklistDestination() { // handleTransactionSuccess updates the cached event queues as well as the success and // backoff information for this server. -func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) { +func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int, sendMethod statistics.SendMethod) { // If we successfully sent the transaction then clear out // the pending events and EDUs, and wipe our transaction ID. - oq.statistics.Success() + + oq.statistics.Success(sendMethod) oq.pendingMutex.Lock() defer oq.pendingMutex.Unlock() diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 75b1b36be..5d6b8d44c 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -30,7 +30,7 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" ) @@ -138,13 +138,13 @@ func NewOutgoingQueues( } type queuedPDU struct { - receipt *shared.Receipt - pdu *gomatrixserverlib.HeaderedEvent + dbReceipt *receipt.Receipt + pdu *gomatrixserverlib.HeaderedEvent } type queuedEDU struct { - receipt *shared.Receipt - edu *gomatrixserverlib.EDU + dbReceipt *receipt.Receipt + edu *gomatrixserverlib.EDU } func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue { @@ -374,24 +374,13 @@ func (oqs *OutgoingQueues) SendEDU( return nil } -// IsServerBlacklisted returns whether or not the provided server is currently -// blacklisted. -func (oqs *OutgoingQueues) IsServerBlacklisted(srv gomatrixserverlib.ServerName) bool { - return oqs.statistics.ForServer(srv).Blacklisted() -} - // RetryServer attempts to resend events to the given server if we had given up. -func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { +func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName, wasBlacklisted bool) { if oqs.disabled { return } - serverStatistics := oqs.statistics.ForServer(srv) - forceWakeup := serverStatistics.Blacklisted() - serverStatistics.RemoveBlacklist() - serverStatistics.ClearBackoff() - if queue := oqs.getQueue(srv); queue != nil { - queue.wakeQueueIfEventsPending(forceWakeup) + queue.wakeQueueIfEventsPending(wasBlacklisted) } } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index c317edc21..bccfb3428 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -18,7 +18,6 @@ import ( "context" "encoding/json" "fmt" - "sync" "testing" "time" @@ -26,13 +25,11 @@ import ( "gotest.tools/v3/poll" "github.com/matrix-org/gomatrixserverlib" - "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" @@ -57,7 +54,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase } } else { // Fake Database - db := createDatabase() + db := test.NewInMemoryFederationDatabase() b := struct { ProcessContext *process.ProcessContext }{ProcessContext: process.NewProcessContext()} @@ -65,220 +62,6 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase } } -func createDatabase() storage.Database { - return &fakeDatabase{ - pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), - pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent), - pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU), - associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), - associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), - } -} - -type fakeDatabase struct { - storage.Database - dbMutex sync.Mutex - pendingPDUServers map[gomatrixserverlib.ServerName]struct{} - pendingEDUServers map[gomatrixserverlib.ServerName]struct{} - blacklistedServers map[gomatrixserverlib.ServerName]struct{} - pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent - pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU - associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} - associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} -} - -var nidMutex sync.Mutex -var nid = int64(0) - -func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var event gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal([]byte(js), &event); err == nil { - nidMutex.Lock() - defer nidMutex.Unlock() - nid++ - receipt := shared.NewReceipt(nid) - d.pendingPDUs[&receipt] = &event - return &receipt, nil - } - - var edu gomatrixserverlib.EDU - if err := json.Unmarshal([]byte(js), &edu); err == nil { - nidMutex.Lock() - defer nidMutex.Unlock() - nid++ - receipt := shared.NewReceipt(nid) - d.pendingEDUs[&receipt] = &edu - return &receipt, nil - } - - return nil, errors.New("Failed to determine type of json to store") -} - -func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - pduCount := 0 - pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent) - if receipts, ok := d.associatedPDUs[serverName]; ok { - for receipt := range receipts { - if event, ok := d.pendingPDUs[receipt]; ok { - pdus[receipt] = event - pduCount++ - if pduCount == limit { - break - } - } - } - } - return pdus, nil -} - -func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - eduCount := 0 - edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU) - if receipts, ok := d.associatedEDUs[serverName]; ok { - for receipt := range receipts { - if event, ok := d.pendingEDUs[receipt]; ok { - edus[receipt] = event - eduCount++ - if eduCount == limit { - break - } - } - } - } - return edus, nil -} - -func (d *fakeDatabase) AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if _, ok := d.pendingPDUs[receipt]; ok { - for destination := range destinations { - if _, ok := d.associatedPDUs[destination]; !ok { - d.associatedPDUs[destination] = make(map[*shared.Receipt]struct{}) - } - d.associatedPDUs[destination][receipt] = struct{}{} - } - - return nil - } else { - return errors.New("PDU doesn't exist") - } -} - -func (d *fakeDatabase) AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if _, ok := d.pendingEDUs[receipt]; ok { - for destination := range destinations { - if _, ok := d.associatedEDUs[destination]; !ok { - d.associatedEDUs[destination] = make(map[*shared.Receipt]struct{}) - } - d.associatedEDUs[destination][receipt] = struct{}{} - } - - return nil - } else { - return errors.New("EDU doesn't exist") - } -} - -func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if pdus, ok := d.associatedPDUs[serverName]; ok { - for _, receipt := range receipts { - delete(pdus, receipt) - } - } - - return nil -} - -func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if edus, ok := d.associatedEDUs[serverName]; ok { - for _, receipt := range receipts { - delete(edus, receipt) - } - } - - return nil -} - -func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - servers := []gomatrixserverlib.ServerName{} - for server := range d.pendingPDUServers { - servers = append(servers, server) - } - return servers, nil -} - -func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - servers := []gomatrixserverlib.ServerName{} - for server := range d.pendingEDUServers { - servers = append(servers, server) - } - return servers, nil -} - -func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - d.blacklistedServers[serverName] = struct{}{} - return nil -} - -func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - delete(d.blacklistedServers, serverName) - return nil -} - -func (d *fakeDatabase) RemoveAllServersFromBlacklist() error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) - return nil -} - -func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - isBlacklisted := false - if _, ok := d.blacklistedServers[serverName]; ok { - isBlacklisted = true - } - - return isBlacklisted, nil -} - type stubFederationRoomServerAPI struct { rsapi.FederationRoomserverAPI } @@ -290,8 +73,10 @@ func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Cont type stubFederationClient struct { api.FederationClient - shouldTxSucceed bool - txCount atomic.Uint32 + shouldTxSucceed bool + shouldTxRelaySucceed bool + txCount atomic.Uint32 + txRelayCount atomic.Uint32 } func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) { @@ -304,6 +89,16 @@ func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixse return gomatrixserverlib.RespSend{}, result } +func (f *stubFederationClient) P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) { + var result error + if !f.shouldTxRelaySucceed { + result = fmt.Errorf("relay transaction failed") + } + + f.txRelayCount.Add(1) + return gomatrixserverlib.EmptyResp{}, result +} + func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent { t.Helper() content := `{"type":"m.room.message"}` @@ -319,15 +114,18 @@ func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU { return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping} } -func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { +func testSetup(failuresUntilBlacklist uint32, failuresUntilAssumedOffline uint32, shouldTxSucceed bool, shouldTxRelaySucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { db, processContext, close := mustCreateFederationDatabase(t, dbType, realDatabase) fc := &stubFederationClient{ - shouldTxSucceed: shouldTxSucceed, - txCount: *atomic.NewUint32(0), + shouldTxSucceed: shouldTxSucceed, + shouldTxRelaySucceed: shouldTxRelaySucceed, + txCount: *atomic.NewUint32(0), + txRelayCount: *atomic.NewUint32(0), } rs := &stubFederationRoomServerAPI{} - stats := statistics.NewStatistics(db, failuresUntilBlacklist) + + stats := statistics.NewStatistics(db, failuresUntilBlacklist, failuresUntilAssumedOffline) signingInfo := []*gomatrixserverlib.SigningIdentity{ { KeyID: "ed21019:auto", @@ -344,7 +142,7 @@ func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -373,7 +171,7 @@ func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -402,7 +200,7 @@ func TestSendPDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -432,7 +230,7 @@ func TestSendEDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -462,7 +260,7 @@ func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -513,7 +311,7 @@ func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -564,7 +362,7 @@ func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -596,7 +394,7 @@ func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -628,7 +426,7 @@ func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -662,7 +460,7 @@ func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -696,7 +494,7 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -730,8 +528,8 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) { poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) + wasBlacklisted := dest.statistics.MarkServerAlive() + queues.RetryServer(destination, wasBlacklisted) checkRetry := func(log poll.LogT) poll.Result { data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) @@ -747,7 +545,7 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -781,8 +579,8 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) { poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) + wasBlacklisted := dest.statistics.MarkServerAlive() + queues.RetryServer(destination, wasBlacklisted) checkRetry := func(log poll.LogT) poll.Result { data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) @@ -801,7 +599,7 @@ func TestSendPDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -845,7 +643,7 @@ func TestSendEDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -889,7 +687,7 @@ func TestSendPDUAndEDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -940,7 +738,7 @@ func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -978,7 +776,7 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { destination := gomatrixserverlib.ServerName("remotehost") destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, dbType, true) // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. defer close() defer func() { @@ -1023,8 +821,8 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) + wasBlacklisted := dest.statistics.MarkServerAlive() + queues.RetryServer(destination, wasBlacklisted) checkRetry := func(log poll.LogT) poll.Result { pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) assert.NoError(t, dbErrPDU) @@ -1038,3 +836,147 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { poll.WaitOn(t, checkRetry, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) }) } + +func TestSendPDUMultipleFailuresAssumedOffline(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(7) + failuresUntilAssumedOffline := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilAssumedOffline { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be assumed offline") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUMultipleFailuresAssumedOffline(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(7) + failuresUntilAssumedOffline := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilAssumedOffline { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be assumed offline") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + failuresUntilAssumedOffline := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + relayServers := []gomatrixserverlib.ServerName{"relayserver"} + queues.statistics.ForServer(destination).AddRelayServers(relayServers) + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() >= 1 { + if fc.txRelayCount.Load() == 1 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more relay send attempts before checking database. Currently %d", fc.txRelayCount.Load()) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination) + assert.Equal(t, true, assumedOffline) +} + +func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + failuresUntilAssumedOffline := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + relayServers := []gomatrixserverlib.ServerName{"relayserver"} + queues.statistics.ForServer(destination).AddRelayServers(relayServers) + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() >= 1 { + if fc.txRelayCount.Load() == 1 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more relay send attempts before checking database. Currently %d", fc.txRelayCount.Load()) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination) + assert.Equal(t, true, assumedOffline) +} diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index ce8b06b70..871d26cd4 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -17,7 +17,7 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" - keyapi "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/tidwall/gjson" @@ -26,11 +26,11 @@ import ( // GetUserDevices for the given user id func GetUserDevices( req *http.Request, - keyAPI keyapi.FederationKeyAPI, + keyAPI api.FederationKeyAPI, userID string, ) util.JSONResponse { - var res keyapi.QueryDeviceMessagesResponse - if err := keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{ + var res api.QueryDeviceMessagesResponse + if err := keyAPI.QueryDeviceMessages(req.Context(), &api.QueryDeviceMessagesRequest{ UserID: userID, }, &res); err != nil { return util.ErrorResponse(err) @@ -40,12 +40,12 @@ func GetUserDevices( return jsonerror.InternalServerError() } - sigReq := &keyapi.QuerySignaturesRequest{ + sigReq := &api.QuerySignaturesRequest{ TargetIDs: map[string][]gomatrixserverlib.KeyID{ userID: {}, }, } - sigRes := &keyapi.QuerySignaturesResponse{} + sigRes := &api.QuerySignaturesResponse{} for _, dev := range res.Devices { sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID)) } diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index dc262cfde..2885cc916 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -22,8 +22,8 @@ import ( clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" diff --git a/federationapi/routing/profile_test.go b/federationapi/routing/profile_test.go new file mode 100644 index 000000000..3b9d576bf --- /dev/null +++ b/federationapi/routing/profile_test.go @@ -0,0 +1,94 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "encoding/hex" + "io" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + fedAPI "github.com/matrix-org/dendrite/federationapi" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" +) + +type fakeUserAPI struct { + userAPI.FederationUserAPI +} + +func (u *fakeUserAPI) QueryProfile(ctx context.Context, req *userAPI.QueryProfileRequest, res *userAPI.QueryProfileResponse) error { + return nil +} + +func TestHandleQueryProfile(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() + base.PublicFederationAPIMux = fedMux + base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedClient := fakeFedClient{} + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true) + userapi := fakeUserAPI{} + r, ok := fedapi.(*fedInternal.FederationInternalAPI) + if !ok { + panic("This is a programming error.") + } + routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, &base.Cfg.MSCs, nil, nil) + + handler := fedMux.Get(routing.QueryProfileRouteName).GetHandler().ServeHTTP + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/profile?user_id="+url.QueryEscape("@user:"+string(testOrigin))) + type queryContent struct{} + content := queryContent{} + err := req.SetContent(content) + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk) + httpReq, err := req.HTTPRequest() + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + // vars := map[string]string{"room_alias": "#room:server"} + w := httptest.NewRecorder() + // httpReq = mux.SetURLVars(httpReq, vars) + handler(w, httpReq) + + res := w.Result() + data, _ := io.ReadAll(res.Body) + println(string(data)) + assert.Equal(t, 200, res.StatusCode) + }) +} diff --git a/federationapi/routing/query_test.go b/federationapi/routing/query_test.go new file mode 100644 index 000000000..d839a16b8 --- /dev/null +++ b/federationapi/routing/query_test.go @@ -0,0 +1,94 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "encoding/hex" + "io" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + fedAPI "github.com/matrix-org/dendrite/federationapi" + fedclient "github.com/matrix-org/dendrite/federationapi/api" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" +) + +type fakeFedClient struct { + fedclient.FederationClient +} + +func (f *fakeFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) { + return +} + +func TestHandleQueryDirectory(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() + base.PublicFederationAPIMux = fedMux + base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedClient := fakeFedClient{} + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true) + userapi := fakeUserAPI{} + r, ok := fedapi.(*fedInternal.FederationInternalAPI) + if !ok { + panic("This is a programming error.") + } + routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, &base.Cfg.MSCs, nil, nil) + + handler := fedMux.Get(routing.QueryDirectoryRouteName).GetHandler().ServeHTTP + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/directory?room_alias="+url.QueryEscape("#room:server")) + type queryContent struct{} + content := queryContent{} + err := req.SetContent(content) + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk) + httpReq, err := req.HTTPRequest() + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + // vars := map[string]string{"room_alias": "#room:server"} + w := httptest.NewRecorder() + // httpReq = mux.SetURLVars(httpReq, vars) + handler(w, httpReq) + + res := w.Result() + data, _ := io.ReadAll(res.Body) + println(string(data)) + assert.Equal(t, 200, res.StatusCode) + }) +} diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 0a3ab7a88..324740ddc 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -29,9 +29,9 @@ import ( "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -40,6 +40,12 @@ import ( "github.com/sirupsen/logrus" ) +const ( + SendRouteName = "Send" + QueryDirectoryRouteName = "QueryDirectory" + QueryProfileRouteName = "QueryProfile" +) + // Setup registers HTTP handlers with the given ServeMux. // The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly // path unescape twice (once from the router, once from MakeFedAPI). We need to have this enabled @@ -49,21 +55,26 @@ import ( // applied: // nolint: gocyclo func Setup( - fedMux, keyMux, wkMux *mux.Router, - cfg *config.FederationAPI, + base *base.BaseDendrite, rsAPI roomserverAPI.FederationRoomserverAPI, fsAPI *fedInternal.FederationInternalAPI, keys gomatrixserverlib.JSONVerifier, federation federationAPI.FederationClient, userAPI userapi.FederationUserAPI, - keyAPI keyserverAPI.FederationKeyAPI, mscCfg *config.MSCs, servers federationAPI.ServersInRoomProvider, producer *producers.SyncAPIProducer, ) { - prometheus.MustRegister( - pduCountTotal, eduCountTotal, - ) + fedMux := base.PublicFederationAPIMux + keyMux := base.PublicKeyAPIMux + wkMux := base.PublicWellKnownAPIMux + cfg := &base.Cfg.FederationAPI + + if base.EnableMetrics { + prometheus.MustRegister( + internal.PDUCountTotal, internal.EDUCountTotal, + ) + } v2keysmux := keyMux.PathPrefix("/v2").Subrouter() v1fedmux := fedMux.PathPrefix("/v1").Subrouter() @@ -128,10 +139,10 @@ func Setup( func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Send( httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), - cfg, rsAPI, keyAPI, keys, federation, mu, servers, producer, + cfg, rsAPI, userAPI, keys, federation, mu, servers, producer, ) }, - )).Methods(http.MethodPut, http.MethodOptions) + )).Methods(http.MethodPut, http.MethodOptions).Name(SendRouteName) v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, @@ -241,7 +252,7 @@ func Setup( httpReq, federation, cfg, rsAPI, fsAPI, ) }, - )).Methods(http.MethodGet) + )).Methods(http.MethodGet).Name(QueryDirectoryRouteName) v1fedmux.Handle("/query/profile", MakeFedAPI( "federation_query_profile", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, @@ -250,13 +261,13 @@ func Setup( httpReq, userAPI, cfg, ) }, - )).Methods(http.MethodGet) + )).Methods(http.MethodGet).Name(QueryProfileRouteName) v1fedmux.Handle("/user/devices/{userID}", MakeFedAPI( "federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetUserDevices( - httpReq, keyAPI, vars["userID"], + httpReq, userAPI, vars["userID"], ) }, )).Methods(http.MethodGet) @@ -481,14 +492,14 @@ func Setup( v1fedmux.Handle("/user/keys/claim", MakeFedAPI( "federation_keys_claim", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) + return ClaimOneTimeKeys(httpReq, request, userAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) v1fedmux.Handle("/user/keys/query", MakeFedAPI( "federation_keys_query", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) + return QueryDeviceKeys(httpReq, request, userAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index a146d85bd..82651719f 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -17,26 +17,20 @@ package routing import ( "context" "encoding/json" - "fmt" "net/http" "sync" "time" - "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/producers" - "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" - syncTypes "github.com/matrix-org/dendrite/syncapi/types" + userAPI "github.com/matrix-org/dendrite/userapi/api" ) const ( @@ -56,26 +50,6 @@ const ( MetricsWorkMissingPrevEvents = "missing_prev_events" ) -var ( - pduCountTotal = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: "dendrite", - Subsystem: "federationapi", - Name: "recv_pdus", - Help: "Number of incoming PDUs from remote servers with labels for success", - }, - []string{"status"}, // 'success' or 'total' - ) - eduCountTotal = prometheus.NewCounter( - prometheus.CounterOpts{ - Namespace: "dendrite", - Subsystem: "federationapi", - Name: "recv_edus", - Help: "Number of incoming EDUs from remote servers", - }, - ) -) - var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse // Send implements /_matrix/federation/v1/send/{txnID} @@ -85,7 +59,7 @@ func Send( txnID gomatrixserverlib.TransactionID, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, - keyAPI keyapi.FederationKeyAPI, + keyAPI userAPI.FederationUserAPI, keys gomatrixserverlib.JSONVerifier, federation federationAPI.FederationClient, mu *internal.MutexByRoom, @@ -123,18 +97,6 @@ func Send( defer close(ch) defer inFlightTxnsPerOrigin.Delete(index) - t := txnReq{ - rsAPI: rsAPI, - keys: keys, - ourServerName: cfg.Matrix.ServerName, - federation: federation, - servers: servers, - keyAPI: keyAPI, - roomsMu: mu, - producer: producer, - inboundPresenceEnabled: cfg.Matrix.Presence.EnableInbound, - } - var txnEvents struct { PDUs []json.RawMessage `json:"pdus"` EDUs []gomatrixserverlib.EDU `json:"edus"` @@ -155,16 +117,23 @@ func Send( } } - // TODO: Really we should have a function to convert FederationRequest to txnReq - t.PDUs = txnEvents.PDUs - t.EDUs = txnEvents.EDUs - t.Origin = request.Origin() - t.TransactionID = txnID - t.Destination = cfg.Matrix.ServerName + t := internal.NewTxnReq( + rsAPI, + keyAPI, + cfg.Matrix.ServerName, + keys, + mu, + producer, + cfg.Matrix.Presence.EnableInbound, + txnEvents.PDUs, + txnEvents.EDUs, + request.Origin(), + txnID, + cfg.Matrix.ServerName) util.GetLogger(httpReq.Context()).Debugf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs)) - resp, jsonErr := t.processTransaction(httpReq.Context()) + resp, jsonErr := t.ProcessTransaction(httpReq.Context()) if jsonErr != nil { util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed") return *jsonErr @@ -181,283 +150,3 @@ func Send( ch <- res return res } - -type txnReq struct { - gomatrixserverlib.Transaction - rsAPI api.FederationRoomserverAPI - keyAPI keyapi.FederationKeyAPI - ourServerName gomatrixserverlib.ServerName - keys gomatrixserverlib.JSONVerifier - federation txnFederationClient - roomsMu *internal.MutexByRoom - servers federationAPI.ServersInRoomProvider - producer *producers.SyncAPIProducer - inboundPresenceEnabled bool -} - -// A subset of FederationClient functionality that txn requires. Useful for testing. -type txnFederationClient interface { - LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - res gomatrixserverlib.RespState, err error, - ) - LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) - GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, - roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) -} - -func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - t.processEDUs(ctx) - }() - - results := make(map[string]gomatrixserverlib.PDUResult) - roomVersions := make(map[string]gomatrixserverlib.RoomVersion) - getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion { - if v, ok := roomVersions[roomID]; ok { - return v - } - verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID) - return "" - } - roomVersions[roomID] = verRes.RoomVersion - return verRes.RoomVersion - } - - for _, pdu := range t.PDUs { - pduCountTotal.WithLabelValues("total").Inc() - var header struct { - RoomID string `json:"room_id"` - } - if err := json.Unmarshal(pdu, &header); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to extract room ID from event") - // We don't know the event ID at this point so we can't return the - // failure in the PDU results - continue - } - roomVersion := getRoomVersion(header.RoomID) - event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) - if err != nil { - if _, ok := err.(gomatrixserverlib.BadJSONError); ok { - // Room version 6 states that homeservers should strictly enforce canonical JSON - // on PDUs. - // - // This enforces that the entire transaction is rejected if a single bad PDU is - // sent. It is unclear if this is the correct behaviour or not. - // - // See https://github.com/matrix-org/synapse/issues/7543 - return nil, &util.JSONResponse{ - Code: 400, - JSON: jsonerror.BadJSON("PDU contains bad JSON"), - } - } - util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) - continue - } - if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { - continue - } - if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { - results[event.EventID()] = gomatrixserverlib.PDUResult{ - Error: "Forbidden by server ACLs", - } - continue - } - if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { - util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) - results[event.EventID()] = gomatrixserverlib.PDUResult{ - Error: err.Error(), - } - continue - } - - // pass the event to the roomserver which will do auth checks - // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently - // discarded by the caller of this function - if err = api.SendEvents( - ctx, - t.rsAPI, - api.KindNew, - []*gomatrixserverlib.HeaderedEvent{ - event.Headered(roomVersion), - }, - t.Destination, - t.Origin, - api.DoNotSendToOtherServers, - nil, - true, - ); err != nil { - util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) - results[event.EventID()] = gomatrixserverlib.PDUResult{ - Error: err.Error(), - } - continue - } - - results[event.EventID()] = gomatrixserverlib.PDUResult{} - pduCountTotal.WithLabelValues("success").Inc() - } - - wg.Wait() - return &gomatrixserverlib.RespSend{PDUs: results}, nil -} - -// nolint:gocyclo -func (t *txnReq) processEDUs(ctx context.Context) { - for _, e := range t.EDUs { - eduCountTotal.Inc() - switch e.Type { - case gomatrixserverlib.MTyping: - // https://matrix.org/docs/spec/server_server/latest#typing-notifications - var typingPayload struct { - RoomID string `json:"room_id"` - UserID string `json:"user_id"` - Typing bool `json:"typing"` - } - if err := json.Unmarshal(e.Content, &typingPayload); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event") - continue - } - if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil { - continue - } else if serverName == t.ourServerName { - continue - } else if serverName != t.Origin { - continue - } - if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to JetStream") - } - case gomatrixserverlib.MDirectToDevice: - // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema - var directPayload gomatrixserverlib.ToDeviceMessage - if err := json.Unmarshal(e.Content, &directPayload); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events") - continue - } - if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil { - continue - } else if serverName == t.ourServerName { - continue - } else if serverName != t.Origin { - continue - } - for userID, byUser := range directPayload.Messages { - for deviceID, message := range byUser { - // TODO: check that the user and the device actually exist here - if err := t.producer.SendToDevice(ctx, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { - sentry.CaptureException(err) - util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ - "sender": directPayload.Sender, - "user_id": userID, - "device_id": deviceID, - }).Error("Failed to send send-to-device event to JetStream") - } - } - } - case gomatrixserverlib.MDeviceListUpdate: - if err := t.producer.SendDeviceListUpdate(ctx, e.Content, t.Origin); err != nil { - sentry.CaptureException(err) - util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate") - } - case gomatrixserverlib.MReceipt: - // https://matrix.org/docs/spec/server_server/r0.1.4#receipts - payload := map[string]types.FederationReceiptMRead{} - - if err := json.Unmarshal(e.Content, &payload); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal receipt event") - continue - } - - for roomID, receipt := range payload { - for userID, mread := range receipt.User { - _, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from receipt event sender") - continue - } - if t.Origin != domain { - util.GetLogger(ctx).Debugf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) - continue - } - if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil { - util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ - "sender": t.Origin, - "user_id": userID, - "room_id": roomID, - "events": mread.EventIDs, - }).Error("Failed to send receipt event to JetStream") - continue - } - } - } - case types.MSigningKeyUpdate: - if err := t.producer.SendSigningKeyUpdate(ctx, e.Content, t.Origin); err != nil { - sentry.CaptureException(err) - logrus.WithError(err).Errorf("Failed to process signing key update") - } - case gomatrixserverlib.MPresence: - if t.inboundPresenceEnabled { - if err := t.processPresence(ctx, e); err != nil { - logrus.WithError(err).Errorf("Failed to process presence update") - } - } - default: - util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") - } - } -} - -// processPresence handles m.receipt events -func (t *txnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) error { - payload := types.Presence{} - if err := json.Unmarshal(e.Content, &payload); err != nil { - return err - } - for _, content := range payload.Push { - if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil { - continue - } else if serverName == t.ourServerName { - continue - } else if serverName != t.Origin { - continue - } - presence, ok := syncTypes.PresenceFromString(content.Presence) - if !ok { - continue - } - if err := t.producer.SendPresence(ctx, content.UserID, presence, content.StatusMsg, content.LastActiveAgo); err != nil { - return err - } - } - return nil -} - -// processReceiptEvent sends receipt events to JetStream -func (t *txnReq) processReceiptEvent(ctx context.Context, - userID, roomID, receiptType string, - timestamp gomatrixserverlib.Timestamp, - eventIDs []string, -) error { - if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil { - return nil - } else if serverName == t.ourServerName { - return nil - } else if serverName != t.Origin { - return nil - } - // store every event - for _, eventID := range eventIDs { - if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil { - return fmt.Errorf("unable to set receipt event: %w", err) - } - } - - return nil -} diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index b8bfe0221..eed4e7e69 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -1,552 +1,87 @@ -package routing +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test import ( - "context" + "encoding/hex" "encoding/json" - "fmt" + "net/http/httptest" "testing" - "time" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/roomserver/api" + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + fedAPI "github.com/matrix-org/dendrite/federationapi" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" ) const ( - testOrigin = gomatrixserverlib.ServerName("kaer.morhen") - testDestination = gomatrixserverlib.ServerName("white.orchard") + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") ) -var ( - testRoomVersion = gomatrixserverlib.RoomVersionV1 - testData = []json.RawMessage{ - []byte(`{"auth_events":[],"content":{"creator":"@userid:kaer.morhen"},"depth":0,"event_id":"$0ok8ynDp7kjc95e3:kaer.morhen","hashes":{"sha256":"17kPoH+h0Dk4Omn7Sus0qMb6+oGcf+CZFEgDhv7UKWs"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"jP4a04f5/F10Pw95FPpdCyKAO44JOwUQ/MZOOeA/RTU1Dn+AHPMzGSaZnuGjRr/xQuADt+I3ctb5ZQfLKNzHDw"}},"state_key":"","type":"m.room.create"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"content":{"membership":"join"},"depth":1,"event_id":"$LEwEu0kxrtu5fOiS:kaer.morhen","hashes":{"sha256":"B7M88PhXf3vd1LaFtjQutFu4x/w7fHD28XKZ4sAsJTo"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"p2vqmuJn7ZBRImctSaKbXCAxCcBlIjPH9JHte1ouIUGy84gpu4eLipOvSBCLL26hXfC0Zrm4WUto6Hr+ohdrCg"}},"state_key":"@userid:kaer.morhen","type":"m.room.member"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"join_rule":"public"},"depth":2,"event_id":"$SMHlqUrNhhBBRLeN:kaer.morhen","hashes":{"sha256":"vIuJQvmMjrGxshAkj1SXe0C4RqvMbv4ZADDw9pFCWqQ"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"hBMsb3Qppo3RaqqAl4JyTgaiWEbW5hlckATky6PrHun+F3YM203TzG7w9clwuQU5F5pZoB1a6nw+to0hN90FAw"}},"state_key":"","type":"m.room.join_rules"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"history_visibility":"shared"},"depth":3,"event_id":"$6F1yGIbO0J7TM93h:kaer.morhen","hashes":{"sha256":"Mr23GKSlZW7UCCYLgOWawI2Sg6KIoMjUWO2TDenuOgw"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$SMHlqUrNhhBBRLeN:kaer.morhen",{"sha256":"SylzE8U02I+6eyEHgL+FlU0L5YdqrVp8OOlxKS9VQW0"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sHLKrFI3hKGrEJfpMVZSDS3LvLasQsy50CTsOwru9XTVxgRsPo6wozNtRVjxo1J3Rk18RC9JppovmQ5VR5EcDw"}},"state_key":"","type":"m.room.history_visibility"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"ban":50,"events":null,"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":null,"users_default":0},"depth":4,"event_id":"$UKNe10XzYzG0TeA9:kaer.morhen","hashes":{"sha256":"ngbP3yja9U5dlckKerUs/fSOhtKxZMCVvsfhPURSS28"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$6F1yGIbO0J7TM93h:kaer.morhen",{"sha256":"A4CucrKSoWX4IaJXhq02mBg1sxIyZEftbC+5p3fZAvk"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"zOmwlP01QL3yFchzuR9WHvogOoBZA3oVtNIF3lM0ZfDnqlSYZB9sns27G/4HVq0k7alaK7ZE3oGoCrVnMkPNCw"}},"state_key":"","type":"m.room.power_levels"}`), - // messages - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`), - } - testEvents = []*gomatrixserverlib.HeaderedEvent{} - testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) -) +type sendContent struct { + PDUs []json.RawMessage `json:"pdus"` + EDUs []gomatrixserverlib.EDU `json:"edus"` +} -func init() { - for _, j := range testData { - e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion) +func TestHandleSend(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() + base.PublicFederationAPIMux = fedMux + base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedapi := fedAPI.NewInternalAPI(base, nil, nil, nil, nil, true) + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + r, ok := fedapi.(*fedInternal.FederationInternalAPI) + if !ok { + panic("This is a programming error.") + } + routing.Setup(base, nil, r, keyRing, nil, nil, &base.Cfg.MSCs, nil, nil) + + handler := fedMux.Get(routing.SendRouteName).GetHandler().ServeHTTP + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("PUT", serverName, testOrigin, "/send/1234") + content := sendContent{} + err := req.SetContent(content) if err != nil { - panic("cannot load test data: " + err.Error()) + t.Fatalf("Error: %s", err.Error()) } - h := e.Headered(testRoomVersion) - testEvents = append(testEvents, h) - if e.StateKey() != nil { - testStateEvents[gomatrixserverlib.StateKeyTuple{ - EventType: e.Type(), - StateKey: *e.StateKey(), - }] = h + req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk) + httpReq, err := req.HTTPRequest() + if err != nil { + t.Fatalf("Error: %s", err.Error()) } - } + vars := map[string]string{"txnID": "1234"} + w := httptest.NewRecorder() + httpReq = mux.SetURLVars(httpReq, vars) + handler(w, httpReq) + + res := w.Result() + assert.Equal(t, 200, res.StatusCode) + }) } - -type testRoomserverAPI struct { - api.RoomserverInternalAPITrace - inputRoomEvents []api.InputRoomEvent - queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse - queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse - queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse -} - -func (t *testRoomserverAPI) InputRoomEvents( - ctx context.Context, - request *api.InputRoomEventsRequest, - response *api.InputRoomEventsResponse, -) error { - t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) - for _, ire := range request.InputRoomEvents { - fmt.Println("InputRoomEvents: ", ire.Event.EventID()) - } - return nil -} - -// Query the latest events and state for a room from the room server. -func (t *testRoomserverAPI) QueryLatestEventsAndState( - ctx context.Context, - request *api.QueryLatestEventsAndStateRequest, - response *api.QueryLatestEventsAndStateResponse, -) error { - r := t.queryLatestEventsAndState(request) - response.RoomExists = r.RoomExists - response.RoomVersion = testRoomVersion - response.LatestEvents = r.LatestEvents - response.StateEvents = r.StateEvents - response.Depth = r.Depth - return nil -} - -// Query the state after a list of events in a room from the room server. -func (t *testRoomserverAPI) QueryStateAfterEvents( - ctx context.Context, - request *api.QueryStateAfterEventsRequest, - response *api.QueryStateAfterEventsResponse, -) error { - response.RoomVersion = testRoomVersion - res := t.queryStateAfterEvents(request) - response.PrevEventsExist = res.PrevEventsExist - response.RoomExists = res.RoomExists - response.StateEvents = res.StateEvents - return nil -} - -// Query a list of events by event ID. -func (t *testRoomserverAPI) QueryEventsByID( - ctx context.Context, - request *api.QueryEventsByIDRequest, - response *api.QueryEventsByIDResponse, -) error { - res := t.queryEventsByID(request) - response.Events = res.Events - return nil -} - -// Query if a server is joined to a room -func (t *testRoomserverAPI) QueryServerJoinedToRoom( - ctx context.Context, - request *api.QueryServerJoinedToRoomRequest, - response *api.QueryServerJoinedToRoomResponse, -) error { - response.RoomExists = true - response.IsInRoom = true - return nil -} - -// Asks for the room version for a given room. -func (t *testRoomserverAPI) QueryRoomVersionForRoom( - ctx context.Context, - request *api.QueryRoomVersionForRoomRequest, - response *api.QueryRoomVersionForRoomResponse, -) error { - response.RoomVersion = testRoomVersion - return nil -} - -func (t *testRoomserverAPI) QueryServerBannedFromRoom( - ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse, -) error { - res.Banned = false - return nil -} - -type txnFedClient struct { - state map[string]gomatrixserverlib.RespState // event_id to response - stateIDs map[string]gomatrixserverlib.RespStateIDs // event_id to response - getEvent map[string]gomatrixserverlib.Transaction // event_id to response - getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) -} - -func (c *txnFedClient) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - res gomatrixserverlib.RespState, err error, -) { - fmt.Println("testFederationClient.LookupState", eventID) - r, ok := c.state[eventID] - if !ok { - err = fmt.Errorf("txnFedClient: no /state for event %s", eventID) - return - } - res = r - return -} -func (c *txnFedClient) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { - fmt.Println("testFederationClient.LookupStateIDs", eventID) - r, ok := c.stateIDs[eventID] - if !ok { - err = fmt.Errorf("txnFedClient: no /state_ids for event %s", eventID) - return - } - res = r - return -} -func (c *txnFedClient) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { - fmt.Println("testFederationClient.GetEvent", eventID) - r, ok := c.getEvent[eventID] - if !ok { - err = fmt.Errorf("txnFedClient: no /event for event ID %s", eventID) - return - } - res = r - return -} -func (c *txnFedClient) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, - roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) { - return c.getMissingEvents(missing) -} - -func mustCreateTransaction(rsAPI api.FederationRoomserverAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq { - t := &txnReq{ - rsAPI: rsAPI, - keys: &test.NopJSONVerifier{}, - federation: fedClient, - roomsMu: internal.NewMutexByRoom(), - } - t.PDUs = pdus - t.Origin = testOrigin - t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) - t.Destination = testDestination - return t -} - -func mustProcessTransaction(t *testing.T, txn *txnReq, pdusWithErrors []string) { - res, err := txn.processTransaction(context.Background()) - if err != nil { - t.Errorf("txn.processTransaction returned an error: %v", err) - return - } - if len(res.PDUs) != len(txn.PDUs) { - t.Errorf("txn.processTransaction did not return results for all PDUs, got %d want %d", len(res.PDUs), len(txn.PDUs)) - return - } -NextPDU: - for eventID, result := range res.PDUs { - if result.Error == "" { - continue - } - for _, eventIDWantError := range pdusWithErrors { - if eventID == eventIDWantError { - break NextPDU - } - } - t.Errorf("txn.processTransaction PDU %s returned an error %s", eventID, result.Error) - } -} - -/* -func fromStateTuples(tuples []gomatrixserverlib.StateKeyTuple, omitTuples []gomatrixserverlib.StateKeyTuple) (result []*gomatrixserverlib.HeaderedEvent) { -NextTuple: - for _, t := range tuples { - for _, o := range omitTuples { - if t == o { - break NextTuple - } - } - h, ok := testStateEvents[t] - if ok { - result = append(result, h) - } - } - return -} -*/ - -func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) { - for _, g := range got { - fmt.Println("GOT ", g.Event.EventID()) - } - if len(got) != len(want) { - t.Errorf("wrong number of InputRoomEvents: got %d want %d", len(got), len(want)) - return - } - for i := range got { - if got[i].Event.EventID() != want[i].EventID() { - t.Errorf("InputRoomEvents[%d] got %s want %s", i, got[i].Event.EventID(), want[i].EventID()) - } - } -} - -// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on -// to the roomserver. It's the most basic test possible. -func TestBasicTransaction(t *testing.T) { - rsAPI := &testRoomserverAPI{} - pdus := []json.RawMessage{ - testData[len(testData)-1], // a message event - } - txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) - mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) -} - -// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver -// as it does the auth check. -func TestTransactionFailAuthChecks(t *testing.T) { - rsAPI := &testRoomserverAPI{} - pdus := []json.RawMessage{ - testData[len(testData)-1], // a message event - } - txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) - mustProcessTransaction(t, txn, []string{}) - // expect message to be sent to the roomserver - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) -} - -// The purpose of this test is to make sure that when an event is received for which we do not know the prev_events, -// we request them from /get_missing_events. It works by setting PrevEventsExist=false in the roomserver query response, -// resulting in a call to /get_missing_events which returns the missing prev event. Both events should be processed in -// topological order and sent to the roomserver. -/* -func TestTransactionFetchMissingPrevEvents(t *testing.T) { - haveEvent := testEvents[len(testEvents)-3] - prevEvent := testEvents[len(testEvents)-2] - inputEvent := testEvents[len(testEvents)-1] - - var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions - rsAPI = &testRoomserverAPI{ - queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { - res := api.QueryEventsByIDResponse{} - for _, ev := range testEvents { - for _, id := range req.EventIDs { - if ev.EventID() == id { - res.Events = append(res.Events, ev) - } - } - } - return res - }, - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: true, - StateEvents: testEvents[:5], - } - }, - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - missingPrevEvent := []string{"missing_prev_event"} - if len(req.PrevEventIDs) == 1 { - switch req.PrevEventIDs[0] { - case haveEvent.EventID(): - missingPrevEvent = []string{} - case prevEvent.EventID(): - // we only have this event if we've been send prevEvent - if len(rsAPI.inputRoomEvents) == 1 && rsAPI.inputRoomEvents[0].Event.EventID() == prevEvent.EventID() { - missingPrevEvent = []string{} - } - } - } - - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: missingPrevEvent, - } - }, - queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { - return api.QueryLatestEventsAndStateResponse{ - RoomExists: true, - Depth: haveEvent.Depth(), - LatestEvents: []gomatrixserverlib.EventReference{ - haveEvent.EventReference(), - }, - StateEvents: fromStateTuples(req.StateToFetch, nil), - } - }, - } - - cli := &txnFedClient{ - getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) { - if !reflect.DeepEqual(missing.EarliestEvents, []string{haveEvent.EventID()}) { - t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, haveEvent.EventID()) - } - if !reflect.DeepEqual(missing.LatestEvents, []string{inputEvent.EventID()}) { - t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, inputEvent.EventID()) - } - return gomatrixserverlib.RespMissingEvents{ - Events: []*gomatrixserverlib.Event{ - prevEvent.Unwrap(), - }, - }, nil - }, - } - - pdus := []json.RawMessage{ - inputEvent.JSON(), - } - txn := mustCreateTransaction(rsAPI, cli, pdus) - mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{prevEvent, inputEvent}) -} - -// The purpose of this test is to check that when there are missing prev_events and we still haven't been able to fill -// in the hole with /get_missing_events that the state BEFORE the events we want to persist is fetched via /state_ids -// and /event. It works by setting PrevEventsExist=false in the roomserver query response, resulting in -// a call to /get_missing_events which returns 1 out of the 2 events it needs to fill in the gap. Synapse and Dendrite -// both give up after 1x /get_missing_events call, relying on requesting the state AFTER the missing event in order to -// continue. The DAG looks something like: -// FE GME TXN -// A ---> B ---> C ---> D -// TXN=event in the txn, GME=response to /get_missing_events, FE=roomserver's forward extremity. Should result in: -// - /state_ids?event=B is requested, then /event/B to get the state AFTER B. B is a state event. -// - state resolution is done to check C is allowed. -// This results in B being sent as an outlier FIRST, then C,D. -func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { - eventA := testEvents[len(testEvents)-5] - // this is also len(testEvents)-4 - eventB := testStateEvents[gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomPowerLevels, - StateKey: "", - }] - eventC := testEvents[len(testEvents)-3] - eventD := testEvents[len(testEvents)-2] - fmt.Println("a:", eventA.EventID()) - fmt.Println("b:", eventB.EventID()) - fmt.Println("c:", eventC.EventID()) - fmt.Println("d:", eventD.EventID()) - var rsAPI *testRoomserverAPI - rsAPI = &testRoomserverAPI{ - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - omitTuples := []gomatrixserverlib.StateKeyTuple{ - { - EventType: gomatrixserverlib.MRoomPowerLevels, - StateKey: "", - }, - } - askingForEvent := req.PrevEventIDs[0] - haveEventB := false - haveEventC := false - for _, ev := range rsAPI.inputRoomEvents { - switch ev.Event.EventID() { - case eventB.EventID(): - haveEventB = true - omitTuples = nil // include event B now - case eventC.EventID(): - haveEventC = true - } - } - prevEventExists := false - if askingForEvent == eventC.EventID() { - prevEventExists = haveEventC - } else if askingForEvent == eventB.EventID() { - prevEventExists = haveEventB - } - var stateEvents []*gomatrixserverlib.HeaderedEvent - if prevEventExists { - stateEvents = fromStateTuples(req.StateToFetch, omitTuples) - } - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: prevEventExists, - RoomExists: true, - StateEvents: stateEvents, - } - }, - - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - askingForEvent := req.PrevEventIDs[0] - haveEventB := false - haveEventC := false - for _, ev := range rsAPI.inputRoomEvents { - switch ev.Event.EventID() { - case eventB.EventID(): - haveEventB = true - case eventC.EventID(): - haveEventC = true - } - } - prevEventExists := false - if askingForEvent == eventC.EventID() { - prevEventExists = haveEventC - } else if askingForEvent == eventB.EventID() { - prevEventExists = haveEventB - } - - var missingPrevEvent []string - if !prevEventExists { - missingPrevEvent = []string{"test"} - } - - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: missingPrevEvent, - } - }, - - queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { - omitTuples := []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: ""}, - } - return api.QueryLatestEventsAndStateResponse{ - RoomExists: true, - Depth: eventA.Depth(), - LatestEvents: []gomatrixserverlib.EventReference{ - eventA.EventReference(), - }, - StateEvents: fromStateTuples(req.StateToFetch, omitTuples), - } - }, - queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { - var res api.QueryEventsByIDResponse - fmt.Println("queryEventsByID ", req.EventIDs) - for _, wantEventID := range req.EventIDs { - for _, ev := range testStateEvents { - // roomserver is missing the power levels event unless it's been sent to us recently as an outlier - if wantEventID == eventB.EventID() { - fmt.Println("Asked for pl event") - for _, inEv := range rsAPI.inputRoomEvents { - fmt.Println("recv ", inEv.Event.EventID()) - if inEv.Event.EventID() == wantEventID { - res.Events = append(res.Events, inEv.Event) - break - } - } - continue - } - if ev.EventID() == wantEventID { - res.Events = append(res.Events, ev) - } - } - } - return res - }, - } - // /state_ids for event B returns every state event but B (it's the state before) - var authEventIDs []string - var stateEventIDs []string - for _, ev := range testStateEvents { - if ev.EventID() == eventB.EventID() { - continue - } - // state res checks what auth events you give it, and this isn't a valid auth event - if ev.Type() != gomatrixserverlib.MRoomHistoryVisibility { - authEventIDs = append(authEventIDs, ev.EventID()) - } - stateEventIDs = append(stateEventIDs, ev.EventID()) - } - cli := &txnFedClient{ - stateIDs: map[string]gomatrixserverlib.RespStateIDs{ - eventB.EventID(): { - StateEventIDs: stateEventIDs, - AuthEventIDs: authEventIDs, - }, - }, - // /event for event B returns it - getEvent: map[string]gomatrixserverlib.Transaction{ - eventB.EventID(): { - PDUs: []json.RawMessage{ - eventB.JSON(), - }, - }, - }, - // /get_missing_events should be done exactly once - getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) { - if !reflect.DeepEqual(missing.EarliestEvents, []string{eventA.EventID()}) { - t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, eventA.EventID()) - } - if !reflect.DeepEqual(missing.LatestEvents, []string{eventD.EventID()}) { - t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, eventD.EventID()) - } - // just return event C, not event B so /state_ids logic kicks in as there will STILL be missing prev_events - return gomatrixserverlib.RespMissingEvents{ - Events: []*gomatrixserverlib.Event{ - eventC.Unwrap(), - }, - }, nil - }, - } - - pdus := []json.RawMessage{ - eventD.JSON(), - } - txn := mustCreateTransaction(rsAPI, cli, pdus) - mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD}) -} -*/ diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index 2ba99112c..e29e3b140 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -1,6 +1,7 @@ package statistics import ( + "context" "math" "math/rand" "sync" @@ -28,25 +29,30 @@ type Statistics struct { // just blacklist the host altogether? The backoff is exponential, // so the max time here to attempt is 2**failures seconds. FailuresUntilBlacklist uint32 + + // How many times should we tolerate consecutive failures before we + // mark the destination as offline. At this point we should attempt + // to send messages to the user's async relay servers if we know them. + FailuresUntilAssumedOffline uint32 } -func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics { +func NewStatistics( + db storage.Database, + failuresUntilBlacklist uint32, + failuresUntilAssumedOffline uint32, +) Statistics { return Statistics{ - DB: db, - FailuresUntilBlacklist: failuresUntilBlacklist, - backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), + DB: db, + FailuresUntilBlacklist: failuresUntilBlacklist, + FailuresUntilAssumedOffline: failuresUntilAssumedOffline, + backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), + servers: make(map[gomatrixserverlib.ServerName]*ServerStatistics), } } // ForServer returns server statistics for the given server name. If it // does not exist, it will create empty statistics and return those. func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics { - // If the map hasn't been initialised yet then do that. - if s.servers == nil { - s.mutex.Lock() - s.servers = make(map[gomatrixserverlib.ServerName]*ServerStatistics) - s.mutex.Unlock() - } // Look up if we have statistics for this server already. s.mutex.RLock() server, found := s.servers[serverName] @@ -55,8 +61,9 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS if !found { s.mutex.Lock() server = &ServerStatistics{ - statistics: s, - serverName: serverName, + statistics: s, + serverName: serverName, + knownRelayServers: []gomatrixserverlib.ServerName{}, } s.servers[serverName] = server s.mutex.Unlock() @@ -66,24 +73,49 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS } else { server.blacklisted.Store(blacklisted) } + assumedOffline, err := s.DB.IsServerAssumedOffline(context.Background(), serverName) + if err != nil { + logrus.WithError(err).Errorf("Failed to get assumed offline entry %q", serverName) + } else { + server.assumedOffline.Store(assumedOffline) + } + + knownRelayServers, err := s.DB.P2PGetRelayServersForServer(context.Background(), serverName) + if err != nil { + logrus.WithError(err).Errorf("Failed to get relay server list for %q", serverName) + } else { + server.relayMutex.Lock() + server.knownRelayServers = knownRelayServers + server.relayMutex.Unlock() + } } return server } +type SendMethod uint8 + +const ( + SendDirect SendMethod = iota + SendViaRelay +) + // ServerStatistics contains information about our interactions with a // remote federated host, e.g. how many times we were successful, how // many times we failed etc. It also manages the backoff time and black- // listing a remote host if it remains uncooperative. type ServerStatistics struct { - statistics *Statistics // - serverName gomatrixserverlib.ServerName // - blacklisted atomic.Bool // is the node blacklisted - backoffStarted atomic.Bool // is the backoff started - backoffUntil atomic.Value // time.Time until this backoff interval ends - backoffCount atomic.Uint32 // number of times BackoffDuration has been called - successCounter atomic.Uint32 // how many times have we succeeded? - backoffNotifier func() // notifies destination queue when backoff completes - notifierMutex sync.Mutex + statistics *Statistics // + serverName gomatrixserverlib.ServerName // + blacklisted atomic.Bool // is the node blacklisted + assumedOffline atomic.Bool // is the node assumed to be offline + backoffStarted atomic.Bool // is the backoff started + backoffUntil atomic.Value // time.Time until this backoff interval ends + backoffCount atomic.Uint32 // number of times BackoffDuration has been called + successCounter atomic.Uint32 // how many times have we succeeded? + backoffNotifier func() // notifies destination queue when backoff completes + notifierMutex sync.Mutex + knownRelayServers []gomatrixserverlib.ServerName + relayMutex sync.Mutex } const maxJitterMultiplier = 1.4 @@ -118,14 +150,22 @@ func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) { // attempt, which increases the sent counter and resets the idle and // failure counters. If a host was blacklisted at this point then // we will unblacklist it. -func (s *ServerStatistics) Success() { +// `relay` specifies whether the success was to the actual destination +// or one of their relay servers. +func (s *ServerStatistics) Success(method SendMethod) { s.cancel() s.backoffCount.Store(0) - s.successCounter.Inc() - if s.statistics.DB != nil { - if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { - logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) + // NOTE : Sending to the final destination vs. a relay server has + // slightly different semantics. + if method == SendDirect { + s.successCounter.Inc() + if s.blacklisted.Load() && s.statistics.DB != nil { + if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) + } } + + s.removeAssumedOffline() } } @@ -144,7 +184,18 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { // start a goroutine which will wait out the backoff and // unset the backoffStarted flag when done. if s.backoffStarted.CompareAndSwap(false, true) { - if s.backoffCount.Inc() >= s.statistics.FailuresUntilBlacklist { + backoffCount := s.backoffCount.Inc() + + if backoffCount >= s.statistics.FailuresUntilAssumedOffline { + s.assumedOffline.CompareAndSwap(false, true) + if s.statistics.DB != nil { + if err := s.statistics.DB.SetServerAssumedOffline(context.Background(), s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to set %q as assumed offline", s.serverName) + } + } + } + + if backoffCount >= s.statistics.FailuresUntilBlacklist { s.blacklisted.Store(true) if s.statistics.DB != nil { if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil { @@ -162,13 +213,21 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { s.backoffUntil.Store(until) s.statistics.backoffMutex.Lock() - defer s.statistics.backoffMutex.Unlock() s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished) + s.statistics.backoffMutex.Unlock() } return s.backoffUntil.Load().(time.Time), false } +// MarkServerAlive removes the assumed offline and blacklisted statuses from this server. +// Returns whether the server was blacklisted before this point. +func (s *ServerStatistics) MarkServerAlive() bool { + s.removeAssumedOffline() + wasBlacklisted := s.removeBlacklist() + return wasBlacklisted +} + // ClearBackoff stops the backoff timer for this destination if it is running // and removes the timer from the backoffTimers map. func (s *ServerStatistics) ClearBackoff() { @@ -196,13 +255,13 @@ func (s *ServerStatistics) backoffFinished() { } // BackoffInfo returns information about the current or previous backoff. -// Returns the last backoffUntil time and whether the server is currently blacklisted or not. -func (s *ServerStatistics) BackoffInfo() (*time.Time, bool) { +// Returns the last backoffUntil time. +func (s *ServerStatistics) BackoffInfo() *time.Time { until, ok := s.backoffUntil.Load().(time.Time) if ok { - return &until, s.blacklisted.Load() + return &until } - return nil, s.blacklisted.Load() + return nil } // Blacklisted returns true if the server is blacklisted and false @@ -211,10 +270,33 @@ func (s *ServerStatistics) Blacklisted() bool { return s.blacklisted.Load() } -// RemoveBlacklist removes the blacklisted status from the server. -func (s *ServerStatistics) RemoveBlacklist() { +// AssumedOffline returns true if the server is assumed offline and false +// otherwise. +func (s *ServerStatistics) AssumedOffline() bool { + return s.assumedOffline.Load() +} + +// removeBlacklist removes the blacklisted status from the server. +// Returns whether the server was blacklisted. +func (s *ServerStatistics) removeBlacklist() bool { + var wasBlacklisted bool + + if s.Blacklisted() { + wasBlacklisted = true + _ = s.statistics.DB.RemoveServerFromBlacklist(s.serverName) + } s.cancel() s.backoffCount.Store(0) + + return wasBlacklisted +} + +// removeAssumedOffline removes the assumed offline status from the server. +func (s *ServerStatistics) removeAssumedOffline() { + if s.AssumedOffline() { + _ = s.statistics.DB.RemoveServerAssumedOffline(context.Background(), s.serverName) + } + s.assumedOffline.Store(false) } // SuccessCount returns the number of successful requests. This is @@ -222,3 +304,46 @@ func (s *ServerStatistics) RemoveBlacklist() { func (s *ServerStatistics) SuccessCount() uint32 { return s.successCounter.Load() } + +// KnownRelayServers returns the list of relay servers associated with this +// server. +func (s *ServerStatistics) KnownRelayServers() []gomatrixserverlib.ServerName { + s.relayMutex.Lock() + defer s.relayMutex.Unlock() + return s.knownRelayServers +} + +func (s *ServerStatistics) AddRelayServers(relayServers []gomatrixserverlib.ServerName) { + seenSet := make(map[gomatrixserverlib.ServerName]bool) + uniqueList := []gomatrixserverlib.ServerName{} + for _, srv := range relayServers { + if seenSet[srv] { + continue + } + seenSet[srv] = true + uniqueList = append(uniqueList, srv) + } + + err := s.statistics.DB.P2PAddRelayServersForServer(context.Background(), s.serverName, uniqueList) + if err != nil { + logrus.WithError(err).Errorf("Failed to add relay servers for %q. Servers: %v", s.serverName, uniqueList) + return + } + + for _, newServer := range uniqueList { + alreadyKnown := false + knownRelayServers := s.KnownRelayServers() + for _, srv := range knownRelayServers { + if srv == newServer { + alreadyKnown = true + } + } + if !alreadyKnown { + { + s.relayMutex.Lock() + s.knownRelayServers = append(s.knownRelayServers, newServer) + s.relayMutex.Unlock() + } + } + } +} diff --git a/federationapi/statistics/statistics_test.go b/federationapi/statistics/statistics_test.go index 6aa997f44..183b9aa0c 100644 --- a/federationapi/statistics/statistics_test.go +++ b/federationapi/statistics/statistics_test.go @@ -4,17 +4,26 @@ import ( "math" "testing" "time" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + FailuresUntilAssumedOffline = 3 + FailuresUntilBlacklist = 8 ) func TestBackoff(t *testing.T) { - stats := NewStatistics(nil, 7) + stats := NewStatistics(nil, FailuresUntilBlacklist, FailuresUntilAssumedOffline) server := ServerStatistics{ statistics: &stats, serverName: "test.com", } // Start by checking that counting successes works. - server.Success() + server.Success(SendDirect) if successes := server.SuccessCount(); successes != 1 { t.Fatalf("Expected success count 1, got %d", successes) } @@ -31,9 +40,8 @@ func TestBackoff(t *testing.T) { // side effects since a backoff is already in progress. If it does // then we'll fail. until, blacklisted := server.Failure() - - // Get the duration. - _, blacklist := server.BackoffInfo() + blacklist := server.Blacklisted() + assumedOffline := server.AssumedOffline() duration := time.Until(until) // Unset the backoff, or otherwise our next call will think that @@ -41,16 +49,43 @@ func TestBackoff(t *testing.T) { server.cancel() server.backoffStarted.Store(false) + if i >= stats.FailuresUntilAssumedOffline { + if !assumedOffline { + t.Fatalf("Backoff %d should have resulted in assuming the destination was offline but didn't", i) + } + } + + // Check if we should be assumed offline by now. + if i >= stats.FailuresUntilAssumedOffline { + if !assumedOffline { + t.Fatalf("Backoff %d should have resulted in assumed offline but didn't", i) + } else { + t.Logf("Backoff %d is assumed offline as expected", i) + } + } else { + if assumedOffline { + t.Fatalf("Backoff %d should not have resulted in assumed offline but did", i) + } else { + t.Logf("Backoff %d is not assumed offline as expected", i) + } + } + // Check if we should be blacklisted by now. if i >= stats.FailuresUntilBlacklist { if !blacklist { t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i) } else if blacklist != blacklisted { - t.Fatalf("BackoffInfo and Failure returned different blacklist values") + t.Fatalf("Blacklisted and Failure returned different blacklist values") } else { t.Logf("Backoff %d is blacklisted as expected", i) continue } + } else { + if blacklist { + t.Fatalf("Backoff %d should not have resulted in blacklist but did", i) + } else { + t.Logf("Backoff %d is not blacklisted as expected", i) + } } // Check if the duration is what we expect. @@ -69,3 +104,14 @@ func TestBackoff(t *testing.T) { } } } + +func TestRelayServersListing(t *testing.T) { + stats := NewStatistics(test.NewInMemoryFederationDatabase(), FailuresUntilBlacklist, FailuresUntilAssumedOffline) + server := ServerStatistics{statistics: &stats} + server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"}) + relayServers := server.KnownRelayServers() + assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers) + server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"}) + relayServers = server.KnownRelayServers() + assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers) +} diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 276cd9a50..4f5300af1 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -20,11 +20,12 @@ import ( "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/types" ) type Database interface { + P2PDatabase gomatrixserverlib.KeyDatabase UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error) @@ -34,16 +35,16 @@ type Database interface { // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) - StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) + StoreJSON(ctx context.Context, js string) (*receipt.Receipt, error) - GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) - GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) + GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) + GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) - AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error - AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error + AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt) error + AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error - CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error - CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error + CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error + CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) @@ -54,6 +55,18 @@ type Database interface { RemoveAllServersFromBlacklist() error IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) + // Adds the server to the list of assumed offline servers. + // If the server already exists in the table, nothing happens and returns success. + SetServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + // Removes the server from the list of assumed offline servers. + // If the server doesn't exist in the table, nothing happens and returns success. + RemoveServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + // Purges all entries from the assumed offline table. + RemoveAllServersAssumedOffline(ctx context.Context) error + // Gets whether the provided server is present in the table. + // If it is present, returns true. If not, returns false. + IsServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) (bool, error) + AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) @@ -71,4 +84,24 @@ type Database interface { GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) // DeleteExpiredEDUs cleans up expired EDUs DeleteExpiredEDUs(ctx context.Context) error + + PurgeRoom(ctx context.Context, roomID string) error +} + +type P2PDatabase interface { + // Stores the given list of servers as relay servers for the provided destination server. + // Providing duplicates will only lead to a single entry and won't lead to an error. + P2PAddRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + + // Get the list of relay servers associated with the provided destination server. + // If no entry exists in the table, an empty list is returned and does not result in an error. + P2PGetRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + + // Deletes any entries for the provided destination server that match the provided relayServers list. + // If any of the provided servers don't match an entry, nothing happens and no error is returned. + P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + + // Deletes all entries for the provided destination server. + // If the destination server doesn't exist in the table, nothing happens and no error is returned. + P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error } diff --git a/federationapi/storage/postgres/assumed_offline_table.go b/federationapi/storage/postgres/assumed_offline_table.go new file mode 100644 index 000000000..5695d2e54 --- /dev/null +++ b/federationapi/storage/postgres/assumed_offline_table.go @@ -0,0 +1,107 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const assumedOfflineSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_assumed_offline( + -- The assumed offline server name + server_name TEXT PRIMARY KEY NOT NULL +); +` + +const insertAssumedOfflineSQL = "" + + "INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectAssumedOfflineSQL = "" + + "SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAllAssumedOfflineSQL = "" + + "TRUNCATE federationsender_assumed_offline" + +type assumedOfflineStatements struct { + db *sql.DB + insertAssumedOfflineStmt *sql.Stmt + selectAssumedOfflineStmt *sql.Stmt + deleteAssumedOfflineStmt *sql.Stmt + deleteAllAssumedOfflineStmt *sql.Stmt +} + +func NewPostgresAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) { + s = &assumedOfflineStatements{ + db: db, + } + _, err = db.Exec(assumedOfflineSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL}, + {&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL}, + {&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL}, + {&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL}, + }.Prepare(db) +} + +func (s *assumedOfflineStatements) InsertAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) SelectAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is assume offline, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *assumedOfflineStatements) DeleteAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) DeleteAllAssumedOffline( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/postgres/relay_servers_table.go b/federationapi/storage/postgres/relay_servers_table.go new file mode 100644 index 000000000..f7267978f --- /dev/null +++ b/federationapi/storage/postgres/relay_servers_table.go @@ -0,0 +1,137 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayServersSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_relay_servers ( + -- The destination server name + server_name TEXT NOT NULL, + -- The relay server name for a given destination + relay_server_name TEXT NOT NULL, + UNIQUE (server_name, relay_server_name) +); + +CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx + ON federationsender_relay_servers (server_name); +` + +const insertRelayServersSQL = "" + + "INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +const selectRelayServersSQL = "" + + "SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1" + +const deleteRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name = ANY($2)" + +const deleteAllRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1" + +type relayServersStatements struct { + db *sql.DB + insertRelayServersStmt *sql.Stmt + selectRelayServersStmt *sql.Stmt + deleteRelayServersStmt *sql.Stmt + deleteAllRelayServersStmt *sql.Stmt +} + +func NewPostgresRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) { + s = &relayServersStatements{ + db: db, + } + _, err = db.Exec(relayServersSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertRelayServersStmt, insertRelayServersSQL}, + {&s.selectRelayServersStmt, selectRelayServersSQL}, + {&s.deleteRelayServersStmt, deleteRelayServersSQL}, + {&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL}, + }.Prepare(db) +} + +func (s *relayServersStatements) InsertRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + for _, relayServer := range relayServers { + stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil { + return err + } + } + return nil +} + +func (s *relayServersStatements) SelectRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt) + rows, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var relayServer string + if err = rows.Scan(&relayServer); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(relayServer)) + } + return result, nil +} + +func (s *relayServersStatements) DeleteRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt) + _, err := stmt.ExecContext(ctx, serverName, pq.Array(relayServers)) + return err +} + +func (s *relayServersStatements) DeleteAllRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName); err != nil { + return err + } + return nil +} diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index fe84e932e..b81f128e7 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -62,6 +62,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + assumedOffline, err := NewPostgresAssumedOfflineTable(d.db) + if err != nil { + return nil, err + } + relayServers, err := NewPostgresRelayServersTable(d.db) + if err != nil { + return nil, err + } inboundPeeks, err := NewPostgresInboundPeeksTable(d.db) if err != nil { return nil, err @@ -104,6 +112,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, FederationQueueEDUs: queueEDUs, FederationQueueJSON: queueJSON, FederationBlacklist: blacklist, + FederationAssumedOffline: assumedOffline, + FederationRelayServers: relayServers, FederationInboundPeeks: inboundPeeks, FederationOutboundPeeks: outboundPeeks, NotaryServerKeysJSON: notaryJSON, diff --git a/federationapi/storage/shared/receipt/receipt.go b/federationapi/storage/shared/receipt/receipt.go new file mode 100644 index 000000000..b347269c1 --- /dev/null +++ b/federationapi/storage/shared/receipt/receipt.go @@ -0,0 +1,42 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// A Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. +// We don't actually export the NIDs but we need the caller to be able +// to pass them back so that we can clean up if the transaction sends +// successfully. + +package receipt + +import "fmt" + +// Receipt is a wrapper type used to represent a nid that corresponds to a unique row entry +// in some database table. +// The internal nid value cannot be modified after a Receipt has been created. +// This guarantees a receipt will always refer to the same table entry that it was created +// to represent. +type Receipt struct { + nid int64 +} + +func NewReceipt(nid int64) Receipt { + return Receipt{nid: nid} +} + +func (r *Receipt) GetNID() int64 { + return r.nid +} + +func (r *Receipt) String() string { + return fmt.Sprintf("%d", r.nid) +} diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 1e1ea9e17..6769637bc 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -20,6 +20,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/storage/tables" "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal/caching" @@ -37,6 +38,8 @@ type Database struct { FederationQueueJSON tables.FederationQueueJSON FederationJoinedHosts tables.FederationJoinedHosts FederationBlacklist tables.FederationBlacklist + FederationAssumedOffline tables.FederationAssumedOffline + FederationRelayServers tables.FederationRelayServers FederationOutboundPeeks tables.FederationOutboundPeeks FederationInboundPeeks tables.FederationInboundPeeks NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON @@ -44,22 +47,6 @@ type Database struct { ServerSigningKeys tables.FederationServerSigningKeys } -// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. -// We don't actually export the NIDs but we need the caller to be able -// to pass them back so that we can clean up if the transaction sends -// successfully. -type Receipt struct { - nid int64 -} - -func NewReceipt(nid int64) Receipt { - return Receipt{nid: nid} -} - -func (r *Receipt) String() string { - return fmt.Sprintf("%d", r.nid) -} - // UpdateRoom updates the joined hosts for a room and returns what the joined // hosts were before the update, or nil if this was a duplicate message. // This is called when we receive a message from kafka, so we pass in @@ -113,11 +100,18 @@ func (d *Database) GetJoinedHosts( // GetAllJoinedHosts returns the currently joined hosts for // all rooms known to the federation sender. // Returns an error if something goes wrong. -func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { +func (d *Database) GetAllJoinedHosts( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) } -func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { +func (d *Database) GetJoinedHostsForRooms( + ctx context.Context, + roomIDs []string, + excludeSelf, + excludeBlacklisted bool, +) ([]gomatrixserverlib.ServerName, error) { servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted) if err != nil { return nil, err @@ -139,7 +133,7 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, // metadata entries. func (d *Database) StoreJSON( ctx context.Context, js string, -) (*Receipt, error) { +) (*receipt.Receipt, error) { var nid int64 var err error _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -149,18 +143,21 @@ func (d *Database) StoreJSON( if err != nil { return nil, fmt.Errorf("d.insertQueueJSON: %w", err) } - return &Receipt{ - nid: nid, - }, nil + newReceipt := receipt.NewReceipt(nid) + return &newReceipt, nil } -func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { +func (d *Database) AddServerToBlacklist( + serverName gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName) }) } -func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { +func (d *Database) RemoveServerFromBlacklist( + serverName gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName) }) @@ -172,51 +169,166 @@ func (d *Database) RemoveAllServersFromBlacklist() error { }) } -func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { +func (d *Database) IsServerBlacklisted( + serverName gomatrixserverlib.ServerName, +) (bool, error) { return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName) } -func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) SetServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.InsertAssumedOffline(ctx, txn, serverName) + }) +} + +func (d *Database) RemoveServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.DeleteAssumedOffline(ctx, txn, serverName) + }) +} + +func (d *Database) RemoveAllServersAssumedOffline( + ctx context.Context, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.DeleteAllAssumedOffline(ctx, txn) + }) +} + +func (d *Database) IsServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (bool, error) { + return d.FederationAssumedOffline.SelectAssumedOffline(ctx, nil, serverName) +} + +func (d *Database) P2PAddRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationRelayServers.InsertRelayServers(ctx, txn, serverName, relayServers) + }) +} + +func (d *Database) P2PGetRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + return d.FederationRelayServers.SelectRelayServers(ctx, nil, serverName) +} + +func (d *Database) P2PRemoveRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationRelayServers.DeleteRelayServers(ctx, txn, serverName, relayServers) + }) +} + +func (d *Database) P2PRemoveAllRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationRelayServers.DeleteAllRelayServers(ctx, txn, serverName) + }) +} + +func (d *Database) AddOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) RenewOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationOutboundPeeks.RenewOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { +func (d *Database) GetOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID, + peekID string, +) (*types.OutboundPeek, error) { return d.FederationOutboundPeeks.SelectOutboundPeek(ctx, nil, serverName, roomID, peekID) } -func (d *Database) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) { +func (d *Database) GetOutboundPeeks( + ctx context.Context, + roomID string, +) ([]types.OutboundPeek, error) { return d.FederationOutboundPeeks.SelectOutboundPeeks(ctx, nil, roomID) } -func (d *Database) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) AddInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationInboundPeeks.InsertInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) RenewInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationInboundPeeks.RenewInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) { +func (d *Database) GetInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, +) (*types.InboundPeek, error) { return d.FederationInboundPeeks.SelectInboundPeek(ctx, nil, serverName, roomID, peekID) } -func (d *Database) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) { +func (d *Database) GetInboundPeeks( + ctx context.Context, + roomID string, +) ([]types.InboundPeek, error) { return d.FederationInboundPeeks.SelectInboundPeeks(ctx, nil, roomID) } -func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { +func (d *Database) UpdateNotaryKeys( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + serverKeys gomatrixserverlib.ServerKeys, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { validUntil := serverKeys.ValidUntilTS // Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid. @@ -251,7 +363,9 @@ func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserv } func (d *Database) GetNotaryKeys( - ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID, + ctx context.Context, + serverName gomatrixserverlib.ServerName, + optKeyIDs []gomatrixserverlib.KeyID, ) (sks []gomatrixserverlib.ServerKeys, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { sks, err = d.NotaryServerKeysMetadata.SelectKeys(ctx, txn, serverName, optKeyIDs) @@ -259,3 +373,18 @@ func (d *Database) GetNotaryKeys( }) return sks, err } + +func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge joined hosts: %w", err) + } + if err := d.FederationInboundPeeks.DeleteInboundPeeks(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge inbound peeks: %w", err) + } + if err := d.FederationOutboundPeeks.DeleteOutboundPeeks(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge outbound peeks: %w", err) + } + return nil + }) +} diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index be8355f31..cff1ade6f 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -22,6 +22,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" ) @@ -41,7 +42,7 @@ var defaultExpireEDUTypes = map[string]time.Duration{ func (d *Database) AssociateEDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - receipt *Receipt, + dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration, ) error { @@ -62,12 +63,12 @@ func (d *Database) AssociateEDUWithDestinations( var err error for destination := range destinations { err = d.FederationQueueEDUs.InsertQueueEDU( - ctx, // context - txn, // SQL transaction - eduType, // EDU type for coalescing - destination, // destination server name - receipt.nid, // NID from the federationapi_queue_json table - expiresAt, // The timestamp this EDU will expire + ctx, // context + txn, // SQL transaction + eduType, // EDU type for coalescing + destination, // destination server name + dbReceipt.GetNID(), // NID from the federationapi_queue_json table + expiresAt, // The timestamp this EDU will expire ) } return err @@ -81,10 +82,10 @@ func (d *Database) GetPendingEDUs( serverName gomatrixserverlib.ServerName, limit int, ) ( - edus map[*Receipt]*gomatrixserverlib.EDU, + edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error, ) { - edus = make(map[*Receipt]*gomatrixserverlib.EDU) + edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU) err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { nids, err := d.FederationQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit) if err != nil { @@ -94,7 +95,8 @@ func (d *Database) GetPendingEDUs( retrieve := make([]int64, 0, len(nids)) for _, nid := range nids { if edu, ok := d.Cache.GetFederationQueuedEDU(nid); ok { - edus[&Receipt{nid}] = edu + newReceipt := receipt.NewReceipt(nid) + edus[&newReceipt] = edu } else { retrieve = append(retrieve, nid) } @@ -110,7 +112,8 @@ func (d *Database) GetPendingEDUs( if err := json.Unmarshal(blob, &event); err != nil { return fmt.Errorf("json.Unmarshal: %w", err) } - edus[&Receipt{nid}] = &event + newReceipt := receipt.NewReceipt(nid) + edus[&newReceipt] = &event d.Cache.StoreFederationQueuedEDU(nid, &event) } @@ -124,7 +127,7 @@ func (d *Database) GetPendingEDUs( func (d *Database) CleanEDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, - receipts []*Receipt, + receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { return errors.New("expected receipt") @@ -132,7 +135,7 @@ func (d *Database) CleanEDUs( nids := make([]int64, len(receipts)) for i := range receipts { - nids[i] = receipts[i].nid + nids[i] = receipts[i].GetNID() } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index da4cb979d..854e00553 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" ) @@ -30,17 +31,17 @@ import ( func (d *Database) AssociatePDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - receipt *Receipt, + dbReceipt *receipt.Receipt, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error for destination := range destinations { err = d.FederationQueuePDUs.InsertQueuePDU( - ctx, // context - txn, // SQL transaction - "", // transaction ID - destination, // destination server name - receipt.nid, // NID from the federationapi_queue_json table + ctx, // context + txn, // SQL transaction + "", // transaction ID + destination, // destination server name + dbReceipt.GetNID(), // NID from the federationapi_queue_json table ) } return err @@ -54,7 +55,7 @@ func (d *Database) GetPendingPDUs( serverName gomatrixserverlib.ServerName, limit int, ) ( - events map[*Receipt]*gomatrixserverlib.HeaderedEvent, + events map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error, ) { // Strictly speaking this doesn't need to be using the writer @@ -62,7 +63,7 @@ func (d *Database) GetPendingPDUs( // a guarantee of transactional isolation, it's actually useful // to know in SQLite mode that nothing else is trying to modify // the database. - events = make(map[*Receipt]*gomatrixserverlib.HeaderedEvent) + events = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent) err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { nids, err := d.FederationQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit) if err != nil { @@ -72,7 +73,8 @@ func (d *Database) GetPendingPDUs( retrieve := make([]int64, 0, len(nids)) for _, nid := range nids { if event, ok := d.Cache.GetFederationQueuedPDU(nid); ok { - events[&Receipt{nid}] = event + newReceipt := receipt.NewReceipt(nid) + events[&newReceipt] = event } else { retrieve = append(retrieve, nid) } @@ -88,7 +90,8 @@ func (d *Database) GetPendingPDUs( if err := json.Unmarshal(blob, &event); err != nil { return fmt.Errorf("json.Unmarshal: %w", err) } - events[&Receipt{nid}] = &event + newReceipt := receipt.NewReceipt(nid) + events[&newReceipt] = &event d.Cache.StoreFederationQueuedPDU(nid, &event) } @@ -103,7 +106,7 @@ func (d *Database) GetPendingPDUs( func (d *Database) CleanPDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, - receipts []*Receipt, + receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { return errors.New("expected receipt") @@ -111,7 +114,7 @@ func (d *Database) CleanPDUs( nids := make([]int64, len(receipts)) for i := range receipts { - nids[i] = receipts[i].nid + nids[i] = receipts[i].GetNID() } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/federationapi/storage/sqlite3/assumed_offline_table.go b/federationapi/storage/sqlite3/assumed_offline_table.go new file mode 100644 index 000000000..ff2afb4da --- /dev/null +++ b/federationapi/storage/sqlite3/assumed_offline_table.go @@ -0,0 +1,107 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const assumedOfflineSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_assumed_offline( + -- The assumed offline server name + server_name TEXT PRIMARY KEY NOT NULL +); +` + +const insertAssumedOfflineSQL = "" + + "INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectAssumedOfflineSQL = "" + + "SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAllAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline" + +type assumedOfflineStatements struct { + db *sql.DB + insertAssumedOfflineStmt *sql.Stmt + selectAssumedOfflineStmt *sql.Stmt + deleteAssumedOfflineStmt *sql.Stmt + deleteAllAssumedOfflineStmt *sql.Stmt +} + +func NewSQLiteAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) { + s = &assumedOfflineStatements{ + db: db, + } + _, err = db.Exec(assumedOfflineSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL}, + {&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL}, + {&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL}, + {&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL}, + }.Prepare(db) +} + +func (s *assumedOfflineStatements) InsertAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) SelectAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is assume offline, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *assumedOfflineStatements) DeleteAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) DeleteAllAssumedOffline( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/sqlite3/relay_servers_table.go b/federationapi/storage/sqlite3/relay_servers_table.go new file mode 100644 index 000000000..27c3cca2c --- /dev/null +++ b/federationapi/storage/sqlite3/relay_servers_table.go @@ -0,0 +1,148 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayServersSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_relay_servers ( + -- The destination server name + server_name TEXT NOT NULL, + -- The relay server name for a given destination + relay_server_name TEXT NOT NULL, + UNIQUE (server_name, relay_server_name) +); + +CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx + ON federationsender_relay_servers (server_name); +` + +const insertRelayServersSQL = "" + + "INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +const selectRelayServersSQL = "" + + "SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1" + +const deleteRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name IN ($2)" + +const deleteAllRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1" + +type relayServersStatements struct { + db *sql.DB + insertRelayServersStmt *sql.Stmt + selectRelayServersStmt *sql.Stmt + // deleteRelayServersStmt *sql.Stmt - prepared at runtime due to variadic + deleteAllRelayServersStmt *sql.Stmt +} + +func NewSQLiteRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) { + s = &relayServersStatements{ + db: db, + } + _, err = db.Exec(relayServersSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertRelayServersStmt, insertRelayServersSQL}, + {&s.selectRelayServersStmt, selectRelayServersSQL}, + {&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL}, + }.Prepare(db) +} + +func (s *relayServersStatements) InsertRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + for _, relayServer := range relayServers { + stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil { + return err + } + } + return nil +} + +func (s *relayServersStatements) SelectRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt) + rows, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var relayServer string + if err = rows.Scan(&relayServer); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(relayServer)) + } + return result, nil +} + +func (s *relayServersStatements) DeleteRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + deleteSQL := strings.Replace(deleteRelayServersSQL, "($2)", sqlutil.QueryVariadicOffset(len(relayServers), 1), 1) + deleteStmt, err := s.db.Prepare(deleteSQL) + if err != nil { + return err + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + params := make([]interface{}, len(relayServers)+1) + params[0] = serverName + for i, v := range relayServers { + params[i+1] = v + } + + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *relayServersStatements) DeleteAllRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName); err != nil { + return err + } + return nil +} diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index d13b5defc..1e7e41a2c 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -1,5 +1,4 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -61,6 +60,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db) + if err != nil { + return nil, err + } + relayServers, err := NewSQLiteRelayServersTable(d.db) + if err != nil { + return nil, err + } outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) if err != nil { return nil, err @@ -103,6 +110,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, FederationQueueEDUs: queueEDUs, FederationQueueJSON: queueJSON, FederationBlacklist: blacklist, + FederationAssumedOffline: assumedOffline, + FederationRelayServers: relayServers, FederationOutboundPeeks: outboundPeeks, FederationInboundPeeks: inboundPeeks, NotaryServerKeysJSON: notaryKeys, diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 5b57d40d4..1d2a13e81 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -6,14 +6,13 @@ import ( "testing" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/stretchr/testify/assert" - "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" ) func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { @@ -246,3 +245,99 @@ func TestInboundPeeking(t *testing.T) { assert.ElementsMatch(t, gotPeekIDs, peekIDs) }) } + +func TestServersAssumedOffline(t *testing.T) { + server1 := gomatrixserverlib.ServerName("server1") + server2 := gomatrixserverlib.ServerName("server2") + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + + // Set server1 & server2 as assumed offline. + err := db.SetServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + err = db.SetServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + + // Ensure both servers are assumed offline. + isOffline, err := db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.True(t, isOffline) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + assert.True(t, isOffline) + + // Set server1 as not assumed offline. + err = db.RemoveServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + + // Ensure both servers have correct state. + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.False(t, isOffline) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + assert.True(t, isOffline) + + // Re-set server1 as assumed offline. + err = db.SetServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + + // Ensure server1 is assumed offline. + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.True(t, isOffline) + + err = db.RemoveAllServersAssumedOffline(context.Background()) + assert.Nil(t, err) + + // Ensure both servers have correct state. + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.False(t, isOffline) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + assert.False(t, isOffline) + }) +} + +func TestRelayServersStored(t *testing.T) { + server := gomatrixserverlib.ServerName("server") + relayServer1 := gomatrixserverlib.ServerName("relayserver1") + relayServer2 := gomatrixserverlib.ServerName("relayserver2") + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + + err := db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) + assert.Nil(t, err) + + relayServers, err := db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Equal(t, relayServer1, relayServers[0]) + + err = db.P2PRemoveRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) + assert.Nil(t, err) + + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Zero(t, len(relayServers)) + + err = db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1, relayServer2}) + assert.Nil(t, err) + + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Equal(t, relayServer1, relayServers[0]) + assert.Equal(t, relayServer2, relayServers[1]) + + err = db.P2PRemoveAllRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Zero(t, len(relayServers)) + }) +} diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 2b36edb46..762504e45 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -49,6 +49,19 @@ type FederationQueueJSON interface { SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) } +type FederationQueueTransactions interface { + InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error + DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) +} + +type FederationTransactionJSON interface { + InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) + DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error + SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) +} + type FederationJoinedHosts interface { InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error @@ -66,6 +79,20 @@ type FederationBlacklist interface { DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error } +type FederationAssumedOffline interface { + InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error) + DeleteAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + DeleteAllAssumedOffline(ctx context.Context, txn *sql.Tx) error +} + +type FederationRelayServers interface { + InsertRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + SelectRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + DeleteRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + DeleteAllRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error +} + type FederationOutboundPeeks interface { InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) diff --git a/federationapi/storage/tables/relay_servers_table_test.go b/federationapi/storage/tables/relay_servers_table_test.go new file mode 100644 index 000000000..b41211551 --- /dev/null +++ b/federationapi/storage/tables/relay_servers_table_test.go @@ -0,0 +1,224 @@ +package tables_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + server1 = "server1" + server2 = "server2" + server3 = "server3" + server4 = "server4" +) + +type RelayServersDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.FederationRelayServers +} + +func mustCreateRelayServersTable( + t *testing.T, + dbType test.DBType, +) (database RelayServersDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.FederationRelayServers + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresRelayServersTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteRelayServersTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = RelayServersDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func Equal(a, b []gomatrixserverlib.ServerName) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} + +func TestShouldInsertRelayServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldInsertRelayServersWithDuplicates(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + insertRelayServers := []gomatrixserverlib.ServerName{server2, server2, server2, server3, server2} + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + // Insert the same list again, this shouldn't fail and should have no effect. + err = db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldGetRelayServersUnknownDestination(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + + // Query relay servers for a destination that doesn't exist in the table. + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, []gomatrixserverlib.ServerName{}) { + t.Fatalf("Expected: %v \nActual: %v", []gomatrixserverlib.ServerName{}, relayServers) + } + }) +} + +func TestShouldDeleteCorrectRelayServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + relayServers1 := []gomatrixserverlib.ServerName{server2, server3} + relayServers2 := []gomatrixserverlib.ServerName{server1, server3, server4} + + err := db.Table.InsertRelayServers(ctx, nil, server1, relayServers1) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertRelayServers(ctx, nil, server2, relayServers2) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + err = db.Table.DeleteRelayServers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2}) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) + } + err = db.Table.DeleteRelayServers(ctx, nil, server2, []gomatrixserverlib.ServerName{server1, server4}) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server2, err.Error()) + } + + expectedRelayServers := []gomatrixserverlib.ServerName{server3} + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldDeleteAllRelayServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + err = db.Table.DeleteAllRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) + } + + expectedRelayServers1 := []gomatrixserverlib.ServerName{} + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers1) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers) + } + relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} diff --git a/go.mod b/go.mod index 2d7174150..6857290f3 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/MFAshby/stdemuxerhook v1.0.0 github.com/Masterminds/semver/v3 v3.1.1 - github.com/blevesearch/bleve/v2 v2.3.4 + github.com/blevesearch/bleve/v2 v2.3.6 github.com/codeclysm/extract v2.2.0+incompatible github.com/dgraph-io/ristretto v0.1.1 github.com/docker/docker v20.10.19+incompatible @@ -22,9 +22,9 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab - github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 - github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230131183213-122f1e0e3fa1 + github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a + github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.15 github.com/nats-io/nats-server/v2 v2.9.8 github.com/nats-io/nats.go v1.20.0 @@ -37,17 +37,16 @@ require ( github.com/prometheus/client_golang v1.13.0 github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.1 - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.14.4 github.com/tidwall/sjson v1.2.5 github.com/uber/jaeger-client-go v2.30.0+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.6 go.uber.org/atomic v1.10.0 - golang.org/x/crypto v0.1.0 - golang.org/x/image v0.1.0 + golang.org/x/crypto v0.5.0 + golang.org/x/image v0.5.0 golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e - golang.org/x/net v0.1.0 - golang.org/x/term v0.1.0 + golang.org/x/term v0.5.0 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 gotest.tools/v3 v3.4.0 @@ -58,24 +57,24 @@ require ( require ( github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect github.com/Microsoft/go-winio v0.5.2 // indirect - github.com/RoaringBitmap/roaring v1.2.1 // indirect + github.com/RoaringBitmap/roaring v1.2.3 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/bits-and-blooms/bitset v1.3.3 // indirect - github.com/blevesearch/bleve_index_api v1.0.3 // indirect - github.com/blevesearch/geo v0.1.14 // indirect + github.com/bits-and-blooms/bitset v1.5.0 // indirect + github.com/blevesearch/bleve_index_api v1.0.5 // indirect + github.com/blevesearch/geo v0.1.17 // indirect github.com/blevesearch/go-porterstemmer v1.0.3 // indirect github.com/blevesearch/gtreap v0.1.1 // indirect github.com/blevesearch/mmap-go v1.0.4 // indirect - github.com/blevesearch/scorch_segment_api/v2 v2.1.2 // indirect - github.com/blevesearch/segment v0.9.0 // indirect + github.com/blevesearch/scorch_segment_api/v2 v2.1.4 // indirect + github.com/blevesearch/segment v0.9.1 // indirect github.com/blevesearch/snowballstem v0.9.0 // indirect - github.com/blevesearch/upsidedown_store_api v1.0.1 // indirect - github.com/blevesearch/vellum v1.0.8 // indirect - github.com/blevesearch/zapx/v11 v11.3.5 // indirect - github.com/blevesearch/zapx/v12 v12.3.5 // indirect - github.com/blevesearch/zapx/v13 v13.3.5 // indirect - github.com/blevesearch/zapx/v14 v14.3.5 // indirect - github.com/blevesearch/zapx/v15 v15.3.5 // indirect + github.com/blevesearch/upsidedown_store_api v1.0.2 // indirect + github.com/blevesearch/vellum v1.0.9 // indirect + github.com/blevesearch/zapx/v11 v11.3.7 // indirect + github.com/blevesearch/zapx/v12 v12.3.7 // indirect + github.com/blevesearch/zapx/v13 v13.3.7 // indirect + github.com/blevesearch/zapx/v14 v14.3.7 // indirect + github.com/blevesearch/zapx/v15 v15.3.8 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/docker/distribution v2.8.1+incompatible // indirect @@ -95,9 +94,6 @@ require ( github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/klauspost/compress v1.15.11 // indirect github.com/kr/pretty v0.3.1 // indirect - github.com/lucas-clemente/quic-go v0.30.0 // indirect - github.com/marten-seemann/qtls-go1-18 v0.1.3 // indirect - github.com/marten-seemann/qtls-go1-19 v0.1.1 // indirect github.com/mattn/go-isatty v0.0.16 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect github.com/minio/highwayhash v1.0.2 // indirect @@ -117,14 +113,19 @@ require ( github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect + github.com/quic-go/qtls-go1-18 v0.2.0 // indirect + github.com/quic-go/qtls-go1-19 v0.2.0 // indirect + github.com/quic-go/qtls-go1-20 v0.1.0 // indirect + github.com/quic-go/quic-go v0.32.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa // indirect github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect go.etcd.io/bbolt v1.3.6 // indirect - golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 // indirect + golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/mod v0.6.0 // indirect - golang.org/x/sys v0.1.0 // indirect - golang.org/x/text v0.4.0 // indirect + golang.org/x/net v0.7.0 // indirect + golang.org/x/sys v0.5.0 // indirect + golang.org/x/text v0.7.0 // indirect golang.org/x/time v0.1.0 // indirect golang.org/x/tools v0.2.0 // indirect google.golang.org/protobuf v1.28.1 // indirect diff --git a/go.sum b/go.sum index b12f65eab..2272e5404 100644 --- a/go.sum +++ b/go.sum @@ -50,9 +50,8 @@ github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0 github.com/Microsoft/go-winio v0.5.2 h1:a9IhgEQBCUEk6QCdml9CiJGhAws+YwffDHEMp1VMrpA= github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= github.com/RoaringBitmap/roaring v0.4.7/go.mod h1:8khRDP4HmeXns4xIj9oGrKSz7XTQiJx2zgh7AcNke4w= -github.com/RoaringBitmap/roaring v0.9.4/go.mod h1:icnadbWcNyfEHlYdr+tDlOTih1Bf/h+rzPpv4sbomAA= -github.com/RoaringBitmap/roaring v1.2.1 h1:58/LJlg/81wfEHd5L9qsHduznOIhyv4qb1yWcSvVq9A= -github.com/RoaringBitmap/roaring v1.2.1/go.mod h1:icnadbWcNyfEHlYdr+tDlOTih1Bf/h+rzPpv4sbomAA= +github.com/RoaringBitmap/roaring v1.2.3 h1:yqreLINqIrX22ErkKI0vY47/ivtJr6n+kMhVOVmhWBY= +github.com/RoaringBitmap/roaring v1.2.3/go.mod h1:plvDsJQpxOC5bw8LRteu/MLWHsHez/3y6cubLI4/1yE= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= @@ -70,51 +69,45 @@ github.com/anacrolix/missinggo v1.2.1 h1:0IE3TqX5y5D0IxeMwTyIgqdDew4QrzcXaaEnJQy github.com/anacrolix/missinggo v1.2.1/go.mod h1:J5cMhif8jPmFoC3+Uvob3OXXNIhOUikzMt+uUjeM21Y= github.com/anacrolix/missinggo/perf v1.0.0/go.mod h1:ljAFWkBuzkO12MQclXzZrosP5urunoLS0Cbvb4V0uMQ= github.com/anacrolix/tagflag v0.0.0-20180109131632-2146c8d41bf0/go.mod h1:1m2U/K6ZT+JZG0+bdMK6qauP49QT4wE5pmhJXOKKCHw= -github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bits-and-blooms/bitset v1.2.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= -github.com/bits-and-blooms/bitset v1.3.3 h1:R1XWiopGiXf66xygsiLpzLo67xEYvMkHw3w+rCOSAwg= -github.com/bits-and-blooms/bitset v1.3.3/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= -github.com/blevesearch/bleve/v2 v2.3.4 h1:SSb7/cwGzo85LWX1jchIsXM8ZiNNMX3shT5lROM63ew= -github.com/blevesearch/bleve/v2 v2.3.4/go.mod h1:Ot0zYum8XQRfPcwhae8bZmNyYubynsoMjVvl1jPqL30= -github.com/blevesearch/bleve_index_api v1.0.3 h1:DDSWaPXOZZJ2BB73ZTWjKxydAugjwywcqU+91AAqcAg= -github.com/blevesearch/bleve_index_api v1.0.3/go.mod h1:fiwKS0xLEm+gBRgv5mumf0dhgFr2mDgZah1pqv1c1M4= -github.com/blevesearch/geo v0.1.13/go.mod h1:cRIvqCdk3cgMhGeHNNe6yPzb+w56otxbfo1FBJfR2Pc= -github.com/blevesearch/geo v0.1.14 h1:TTDpJN6l9ck/cUYbXSn4aCElNls0Whe44rcQKsB7EfU= -github.com/blevesearch/geo v0.1.14/go.mod h1:cRIvqCdk3cgMhGeHNNe6yPzb+w56otxbfo1FBJfR2Pc= -github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:9eJDeqxJ3E7WnLebQUlPD7ZjSce7AnDb9vjGmMCbD0A= +github.com/bits-and-blooms/bitset v1.5.0 h1:NpE8frKRLGHIcEzkR+gZhiioW1+WbYV6fKwD6ZIpQT8= +github.com/bits-and-blooms/bitset v1.5.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= +github.com/blevesearch/bleve/v2 v2.3.6 h1:NlntUHcV5CSWIhpugx4d/BRMGCiaoI8ZZXrXlahzNq4= +github.com/blevesearch/bleve/v2 v2.3.6/go.mod h1:JM2legf1cKVkdV8Ehu7msKIOKC0McSw0Q16Fmv9vsW4= +github.com/blevesearch/bleve_index_api v1.0.5 h1:Lc986kpC4Z0/n1g3gg8ul7H+lxgOQPcXb9SxvQGu+tw= +github.com/blevesearch/bleve_index_api v1.0.5/go.mod h1:YXMDwaXFFXwncRS8UobWs7nvo0DmusriM1nztTlj1ms= +github.com/blevesearch/geo v0.1.17 h1:AguzI6/5mHXapzB0gE9IKWo+wWPHZmXZoscHcjFgAFA= +github.com/blevesearch/geo v0.1.17/go.mod h1:uRMGWG0HJYfWfFJpK3zTdnnr1K+ksZTuWKhXeSokfnM= github.com/blevesearch/go-porterstemmer v1.0.3 h1:GtmsqID0aZdCSNiY8SkuPJ12pD4jI+DdXTAn4YRcHCo= github.com/blevesearch/go-porterstemmer v1.0.3/go.mod h1:angGc5Ht+k2xhJdZi511LtmxuEf0OVpvUUNrwmM1P7M= -github.com/blevesearch/goleveldb v1.0.1/go.mod h1:WrU8ltZbIp0wAoig/MHbrPCXSOLpe79nz5lv5nqfYrQ= github.com/blevesearch/gtreap v0.1.1 h1:2JWigFrzDMR+42WGIN/V2p0cUvn4UP3C4Q5nmaZGW8Y= github.com/blevesearch/gtreap v0.1.1/go.mod h1:QaQyDRAT51sotthUWAH4Sj08awFSSWzgYICSZ3w0tYk= -github.com/blevesearch/mmap-go v1.0.2/go.mod h1:ol2qBqYaOUsGdm7aRMRrYGgPvnwLe6Y+7LMvAB5IbSA= github.com/blevesearch/mmap-go v1.0.4 h1:OVhDhT5B/M1HNPpYPBKIEJaD0F3Si+CrEKULGCDPWmc= github.com/blevesearch/mmap-go v1.0.4/go.mod h1:EWmEAOmdAS9z/pi/+Toxu99DnsbhG1TIxUoRmJw/pSs= -github.com/blevesearch/scorch_segment_api/v2 v2.1.2 h1:TAte9VZLWda5WAVlZTTZ+GCzEHqGJb4iB2aiZSA6Iv8= -github.com/blevesearch/scorch_segment_api/v2 v2.1.2/go.mod h1:rvoQXZGq8drq7vXbNeyiRzdEOwZkjkiYGf1822i6CRA= -github.com/blevesearch/segment v0.9.0 h1:5lG7yBCx98or7gK2cHMKPukPZ/31Kag7nONpoBt22Ac= -github.com/blevesearch/segment v0.9.0/go.mod h1:9PfHYUdQCgHktBgvtUOF4x+pc4/l8rdH0u5spnW85UQ= -github.com/blevesearch/snowball v0.6.1/go.mod h1:ZF0IBg5vgpeoUhnMza2v0A/z8m1cWPlwhke08LpNusg= +github.com/blevesearch/scorch_segment_api/v2 v2.1.4 h1:LmGmo5twU3gV+natJbKmOktS9eMhokPGKWuR+jX84vk= +github.com/blevesearch/scorch_segment_api/v2 v2.1.4/go.mod h1:PgVnbbg/t1UkgezPDu8EHLi1BHQ17xUwsFdU6NnOYS0= +github.com/blevesearch/segment v0.9.1 h1:+dThDy+Lvgj5JMxhmOVlgFfkUtZV2kw49xax4+jTfSU= +github.com/blevesearch/segment v0.9.1/go.mod h1:zN21iLm7+GnBHWTao9I+Au/7MBiL8pPFtJBJTsk6kQw= github.com/blevesearch/snowballstem v0.9.0 h1:lMQ189YspGP6sXvZQ4WZ+MLawfV8wOmPoD/iWeNXm8s= github.com/blevesearch/snowballstem v0.9.0/go.mod h1:PivSj3JMc8WuaFkTSRDW2SlrulNWPl4ABg1tC/hlgLs= -github.com/blevesearch/upsidedown_store_api v1.0.1 h1:1SYRwyoFLwG3sj0ed89RLtM15amfX2pXlYbFOnF8zNU= -github.com/blevesearch/upsidedown_store_api v1.0.1/go.mod h1:MQDVGpHZrpe3Uy26zJBf/a8h0FZY6xJbthIMm8myH2Q= -github.com/blevesearch/vellum v1.0.8 h1:iMGh4lfxza4BnWO/UJTMPlI3HsK9YawjPv+TteVa9ck= -github.com/blevesearch/vellum v1.0.8/go.mod h1:+cpRi/tqq49xUYSQN2P7A5zNSNrS+MscLeeaZ3J46UA= -github.com/blevesearch/zapx/v11 v11.3.5 h1:eBQWQ7huA+mzm0sAGnZDwgGGli7S45EO+N+ObFWssbI= -github.com/blevesearch/zapx/v11 v11.3.5/go.mod h1:5UdIa/HRMdeRCiLQOyFESsnqBGiip7vQmYReA9toevU= -github.com/blevesearch/zapx/v12 v12.3.5 h1:5pX2hU+R1aZihT7ac1dNWh1n4wqkIM9pZzWp0ANED9s= -github.com/blevesearch/zapx/v12 v12.3.5/go.mod h1:ANcthYRZQycpbRut/6ArF5gP5HxQyJqiFcuJCBju/ss= -github.com/blevesearch/zapx/v13 v13.3.5 h1:eJ3gbD+Nu8p36/O6lhfdvWQ4pxsGYSuTOBrLLPVWJ74= -github.com/blevesearch/zapx/v13 v13.3.5/go.mod h1:FV+dRnScFgKnRDIp08RQL4JhVXt1x2HE3AOzqYa6fjo= -github.com/blevesearch/zapx/v14 v14.3.5 h1:hEvVjZaagFCvOUJrlFQ6/Z6Jjy0opM3g7TMEo58TwP4= -github.com/blevesearch/zapx/v14 v14.3.5/go.mod h1:954A/eKFb+pg/ncIYWLWCKY+mIjReM9FGTGIO2Wu1cU= -github.com/blevesearch/zapx/v15 v15.3.5 h1:NVD0qq8vRk66ImJn1KloXT5ckqPDUZT7VbVJs9jKlac= -github.com/blevesearch/zapx/v15 v15.3.5/go.mod h1:QMUh2hXCaYIWFKPYGavq/Iga2zbHWZ9DZAa9uFbWyvg= +github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMGZzVrdmaozG2MfoB+A= +github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ= +github.com/blevesearch/vellum v1.0.9 h1:PL+NWVk3dDGPCV0hoDu9XLLJgqU4E5s/dOeEJByQ2uQ= +github.com/blevesearch/vellum v1.0.9/go.mod h1:ul1oT0FhSMDIExNjIxHqJoGpVrBpKCdgDQNxfqgJt7k= +github.com/blevesearch/zapx/v11 v11.3.7 h1:Y6yIAF/DVPiqZUA/jNgSLXmqewfzwHzuwfKyfdG+Xaw= +github.com/blevesearch/zapx/v11 v11.3.7/go.mod h1:Xk9Z69AoAWIOvWudNDMlxJDqSYGf90LS0EfnaAIvXCA= +github.com/blevesearch/zapx/v12 v12.3.7 h1:DfQ6rsmZfEK4PzzJJRXjiM6AObG02+HWvprlXQ1Y7eI= +github.com/blevesearch/zapx/v12 v12.3.7/go.mod h1:SgEtYIBGvM0mgIBn2/tQE/5SdrPXaJUaT/kVqpAPxm0= +github.com/blevesearch/zapx/v13 v13.3.7 h1:igIQg5eKmjw168I7av0Vtwedf7kHnQro/M+ubM4d2l8= +github.com/blevesearch/zapx/v13 v13.3.7/go.mod h1:yyrB4kJ0OT75UPZwT/zS+Ru0/jYKorCOOSY5dBzAy+s= +github.com/blevesearch/zapx/v14 v14.3.7 h1:gfe+fbWslDWP/evHLtp/GOvmNM3sw1BbqD7LhycBX20= +github.com/blevesearch/zapx/v14 v14.3.7/go.mod h1:9J/RbOkqZ1KSjmkOes03AkETX7hrXT0sFMpWH4ewC4w= +github.com/blevesearch/zapx/v15 v15.3.8 h1:q4uMngBHzL1IIhRc8AJUEkj6dGOE3u1l3phLu7hq8uk= +github.com/blevesearch/zapx/v15 v15.3.8/go.mod h1:m7Y6m8soYUvS7MjN9eKlz1xrLCcmqfFadmu7GhWIrLY= github.com/bradfitz/iter v0.0.0-20140124041915-454541ec3da2/go.mod h1:PyRFw1Lt2wKX4ZVSQ2mk+PeDa1rxyObEDlApuIsUKuo= github.com/bradfitz/iter v0.0.0-20190303215204-33e6a9893b0c/go.mod h1:PyRFw1Lt2wKX4ZVSQ2mk+PeDa1rxyObEDlApuIsUKuo= github.com/bradfitz/iter v0.0.0-20191230175014-e8f45d346db8 h1:GKTyiRCL6zVf5wWaqKnf+7Qs6GbEPfd4iMOitWzXJx8= @@ -130,12 +123,6 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/codeclysm/extract v2.2.0+incompatible h1:q3wyckoA30bhUSiwdQezMqVhwd8+WGE64/GL//LtUhI= github.com/codeclysm/extract v2.2.0+incompatible/go.mod h1:2nhFMPHiU9At61hz+12bfrlpXSUrOnK+wR+KlGO4Uks= -github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= -github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= -github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/couchbase/ghistogram v0.1.0/go.mod h1:s1Jhy76zqfEecpNWJfWUiKZookAFaiGOEoyzgHt9i7k= -github.com/couchbase/moss v0.2.0/go.mod h1:9MaHIaRuy9pvLPUJxB8sh8OrLfyDczECVL37grCIubs= -github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -164,7 +151,6 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7 github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k= github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/getsentry/sentry-go v0.14.0 h1:rlOBkuFZRKKdUnKO+0U3JclRDQKlRu5vVQtkWSQvC70= github.com/getsentry/sentry-go v0.14.0/go.mod h1:RZPJKSw+adu8PBNygiri/A98FqVr2HtRckJk9XVxJ9I= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -241,7 +227,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo= @@ -261,7 +246,6 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -289,15 +273,11 @@ github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huandu/xstrings v1.0.0 h1:pO2K/gKgKaat5LdpAhxhluX2GPQMaI3W5FUz/I/UnWk= github.com/huandu/xstrings v1.0.0/go.mod h1:4qWG/gcEcfX4z/mBDHJ++3ReCw9ibxbsNJbcucJdbSo= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= -github.com/json-iterator/go v0.0.0-20171115153421-f7279a603ede/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -335,25 +315,18 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lucas-clemente/quic-go v0.30.0 h1:nwLW0h8ahVQ5EPTIM7uhl/stHqQDea15oRlYKZmw2O0= -github.com/lucas-clemente/quic-go v0.30.0/go.mod h1:ssOrRsOmdxa768Wr78vnh2B8JozgLsMzG/g+0qEC7uk= -github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= -github.com/marten-seemann/qtls-go1-18 v0.1.3 h1:R4H2Ks8P6pAtUagjFty2p7BVHn3XiwDAl7TTQf5h7TI= -github.com/marten-seemann/qtls-go1-18 v0.1.3/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= -github.com/marten-seemann/qtls-go1-19 v0.1.1 h1:mnbxeq3oEyQxQXwI4ReCgW9DPoPR94sNlqWoDZnjRIE= -github.com/marten-seemann/qtls-go1-19 v0.1.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e h1:DP5RC0Z3XdyBEW5dKt8YPeN6vZbm6OzVaGVp7f1BQRM= github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw2QV3YD/fRrzEDPNGgTlJlvXY0EHHnT87wF3OA= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab h1:ChaQdT2mpxMm3GRXNOZzLDQ/wOnlKZ8o60LmZGOjdj8= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= -github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= -github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= -github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230131183213-122f1e0e3fa1 h1:JSw0nmjMrgBmoM2aQsa78LTpI5BnuD9+vOiEQ4Qo0qw= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230131183213-122f1e0e3fa1/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= +github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= +github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= +github.com/matrix-org/util v0.0.0-20221111132719-399730281e66/go.mod h1:iBI1foelCqA09JJgPV0FYz4qA5dUXYOxMi57FxKBdd4= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -365,8 +338,6 @@ github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182aff github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= -github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/moby/term v0.0.0-20220808134915-39b0c02b01ae h1:O4SWKdcHVCvYqyDV+9CJA1fcDN2L11Bule0iFy3YlAI= github.com/moby/term v0.0.0-20220808134915-39b0c02b01ae/go.mod h1:E2VnQOmVuvZB6UYnnDB0qG5Nq/1tD9acaOpo6xmt0Kw= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -400,11 +371,8 @@ github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79 h1:Dmx8g2747UTVPzSkmohk84S3g/uWqd6+f4SSLPhLcfA= github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79/go.mod h1:E26fwEtRNigBfFfHDWsklmo0T7Ixbg0XXgck+Hq4O9k= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo/v2 v2.3.0 h1:kUMoxMoQG3ogk/QWyKh3zibV7BKZ+xBpWil1cTylVqc= github.com/onsi/ginkgo/v2 v2.3.0/go.mod h1:Eew0uilEqZmIEZr8JrvYlvOM7Rr6xzTmMV8AyFNU9d0= -github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.22.1 h1:pY8O4lBfsHKZHM/6nrxkhVPUznOlIu3quZcKP/M20KI= github.com/onsi/gomega v1.22.1/go.mod h1:x6n7VNe4hw0vkyYUM4mjIXx3JbLiPaBPNgB7PRQ1tuM= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= @@ -415,8 +383,6 @@ github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+ github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= -github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= -github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwbMiyQg= github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= @@ -452,13 +418,20 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1 github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo= github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4= +github.com/quic-go/qtls-go1-18 v0.2.0 h1:5ViXqBZ90wpUcZS0ge79rf029yx0dYB0McyPJwqqj7U= +github.com/quic-go/qtls-go1-18 v0.2.0/go.mod h1:moGulGHK7o6O8lSPSZNoOwcLvJKJ85vVNc7oJFD65bc= +github.com/quic-go/qtls-go1-19 v0.2.0 h1:Cvn2WdhyViFUHoOqK52i51k4nDX8EwIh5VJiVM4nttk= +github.com/quic-go/qtls-go1-19 v0.2.0/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= +github.com/quic-go/qtls-go1-20 v0.1.0 h1:d1PK3ErFy9t7zxKsG3NXBJXZjp/kMLoIb3y/kV54oAI= +github.com/quic-go/qtls-go1-20 v0.1.0/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= +github.com/quic-go/quic-go v0.32.0 h1:lY02md31s1JgPiiyfqJijpu/UX/Iun304FI3yUqX7tA= +github.com/quic-go/quic-go v0.32.0/go.mod h1:/fCsKANhQIeD5l76c2JFU+07gVE3KaA0FP+0zMWwfwo= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa h1:tEkEyxYeZ43TR55QU/hsIt9aRGBxbgGuz9CGykjvogY= github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46/go.mod h1:uAQ5PCi+MFsC7HjREoAz1BU+Mq60+05gifQSsHSDG/8= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= @@ -468,12 +441,7 @@ github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0 github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= -github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= -github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= -github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= -github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= -github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= @@ -490,12 +458,13 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= @@ -505,11 +474,9 @@ github.com/uber/jaeger-lib v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVK github.com/uber/jaeger-lib v2.4.1+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= -github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/yggdrasil-network/yggdrasil-go v0.4.6 h1:GALUDV9QPz/5FVkbazpkTc9EABHufA556JwUJZr41j4= github.com/yggdrasil-network/yggdrasil-go v0.4.6/go.mod h1:PBMoAOvQjA9geNEeGyMXA9QgCS6Bu+9V+1VkWM84wpw= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -518,7 +485,6 @@ github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= go.etcd.io/bbolt v1.3.6 h1:/ecaJf0sk1l4l6V4awd65v2C3ILy7MSj+s/x1ADCIMU= go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= @@ -530,7 +496,6 @@ go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -539,8 +504,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= -golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= +golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= +golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -554,13 +519,13 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 h1:QfTh0HpN6hlw6D3vu8DAwC8pBIwikq0AI1evdm+FksE= -golang.org/x/exp v0.0.0-20221031165847-c99f073a8326/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= +golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.1.0 h1:r8Oj8ZA2Xy12/b5KZYj3tuv7NG/fBz3TwQVvpJ9l8Rk= -golang.org/x/image v0.1.0/go.mod h1:iyPr49SD/G/TBxYVB/9RRtGUT5eNbo2u4NamWeQcD5c= +golang.org/x/image v0.5.0 h1:5JMiNunQeQw++mMOz48/ISeNu3Iweh/JaZU8ZLqHRrI= +golang.org/x/image v0.5.0/go.mod h1:FVC7BI/5Ym8R25iw5OLsgshdUBbT1h5jZTpA+mvAdZ4= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -587,7 +552,6 @@ golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -621,8 +585,8 @@ golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0= -golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= +golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -644,10 +608,7 @@ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181221143128-b4a75ba826a6/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -697,12 +658,12 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw= -golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -710,8 +671,8 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -859,13 +820,11 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/h2non/bimg.v1 v1.1.9 h1:wZIUbeOnwr37Ta4aofhIv8OI8v4ujpjXC9mXnAGpQjM= gopkg.in/h2non/bimg.v1 v1.1.9/go.mod h1:PgsZL7dLwUbsGm1NYps320GxGgvQNTnecMCZqxV11So= gopkg.in/h2non/gock.v1 v1.1.2 h1:jBbHXgGBK/AoPVfJh5x4r/WxIrElvbLel8TCZkkZJoY= gopkg.in/macaroon.v2 v2.1.0 h1:HZcsjBCzq9t0eBPMKqTN/uSN6JOm78ZJ2INbqcBQOUI= gopkg.in/macaroon.v2 v2.1.0/go.mod h1:OUb+TQP/OP0WOerC2Jp/3CwhIKyIa9kQjuc7H24e6/o= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/helm/dendrite/.helm-docs/monitoring.gotmpl b/helm/dendrite/.helm-docs/monitoring.gotmpl new file mode 100644 index 000000000..3618a1c1a --- /dev/null +++ b/helm/dendrite/.helm-docs/monitoring.gotmpl @@ -0,0 +1,22 @@ +{{ define "chart.monitoringSection" }} +## Monitoring + +[![Grafana Dashboard](https://grafana.com/api/dashboards/13916/images/9894/image)](https://grafana.com/grafana/dashboards/13916-dendrite/) + +* Works well with [Prometheus Operator](https://prometheus-operator.dev/) ([Helmchart](https://artifacthub.io/packages/helm/prometheus-community/kube-prometheus-stack)) and their setup of [Grafana](https://grafana.com/grafana/), by enabling the following values: +```yaml +prometheus: + servicemonitor: + enabled: true + labels: + release: "kube-prometheus-stack" + rules: + enabled: true # will deploy alert rules + labels: + release: "kube-prometheus-stack" +grafana: + dashboards: + enabled: true # will deploy default dashboards +``` +PS: The label `release=kube-prometheus-stack` is setup with the helmchart of the Prometheus Operator. For Grafana Dashboards it may be necessary to enable scanning in the correct namespaces (or ALL), enabled by `sidecar.dashboards.searchNamespace` in [Helmchart of grafana](https://artifacthub.io/packages/helm/grafana/grafana) (which is part of PrometheusOperator, so `grafana.sidecar.dashboards.searchNamespace`) +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/Chart.yaml b/helm/dendrite/Chart.yaml index 15d1e6d19..dc2764939 100644 --- a/helm/dendrite/Chart.yaml +++ b/helm/dendrite/Chart.yaml @@ -1,7 +1,7 @@ apiVersion: v2 name: dendrite -version: "0.10.8" -appVersion: "0.10.8" +version: "0.11.2" +appVersion: "0.11.1" description: Dendrite Matrix Homeserver type: application keywords: diff --git a/helm/dendrite/README.md b/helm/dendrite/README.md index cb850d655..51587b766 100644 --- a/helm/dendrite/README.md +++ b/helm/dendrite/README.md @@ -1,6 +1,6 @@ # dendrite -![Version: 0.10.8](https://img.shields.io/badge/Version-0.10.8-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.10.8](https://img.shields.io/badge/AppVersion-0.10.8-informational?style=flat-square) +![Version: 0.11.2](https://img.shields.io/badge/Version-0.11.2-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.11.1](https://img.shields.io/badge/AppVersion-0.11.1-informational?style=flat-square) Dendrite Matrix Homeserver Status: **NOT PRODUCTION READY** @@ -41,8 +41,9 @@ Create a folder `appservices` and place your configurations in there. The confi | Key | Type | Default | Description | |-----|------|---------|-------------| -| image.name | string | `"ghcr.io/matrix-org/dendrite-monolith:v0.10.8"` | Docker repository/image to use | +| image.repository | string | `"ghcr.io/matrix-org/dendrite-monolith"` | Docker repository/image to use | | image.pullPolicy | string | `"IfNotPresent"` | Kubernetes pullPolicy | +| image.tag | string | `""` | Overrides the image tag whose default is the chart appVersion. | | signing_key.create | bool | `true` | Create a new signing key, if not exists | | signing_key.existingSecret | string | `""` | Use an existing secret | | resources | object | sets some sane default values | Default resource requests/limits. | @@ -144,4 +145,36 @@ Create a folder `appservices` and place your configurations in there. The confi | ingress.annotations | object | `{}` | Extra, custom annotations | | ingress.tls | list | `[]` | | | service.type | string | `"ClusterIP"` | | -| service.port | int | `80` | | +| service.port | int | `8008` | | +| prometheus.servicemonitor.enabled | bool | `false` | Enable ServiceMonitor for Prometheus-Operator for scrape metric-endpoint | +| prometheus.servicemonitor.labels | object | `{}` | Extra Labels on ServiceMonitor for selector of Prometheus Instance | +| prometheus.rules.enabled | bool | `false` | Enable PrometheusRules for Prometheus-Operator for setup alerting | +| prometheus.rules.labels | object | `{}` | Extra Labels on PrometheusRules for selector of Prometheus Instance | +| prometheus.rules.additionalRules | list | `[]` | additional alertrules (no default alertrules are provided) | +| grafana.dashboards.enabled | bool | `false` | | +| grafana.dashboards.labels | object | `{"grafana_dashboard":"1"}` | Extra Labels on ConfigMap for selector of grafana sidecar | +| grafana.dashboards.annotations | object | `{}` | Extra Annotations on ConfigMap additional config in grafana sidecar | + +## Monitoring + +[![Grafana Dashboard](https://grafana.com/api/dashboards/13916/images/9894/image)](https://grafana.com/grafana/dashboards/13916-dendrite/) + +* Works well with [Prometheus Operator](https://prometheus-operator.dev/) ([Helmchart](https://artifacthub.io/packages/helm/prometheus-community/kube-prometheus-stack)) and their setup of [Grafana](https://grafana.com/grafana/), by enabling the following values: +```yaml +prometheus: + servicemonitor: + enabled: true + labels: + release: "kube-prometheus-stack" + rules: + enabled: true # will deploy alert rules + labels: + release: "kube-prometheus-stack" +grafana: + dashboards: + enabled: true # will deploy default dashboards +``` +PS: The label `release=kube-prometheus-stack` is setup with the helmchart of the Prometheus Operator. For Grafana Dashboards it may be necessary to enable scanning in the correct namespaces (or ALL), enabled by `sidecar.dashboards.searchNamespace` in [Helmchart of grafana](https://artifacthub.io/packages/helm/grafana/grafana) (which is part of PrometheusOperator, so `grafana.sidecar.dashboards.searchNamespace`) + +---------------------------------------------- +Autogenerated from chart metadata using [helm-docs vv1.11.0](https://github.com/norwoodj/helm-docs/releases/vv1.11.0) \ No newline at end of file diff --git a/helm/dendrite/README.md.gotmpl b/helm/dendrite/README.md.gotmpl index 7c32f7b02..9411733ce 100644 --- a/helm/dendrite/README.md.gotmpl +++ b/helm/dendrite/README.md.gotmpl @@ -10,4 +10,5 @@ {{ template "chart.sourcesSection" . }} {{ template "chart.requirementsSection" . }} {{ template "chart.valuesSection" . }} +{{ template "chart.monitoringSection" . }} {{ template "helm-docs.versionFooter" . }} \ No newline at end of file diff --git a/helm/dendrite/ci/ct-ingress-values.yaml b/helm/dendrite/ci/ct-ingress-values.yaml index 28311d33e..f3f58b5ca 100644 --- a/helm/dendrite/ci/ct-ingress-values.yaml +++ b/helm/dendrite/ci/ct-ingress-values.yaml @@ -11,3 +11,8 @@ dendrite_config: ingress: enabled: true + +# dashboard is an ConfigMap with labels - it does not harm on testing +grafana: + dashboards: + enabled: true diff --git a/helm/dendrite/grafana_dashboards/dendrite-rev1.json b/helm/dendrite/grafana_dashboards/dendrite-rev1.json new file mode 100644 index 000000000..206e8af87 --- /dev/null +++ b/helm/dendrite/grafana_dashboards/dendrite-rev1.json @@ -0,0 +1,1119 @@ +{ + "__inputs": [ + { + "name": "DS_INFLUXDB_DOMOTICA", + "label": "", + "description": "", + "type": "datasource", + "pluginId": "influxdb", + "pluginName": "InfluxDB" + }, + { + "name": "DS_PROMETHEUS", + "label": "Prometheus", + "description": "", + "type": "datasource", + "pluginId": "prometheus", + "pluginName": "Prometheus" + } + ], + "__requires": [ + { + "type": "grafana", + "id": "grafana", + "name": "Grafana", + "version": "7.4.2" + }, + { + "type": "panel", + "id": "graph", + "name": "Graph", + "version": "" + }, + { + "type": "panel", + "id": "heatmap", + "name": "Heatmap", + "version": "" + }, + { + "type": "datasource", + "id": "influxdb", + "name": "InfluxDB", + "version": "1.0.0" + }, + { + "type": "datasource", + "id": "prometheus", + "name": "Prometheus", + "version": "1.0.0" + }, + { + "type": "panel", + "id": "stat", + "name": "Stat", + "version": "" + } + ], + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": "-- Grafana --", + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "description": "Dendrite dashboard from https://github.com/matrix-org/dendrite/", + "editable": true, + "gnetId": 13916, + "graphTooltip": 0, + "id": null, + "iteration": 1613683251329, + "links": [], + "panels": [ + { + "collapsed": false, + "datasource": "${DS_INFLUXDB_DOMOTICA}", + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 0 + }, + "id": 4, + "panels": [], + "title": "Overview", + "type": "row" + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "$datasource", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 5, + "w": 10, + "x": 0, + "y": 1 + }, + "hiddenSeries": false, + "id": 2, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.4.2", + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "rate(process_cpu_seconds_total{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "{{job}}-{{index}} ", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "CPU usage", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "decimals": null, + "format": "percentunit", + "label": null, + "logBase": 1, + "max": "1", + "min": "0", + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "datasource": "${DS_PROMETHEUS}", + "description": "Total number of registered users", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "custom": {}, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 5, + "w": 4, + "x": 10, + "y": 1 + }, + "id": 20, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "text": {}, + "textMode": "auto" + }, + "pluginVersion": "7.4.2", + "targets": [ + { + "exemplar": false, + "expr": "dendrite_clientapi_reg_users_total", + "instant": false, + "interval": "", + "legendFormat": "Users", + "refId": "A" + } + ], + "title": "Registerd Users", + "type": "stat" + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_PROMETHEUS}", + "description": "The number of sync requests that are active right now and are waiting to be woken by a notifier", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 5, + "w": 10, + "x": 14, + "y": 1 + }, + "hiddenSeries": false, + "id": 6, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 2, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.4.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "sum(rate(dendrite_syncapi_active_sync_requests{instance=\"$instance\"}[$bucket_size]))without (job,index)", + "hide": false, + "interval": "", + "legendFormat": "active", + "refId": "A" + }, + { + "expr": "sum(rate(dendrite_syncapi_waiting_sync_requests{instance=\"$instance\"}[$bucket_size]))without (job,index)", + "hide": false, + "interval": "", + "legendFormat": "waiting", + "refId": "B" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Sync API", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:232", + "format": "hertz", + "label": null, + "logBase": 1, + "max": null, + "min": "0", + "show": true + }, + { + "$$hashKey": "object:233", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "cards": { + "cardPadding": null, + "cardRound": null + }, + "color": { + "cardColor": "#b4ff00", + "colorScale": "sqrt", + "colorScheme": "interpolateOranges", + "exponent": 0.5, + "mode": "spectrum" + }, + "dataFormat": "tsbuckets", + "datasource": "${DS_PROMETHEUS}", + "description": "How long it takes to build and submit a new event from the client API to the roomserver", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "gridPos": { + "h": 5, + "w": 24, + "x": 0, + "y": 6 + }, + "heatmap": {}, + "hideZeroBuckets": false, + "highlightCards": true, + "id": 24, + "legend": { + "show": false + }, + "pluginVersion": "7.4.2", + "reverseYBuckets": false, + "targets": [ + { + "expr": "dendrite_clientapi_sendevent_duration_millis_bucket{action=\"build\",instance=\"$instance\"}", + "interval": "", + "legendFormat": "{{le}}", + "refId": "A" + } + ], + "title": "Sendevent Duration", + "tooltip": { + "show": true, + "showHistogram": false + }, + "type": "heatmap", + "xAxis": { + "show": true + }, + "xBucketNumber": null, + "xBucketSize": null, + "yAxis": { + "decimals": null, + "format": "s", + "logBase": 1, + "max": null, + "min": "0", + "show": true, + "splitFactor": null + }, + "yBucketBound": "auto", + "yBucketNumber": null, + "yBucketSize": null + }, + { + "collapsed": false, + "datasource": "${DS_INFLUXDB_DOMOTICA}", + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 11 + }, + "id": 8, + "panels": [], + "title": "Federation", + "type": "row" + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_PROMETHEUS}", + "description": "Collection of queues for sending transactions to other matrix servers", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 6, + "w": 24, + "x": 0, + "y": 12 + }, + "hiddenSeries": false, + "id": 10, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.4.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "dendrite_federationsender_destination_queues_running", + "interval": "", + "legendFormat": "Queue Running", + "refId": "A" + }, + { + "expr": "dendrite_federationsender_destination_queues_total", + "hide": false, + "interval": "", + "legendFormat": "Queue Total", + "refId": "B" + }, + { + "expr": "dendrite_federationsender_destination_queues_backing_off", + "hide": false, + "interval": "", + "legendFormat": "Backing Off", + "refId": "C" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Federation Sender Destination", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:443", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "$$hashKey": "object:444", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "collapsed": false, + "datasource": "${DS_INFLUXDB_DOMOTICA}", + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 18 + }, + "id": 26, + "panels": [], + "title": "Rooms", + "type": "row" + }, + { + "cards": { + "cardPadding": null, + "cardRound": null + }, + "color": { + "cardColor": "#b4ff00", + "colorScale": "sqrt", + "colorScheme": "interpolateOranges", + "exponent": 0.5, + "mode": "spectrum" + }, + "dataFormat": "timeseries", + "datasource": "${DS_PROMETHEUS}", + "description": "How long it takes the roomserver to process an event", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "gridPos": { + "h": 7, + "w": 24, + "x": 0, + "y": 19 + }, + "heatmap": {}, + "hideZeroBuckets": false, + "highlightCards": true, + "id": 28, + "legend": { + "show": false + }, + "pluginVersion": "7.4.2", + "reverseYBuckets": false, + "targets": [ + { + "expr": "sum(rate(dendrite_roomserver_processroomevent_duration_millis_bucket{instance=\"$instance\"}[$bucket_size])) by (le)", + "interval": "", + "legendFormat": "{{le}}", + "refId": "A" + } + ], + "title": "Room Event Processing", + "tooltip": { + "show": true, + "showHistogram": false + }, + "type": "heatmap", + "xAxis": { + "show": true + }, + "xBucketNumber": null, + "xBucketSize": null, + "yAxis": { + "decimals": null, + "format": "s", + "logBase": 1, + "max": null, + "min": null, + "show": true, + "splitFactor": null + }, + "yBucketBound": "auto", + "yBucketNumber": null, + "yBucketSize": null + }, + { + "collapsed": false, + "datasource": "${DS_INFLUXDB_DOMOTICA}", + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 26 + }, + "id": 12, + "panels": [], + "title": "Caches", + "type": "row" + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_PROMETHEUS}", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 7, + "w": 8, + "x": 0, + "y": 27 + }, + "hiddenSeries": false, + "id": 14, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": false, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.4.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "dendrite_caching_in_memory_lru_server_key", + "interval": "", + "legendFormat": "Server keys", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Server Keys", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:667", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "$$hashKey": "object:668", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_PROMETHEUS}", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 7, + "w": 8, + "x": 8, + "y": 27 + }, + "hiddenSeries": false, + "id": 16, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": false, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.4.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "dendrite_caching_in_memory_lru_federation_event", + "interval": "", + "legendFormat": "Federation Event", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Federation Events", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:784", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "$$hashKey": "object:785", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_PROMETHEUS}", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 7, + "w": 8, + "x": 16, + "y": 27 + }, + "hiddenSeries": false, + "id": 18, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": false, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.4.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "dendrite_caching_in_memory_lru_roomserver_room_ids", + "interval": "", + "legendFormat": "Room IDs", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Room IDs", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:898", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "$$hashKey": "object:899", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + } + ], + "refresh": "10s", + "schemaVersion": 27, + "style": "dark", + "tags": [ + "matrix", + "dendrite" + ], + "templating": { + "list": [ + { + "current": { + "selected": false, + "text": "Prometheus", + "value": "Prometheus" + }, + "description": null, + "error": null, + "hide": 0, + "includeAll": false, + "label": null, + "multi": false, + "name": "datasource", + "options": [], + "query": "prometheus", + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "type": "datasource" + }, + { + "auto": true, + "auto_count": 100, + "auto_min": "30s", + "current": { + "selected": false, + "text": "auto", + "value": "$__auto_interval_bucket_size" + }, + "description": null, + "error": null, + "hide": 0, + "label": "Bucket Size", + "name": "bucket_size", + "options": [ + { + "selected": true, + "text": "auto", + "value": "$__auto_interval_bucket_size" + }, + { + "selected": false, + "text": "30s", + "value": "30s" + }, + { + "selected": false, + "text": "1m", + "value": "1m" + }, + { + "selected": false, + "text": "2m", + "value": "2m" + }, + { + "selected": false, + "text": "5m", + "value": "5m" + }, + { + "selected": false, + "text": "10m", + "value": "10m" + }, + { + "selected": false, + "text": "15m", + "value": "15m" + } + ], + "query": "30s,1m,2m,5m,10m,15m", + "queryValue": "", + "refresh": 2, + "skipUrlSync": false, + "type": "interval" + }, + { + "allValue": null, + "current": {}, + "datasource": "${DS_PROMETHEUS}", + "definition": "label_values(dendrite_caching_in_memory_lru_roominfo, instance)", + "description": null, + "error": null, + "hide": 0, + "includeAll": false, + "label": null, + "multi": false, + "name": "instance", + "options": [], + "query": { + "query": "label_values(dendrite_caching_in_memory_lru_roominfo, instance)", + "refId": "StandardVariableQuery" + }, + "refresh": 2, + "regex": "", + "skipUrlSync": false, + "sort": 0, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + }, + { + "allValue": null, + "current": {}, + "datasource": "${DS_PROMETHEUS}", + "definition": "label_values(dendrite_caching_in_memory_lru_roominfo, job)", + "description": null, + "error": null, + "hide": 0, + "includeAll": true, + "label": "Job", + "multi": true, + "name": "job", + "options": [], + "query": { + "query": "label_values(dendrite_caching_in_memory_lru_roominfo, job)", + "refId": "StandardVariableQuery" + }, + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "sort": 1, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + }, + { + "allValue": ".*", + "current": {}, + "datasource": "${DS_PROMETHEUS}", + "definition": "label_values(dendrite_caching_in_memory_lru_roominfo, index)", + "description": null, + "error": null, + "hide": 0, + "includeAll": true, + "label": null, + "multi": true, + "name": "index", + "options": [], + "query": { + "query": "label_values(dendrite_caching_in_memory_lru_roominfo, index)", + "refId": "StandardVariableQuery" + }, + "refresh": 2, + "regex": "", + "skipUrlSync": false, + "sort": 3, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + } + ] + }, + "time": { + "from": "now-3h", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "Dendrite", + "uid": "RoRt1jEGz", + "version": 8 +} \ No newline at end of file diff --git a/helm/dendrite/templates/_helpers.tpl b/helm/dendrite/templates/_helpers.tpl index 291f351bc..026706588 100644 --- a/helm/dendrite/templates/_helpers.tpl +++ b/helm/dendrite/templates/_helpers.tpl @@ -15,9 +15,11 @@ {{- define "image.name" -}} -image: {{ .name }} +{{- with .Values.image -}} +image: {{ .repository }}:{{ .tag | default (printf "v%s" $.Chart.AppVersion) }} imagePullPolicy: {{ .pullPolicy }} {{- end -}} +{{- end -}} {{/* Expand the name of the chart. diff --git a/helm/dendrite/templates/configmap_grafana_dashboards.yaml b/helm/dendrite/templates/configmap_grafana_dashboards.yaml new file mode 100644 index 000000000..e2abc4909 --- /dev/null +++ b/helm/dendrite/templates/configmap_grafana_dashboards.yaml @@ -0,0 +1,16 @@ +{{- if .Values.grafana.dashboards.enabled }} +{{- range $path, $bytes := .Files.Glob "grafana_dashboards/*" }} +--- +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ include "dendrite.fullname" $ }}-grafana-dashboards-{{ base $path }} + labels: + {{- include "dendrite.labels" $ | nindent 4 }} + {{- toYaml $.Values.grafana.dashboards.labels | nindent 4 }} + annotations: + {{- toYaml $.Values.grafana.dashboards.annotations | nindent 4 }} +data: + {{- ($.Files.Glob $path ).AsConfig | nindent 2 }} +{{- end }} +{{- end }} diff --git a/helm/dendrite/templates/deployment.yaml b/helm/dendrite/templates/deployment.yaml index 629ffe528..b463c7d0b 100644 --- a/helm/dendrite/templates/deployment.yaml +++ b/helm/dendrite/templates/deployment.yaml @@ -45,8 +45,8 @@ spec: persistentVolumeClaim: claimName: {{ default (print ( include "dendrite.fullname" . ) "-search-pvc") $.Values.persistence.search.existingClaim | quote }} containers: - - name: {{ $.Chart.Name }} - {{- include "image.name" $.Values.image | nindent 8 }} + - name: {{ .Chart.Name }} + {{- include "image.name" . | nindent 8 }} args: - '--config' - '/etc/dendrite/dendrite.yaml' diff --git a/helm/dendrite/templates/jobs.yaml b/helm/dendrite/templates/jobs.yaml index 76915694d..c10f358b0 100644 --- a/helm/dendrite/templates/jobs.yaml +++ b/helm/dendrite/templates/jobs.yaml @@ -8,6 +8,7 @@ metadata: name: {{ $name }} labels: app.kubernetes.io/component: signingkey-job + {{- include "dendrite.labels" . | nindent 4 }} --- apiVersion: rbac.authorization.k8s.io/v1 kind: Role @@ -80,7 +81,7 @@ spec: name: signing-key readOnly: true - name: generate-key - {{- include "image.name" $.Values.image | nindent 8 }} + {{- include "image.name" . | nindent 8 }} command: - sh - -c diff --git a/helm/dendrite/templates/prometheus-rules.yaml b/helm/dendrite/templates/prometheus-rules.yaml new file mode 100644 index 000000000..6693a4ed9 --- /dev/null +++ b/helm/dendrite/templates/prometheus-rules.yaml @@ -0,0 +1,16 @@ +{{- if and ( .Values.prometheus.rules.enabled ) ( .Capabilities.APIVersions.Has "monitoring.coreos.com/v1" ) }} +--- +apiVersion: monitoring.coreos.com/v1 +kind: PrometheusRule +metadata: + name: {{ include "dendrite.fullname" . }} + labels: + {{- include "dendrite.labels" . | nindent 4 }} + {{- toYaml .Values.prometheus.rules.labels | nindent 4 }} +spec: + groups: + {{- if .Values.prometheus.rules.additionalRules }} + - name: {{ template "dendrite.name" . }}-Additional + rules: {{- toYaml .Values.prometheus.rules.additionalRules | nindent 4 }} + {{- end }} +{{- end }} diff --git a/helm/dendrite/templates/secrets.yaml b/helm/dendrite/templates/secrets.yaml index d4b8ecbf2..2084c9a56 100644 --- a/helm/dendrite/templates/secrets.yaml +++ b/helm/dendrite/templates/secrets.yaml @@ -1,15 +1,15 @@ -{{ if (gt (len (.Files.Glob "appservices/*")) 0) }} +{{- if (gt (len (.Files.Glob "appservices/*")) 0) }} --- apiVersion: v1 kind: Secret metadata: name: {{ include "dendrite.fullname" . }}-appservices-conf - namespace: {{ .Release.Namespace }} type: Opaque data: {{ (.Files.Glob "appservices/*").AsSecrets | indent 2 }} -{{ end }} -{{ if and .Values.signing_key.create (not .Values.signing_key.existingSecret) }} +{{- end }} + +{{- if and .Values.signing_key.create (not .Values.signing_key.existingSecret) }} --- apiVersion: v1 kind: Secret @@ -17,17 +17,29 @@ metadata: annotations: helm.sh/resource-policy: keep name: {{ include "dendrite.fullname" . }}-signing-key - namespace: {{ .Release.Namespace }} type: Opaque -{{ end }} +{{- end }} + +{{- with .Values.dendrite_config.global.metrics }} +{{- if .enabled }} +--- +apiVersion: v1 +kind: Secret +metadata: + name: {{ include "dendrite.fullname" $ }}-metrics-basic-auth +type: Opaque +stringData: + user: {{ .basic_auth.user | quote }} + password: {{ .basic_auth.password | quote }} +{{- end }} +{{- end }} --- apiVersion: v1 kind: Secret -type: Opaque metadata: name: {{ include "dendrite.fullname" . }}-conf - namespace: {{ .Release.Namespace }} +type: Opaque stringData: dendrite.yaml: | {{ toYaml ( mustMergeOverwrite .Values.dendrite_config ( fromYaml (include "override.config" .) ) .Values.dendrite_config ) | nindent 4 }} \ No newline at end of file diff --git a/helm/dendrite/templates/service.yaml b/helm/dendrite/templates/service.yaml index 365a43f04..3b571df1f 100644 --- a/helm/dendrite/templates/service.yaml +++ b/helm/dendrite/templates/service.yaml @@ -13,5 +13,5 @@ spec: ports: - name: http protocol: TCP - port: 8008 + port: {{ .Values.service.port }} targetPort: 8008 \ No newline at end of file diff --git a/helm/dendrite/templates/servicemonitor.yaml b/helm/dendrite/templates/servicemonitor.yaml new file mode 100644 index 000000000..3819c7d02 --- /dev/null +++ b/helm/dendrite/templates/servicemonitor.yaml @@ -0,0 +1,26 @@ +{{- if and + (and .Values.prometheus.servicemonitor.enabled .Values.dendrite_config.global.metrics.enabled ) + ( .Capabilities.APIVersions.Has "monitoring.coreos.com/v1" ) +}} +--- +apiVersion: monitoring.coreos.com/v1 +kind: ServiceMonitor +metadata: + name: {{ include "dendrite.fullname" . }} + labels: + {{- include "dendrite.labels" . | nindent 4 }} + {{- toYaml .Values.prometheus.servicemonitor.labels | nindent 4 }} +spec: + endpoints: + - port: http + basicAuth: + username: + name: {{ include "dendrite.fullname" . }}-metrics-basic-auth + key: "user" + password: + name: {{ include "dendrite.fullname" . }}-metrics-basic-auth + key: "password" + selector: + matchLabels: + {{- include "dendrite.selectorLabels" . | nindent 6 }} +{{- end }} diff --git a/helm/dendrite/values.yaml b/helm/dendrite/values.yaml index 2c6e80942..c219d27f8 100644 --- a/helm/dendrite/values.yaml +++ b/helm/dendrite/values.yaml @@ -1,8 +1,10 @@ image: # -- Docker repository/image to use - name: "ghcr.io/matrix-org/dendrite-monolith:v0.10.8" + repository: "ghcr.io/matrix-org/dendrite-monolith" # -- Kubernetes pullPolicy pullPolicy: IfNotPresent + # -- Overrides the image tag whose default is the chart appVersion. + tag: "" # signing key to use @@ -345,4 +347,27 @@ ingress: service: type: ClusterIP - port: 80 + port: 8008 + +prometheus: + servicemonitor: + # -- Enable ServiceMonitor for Prometheus-Operator for scrape metric-endpoint + enabled: false + # -- Extra Labels on ServiceMonitor for selector of Prometheus Instance + labels: {} + rules: + # -- Enable PrometheusRules for Prometheus-Operator for setup alerting + enabled: false + # -- Extra Labels on PrometheusRules for selector of Prometheus Instance + labels: {} + # -- additional alertrules (no default alertrules are provided) + additionalRules: [] + +grafana: + dashboards: + enabled: false + # -- Extra Labels on ConfigMap for selector of grafana sidecar + labels: + grafana_dashboard: "1" + # -- Extra Annotations on ConfigMap additional config in grafana sidecar + annotations: {} diff --git a/internal/caching/cache_eventstatekeys.go b/internal/caching/cache_eventstatekeys.go index 05580ab05..51e2499d5 100644 --- a/internal/caching/cache_eventstatekeys.go +++ b/internal/caching/cache_eventstatekeys.go @@ -7,6 +7,7 @@ import "github.com/matrix-org/dendrite/roomserver/types" type EventStateKeyCache interface { GetEventStateKey(eventStateKeyNID types.EventStateKeyNID) (string, bool) StoreEventStateKey(eventStateKeyNID types.EventStateKeyNID, eventStateKey string) + GetEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, bool) } func (c Caches) GetEventStateKey(eventStateKeyNID types.EventStateKeyNID) (string, bool) { @@ -15,4 +16,23 @@ func (c Caches) GetEventStateKey(eventStateKeyNID types.EventStateKeyNID) (strin func (c Caches) StoreEventStateKey(eventStateKeyNID types.EventStateKeyNID, eventStateKey string) { c.RoomServerStateKeys.Set(eventStateKeyNID, eventStateKey) + c.RoomServerStateKeyNIDs.Set(eventStateKey, eventStateKeyNID) +} + +func (c Caches) GetEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, bool) { + return c.RoomServerStateKeyNIDs.Get(eventStateKey) +} + +type EventTypeCache interface { + GetEventTypeKey(eventType string) (types.EventTypeNID, bool) + StoreEventTypeKey(eventTypeNID types.EventTypeNID, eventType string) +} + +func (c Caches) StoreEventTypeKey(eventTypeNID types.EventTypeNID, eventType string) { + c.RoomServerEventTypeNIDs.Set(eventType, eventTypeNID) + c.RoomServerEventTypes.Set(eventTypeNID, eventType) +} + +func (c Caches) GetEventTypeKey(eventType string) (types.EventTypeNID, bool) { + return c.RoomServerEventTypeNIDs.Get(eventType) } diff --git a/internal/caching/cache_roomservernids.go b/internal/caching/cache_roomservernids.go index 88a5b28bc..734a3a04f 100644 --- a/internal/caching/cache_roomservernids.go +++ b/internal/caching/cache_roomservernids.go @@ -9,19 +9,28 @@ type RoomServerCaches interface { RoomVersionCache RoomServerEventsCache EventStateKeyCache + EventTypeCache } // RoomServerNIDsCache contains the subset of functions needed for // a roomserver NID cache. type RoomServerNIDsCache interface { GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) + // StoreRoomServerRoomID stores roomNID -> roomID and roomID -> roomNID StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) + GetRoomServerRoomNID(roomID string) (types.RoomNID, bool) } func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) { return c.RoomServerRoomIDs.Get(roomNID) } +// StoreRoomServerRoomID stores roomNID -> roomID and roomID -> roomNID func (c Caches) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) { + c.RoomServerRoomNIDs.Set(roomID, roomNID) c.RoomServerRoomIDs.Set(roomNID, roomID) } + +func (c Caches) GetRoomServerRoomNID(roomID string) (types.RoomNID, bool) { + return c.RoomServerRoomNIDs.Get(roomID) +} diff --git a/internal/caching/caches.go b/internal/caching/caches.go index 78c9ab7ee..479920466 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -23,16 +23,19 @@ import ( // different implementations as long as they satisfy the Cache // interface. type Caches struct { - RoomVersions Cache[string, gomatrixserverlib.RoomVersion] // room ID -> room version - ServerKeys Cache[string, gomatrixserverlib.PublicKeyLookupResult] // server name -> server keys - RoomServerRoomNIDs Cache[string, types.RoomNID] // room ID -> room NID - RoomServerRoomIDs Cache[types.RoomNID, string] // room NID -> room ID - RoomServerEvents Cache[int64, *gomatrixserverlib.Event] // event NID -> event - RoomServerStateKeys Cache[types.EventStateKeyNID, string] // event NID -> event state key - FederationPDUs Cache[int64, *gomatrixserverlib.HeaderedEvent] // queue NID -> PDU - FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU - SpaceSummaryRooms Cache[string, gomatrixserverlib.MSC2946SpacesResponse] // room ID -> space response - LazyLoading Cache[lazyLoadingCacheKey, string] // composite key -> event ID + RoomVersions Cache[string, gomatrixserverlib.RoomVersion] // room ID -> room version + ServerKeys Cache[string, gomatrixserverlib.PublicKeyLookupResult] // server name -> server keys + RoomServerRoomNIDs Cache[string, types.RoomNID] // room ID -> room NID + RoomServerRoomIDs Cache[types.RoomNID, string] // room NID -> room ID + RoomServerEvents Cache[int64, *gomatrixserverlib.Event] // event NID -> event + RoomServerStateKeys Cache[types.EventStateKeyNID, string] // eventStateKey NID -> event state key + RoomServerStateKeyNIDs Cache[string, types.EventStateKeyNID] // event state key -> eventStateKey NID + RoomServerEventTypeNIDs Cache[string, types.EventTypeNID] // eventType -> eventType NID + RoomServerEventTypes Cache[types.EventTypeNID, string] // eventType NID -> eventType + FederationPDUs Cache[int64, *gomatrixserverlib.HeaderedEvent] // queue NID -> PDU + FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU + SpaceSummaryRooms Cache[string, gomatrixserverlib.MSC2946SpacesResponse] // room ID -> space response + LazyLoading Cache[lazyLoadingCacheKey, string] // composite key -> event ID } // Cache is the interface that an implementation must satisfy. diff --git a/internal/caching/impl_ristretto.go b/internal/caching/impl_ristretto.go index 49292d0dc..fca93afd1 100644 --- a/internal/caching/impl_ristretto.go +++ b/internal/caching/impl_ristretto.go @@ -40,6 +40,9 @@ const ( spaceSummaryRoomsCache lazyLoadingCache eventStateKeyCache + eventTypeCache + eventTypeNIDCache + eventStateKeyNIDCache ) func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enablePrometheus bool) *Caches { @@ -105,6 +108,21 @@ func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enableProm Prefix: eventStateKeyCache, MaxAge: maxAge, }, + RoomServerStateKeyNIDs: &RistrettoCachePartition[string, types.EventStateKeyNID]{ // eventStateKey -> eventStateKey NID + cache: cache, + Prefix: eventStateKeyNIDCache, + MaxAge: maxAge, + }, + RoomServerEventTypeNIDs: &RistrettoCachePartition[string, types.EventTypeNID]{ // eventType -> eventType NID + cache: cache, + Prefix: eventTypeCache, + MaxAge: maxAge, + }, + RoomServerEventTypes: &RistrettoCachePartition[types.EventTypeNID, string]{ // eventType NID -> eventType + cache: cache, + Prefix: eventTypeNIDCache, + MaxAge: maxAge, + }, FederationPDUs: &RistrettoCostedCachePartition[int64, *gomatrixserverlib.HeaderedEvent]{ // queue NID -> PDU &RistrettoCachePartition[int64, *gomatrixserverlib.HeaderedEvent]{ cache: cache, diff --git a/internal/httputil/http.go b/internal/httputil/http.go deleted file mode 100644 index ad26de512..000000000 --- a/internal/httputil/http.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package httputil - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "strings" - - "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/ext" -) - -// PostJSON performs a POST request with JSON on an internal HTTP API. -// The error will match the errtype if returned from the remote API, or -// will be a different type if there was a problem reaching the API. -func PostJSON[reqtype, restype any, errtype error]( - ctx context.Context, span opentracing.Span, httpClient *http.Client, - apiURL string, request *reqtype, response *restype, -) error { - jsonBytes, err := json.Marshal(request) - if err != nil { - return err - } - - parsedAPIURL, err := url.Parse(apiURL) - if err != nil { - return err - } - - parsedAPIURL.Path = InternalPathPrefix + strings.TrimLeft(parsedAPIURL.Path, "/") - apiURL = parsedAPIURL.String() - - req, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewReader(jsonBytes)) - if err != nil { - return err - } - - // Mark the span as being an RPC client. - ext.SpanKindRPCClient.Set(span) - carrier := opentracing.HTTPHeadersCarrier(req.Header) - tracer := opentracing.GlobalTracer() - - if err = tracer.Inject(span.Context(), opentracing.HTTPHeaders, carrier); err != nil { - return err - } - - req.Header.Set("Content-Type", "application/json") - - res, err := httpClient.Do(req.WithContext(ctx)) - if res != nil { - defer (func() { err = res.Body.Close() })() - } - if err != nil { - return err - } - var body []byte - body, err = io.ReadAll(res.Body) - if err != nil { - return err - } - if res.StatusCode != http.StatusOK { - if len(body) == 0 { - return fmt.Errorf("HTTP %d from %s (no response body)", res.StatusCode, apiURL) - } - var reserr errtype - if err = json.Unmarshal(body, &reserr); err != nil { - return fmt.Errorf("HTTP %d from %s - %w", res.StatusCode, apiURL, err) - } - return reserr - } - if err = json.Unmarshal(body, response); err != nil { - return fmt.Errorf("json.Unmarshal: %w", err) - } - return nil -} diff --git a/internal/httputil/internalapi.go b/internal/httputil/internalapi.go deleted file mode 100644 index 22f436e38..000000000 --- a/internal/httputil/internalapi.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package httputil - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "reflect" - - "github.com/matrix-org/util" - "github.com/opentracing/opentracing-go" -) - -type InternalAPIError struct { - Type string - Message string -} - -func (e InternalAPIError) Error() string { - return fmt.Sprintf("internal API returned %q error: %s", e.Type, e.Message) -} - -func MakeInternalRPCAPI[reqtype, restype any](metricsName string, enableMetrics bool, f func(context.Context, *reqtype, *restype) error) http.Handler { - return MakeInternalAPI(metricsName, enableMetrics, func(req *http.Request) util.JSONResponse { - var request reqtype - var response restype - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := f(req.Context(), &request, &response); err != nil { - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: &InternalAPIError{ - Type: reflect.TypeOf(err).String(), - Message: fmt.Sprintf("%s", err), - }, - } - } - return util.JSONResponse{ - Code: http.StatusOK, - JSON: &response, - } - }) -} - -func MakeInternalProxyAPI[reqtype, restype any](metricsName string, enableMetrics bool, f func(context.Context, *reqtype) (*restype, error)) http.Handler { - return MakeInternalAPI(metricsName, enableMetrics, func(req *http.Request) util.JSONResponse { - var request reqtype - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - response, err := f(req.Context(), &request) - if err != nil { - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: err, - } - } - return util.JSONResponse{ - Code: http.StatusOK, - JSON: response, - } - }) -} - -func CallInternalRPCAPI[reqtype, restype any](name, url string, client *http.Client, ctx context.Context, request *reqtype, response *restype) error { - span, ctx := opentracing.StartSpanFromContext(ctx, name) - defer span.Finish() - - return PostJSON[reqtype, restype, InternalAPIError](ctx, span, client, url, request, response) -} - -func CallInternalProxyAPI[reqtype, restype any, errtype error](name, url string, client *http.Client, ctx context.Context, request *reqtype) (restype, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, name) - defer span.Finish() - - var response restype - return response, PostJSON[reqtype, restype, errtype](ctx, span, client, url, request, &response) -} diff --git a/internal/httputil/paths.go b/internal/httputil/paths.go index 12cf59eb4..d06875428 100644 --- a/internal/httputil/paths.go +++ b/internal/httputil/paths.go @@ -19,8 +19,8 @@ const ( PublicFederationPathPrefix = "/_matrix/federation/" PublicKeyPathPrefix = "/_matrix/key/" PublicMediaPathPrefix = "/_matrix/media/" + PublicStaticPath = "/_matrix/static/" PublicWellKnownPrefix = "/.well-known/matrix/" - InternalPathPrefix = "/api/" DendriteAdminPathPrefix = "/_dendrite/" SynapseAdminPathPrefix = "/_synapse/" ) diff --git a/internal/log.go b/internal/log.go index d7e852c81..8fe98f20c 100644 --- a/internal/log.go +++ b/internal/log.go @@ -24,6 +24,7 @@ import ( "path/filepath" "runtime" "strings" + "sync" "github.com/matrix-org/util" @@ -37,6 +38,7 @@ import ( // this unfortunately results in us adding the same hook multiple times. // This map ensures we only ever add one level hook. var stdLevelLogAdded = make(map[logrus.Level]bool) +var levelLogAddedMu = &sync.Mutex{} type utcFormatter struct { logrus.Formatter @@ -99,6 +101,8 @@ func SetupPprof() { // SetupStdLogging configures the logging format to standard output. Typically, it is called when the config is not yet loaded. func SetupStdLogging() { + levelLogAddedMu.Lock() + defer levelLogAddedMu.Unlock() logrus.SetReportCaller(true) logrus.SetFormatter(&utcFormatter{ &logrus.TextFormatter{ @@ -125,9 +129,9 @@ func checkFileHookParams(params map[string]interface{}) { } // Add a new FSHook to the logger. Each component will log in its own file -func setupFileHook(hook config.LogrusHook, level logrus.Level, componentName string) { +func setupFileHook(hook config.LogrusHook, level logrus.Level) { dirPath := (hook.Params["path"]).(string) - fullPath := filepath.Join(dirPath, componentName+".log") + fullPath := filepath.Join(dirPath, "dendrite.log") if err := os.MkdirAll(path.Dir(fullPath), os.ModePerm); err != nil { logrus.Fatalf("Couldn't create directory %s: %q", path.Dir(fullPath), err) diff --git a/internal/log_unix.go b/internal/log_unix.go index b38e7c2e8..3f15063d1 100644 --- a/internal/log_unix.go +++ b/internal/log_unix.go @@ -31,7 +31,9 @@ import ( // SetupHookLogging configures the logging hooks defined in the configuration. // If something fails here it means that the logging was improperly configured, // so we just exit with the error -func SetupHookLogging(hooks []config.LogrusHook, componentName string) { +func SetupHookLogging(hooks []config.LogrusHook) { + levelLogAddedMu.Lock() + defer levelLogAddedMu.Unlock() for _, hook := range hooks { // Check we received a proper logging level level, err := logrus.ParseLevel(hook.Level) @@ -48,10 +50,10 @@ func SetupHookLogging(hooks []config.LogrusHook, componentName string) { switch hook.Type { case "file": checkFileHookParams(hook.Params) - setupFileHook(hook, level, componentName) + setupFileHook(hook, level) case "syslog": checkSyslogHookParams(hook.Params) - setupSyslogHook(hook, level, componentName) + setupSyslogHook(hook, level) case "std": setupStdLogHook(level) default: @@ -92,8 +94,8 @@ func setupStdLogHook(level logrus.Level) { stdLevelLogAdded[level] = true } -func setupSyslogHook(hook config.LogrusHook, level logrus.Level, componentName string) { - syslogHook, err := lSyslog.NewSyslogHook(hook.Params["protocol"].(string), hook.Params["address"].(string), syslog.LOG_INFO, componentName) +func setupSyslogHook(hook config.LogrusHook, level logrus.Level) { + syslogHook, err := lSyslog.NewSyslogHook(hook.Params["protocol"].(string), hook.Params["address"].(string), syslog.LOG_INFO, "dendrite") if err == nil { logrus.AddHook(&logLevelHook{level, syslogHook}) } diff --git a/internal/log_windows.go b/internal/log_windows.go index 39562328c..e1f0098a1 100644 --- a/internal/log_windows.go +++ b/internal/log_windows.go @@ -22,7 +22,7 @@ import ( // SetupHookLogging configures the logging hooks defined in the configuration. // If something fails here it means that the logging was improperly configured, // so we just exit with the error -func SetupHookLogging(hooks []config.LogrusHook, componentName string) { +func SetupHookLogging(hooks []config.LogrusHook) { logrus.SetReportCaller(true) for _, hook := range hooks { // Check we received a proper logging level @@ -40,7 +40,7 @@ func SetupHookLogging(hooks []config.LogrusHook, componentName string) { switch hook.Type { case "file": checkFileHookParams(hook.Params) - setupFileHook(hook, level, componentName) + setupFileHook(hook, level) default: logrus.Fatalf("Unrecognised logging hook type: %s", hook.Type) } diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 19483b268..81c055edd 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -124,6 +124,11 @@ type QueryProvider interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } +// ExecProvider defines the interface for querys used by RunLimitedVariablesExec. +type ExecProvider interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + // SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement // SQLlite can handle. See https://www.sqlite.org/limits.html for more information. const SQLite3MaxVariables = 999 @@ -153,6 +158,22 @@ func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvide return nil } +// RunLimitedVariablesExec split up a query with more variables than the used database can handle in multiple queries. +func RunLimitedVariablesExec(ctx context.Context, query string, qp ExecProvider, variables []interface{}, limit uint) error { + var start int + for start < len(variables) { + n := minOfInts(len(variables)-start, int(limit)) + nextQuery := strings.Replace(query, "($1)", QueryVariadic(n), 1) + _, err := qp.ExecContext(ctx, nextQuery, variables[start:start+n]...) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("ExecContext returned an error") + return err + } + start = start + n + } + return nil +} + // StatementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. type StatementList []struct { Statement **sql.Stmt diff --git a/internal/sqlutil/sqlutil_test.go b/internal/sqlutil/sqlutil_test.go index 79469cddc..c40757893 100644 --- a/internal/sqlutil/sqlutil_test.go +++ b/internal/sqlutil/sqlutil_test.go @@ -3,10 +3,11 @@ package sqlutil import ( "context" "database/sql" + "errors" "reflect" "testing" - sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/DATA-DOG/go-sqlmock" ) func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) { @@ -164,6 +165,54 @@ func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) { } } +func TestRunLimitedVariablesExec(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + + // Query and expect two queries to be executed + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + variables := []interface{}{ + 1, 2, 3, 4, + } + + query := "DELETE FROM WHERE id IN ($1)" + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables, 2); err != nil { + t.Fatal(err) + } + + // Query again, but only 3 parameters, still queries two times + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables[:3], 2); err != nil { + t.Fatal(err) + } + + // Query again, but only 2 parameters, queries only once + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables[:2], 2); err != nil { + t.Fatal(err) + } + + // Test with invalid query (typo) should return an error + mock.ExpectExec(`DELTE FROM`). + WillReturnResult(sqlmock.NewResult(0, 0)). + WillReturnError(errors.New("typo in query")) + + if err = RunLimitedVariablesExec(context.Background(), "DELTE FROM", db, variables[:2], 2); err == nil { + t.Fatal("expected an error, but got none") + } +} + func assertNoError(t *testing.T, err error, msg string) { t.Helper() if err == nil { diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go new file mode 100644 index 000000000..13b00af50 --- /dev/null +++ b/internal/transactionrequest.go @@ -0,0 +1,356 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/getsentry/sentry-go" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/federationapi/types" + "github.com/matrix-org/dendrite/roomserver/api" + syncTypes "github.com/matrix-org/dendrite/syncapi/types" + userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" +) + +var ( + PDUCountTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "dendrite", + Subsystem: "federationapi", + Name: "recv_pdus", + Help: "Number of incoming PDUs from remote servers with labels for success", + }, + []string{"status"}, // 'success' or 'total' + ) + EDUCountTotal = prometheus.NewCounter( + prometheus.CounterOpts{ + Namespace: "dendrite", + Subsystem: "federationapi", + Name: "recv_edus", + Help: "Number of incoming EDUs from remote servers", + }, + ) +) + +type TxnReq struct { + gomatrixserverlib.Transaction + rsAPI api.FederationRoomserverAPI + userAPI userAPI.FederationUserAPI + ourServerName gomatrixserverlib.ServerName + keys gomatrixserverlib.JSONVerifier + roomsMu *MutexByRoom + producer *producers.SyncAPIProducer + inboundPresenceEnabled bool +} + +func NewTxnReq( + rsAPI api.FederationRoomserverAPI, + userAPI userAPI.FederationUserAPI, + ourServerName gomatrixserverlib.ServerName, + keys gomatrixserverlib.JSONVerifier, + roomsMu *MutexByRoom, + producer *producers.SyncAPIProducer, + inboundPresenceEnabled bool, + pdus []json.RawMessage, + edus []gomatrixserverlib.EDU, + origin gomatrixserverlib.ServerName, + transactionID gomatrixserverlib.TransactionID, + destination gomatrixserverlib.ServerName, +) TxnReq { + t := TxnReq{ + rsAPI: rsAPI, + userAPI: userAPI, + ourServerName: ourServerName, + keys: keys, + roomsMu: roomsMu, + producer: producer, + inboundPresenceEnabled: inboundPresenceEnabled, + } + + t.PDUs = pdus + t.EDUs = edus + t.Origin = origin + t.TransactionID = transactionID + t.Destination = destination + + return t +} + +func (t *TxnReq) ProcessTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if t.producer != nil { + t.processEDUs(ctx) + } + }() + + results := make(map[string]gomatrixserverlib.PDUResult) + roomVersions := make(map[string]gomatrixserverlib.RoomVersion) + getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion { + if v, ok := roomVersions[roomID]; ok { + return v + } + verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} + verRes := api.QueryRoomVersionForRoomResponse{} + if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID) + return "" + } + roomVersions[roomID] = verRes.RoomVersion + return verRes.RoomVersion + } + + for _, pdu := range t.PDUs { + PDUCountTotal.WithLabelValues("total").Inc() + var header struct { + RoomID string `json:"room_id"` + } + if err := json.Unmarshal(pdu, &header); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to extract room ID from event") + // We don't know the event ID at this point so we can't return the + // failure in the PDU results + continue + } + roomVersion := getRoomVersion(header.RoomID) + event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) + if err != nil { + if _, ok := err.(gomatrixserverlib.BadJSONError); ok { + // Room version 6 states that homeservers should strictly enforce canonical JSON + // on PDUs. + // + // This enforces that the entire transaction is rejected if a single bad PDU is + // sent. It is unclear if this is the correct behaviour or not. + // + // See https://github.com/matrix-org/synapse/issues/7543 + return nil, &util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON("PDU contains bad JSON"), + } + } + util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) + continue + } + if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { + continue + } + if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: "Forbidden by server ACLs", + } + continue + } + if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { + util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: err.Error(), + } + continue + } + + // pass the event to the roomserver which will do auth checks + // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently + // discarded by the caller of this function + if err = api.SendEvents( + ctx, + t.rsAPI, + api.KindNew, + []*gomatrixserverlib.HeaderedEvent{ + event.Headered(roomVersion), + }, + t.Destination, + t.Origin, + api.DoNotSendToOtherServers, + nil, + true, + ); err != nil { + util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: err.Error(), + } + continue + } + + results[event.EventID()] = gomatrixserverlib.PDUResult{} + PDUCountTotal.WithLabelValues("success").Inc() + } + + wg.Wait() + return &gomatrixserverlib.RespSend{PDUs: results}, nil +} + +// nolint:gocyclo +func (t *TxnReq) processEDUs(ctx context.Context) { + for _, e := range t.EDUs { + EDUCountTotal.Inc() + switch e.Type { + case gomatrixserverlib.MTyping: + // https://matrix.org/docs/spec/server_server/latest#typing-notifications + var typingPayload struct { + RoomID string `json:"room_id"` + UserID string `json:"user_id"` + Typing bool `json:"typing"` + } + if err := json.Unmarshal(e.Content, &typingPayload); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event") + continue + } + if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } + if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to JetStream") + } + case gomatrixserverlib.MDirectToDevice: + // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema + var directPayload gomatrixserverlib.ToDeviceMessage + if err := json.Unmarshal(e.Content, &directPayload); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events") + continue + } + if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } + for userID, byUser := range directPayload.Messages { + for deviceID, message := range byUser { + // TODO: check that the user and the device actually exist here + if err := t.producer.SendToDevice(ctx, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { + sentry.CaptureException(err) + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "sender": directPayload.Sender, + "user_id": userID, + "device_id": deviceID, + }).Error("Failed to send send-to-device event to JetStream") + } + } + } + case gomatrixserverlib.MDeviceListUpdate: + if err := t.producer.SendDeviceListUpdate(ctx, e.Content, t.Origin); err != nil { + sentry.CaptureException(err) + util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate") + } + case gomatrixserverlib.MReceipt: + // https://matrix.org/docs/spec/server_server/r0.1.4#receipts + payload := map[string]types.FederationReceiptMRead{} + + if err := json.Unmarshal(e.Content, &payload); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal receipt event") + continue + } + + for roomID, receipt := range payload { + for userID, mread := range receipt.User { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from receipt event sender") + continue + } + if t.Origin != domain { + util.GetLogger(ctx).Debugf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) + continue + } + if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil { + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "sender": t.Origin, + "user_id": userID, + "room_id": roomID, + "events": mread.EventIDs, + }).Error("Failed to send receipt event to JetStream") + continue + } + } + } + case types.MSigningKeyUpdate: + if err := t.producer.SendSigningKeyUpdate(ctx, e.Content, t.Origin); err != nil { + sentry.CaptureException(err) + logrus.WithError(err).Errorf("Failed to process signing key update") + } + case gomatrixserverlib.MPresence: + if t.inboundPresenceEnabled { + if err := t.processPresence(ctx, e); err != nil { + logrus.WithError(err).Errorf("Failed to process presence update") + } + } + default: + util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") + } + } +} + +// processPresence handles m.receipt events +func (t *TxnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) error { + payload := types.Presence{} + if err := json.Unmarshal(e.Content, &payload); err != nil { + return err + } + for _, content := range payload.Push { + if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } + presence, ok := syncTypes.PresenceFromString(content.Presence) + if !ok { + continue + } + if err := t.producer.SendPresence(ctx, content.UserID, presence, content.StatusMsg, content.LastActiveAgo); err != nil { + return err + } + } + return nil +} + +// processReceiptEvent sends receipt events to JetStream +func (t *TxnReq) processReceiptEvent(ctx context.Context, + userID, roomID, receiptType string, + timestamp gomatrixserverlib.Timestamp, + eventIDs []string, +) error { + if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil { + return nil + } else if serverName == t.ourServerName { + return nil + } else if serverName != t.Origin { + return nil + } + // store every event + for _, eventID := range eventIDs { + if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil { + return fmt.Errorf("unable to set receipt event: %w", err) + } + } + + return nil +} diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go new file mode 100644 index 000000000..8597ae24b --- /dev/null +++ b/internal/transactionrequest_test.go @@ -0,0 +1,820 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "testing" + "time" + + "github.com/matrix-org/dendrite/federationapi/producers" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" + keyAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" + "gotest.tools/v3/poll" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") + testDestination = gomatrixserverlib.ServerName("white.orchard") +) + +var ( + invalidSignatures = json.RawMessage(`{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localishhost","sender":"@userid:localhost","signatures":{"localhost":{"ed2559:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiaQiWAQ"}},"type":"m.room.member"}`) + testData = []json.RawMessage{ + []byte(`{"auth_events":[],"content":{"creator":"@userid:kaer.morhen"},"depth":0,"event_id":"$0ok8ynDp7kjc95e3:kaer.morhen","hashes":{"sha256":"17kPoH+h0Dk4Omn7Sus0qMb6+oGcf+CZFEgDhv7UKWs"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"jP4a04f5/F10Pw95FPpdCyKAO44JOwUQ/MZOOeA/RTU1Dn+AHPMzGSaZnuGjRr/xQuADt+I3ctb5ZQfLKNzHDw"}},"state_key":"","type":"m.room.create"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"content":{"membership":"join"},"depth":1,"event_id":"$LEwEu0kxrtu5fOiS:kaer.morhen","hashes":{"sha256":"B7M88PhXf3vd1LaFtjQutFu4x/w7fHD28XKZ4sAsJTo"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"p2vqmuJn7ZBRImctSaKbXCAxCcBlIjPH9JHte1ouIUGy84gpu4eLipOvSBCLL26hXfC0Zrm4WUto6Hr+ohdrCg"}},"state_key":"@userid:kaer.morhen","type":"m.room.member"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"join_rule":"public"},"depth":2,"event_id":"$SMHlqUrNhhBBRLeN:kaer.morhen","hashes":{"sha256":"vIuJQvmMjrGxshAkj1SXe0C4RqvMbv4ZADDw9pFCWqQ"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"hBMsb3Qppo3RaqqAl4JyTgaiWEbW5hlckATky6PrHun+F3YM203TzG7w9clwuQU5F5pZoB1a6nw+to0hN90FAw"}},"state_key":"","type":"m.room.join_rules"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"history_visibility":"shared"},"depth":3,"event_id":"$6F1yGIbO0J7TM93h:kaer.morhen","hashes":{"sha256":"Mr23GKSlZW7UCCYLgOWawI2Sg6KIoMjUWO2TDenuOgw"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$SMHlqUrNhhBBRLeN:kaer.morhen",{"sha256":"SylzE8U02I+6eyEHgL+FlU0L5YdqrVp8OOlxKS9VQW0"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sHLKrFI3hKGrEJfpMVZSDS3LvLasQsy50CTsOwru9XTVxgRsPo6wozNtRVjxo1J3Rk18RC9JppovmQ5VR5EcDw"}},"state_key":"","type":"m.room.history_visibility"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"ban":50,"events":null,"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":null,"users_default":0},"depth":4,"event_id":"$UKNe10XzYzG0TeA9:kaer.morhen","hashes":{"sha256":"ngbP3yja9U5dlckKerUs/fSOhtKxZMCVvsfhPURSS28"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$6F1yGIbO0J7TM93h:kaer.morhen",{"sha256":"A4CucrKSoWX4IaJXhq02mBg1sxIyZEftbC+5p3fZAvk"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"zOmwlP01QL3yFchzuR9WHvogOoBZA3oVtNIF3lM0ZfDnqlSYZB9sns27G/4HVq0k7alaK7ZE3oGoCrVnMkPNCw"}},"state_key":"","type":"m.room.power_levels"}`), + // messages + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`), + } + testEvent = []byte(`{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiuQiWAQ"}},"type":"m.room.message"}`) + testRoomVersion = gomatrixserverlib.RoomVersionV1 + testEvents = []*gomatrixserverlib.HeaderedEvent{} + testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) +) + +type FakeRsAPI struct { + rsAPI.RoomserverInternalAPI + shouldFailQuery bool + bannedFromRoom bool + shouldEventsFail bool +} + +func (r *FakeRsAPI) QueryRoomVersionForRoom( + ctx context.Context, + req *rsAPI.QueryRoomVersionForRoomRequest, + res *rsAPI.QueryRoomVersionForRoomResponse, +) error { + if r.shouldFailQuery { + return fmt.Errorf("Failure") + } + res.RoomVersion = gomatrixserverlib.RoomVersionV10 + return nil +} + +func (r *FakeRsAPI) QueryServerBannedFromRoom( + ctx context.Context, + req *rsAPI.QueryServerBannedFromRoomRequest, + res *rsAPI.QueryServerBannedFromRoomResponse, +) error { + if r.bannedFromRoom { + res.Banned = true + } else { + res.Banned = false + } + return nil +} + +func (r *FakeRsAPI) InputRoomEvents( + ctx context.Context, + req *rsAPI.InputRoomEventsRequest, + res *rsAPI.InputRoomEventsResponse, +) error { + if r.shouldEventsFail { + return fmt.Errorf("Failure") + } + return nil +} + +func TestEmptyTransactionRequest(t *testing.T) { + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", nil, nil, nil, false, []json.RawMessage{}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) +} + +func TestProcessTransactionRequestPDU(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.Empty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUs(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, append(testData, testEvent), []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.Empty(t, result.Error) + } +} + +func TestProcessTransactionRequestBadPDU(t *testing.T) { + pdu := json.RawMessage("{\"room_id\":\"asdf\"}") + pdu2 := json.RawMessage("\"roomid\":\"asdf\"") + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{pdu, pdu2, testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.Empty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUQueryFailure(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{shouldFailQuery: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) +} + +func TestProcessTransactionRequestPDUBannedFromRoom(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{bannedFromRoom: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.NotEmpty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUInvalidSignature(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{invalidSignatures}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.NotEmpty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUSendFail(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{shouldEventsFail: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.NotEmpty(t, result.Error) + } +} + +func createTransactionWithEDU(ctx *process.ProcessContext, edus []gomatrixserverlib.EDU) (TxnReq, nats.JetStreamContext, *config.Dendrite) { + cfg := &config.Dendrite{} + cfg.Defaults(config.DefaultOpts{ + Generate: true, + SingleDatabase: true, + }) + cfg.Global.JetStream.InMemory = true + natsInstance := &jetstream.NATSInstance{} + js, _ := natsInstance.Prepare(ctx, &cfg.Global.JetStream) + producer := &producers.SyncAPIProducer{ + JetStream: js, + TopicReceiptEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + Config: &cfg.FederationAPI, + UserAPI: nil, + } + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, producer, true, []json.RawMessage{}, edus, "kaer.morhen", "", "ourserver") + return txn, js, cfg +} + +func TestProcessTransactionRequestEDUTyping(t *testing.T) { + var err error + roomID := "!roomid:kaer.morhen" + userID := "@userid:kaer.morhen" + typing := true + edu := gomatrixserverlib.EDU{Type: "m.typing"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "room_id": roomID, + "user_id": userID, + "typing": typing, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.typing"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + room := msg.Header.Get(jetstream.RoomID) + assert.Equal(t, roomID, room) + user := msg.Header.Get(jetstream.UserID) + assert.Equal(t, userID, user) + typ, parseErr := strconv.ParseBool(msg.Header.Get("typing")) + if parseErr != nil { + return true + } + assert.Equal(t, typing, typ) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + cfg.Global.JetStream.Durable("TestTypingConsumer"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUToDevice(t *testing.T) { + var err error + sender := "@userid:kaer.morhen" + messageID := "$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg" + msgType := "m.dendrite.test" + edu := gomatrixserverlib.EDU{Type: "m.direct_to_device"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "sender": sender, + "type": msgType, + "message_id": messageID, + "messages": map[string]interface{}{ + "@alice:example.org": map[string]interface{}{ + "IWHQUZUIAH": map[string]interface{}{ + "algorithm": "m.megolm.v1.aes-sha2", + "room_id": "!Cuyf34gef24t:localhost", + "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ", + "session_key": "AgAAAADxKHa9uFxcXzwYoNueL5Xqi69IkD4sni8LlfJL7qNBEY...", + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.direct_to_device"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output types.OutputSendToDeviceEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + assert.Equal(t, sender, output.Sender) + assert.Equal(t, msgType, output.Type) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + cfg.Global.JetStream.Durable("TestToDevice"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUDeviceListUpdate(t *testing.T) { + var err error + deviceID := "QBUAZIFURK" + userID := "@john:example.com" + edu := gomatrixserverlib.EDU{Type: "m.device_list_update"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "device_display_name": "Mobile", + "device_id": deviceID, + "key": "value", + "keys": map[string]interface{}{ + "algorithms": []string{ + "m.olm.v1.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2", + }, + "device_id": "JLAFKJWSCS", + "keys": map[string]interface{}{ + "curve25519:JLAFKJWSCS": "3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI", + "ed25519:JLAFKJWSCS": "lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI", + }, + "signatures": map[string]interface{}{ + "@alice:example.com": map[string]interface{}{ + "ed25519:JLAFKJWSCS": "dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA", + }, + }, + "user_id": "@alice:example.com", + }, + "prev_id": []int{ + 5, + }, + "stream_id": 6, + "user_id": userID, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.device_list_update"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output gomatrixserverlib.DeviceListUpdateEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + assert.Equal(t, userID, output.UserID) + assert.Equal(t, deviceID, output.DeviceID) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + cfg.Global.JetStream.Durable("TestDeviceListUpdate"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUReceipt(t *testing.T) { + var err error + roomID := "!some_room:example.org" + edu := gomatrixserverlib.EDU{Type: "m.receipt"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "@john:kaer.morhen": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.receipt"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badUser := gomatrixserverlib.EDU{Type: "m.receipt"} + if badUser.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "johnkaer.morhen": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badDomain := gomatrixserverlib.EDU{Type: "m.receipt"} + if badDomain.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "@john:bad.domain": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + edus := []gomatrixserverlib.EDU{badEDU, badUser, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output types.OutputReceiptEvent + output.RoomID = msg.Header.Get(jetstream.RoomID) + assert.Equal(t, roomID, output.RoomID) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + cfg.Global.JetStream.Durable("TestReceipt"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUSigningKeyUpdate(t *testing.T) { + var err error + edu := gomatrixserverlib.EDU{Type: "m.signing_key_update"} + if edu.Content, err = json.Marshal(map[string]interface{}{}); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.signing_key_update"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output keyAPI.CrossSigningKeyUpdate + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + cfg.Global.JetStream.Durable("TestSigningKeyUpdate"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUPresence(t *testing.T) { + var err error + userID := "@john:kaer.morhen" + presence := "online" + edu := gomatrixserverlib.EDU{Type: "m.presence"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "push": []map[string]interface{}{{ + "currently_active": true, + "last_active_ago": 5000, + "presence": presence, + "status_msg": "Making cupcakes", + "user_id": userID, + }}, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.presence"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + userIDRes := msg.Header.Get(jetstream.UserID) + presenceRes := msg.Header.Get("presence") + assert.Equal(t, userID, userIDRes) + assert.Equal(t, presence, presenceRes) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + cfg.Global.JetStream.Durable("TestPresence"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUUnhandled(t *testing.T) { + var err error + edu := gomatrixserverlib.EDU{Type: "m.unhandled"} + if edu.Content, err = json.Marshal(map[string]interface{}{}); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, _, _ := createTransactionWithEDU(ctx, []gomatrixserverlib.EDU{edu}) + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) +} + +func init() { + for _, j := range testData { + e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion) + if err != nil { + panic("cannot load test data: " + err.Error()) + } + h := e.Headered(testRoomVersion) + testEvents = append(testEvents, h) + if e.StateKey() != nil { + testStateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: e.Type(), + StateKey: *e.StateKey(), + }] = h + } + } +} + +type testRoomserverAPI struct { + rsAPI.RoomserverInternalAPI + inputRoomEvents []rsAPI.InputRoomEvent + queryStateAfterEvents func(*rsAPI.QueryStateAfterEventsRequest) rsAPI.QueryStateAfterEventsResponse + queryEventsByID func(req *rsAPI.QueryEventsByIDRequest) rsAPI.QueryEventsByIDResponse + queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse +} + +func (t *testRoomserverAPI) InputRoomEvents( + ctx context.Context, + request *rsAPI.InputRoomEventsRequest, + response *rsAPI.InputRoomEventsResponse, +) error { + t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) + for _, ire := range request.InputRoomEvents { + fmt.Println("InputRoomEvents: ", ire.Event.EventID()) + } + return nil +} + +// Query the latest events and state for a room from the room server. +func (t *testRoomserverAPI) QueryLatestEventsAndState( + ctx context.Context, + request *rsAPI.QueryLatestEventsAndStateRequest, + response *rsAPI.QueryLatestEventsAndStateResponse, +) error { + r := t.queryLatestEventsAndState(request) + response.RoomExists = r.RoomExists + response.RoomVersion = testRoomVersion + response.LatestEvents = r.LatestEvents + response.StateEvents = r.StateEvents + response.Depth = r.Depth + return nil +} + +// Query the state after a list of events in a room from the room server. +func (t *testRoomserverAPI) QueryStateAfterEvents( + ctx context.Context, + request *rsAPI.QueryStateAfterEventsRequest, + response *rsAPI.QueryStateAfterEventsResponse, +) error { + response.RoomVersion = testRoomVersion + res := t.queryStateAfterEvents(request) + response.PrevEventsExist = res.PrevEventsExist + response.RoomExists = res.RoomExists + response.StateEvents = res.StateEvents + return nil +} + +// Query a list of events by event ID. +func (t *testRoomserverAPI) QueryEventsByID( + ctx context.Context, + request *rsAPI.QueryEventsByIDRequest, + response *rsAPI.QueryEventsByIDResponse, +) error { + res := t.queryEventsByID(request) + response.Events = res.Events + return nil +} + +// Query if a server is joined to a room +func (t *testRoomserverAPI) QueryServerJoinedToRoom( + ctx context.Context, + request *rsAPI.QueryServerJoinedToRoomRequest, + response *rsAPI.QueryServerJoinedToRoomResponse, +) error { + response.RoomExists = true + response.IsInRoom = true + return nil +} + +// Asks for the room version for a given room. +func (t *testRoomserverAPI) QueryRoomVersionForRoom( + ctx context.Context, + request *rsAPI.QueryRoomVersionForRoomRequest, + response *rsAPI.QueryRoomVersionForRoomResponse, +) error { + response.RoomVersion = testRoomVersion + return nil +} + +func (t *testRoomserverAPI) QueryServerBannedFromRoom( + ctx context.Context, req *rsAPI.QueryServerBannedFromRoomRequest, res *rsAPI.QueryServerBannedFromRoomResponse, +) error { + res.Banned = false + return nil +} + +func mustCreateTransaction(rsAPI rsAPI.FederationRoomserverAPI, pdus []json.RawMessage) *TxnReq { + t := NewTxnReq( + rsAPI, + nil, + "", + &test.NopJSONVerifier{}, + NewMutexByRoom(), + nil, + false, + pdus, + nil, + testOrigin, + gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())), + testDestination) + t.PDUs = pdus + t.Origin = testOrigin + t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + t.Destination = testDestination + return &t +} + +func mustProcessTransaction(t *testing.T, txn *TxnReq, pdusWithErrors []string) { + res, err := txn.ProcessTransaction(context.Background()) + if err != nil { + t.Errorf("txn.processTransaction returned an error: %v", err) + return + } + if len(res.PDUs) != len(txn.PDUs) { + t.Errorf("txn.processTransaction did not return results for all PDUs, got %d want %d", len(res.PDUs), len(txn.PDUs)) + return + } +NextPDU: + for eventID, result := range res.PDUs { + if result.Error == "" { + continue + } + for _, eventIDWantError := range pdusWithErrors { + if eventID == eventIDWantError { + break NextPDU + } + } + t.Errorf("txn.processTransaction PDU %s returned an error %s", eventID, result.Error) + } +} + +func assertInputRoomEvents(t *testing.T, got []rsAPI.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) { + for _, g := range got { + fmt.Println("GOT ", g.Event.EventID()) + } + if len(got) != len(want) { + t.Errorf("wrong number of InputRoomEvents: got %d want %d", len(got), len(want)) + return + } + for i := range got { + if got[i].Event.EventID() != want[i].EventID() { + t.Errorf("InputRoomEvents[%d] got %s want %s", i, got[i].Event.EventID(), want[i].EventID()) + } + } +} + +// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on +// to the roomserver. It's the most basic test possible. +func TestBasicTransaction(t *testing.T) { + rsAPI := &testRoomserverAPI{} + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, pdus) + mustProcessTransaction(t, txn, nil) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) +} + +// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver +// as it does the auth check. +func TestTransactionFailAuthChecks(t *testing.T) { + rsAPI := &testRoomserverAPI{} + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, pdus) + mustProcessTransaction(t, txn, []string{}) + // expect message to be sent to the roomserver + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) +} diff --git a/internal/version.go b/internal/version.go index 685237b9e..e2655cb63 100644 --- a/internal/version.go +++ b/internal/version.go @@ -16,8 +16,8 @@ var build string const ( VersionMajor = 0 - VersionMinor = 10 - VersionPatch = 8 + VersionMinor = 11 + VersionPatch = 1 VersionTag = "" // example: "rc1" ) diff --git a/keyserver/README.md b/keyserver/README.md deleted file mode 100644 index fd9f37d27..000000000 --- a/keyserver/README.md +++ /dev/null @@ -1,19 +0,0 @@ -## Key Server - -This is an internal component which manages E2E keys from clients. It handles all the [Key Management APIs](https://matrix.org/docs/spec/client_server/r0.6.1#key-management-api) with the exception of `/keys/changes` which is handled by Sync API. This component is designed to shard by user ID. - -Keys are uploaded and stored in this component, and key changes are emitted to a Kafka topic for downstream components such as Sync API. - -### Internal APIs -- `PerformUploadKeys` stores identity keys and one-time public keys for given user(s). -- `PerformClaimKeys` acquires one-time public keys for given user(s). This may involve outbound federation calls. -- `QueryKeys` returns identity keys for given user(s). This may involve outbound federation calls. This component may then cache federated identity keys to avoid repeatedly hitting remote servers. -- A topic which emits identity keys every time there is a change (addition or deletion). - -### Endpoint mappings -- Client API maps `/keys/upload` to `PerformUploadKeys`. -- Client API maps `/keys/query` to `QueryKeys`. -- Client API maps `/keys/claim` to `PerformClaimKeys`. -- Federation API maps `/user/keys/query` to `QueryKeys`. -- Federation API maps `/user/keys/claim` to `PerformClaimKeys`. -- Sync API maps `/keys/changes` to consuming from the Kafka topic. diff --git a/keyserver/api/api.go b/keyserver/api/api.go deleted file mode 100644 index 14fced3e8..000000000 --- a/keyserver/api/api.go +++ /dev/null @@ -1,346 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package api - -import ( - "bytes" - "context" - "encoding/json" - "strings" - "time" - - "github.com/matrix-org/gomatrixserverlib" - - "github.com/matrix-org/dendrite/keyserver/types" - userapi "github.com/matrix-org/dendrite/userapi/api" -) - -type KeyInternalAPI interface { - SyncKeyAPI - ClientKeyAPI - FederationKeyAPI - UserKeyAPI - - // SetUserAPI assigns a user API to query when extracting device names. - SetUserAPI(i userapi.KeyserverUserAPI) -} - -// API functions required by the clientapi -type ClientKeyAPI interface { - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error - PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error - PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error - // PerformClaimKeys claims one-time keys for use in pre-key messages - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error - PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error -} - -// API functions required by the userapi -type UserKeyAPI interface { - PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error - PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) error -} - -// API functions required by the syncapi -type SyncKeyAPI interface { - QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error - QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error - PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error -} - -type FederationKeyAPI interface { - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error - QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error - QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error -} - -// KeyError is returned if there was a problem performing/querying the server -type KeyError struct { - Err string `json:"error"` - IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE - IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM - IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM -} - -func (k *KeyError) Error() string { - return k.Err -} - -type DeviceMessageType int - -const ( - TypeDeviceKeyUpdate DeviceMessageType = iota - TypeCrossSigningUpdate -) - -// DeviceMessage represents the message produced into Kafka by the key server. -type DeviceMessage struct { - Type DeviceMessageType `json:"Type,omitempty"` - *DeviceKeys `json:"DeviceKeys,omitempty"` - *OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` - // A monotonically increasing number which represents device changes for this user. - StreamID int64 - DeviceChangeID int64 -} - -// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log -type OutputCrossSigningKeyUpdate struct { - CrossSigningKeyUpdate `json:"signing_keys"` -} - -type CrossSigningKeyUpdate struct { - MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"` - SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"` - UserID string `json:"user_id"` -} - -// DeviceKeysEqual returns true if the device keys updates contain the -// same display name and key JSON. This will return false if either of -// the updates is not a device keys update, or if the user ID/device ID -// differ between the two. -func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool { - if m1.DeviceKeys == nil || m2.DeviceKeys == nil { - return false - } - if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID { - return false - } - if m1.DisplayName != m2.DisplayName { - return false // different display names - } - if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 { - return false // either is empty - } - return bytes.Equal(m1.KeyJSON, m2.KeyJSON) -} - -// DeviceKeys represents a set of device keys for a single device -// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload -type DeviceKeys struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // The device display name - DisplayName string - // The raw device key JSON - KeyJSON []byte -} - -// WithStreamID returns a copy of this device message with the given stream ID -func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage { - return DeviceMessage{ - DeviceKeys: k, - StreamID: streamID, - } -} - -// OneTimeKeys represents a set of one-time keys for a single device -// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload -type OneTimeKeys struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // A map of algorithm:key_id => key JSON - KeyJSON map[string]json.RawMessage -} - -// Split a key in KeyJSON into algorithm and key ID -func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) { - segments := strings.Split(keyIDWithAlgo, ":") - return segments[0], segments[1] -} - -// OneTimeKeysCount represents the counts of one-time keys for a single device -type OneTimeKeysCount struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // algorithm to count e.g: - // { - // "curve25519": 10, - // "signed_curve25519": 20 - // } - KeyCount map[string]int -} - -// PerformUploadKeysRequest is the request to PerformUploadKeys -type PerformUploadKeysRequest struct { - UserID string // Required - User performing the request - DeviceID string // Optional - Device performing the request, for fetching OTK count - DeviceKeys []DeviceKeys - OneTimeKeys []OneTimeKeys - // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update - // the display name for their respective device, and NOT to modify the keys. The key - // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. - // Without this flag, requests to modify device display names would delete device keys. - OnlyDisplayNameUpdates bool -} - -// PerformUploadKeysResponse is the response to PerformUploadKeys -type PerformUploadKeysResponse struct { - // A fatal error when processing e.g database failures - Error *KeyError - // A map of user_id -> device_id -> Error for tracking failures. - KeyErrors map[string]map[string]*KeyError - OneTimeKeyCounts []OneTimeKeysCount -} - -// PerformDeleteKeysRequest asks the keyserver to forget about certain -// keys, and signatures related to those keys. -type PerformDeleteKeysRequest struct { - UserID string - KeyIDs []gomatrixserverlib.KeyID -} - -// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest. -type PerformDeleteKeysResponse struct { - Error *KeyError -} - -// KeyError sets a key error field on KeyErrors -func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) { - if r.KeyErrors[userID] == nil { - r.KeyErrors[userID] = make(map[string]*KeyError) - } - r.KeyErrors[userID][deviceID] = err -} - -type PerformClaimKeysRequest struct { - // Map of user_id to device_id to algorithm name - OneTimeKeys map[string]map[string]string - Timeout time.Duration -} - -type PerformClaimKeysResponse struct { - // Map of user_id to device_id to algorithm:key_id to key JSON - OneTimeKeys map[string]map[string]map[string]json.RawMessage - // Map of remote server domain to error JSON - Failures map[string]interface{} - // Set if there was a fatal error processing this action - Error *KeyError -} - -type PerformUploadDeviceKeysRequest struct { - gomatrixserverlib.CrossSigningKeys - // The user that uploaded the key, should be populated by the clientapi. - UserID string -} - -type PerformUploadDeviceKeysResponse struct { - Error *KeyError -} - -type PerformUploadDeviceSignaturesRequest struct { - Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice - // The user that uploaded the sig, should be populated by the clientapi. - UserID string -} - -type PerformUploadDeviceSignaturesResponse struct { - Error *KeyError -} - -type QueryKeysRequest struct { - // The user ID asking for the keys, e.g. if from a client API request. - // Will not be populated if the key request came from federation. - UserID string - // Maps user IDs to a list of devices - UserToDevices map[string][]string - Timeout time.Duration -} - -type QueryKeysResponse struct { - // Map of remote server domain to error JSON - Failures map[string]interface{} - // Map of user_id to device_id to device_key - DeviceKeys map[string]map[string]json.RawMessage - // Maps of user_id to cross signing key - MasterKeys map[string]gomatrixserverlib.CrossSigningKey - SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey - UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey - // Set if there was a fatal error processing this query - Error *KeyError -} - -type QueryKeyChangesRequest struct { - // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning - Offset int64 - // The inclusive offset where to track key changes up to. Messages with this offset are included in the response. - // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). - ToOffset int64 -} - -type QueryKeyChangesResponse struct { - // The set of users who have had their keys change. - UserIDs []string - // The latest offset represented in this response. - Offset int64 - // Set if there was a problem handling the request. - Error *KeyError -} - -type QueryOneTimeKeysRequest struct { - // The local user to query OTK counts for - UserID string - // The device to query OTK counts for - DeviceID string -} - -type QueryOneTimeKeysResponse struct { - // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84 - Count OneTimeKeysCount - Error *KeyError -} - -type QueryDeviceMessagesRequest struct { - UserID string -} - -type QueryDeviceMessagesResponse struct { - // The latest stream ID - StreamID int64 - Devices []DeviceMessage - Error *KeyError -} - -type QuerySignaturesRequest struct { - // A map of target user ID -> target key/device IDs to retrieve signatures for - TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"` -} - -type QuerySignaturesResponse struct { - // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures - Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap - // A map of target user ID -> cross-signing master key - MasterKeys map[string]gomatrixserverlib.CrossSigningKey - // A map of target user ID -> cross-signing self-signing key - SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey - // A map of target user ID -> cross-signing user-signing key - UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey - // The request error, if any - Error *KeyError -} - -type PerformMarkAsStaleRequest struct { - UserID string - Domain gomatrixserverlib.ServerName - DeviceID string -} diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go deleted file mode 100644 index 75d537d9c..000000000 --- a/keyserver/inthttp/client.go +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package inthttp - -import ( - "context" - "errors" - "net/http" - - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver/api" - userapi "github.com/matrix-org/dendrite/userapi/api" -) - -// HTTP paths for the internal HTTP APIs -const ( - InputDeviceListUpdatePath = "/keyserver/inputDeviceListUpdate" - PerformUploadKeysPath = "/keyserver/performUploadKeys" - PerformClaimKeysPath = "/keyserver/performClaimKeys" - PerformDeleteKeysPath = "/keyserver/performDeleteKeys" - PerformUploadDeviceKeysPath = "/keyserver/performUploadDeviceKeys" - PerformUploadDeviceSignaturesPath = "/keyserver/performUploadDeviceSignatures" - QueryKeysPath = "/keyserver/queryKeys" - QueryKeyChangesPath = "/keyserver/queryKeyChanges" - QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys" - QueryDeviceMessagesPath = "/keyserver/queryDeviceMessages" - QuerySignaturesPath = "/keyserver/querySignatures" - PerformMarkAsStalePath = "/keyserver/markAsStale" -) - -// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API. -// If httpClient is nil an error is returned -func NewKeyServerClient( - apiURL string, - httpClient *http.Client, -) (api.KeyInternalAPI, error) { - if httpClient == nil { - return nil, errors.New("NewKeyServerClient: httpClient is ") - } - return &httpKeyInternalAPI{ - apiURL: apiURL, - httpClient: httpClient, - }, nil -} - -type httpKeyInternalAPI struct { - apiURL string - httpClient *http.Client -} - -func (h *httpKeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) { - // no-op: doesn't need it -} - -func (h *httpKeyInternalAPI) PerformClaimKeys( - ctx context.Context, - request *api.PerformClaimKeysRequest, - response *api.PerformClaimKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformClaimKeys", h.apiURL+PerformClaimKeysPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) PerformDeleteKeys( - ctx context.Context, - request *api.PerformDeleteKeysRequest, - response *api.PerformDeleteKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformDeleteKeys", h.apiURL+PerformDeleteKeysPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) PerformUploadKeys( - ctx context.Context, - request *api.PerformUploadKeysRequest, - response *api.PerformUploadKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformUploadKeys", h.apiURL+PerformUploadKeysPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) QueryKeys( - ctx context.Context, - request *api.QueryKeysRequest, - response *api.QueryKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryKeys", h.apiURL+QueryKeysPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) QueryOneTimeKeys( - ctx context.Context, - request *api.QueryOneTimeKeysRequest, - response *api.QueryOneTimeKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryOneTimeKeys", h.apiURL+QueryOneTimeKeysPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) QueryDeviceMessages( - ctx context.Context, - request *api.QueryDeviceMessagesRequest, - response *api.QueryDeviceMessagesResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryDeviceMessages", h.apiURL+QueryDeviceMessagesPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) QueryKeyChanges( - ctx context.Context, - request *api.QueryKeyChangesRequest, - response *api.QueryKeyChangesResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryKeyChanges", h.apiURL+QueryKeyChangesPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) PerformUploadDeviceKeys( - ctx context.Context, - request *api.PerformUploadDeviceKeysRequest, - response *api.PerformUploadDeviceKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformUploadDeviceKeys", h.apiURL+PerformUploadDeviceKeysPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) PerformUploadDeviceSignatures( - ctx context.Context, - request *api.PerformUploadDeviceSignaturesRequest, - response *api.PerformUploadDeviceSignaturesResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformUploadDeviceSignatures", h.apiURL+PerformUploadDeviceSignaturesPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) QuerySignatures( - ctx context.Context, - request *api.QuerySignaturesRequest, - response *api.QuerySignaturesResponse, -) error { - return httputil.CallInternalRPCAPI( - "QuerySignatures", h.apiURL+QuerySignaturesPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) PerformMarkAsStaleIfNeeded( - ctx context.Context, - request *api.PerformMarkAsStaleRequest, - response *struct{}, -) error { - return httputil.CallInternalRPCAPI( - "MarkAsStale", h.apiURL+PerformMarkAsStalePath, - h.httpClient, ctx, request, response, - ) -} diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go deleted file mode 100644 index 443269f73..000000000 --- a/keyserver/inthttp/server.go +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package inthttp - -import ( - "github.com/gorilla/mux" - - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver/api" -) - -func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI, enableMetrics bool) { - internalAPIMux.Handle( - PerformClaimKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformClaimKeys", enableMetrics, s.PerformClaimKeys), - ) - - internalAPIMux.Handle( - PerformDeleteKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformDeleteKeys", enableMetrics, s.PerformDeleteKeys), - ) - - internalAPIMux.Handle( - PerformUploadKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformUploadKeys", enableMetrics, s.PerformUploadKeys), - ) - - internalAPIMux.Handle( - PerformUploadDeviceKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceKeys", enableMetrics, s.PerformUploadDeviceKeys), - ) - - internalAPIMux.Handle( - PerformUploadDeviceSignaturesPath, - httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceSignatures", enableMetrics, s.PerformUploadDeviceSignatures), - ) - - internalAPIMux.Handle( - QueryKeysPath, - httputil.MakeInternalRPCAPI("KeyserverQueryKeys", enableMetrics, s.QueryKeys), - ) - - internalAPIMux.Handle( - QueryOneTimeKeysPath, - httputil.MakeInternalRPCAPI("KeyserverQueryOneTimeKeys", enableMetrics, s.QueryOneTimeKeys), - ) - - internalAPIMux.Handle( - QueryDeviceMessagesPath, - httputil.MakeInternalRPCAPI("KeyserverQueryDeviceMessages", enableMetrics, s.QueryDeviceMessages), - ) - - internalAPIMux.Handle( - QueryKeyChangesPath, - httputil.MakeInternalRPCAPI("KeyserverQueryKeyChanges", enableMetrics, s.QueryKeyChanges), - ) - - internalAPIMux.Handle( - QuerySignaturesPath, - httputil.MakeInternalRPCAPI("KeyserverQuerySignatures", enableMetrics, s.QuerySignatures), - ) - - internalAPIMux.Handle( - PerformMarkAsStalePath, - httputil.MakeInternalRPCAPI("KeyserverMarkAsStale", enableMetrics, s.PerformMarkAsStaleIfNeeded), - ) -} diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go deleted file mode 100644 index 275576773..000000000 --- a/keyserver/keyserver.go +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package keyserver - -import ( - "github.com/gorilla/mux" - "github.com/sirupsen/logrus" - - rsapi "github.com/matrix-org/dendrite/roomserver/api" - - fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/consumers" - "github.com/matrix-org/dendrite/keyserver/internal" - "github.com/matrix-org/dendrite/keyserver/inthttp" - "github.com/matrix-org/dendrite/keyserver/producers" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" -) - -// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions -// on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI, enableMetrics bool) { - inthttp.AddRoutes(router, intAPI, enableMetrics) -} - -// NewInternalAPI returns a concerete implementation of the internal API. Callers -// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. -func NewInternalAPI( - base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI, - rsAPI rsapi.KeyserverRoomserverAPI, -) api.KeyInternalAPI { - js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) - - db, err := storage.NewDatabase(base, &cfg.Database) - if err != nil { - logrus.WithError(err).Panicf("failed to connect to key server database") - } - - keyChangeProducer := &producers.KeyChange{ - Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)), - JetStream: js, - DB: db, - } - ap := &internal.KeyInternalAPI{ - DB: db, - Cfg: cfg, - FedClient: fedClient, - Producer: keyChangeProducer, - } - updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable - ap.Updater = updater - - // Remove users which we don't share a room with anymore - if err := updater.CleanUp(); err != nil { - logrus.WithError(err).Error("failed to cleanup stale device lists") - } - - go func() { - if err := updater.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start device list updater") - } - }() - - dlConsumer := consumers.NewDeviceListUpdateConsumer( - base.ProcessContext, cfg, js, updater, - ) - if err := dlConsumer.Start(); err != nil { - logrus.WithError(err).Panic("failed to start device list consumer") - } - - sigConsumer := consumers.NewSigningKeyUpdateConsumer( - base.ProcessContext, cfg, js, ap, - ) - if err := sigConsumer.Start(); err != nil { - logrus.WithError(err).Panic("failed to start signing key consumer") - } - - return ap -} diff --git a/keyserver/keyserver_test.go b/keyserver/keyserver_test.go deleted file mode 100644 index 159b280f5..000000000 --- a/keyserver/keyserver_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package keyserver - -import ( - "context" - "testing" - - roomserver "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" -) - -type mockKeyserverRoomserverAPI struct { - leftUsers []string -} - -func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { - res.LeftUsers = m.leftUsers - return nil -} - -// Merely tests that we can create an internal keyserver API -func Test_NewInternalAPI(t *testing.T) { - rsAPI := &mockKeyserverRoomserverAPI{} - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, closeBase := testrig.CreateBaseDendrite(t, dbType) - defer closeBase() - _ = NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - }) -} diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go deleted file mode 100644 index c6a8f44cd..000000000 --- a/keyserver/storage/interface.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type Database interface { - // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination - // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. - ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) - - // StoreOneTimeKeys persists the given one-time keys. - StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) - - // OneTimeKeysCount returns a count of all OTKs for this device. - OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - - // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. - DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error - - // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). - // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. - // Returns an error if there was a problem storing the keys. - StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error - - // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior - // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. - StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error - - // PrevIDsExists returns true if all prev IDs exist for this user. - PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) - - // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. - // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. - DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) - - // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying - // cross-signing signatures relating to that device. - DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error - - // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key - // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. - ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) - - // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change. - // `userID` is the the user who has changed their keys in some way. - StoreKeyChange(ctx context.Context, userID string) (int64, error) - - // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). - // A to offset of types.OffsetNewest means no upper limit. - // Returns the offset of the latest key change. - KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) - - // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. - // If no domains are given, all user IDs with stale device lists are returned. - StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) - - // MarkDeviceListStale sets the stale bit for this user to isStale. - MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error - - CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) - CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) - CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) - - StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error - StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error - - DeleteStaleDeviceLists( - ctx context.Context, - userIDs []string, - ) error -} diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go deleted file mode 100644 index 35e630559..000000000 --- a/keyserver/storage/postgres/storage.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/shared" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -// NewDatabase creates a new sync server database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { - var err error - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) - if err != nil { - return nil, err - } - otk, err := NewPostgresOneTimeKeysTable(db) - if err != nil { - return nil, err - } - dk, err := NewPostgresDeviceKeysTable(db) - if err != nil { - return nil, err - } - kc, err := NewPostgresKeyChangesTable(db) - if err != nil { - return nil, err - } - sdl, err := NewPostgresStaleDeviceListsTable(db) - if err != nil { - return nil, err - } - csk, err := NewPostgresCrossSigningKeysTable(db) - if err != nil { - return nil, err - } - css, err := NewPostgresCrossSigningSigsTable(db) - if err != nil { - return nil, err - } - if err = kc.Prepare(); err != nil { - return nil, err - } - d := &shared.Database{ - DB: db, - Writer: writer, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, - StaleDeviceListsTable: sdl, - CrossSigningKeysTable: csk, - CrossSigningSigsTable: css, - } - return d, nil -} diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go deleted file mode 100644 index 54dd6ddc9..000000000 --- a/keyserver/storage/shared/storage.go +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package shared - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type Database struct { - DB *sql.DB - Writer sqlutil.Writer - OneTimeKeysTable tables.OneTimeKeys - DeviceKeysTable tables.DeviceKeys - KeyChangesTable tables.KeyChanges - StaleDeviceListsTable tables.StaleDeviceLists - CrossSigningKeysTable tables.CrossSigningKeys - CrossSigningSigsTable tables.CrossSigningSigs -} - -func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { - return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms) -} - -func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { - _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys) - return err - }) - return -} - -func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { - return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) -} - -func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { - return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) -} - -func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { - count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs) - if err != nil { - return false, err - } - return count == len(prevIDs), nil -} - -func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for _, userID := range clearUserIDs { - err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID) - if err != nil { - return err - } - } - return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) - }) -} - -func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { - // work out the latest stream IDs for each user - userIDToStreamID := make(map[string]int64) - for _, k := range keys { - userIDToStreamID[k.UserID] = 0 - } - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for userID := range userIDToStreamID { - streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) - if err != nil { - return err - } - userIDToStreamID[userID] = streamID - } - // set the stream IDs for each key - for i := range keys { - k := keys[i] - userIDToStreamID[k.UserID]++ // start stream from 1 - k.StreamID = userIDToStreamID[k.UserID] - keys[i] = k - } - return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) - }) -} - -func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { - return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) -} - -func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { - var result []api.OneTimeKeys - err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for userID, deviceToAlgo := range userToDeviceToAlgorithm { - for deviceID, algo := range deviceToAlgo { - keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo) - if err != nil { - return err - } - if keyJSON != nil { - result = append(result, api.OneTimeKeys{ - UserID: userID, - DeviceID: deviceID, - KeyJSON: keyJSON, - }) - } - } - } - return nil - }) - return result, err -} - -func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) { - err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID) - return err - }) - return -} - -func (d *Database) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { - return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset) -} - -// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. -// If no domains are given, all user IDs with stale device lists are returned. -func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { - return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) -} - -// MarkDeviceListStale sets the stale bit for this user to isStale. -func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { - return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) - }) -} - -// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying -// cross-signing signatures relating to that device. -func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for _, deviceID := range deviceIDs { - if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err) - } - if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) - } - if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err) - } - } - return nil - }) -} - -// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. -func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { - keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) - if err != nil { - return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err) - } - results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} - for purpose, key := range keyMap { - keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode()) - result := gomatrixserverlib.CrossSigningKey{ - UserID: userID, - Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose}, - Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ - keyID: key, - }, - } - sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID) - if err != nil { - continue - } - for sigUserID, forSigUserID := range sigMap { - if userID != sigUserID { - continue - } - if result.Signatures == nil { - result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - if _, ok := result.Signatures[sigUserID]; !ok { - result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - for sigKeyID, sigBytes := range forSigUserID { - result.Signatures[sigUserID][sigKeyID] = sigBytes - } - } - results[purpose] = result - } - return results, nil -} - -// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. -func (d *Database) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { - return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) -} - -// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. -func (d *Database) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { - return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID) -} - -// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. -func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for keyType, keyData := range keyMap { - if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { - return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err) - } - } - return nil - }) -} - -// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice. -func (d *Database) StoreCrossSigningSigsForTarget( - ctx context.Context, - originUserID string, originKeyID gomatrixserverlib.KeyID, - targetUserID string, targetKeyID gomatrixserverlib.KeyID, - signature gomatrixserverlib.Base64Bytes, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { - return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err) - } - return nil - }) -} - -// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore. -func (d *Database) DeleteStaleDeviceLists( - ctx context.Context, - userIDs []string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs) - }) -} diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go deleted file mode 100644 index 873fe3e24..000000000 --- a/keyserver/storage/sqlite3/storage.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/shared" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) - if err != nil { - return nil, err - } - otk, err := NewSqliteOneTimeKeysTable(db) - if err != nil { - return nil, err - } - dk, err := NewSqliteDeviceKeysTable(db) - if err != nil { - return nil, err - } - kc, err := NewSqliteKeyChangesTable(db) - if err != nil { - return nil, err - } - sdl, err := NewSqliteStaleDeviceListsTable(db) - if err != nil { - return nil, err - } - csk, err := NewSqliteCrossSigningKeysTable(db) - if err != nil { - return nil, err - } - css, err := NewSqliteCrossSigningSigsTable(db) - if err != nil { - return nil, err - } - - if err = kc.Prepare(); err != nil { - return nil, err - } - d := &shared.Database{ - DB: db, - Writer: writer, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, - StaleDeviceListsTable: sdl, - CrossSigningKeysTable: csk, - CrossSigningSigsTable: css, - } - return d, nil -} diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go deleted file mode 100644 index e7a2af7c2..000000000 --- a/keyserver/storage/storage_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package storage_test - -import ( - "context" - "reflect" - "sync" - "testing" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" -) - -var ctx = context.Background() - -func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { - base, close := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database) - if err != nil { - t.Fatalf("failed to create new database: %v", err) - } - return db, close -} - -func MustNotError(t *testing.T, err error) { - t.Helper() - if err == nil { - return - } - t.Fatalf("operation failed: %s", err) -} - -func TestKeyChanges(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - _, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") - MustNotError(t, err) - deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeIDC { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) - } - if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } - }) -} - -func TestKeyChangesNoDupes(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - if deviceChangeIDA == deviceChangeIDB { - t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) - } - deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeID { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) - } - if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } - }) -} - -func TestKeyChangesUpperLimit(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") - MustNotError(t, err) - _, err = db.StoreKeyChange(ctx, "@charlie:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeIDB { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) - } - if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } - }) -} - -var dbLock sync.Mutex -var deviceArray = []string{"AAA", "another_device"} - -// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, -// and that they are returned correctly when querying for device keys. -func TestDeviceKeysStreamIDGeneration(t *testing.T) { - var err error - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - alice := "@alice:TestDeviceKeysStreamIDGeneration" - bob := "@bob:TestDeviceKeysStreamIDGeneration" - msgs := []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: alice, - KeyJSON: []byte(`{"key":"v1"}`), - }, - // StreamID: 1 - }, - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: bob, - KeyJSON: []byte(`{"key":"v1"}`), - }, - // StreamID: 1 as this is a different user - }, - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "another_device", - UserID: alice, - KeyJSON: []byte(`{"key":"v1"}`), - }, - // StreamID: 2 as this is a 2nd device key - }, - } - MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) - if msgs[0].StreamID != 1 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) - } - if msgs[1].StreamID != 1 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) - } - if msgs[2].StreamID != 2 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) - } - - // updating a device sets the next stream ID for that user - msgs = []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: alice, - KeyJSON: []byte(`{"key":"v2"}`), - }, - // StreamID: 3 - }, - } - MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) - if msgs[0].StreamID != 3 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) - } - - dbLock.Lock() - defer dbLock.Unlock() - // Querying for device keys returns the latest stream IDs - msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false) - - if err != nil { - t.Fatalf("DeviceKeysForUser returned error: %s", err) - } - wantStreamIDs := map[string]int64{ - "AAA": 3, - "another_device": 2, - } - if len(msgs) != len(wantStreamIDs) { - t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) - } - for _, m := range msgs { - if m.StreamID != wantStreamIDs[m.DeviceID] { - t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) - } - } - }) -} diff --git a/keyserver/storage/storage_wasm.go b/keyserver/storage/storage_wasm.go deleted file mode 100644 index 75c9053e8..000000000 --- a/keyserver/storage/storage_wasm.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "fmt" - - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) - case dbProperties.ConnectionString.IsPostgres(): - return nil, fmt.Errorf("can't use Postgres implementation") - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go deleted file mode 100644 index 24da1125e..000000000 --- a/keyserver/storage/tables/interface.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tables - -import ( - "context" - "database/sql" - "encoding/json" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type OneTimeKeys interface { - SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) - CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) - // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. - // Returns an empty map if the key does not exist. - SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) - DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error -} - -type DeviceKeys interface { - SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error - InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error - SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) - CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) - SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) - DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error - DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error -} - -type KeyChanges interface { - InsertKeyChange(ctx context.Context, userID string) (int64, error) - // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets. - // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset. - SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) - - Prepare() error -} - -type StaleDeviceLists interface { - InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error - SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) - DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error -} - -type CrossSigningKeys interface { - SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) - UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error -} - -type CrossSigningSigs interface { - SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) - UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error - DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error -} diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 9dcfa955f..50af2f884 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -108,13 +108,16 @@ func makeDownloadAPI( activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, ) http.HandlerFunc { - counterVec := promauto.NewCounterVec( - prometheus.CounterOpts{ - Name: name, - Help: "Total number of media_api requests for either thumbnails or full downloads", - }, - []string{"code"}, - ) + var counterVec *prometheus.CounterVec + if cfg.Matrix.Metrics.Enabled { + counterVec = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: name, + Help: "Total number of media_api requests for either thumbnails or full downloads", + }, + []string{"code"}, + ) + } httpHandler := func(w http.ResponseWriter, req *http.Request) { req = util.RequestWithLogging(req) @@ -166,5 +169,12 @@ func makeDownloadAPI( vars["downloadName"], ) } - return promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler)) + + var handlerFunc http.HandlerFunc + if counterVec != nil { + handlerFunc = promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler)) + } else { + handlerFunc = http.HandlerFunc(httpHandler) + } + return handlerFunc } diff --git a/relayapi/ARCHITECTURE.md b/relayapi/ARCHITECTURE.md new file mode 100644 index 000000000..f4cf529b4 --- /dev/null +++ b/relayapi/ARCHITECTURE.md @@ -0,0 +1,134 @@ +## Relay Server Architecture + +Relay Servers function similar to the way physical mail drop boxes do. +A node can have many associated relay servers. Matrix events can be sent to them instead of to the destination node, and the destination node will eventually retrieve them from the relay server. +Nodes that want to send events to an offline node need to know what relay servers are associated with their intended destination. +Currently this is manually configured in the dendrite database. In the future this information could be configurable in the app and shared automatically via other means. + +Currently events are sent as complete Matrix Transactions. +Transactions include a list of PDUs, (which contain, among other things, lists of authorization events, previous events, and signatures) a list of EDUs, and other information about the transaction. +There is no additional information sent along with the transaction other than what is typically added to them during Matrix federation today. +In the future this will probably need to change in order to handle more complex room state resolution during p2p usage. + +### Design + +``` + 0 +--------------------+ + +----------------------------------------+ | P2P Node A | + | Relay Server | | +--------+ | + | | | | Client | | + | +--------------------+ | | +--------+ | + | | Relay Server API | | | | | + | | | | | V | + | .--------. 2 | +-------------+ | | 1 | +------------+ | + | |`--------`| <----- | Forwarder | <------------- | Homeserver | | + | | Database | | +-------------+ | | | +------------+ | + | `----------` | | | +--------------------+ + | ^ | | | + | | 4 | +-------------+ | | + | `------------ | Retriever | <------. +--------------------+ + | | +-------------+ | | | | P2P Node B | + | | | | | | +--------+ | + | +--------------------+ | | | | Client | | + | | | | +--------+ | + +----------------------------------------+ | | | | + | | V | + 3 | | +------------+ | + `------ | Homeserver | | + | +------------+ | + +--------------------+ +``` + +- 0: This relay server is currently only acting on behalf of `P2P Node B`. It will only receive, and later forward events that are destined for `P2P Node B`. +- 1: When `P2P Node A` fails sending directly to `P2P Node B` (after a configurable number of attempts), it checks for any known relay servers associated with `P2P Node B` and sends to all of them. + - If sending to any of the relay servers succeeds, that transaction is considered to be successfully sent. +- 2: The relay server `forwarder` stores the transaction json in its database and marks it as destined for `P2P Node B`. +- 3: When `P2P Node B` comes online, it queries all its relay servers for any missed messages. +- 4: The relay server `retriever` will look in its database for any transactions that are destined for `P2P Node B` and returns them one at a time. + +For now, it is important that we don’t design out a hybrid approach of having both sender-side and recipient-side relay servers. +Both approaches make sense and determining which makes for a better experience depends on the use case. + +#### Sender-Side Relay Servers + +If we are running around truly ad-hoc, and I don't know when or where you will be able to pick up messages, then having a sender designated server makes sense to give things the best chance at making their way to the destination. +But in order to achieve this, you are either relying on p2p presence broadcasts for the relay to know when to try forwarding (which means you are in a pretty small network), or the relay just keeps on periodically attempting to forward to the destination which will lead to a lot of extra traffic on the network. + +#### Recipient-Side Relay Servers + +If we have agreed to some static relay server before going off and doing other things, or if we are talking about more global p2p federation, then having a recipient designated relay server can cut down on redundant traffic since it will sit there idle until the recipient pulls events from it. + +### API + +Relay servers make use of 2 new matrix federation endpoints. +These are: +- PUT /_matrix/federation/v1/send_relay/{txnID}/{userID} +- GET /_matrix/federation/v1/relay_txn/{userID} + +#### Send_Relay + +The `send_relay` endpoint is used to send events to a relay server that are destined for some other node. Servers can send events to this endpoint if they wish for the relay server to store & forward events for them when they go offline. + +##### Request + +###### Request Parameters + +| Name | Type | Description | +|--------|--------|-----------------------------------------------------| +| txnID | string | **Required:** The transaction ID. | +| userID | string | **Required:** The destination for this transaciton. | + +###### Request Body + +| Name | Type | Description | +|--------|--------|----------------------------------------| +| pdus | [PDU] | **Required:** List of pdus. Max 50. | +| edus | [EDU] | List of edus. May be omitted. Max 100. | + +##### Responses + +| Code | Reason | +|--------|--------------------------------------------------| +| 200 | Successfully stored transaction for forwarding. | +| 400 | Invalid userID. | +| 400 | Invalid request body. | +| 400 | Too many pdus or edus. | +| 500 | Server failed processing transaction. | + +#### Relay_Txn + +The `relay_txn` endpoint is used to get events from a relay server that are destined for you. Servers can send events to this endpoint if they wish for the relay server to store & forward events for them when they go offline. + +##### Request + +**This needs to be changed to prevent nodes from obtaining transactions not destined for them. Possibly by adding a signature field to the request.** + +###### Request Parameters + +| Name | Type | Description | +|--------|--------|----------------------------------------------------------------| +| userID | string | **Required:** The user ID that events are being requested for. | + +###### Request Body + +| Name | Type | Description | +|----------|--------|----------------------------------------| +| entry_id | int64 | **Required:** The id of the previous transaction received from the relay. Provided in the previous response to this endpoint. | + +##### Responses + +| Code | Reason | +|--------|--------------------------------------------------| +| 200 | Successfully stored transaction for forwarding. | +| 400 | Invalid userID. | +| 400 | Invalid request body. | +| 400 | Invalid previous entry. Must be >= 0 | +| 500 | Server failed processing transaction. | + +###### 200 Response Body + +| Name | Type | Description | +|----------------|--------------|--------------------------------------------------------------------------| +| transaction | Transaction | **Required:** A matrix transaction. | +| entry_id | int64 | An ID associated with this transaction. | +| entries_queued | bool | **Required:** Whether or not there are more events stored for this user. | diff --git a/relayapi/api/api.go b/relayapi/api/api.go new file mode 100644 index 000000000..9b4b62e58 --- /dev/null +++ b/relayapi/api/api.go @@ -0,0 +1,62 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + + "github.com/matrix-org/gomatrixserverlib" +) + +// RelayInternalAPI is used to query information from the relay server. +type RelayInternalAPI interface { + RelayServerAPI + + // Retrieve from external relay server all transactions stored for us and process them. + PerformRelayServerSync( + ctx context.Context, + userID gomatrixserverlib.UserID, + relayServer gomatrixserverlib.ServerName, + ) error + + // Tells the relayapi whether or not it should act as a relay server for external servers. + SetRelayingEnabled(bool) + + // Obtain whether the relayapi is currently configured to act as a relay server for external servers. + RelayingEnabled() bool +} + +// RelayServerAPI exposes the store & query transaction functionality of a relay server. +type RelayServerAPI interface { + // Store transactions for forwarding to the destination at a later time. + PerformStoreTransaction( + ctx context.Context, + transaction gomatrixserverlib.Transaction, + userID gomatrixserverlib.UserID, + ) error + + // Obtain the oldest stored transaction for the specified userID. + QueryTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + previousEntry gomatrixserverlib.RelayEntry, + ) (QueryRelayTransactionsResponse, error) +} + +type QueryRelayTransactionsResponse struct { + Transaction gomatrixserverlib.Transaction `json:"transaction"` + EntryID int64 `json:"entry_id"` + EntriesQueued bool `json:"entries_queued"` +} diff --git a/relayapi/internal/api.go b/relayapi/internal/api.go new file mode 100644 index 000000000..3a5762fb8 --- /dev/null +++ b/relayapi/internal/api.go @@ -0,0 +1,59 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "sync" + + fedAPI "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/relayapi/storage" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" +) + +type RelayInternalAPI struct { + db storage.Database + fedClient fedAPI.FederationClient + rsAPI rsAPI.RoomserverInternalAPI + keyRing *gomatrixserverlib.KeyRing + producer *producers.SyncAPIProducer + presenceEnabledInbound bool + serverName gomatrixserverlib.ServerName + relayingEnabledMutex sync.Mutex + relayingEnabled bool +} + +func NewRelayInternalAPI( + db storage.Database, + fedClient fedAPI.FederationClient, + rsAPI rsAPI.RoomserverInternalAPI, + keyRing *gomatrixserverlib.KeyRing, + producer *producers.SyncAPIProducer, + presenceEnabledInbound bool, + serverName gomatrixserverlib.ServerName, + relayingEnabled bool, +) *RelayInternalAPI { + return &RelayInternalAPI{ + db: db, + fedClient: fedClient, + rsAPI: rsAPI, + keyRing: keyRing, + producer: producer, + presenceEnabledInbound: presenceEnabledInbound, + serverName: serverName, + relayingEnabled: relayingEnabled, + } +} diff --git a/relayapi/internal/perform.go b/relayapi/internal/perform.go new file mode 100644 index 000000000..62c7d446e --- /dev/null +++ b/relayapi/internal/perform.go @@ -0,0 +1,155 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// SetRelayingEnabled implements api.RelayInternalAPI +func (r *RelayInternalAPI) SetRelayingEnabled(enabled bool) { + r.relayingEnabledMutex.Lock() + defer r.relayingEnabledMutex.Unlock() + r.relayingEnabled = enabled +} + +// RelayingEnabled implements api.RelayInternalAPI +func (r *RelayInternalAPI) RelayingEnabled() bool { + r.relayingEnabledMutex.Lock() + defer r.relayingEnabledMutex.Unlock() + return r.relayingEnabled +} + +// PerformRelayServerSync implements api.RelayInternalAPI +func (r *RelayInternalAPI) PerformRelayServerSync( + ctx context.Context, + userID gomatrixserverlib.UserID, + relayServer gomatrixserverlib.ServerName, +) error { + // Providing a default RelayEntry (EntryID = 0) is done to ask the relay if there are any + // transactions available for this node. + prevEntry := gomatrixserverlib.RelayEntry{} + asyncResponse, err := r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer) + if err != nil { + logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error()) + return err + } + r.processTransaction(&asyncResponse.Transaction) + + prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID} + for asyncResponse.EntriesQueued { + // There are still more entries available for this node from the relay. + logrus.Infof("Retrieving next entry from relay, previous: %v", prevEntry) + asyncResponse, err = r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer) + prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID} + if err != nil { + logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error()) + return err + } + r.processTransaction(&asyncResponse.Transaction) + } + + return nil +} + +// PerformStoreTransaction implements api.RelayInternalAPI +func (r *RelayInternalAPI) PerformStoreTransaction( + ctx context.Context, + transaction gomatrixserverlib.Transaction, + userID gomatrixserverlib.UserID, +) error { + logrus.Warnf("Storing transaction for %v", userID) + receipt, err := r.db.StoreTransaction(ctx, transaction) + if err != nil { + logrus.Errorf("db.StoreTransaction: %s", err.Error()) + return err + } + err = r.db.AssociateTransactionWithDestinations( + ctx, + map[gomatrixserverlib.UserID]struct{}{ + userID: {}, + }, + transaction.TransactionID, + receipt) + + return err +} + +// QueryTransactions implements api.RelayInternalAPI +func (r *RelayInternalAPI) QueryTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + previousEntry gomatrixserverlib.RelayEntry, +) (api.QueryRelayTransactionsResponse, error) { + logrus.Infof("QueryTransactions for %s", userID.Raw()) + if previousEntry.EntryID > 0 { + logrus.Infof("Cleaning previous entry (%v) from db for %s", + previousEntry.EntryID, + userID.Raw(), + ) + prevReceipt := receipt.NewReceipt(previousEntry.EntryID) + err := r.db.CleanTransactions(ctx, userID, []*receipt.Receipt{&prevReceipt}) + if err != nil { + logrus.Errorf("db.CleanTransactions: %s", err.Error()) + return api.QueryRelayTransactionsResponse{}, err + } + } + + transaction, receipt, err := r.db.GetTransaction(ctx, userID) + if err != nil { + logrus.Errorf("db.GetTransaction: %s", err.Error()) + return api.QueryRelayTransactionsResponse{}, err + } + + response := api.QueryRelayTransactionsResponse{} + if transaction != nil && receipt != nil { + logrus.Infof("Obtained transaction (%v) for %s", transaction.TransactionID, userID.Raw()) + response.Transaction = *transaction + response.EntryID = receipt.GetNID() + response.EntriesQueued = true + } else { + logrus.Infof("No more entries in the queue for %s", userID.Raw()) + response.EntryID = 0 + response.EntriesQueued = false + } + + return response, nil +} + +func (r *RelayInternalAPI) processTransaction(txn *gomatrixserverlib.Transaction) { + logrus.Warn("Processing transaction from relay server") + mu := internal.NewMutexByRoom() + t := internal.NewTxnReq( + r.rsAPI, + nil, + r.serverName, + r.keyRing, + mu, + r.producer, + r.presenceEnabledInbound, + txn.PDUs, + txn.EDUs, + txn.Origin, + txn.TransactionID, + txn.Destination) + + t.ProcessTransaction(context.TODO()) +} diff --git a/relayapi/internal/perform_test.go b/relayapi/internal/perform_test.go new file mode 100644 index 000000000..278706a3e --- /dev/null +++ b/relayapi/internal/perform_test.go @@ -0,0 +1,121 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "fmt" + "testing" + + fedAPI "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type testFedClient struct { + fedAPI.FederationClient + shouldFail bool + queryCount uint + queueDepth uint +} + +func (f *testFedClient) P2PGetTransactionFromRelay( + ctx context.Context, + u gomatrixserverlib.UserID, + prev gomatrixserverlib.RelayEntry, + relayServer gomatrixserverlib.ServerName, +) (res gomatrixserverlib.RespGetRelayTransaction, err error) { + f.queryCount++ + if f.shouldFail { + return res, fmt.Errorf("Error") + } + + res = gomatrixserverlib.RespGetRelayTransaction{ + Transaction: gomatrixserverlib.Transaction{}, + EntryID: 0, + } + if f.queueDepth > 0 { + res.EntriesQueued = true + } else { + res.EntriesQueued = false + } + f.queueDepth -= 1 + + return +} + +func TestPerformRelayServerSync(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + fedClient := &testFedClient{} + relayAPI := NewRelayInternalAPI( + &db, fedClient, nil, nil, nil, false, "", true, + ) + + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + assert.NoError(t, err) +} + +func TestPerformRelayServerSyncFedError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + fedClient := &testFedClient{shouldFail: true} + relayAPI := NewRelayInternalAPI( + &db, fedClient, nil, nil, nil, false, "", true, + ) + + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + assert.Error(t, err) +} + +func TestPerformRelayServerSyncRunsUntilQueueEmpty(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + fedClient := &testFedClient{queueDepth: 2} + relayAPI := NewRelayInternalAPI( + &db, fedClient, nil, nil, nil, false, "", true, + ) + + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + assert.NoError(t, err) + assert.Equal(t, uint(3), fedClient.queryCount) +} diff --git a/relayapi/relayapi.go b/relayapi/relayapi.go new file mode 100644 index 000000000..200a1814a --- /dev/null +++ b/relayapi/relayapi.go @@ -0,0 +1,76 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relayapi + +import ( + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. +func AddPublicRoutes( + base *base.BaseDendrite, + keyRing gomatrixserverlib.JSONVerifier, + relayAPI api.RelayInternalAPI, +) { + fedCfg := &base.Cfg.FederationAPI + + relay, ok := relayAPI.(*internal.RelayInternalAPI) + if !ok { + panic("relayapi.AddPublicRoutes called with a RelayInternalAPI impl which was not " + + "RelayInternalAPI. This is a programming error.") + } + + routing.Setup( + base.PublicFederationAPIMux, + fedCfg, + relay, + keyRing, + ) +} + +func NewRelayInternalAPI( + base *base.BaseDendrite, + fedClient *gomatrixserverlib.FederationClient, + rsAPI rsAPI.RoomserverInternalAPI, + keyRing *gomatrixserverlib.KeyRing, + producer *producers.SyncAPIProducer, + relayingEnabled bool, +) api.RelayInternalAPI { + cfg := &base.Cfg.RelayAPI + + relayDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName) + if err != nil { + logrus.WithError(err).Panic("failed to connect to relay db") + } + + return internal.NewRelayInternalAPI( + relayDB, + fedClient, + rsAPI, + keyRing, + producer, + base.Cfg.Global.Presence.EnableInbound, + base.Cfg.Global.ServerName, + relayingEnabled, + ) +} diff --git a/relayapi/relayapi_test.go b/relayapi/relayapi_test.go new file mode 100644 index 000000000..f1b3262aa --- /dev/null +++ b/relayapi/relayapi_test.go @@ -0,0 +1,192 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relayapi_test + +import ( + "crypto/ed25519" + "encoding/hex" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + "github.com/matrix-org/dendrite/relayapi" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +func TestCreateNewRelayInternalAPI(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, true) + assert.NotNil(t, relayAPI) + }) +} + +func TestCreateRelayInternalInvalidDatabasePanics(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + if dbType == test.DBTypeSQLite { + base.Cfg.RelayAPI.Database.ConnectionString = "file:" + } else { + base.Cfg.RelayAPI.Database.ConnectionString = "test" + } + defer close() + + assert.Panics(t, func() { + relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, true) + }) + }) +} + +func TestCreateInvalidRelayPublicRoutesPanics(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + assert.Panics(t, func() { + relayapi.AddPublicRoutes(base, nil, nil) + }) + }) +} + +func createGetRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, userID string) *http.Request { + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("GET", origin, serverName, "/_matrix/federation/v1/relay_txn/"+userID) + content := gomatrixserverlib.RelayEntry{EntryID: 0} + req.SetContent(content) + req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk) + httpreq, _ := req.HTTPRequest() + vars := map[string]string{"userID": userID} + httpreq = mux.SetURLVars(httpreq, vars) + return httpreq +} + +type sendRelayContent struct { + PDUs []json.RawMessage `json:"pdus"` + EDUs []gomatrixserverlib.EDU `json:"edus"` +} + +func createSendRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, txnID string, userID string) *http.Request { + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("PUT", origin, serverName, "/_matrix/federation/v1/send_relay/"+txnID+"/"+userID) + content := sendRelayContent{} + req.SetContent(content) + req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk) + httpreq, _ := req.HTTPRequest() + vars := map[string]string{"userID": userID, "txnID": txnID} + httpreq = mux.SetURLVars(httpreq, vars) + return httpreq +} + +func TestCreateRelayPublicRoutes(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, true) + assert.NotNil(t, relayAPI) + + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + relayapi.AddPublicRoutes(base, keyRing, relayAPI) + + testCases := []struct { + name string + req *http.Request + wantCode int + }{ + { + name: "relay_txn invalid user id", + req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "user:local"), + wantCode: 400, + }, + { + name: "relay_txn valid user id", + req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"), + wantCode: 200, + }, + { + name: "send_relay invalid user id", + req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "user:local"), + wantCode: 400, + }, + { + name: "send_relay valid user id", + req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"), + wantCode: 200, + }, + } + + for _, tc := range testCases { + w := httptest.NewRecorder() + base.PublicFederationAPIMux.ServeHTTP(w, tc.req) + if w.Code != tc.wantCode { + t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) + } + } + }) +} + +func TestDisableRelayPublicRoutes(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, false) + assert.NotNil(t, relayAPI) + + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + relayapi.AddPublicRoutes(base, keyRing, relayAPI) + + testCases := []struct { + name string + req *http.Request + wantCode int + }{ + { + name: "relay_txn valid user id", + req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"), + wantCode: 404, + }, + { + name: "send_relay valid user id", + req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"), + wantCode: 404, + }, + } + + for _, tc := range testCases { + w := httptest.NewRecorder() + base.PublicFederationAPIMux.ServeHTTP(w, tc.req) + if w.Code != tc.wantCode { + t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) + } + } + }) +} diff --git a/relayapi/routing/relaytxn.go b/relayapi/routing/relaytxn.go new file mode 100644 index 000000000..63b42ec7d --- /dev/null +++ b/relayapi/routing/relaytxn.go @@ -0,0 +1,68 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +// GetTransactionFromRelay implements GET /_matrix/federation/v1/relay_txn/{userID} +// This endpoint can be extracted into a separate relay server service. +func GetTransactionFromRelay( + httpReq *http.Request, + fedReq *gomatrixserverlib.FederationRequest, + relayAPI api.RelayInternalAPI, + userID gomatrixserverlib.UserID, +) util.JSONResponse { + logrus.Infof("Processing relay_txn for %s", userID.Raw()) + + var previousEntry gomatrixserverlib.RelayEntry + if err := json.Unmarshal(fedReq.Content(), &previousEntry); err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("invalid json provided"), + } + } + if previousEntry.EntryID < 0 { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("Invalid entry id provided. Must be >= 0."), + } + } + logrus.Infof("Previous entry provided: %v", previousEntry.EntryID) + + response, err := relayAPI.QueryTransactions(httpReq.Context(), userID, previousEntry) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + } + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: gomatrixserverlib.RespGetRelayTransaction{ + Transaction: response.Transaction, + EntryID: response.EntryID, + EntriesQueued: response.EntriesQueued, + }, + } +} diff --git a/relayapi/routing/relaytxn_test.go b/relayapi/routing/relaytxn_test.go new file mode 100644 index 000000000..4c099a642 --- /dev/null +++ b/relayapi/routing/relaytxn_test.go @@ -0,0 +1,220 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "net/http" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +func createQuery( + userID gomatrixserverlib.UserID, + prevEntry gomatrixserverlib.RelayEntry, +) gomatrixserverlib.FederationRequest { + var federationPathPrefixV1 = "/_matrix/federation/v1" + path := federationPathPrefixV1 + "/relay_txn/" + userID.Raw() + request := gomatrixserverlib.NewFederationRequest("GET", userID.Domain(), "relay", path) + request.SetContent(prevEntry) + + return request +} + +func TestGetEmptyDatabaseReturnsNothing(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + + _, err = db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", true, + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(gomatrixserverlib.RespGetRelayTransaction) + assert.Equal(t, false, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + + count, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err) + assert.Zero(t, count) +} + +func TestGetInvalidPrevEntryFails(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + + _, err = db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", true, + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusInternalServerError, response.Code) +} + +func TestGetReturnsSavedTransaction(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + receipt, err := db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + err = db.AssociateTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction.TransactionID, + receipt) + assert.NoError(t, err, "Failed to associate transaction with user") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", true, + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(gomatrixserverlib.RespGetRelayTransaction) + assert.True(t, jsonResponse.EntriesQueued) + assert.Equal(t, transaction, jsonResponse.Transaction) + + // And once more to clear the queue + request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(gomatrixserverlib.RespGetRelayTransaction) + assert.False(t, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + + count, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err) + assert.Zero(t, count) +} + +func TestGetReturnsMultipleSavedTransactions(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + receipt, err := db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + err = db.AssociateTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction.TransactionID, + receipt) + assert.NoError(t, err, "Failed to associate transaction with user") + + transaction2 := createTransaction() + receipt2, err := db.StoreTransaction(context.Background(), transaction2) + assert.NoError(t, err, "Failed to store transaction") + + err = db.AssociateTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction2.TransactionID, + receipt2) + assert.NoError(t, err, "Failed to associate transaction with user") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", true, + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(gomatrixserverlib.RespGetRelayTransaction) + assert.True(t, jsonResponse.EntriesQueued) + assert.Equal(t, transaction, jsonResponse.Transaction) + + request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(gomatrixserverlib.RespGetRelayTransaction) + assert.True(t, jsonResponse.EntriesQueued) + assert.Equal(t, transaction2, jsonResponse.Transaction) + + // And once more to clear the queue + request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(gomatrixserverlib.RespGetRelayTransaction) + assert.False(t, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + + count, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err) + assert.Zero(t, count) +} diff --git a/relayapi/routing/routing.go b/relayapi/routing/routing.go new file mode 100644 index 000000000..8ee0743ef --- /dev/null +++ b/relayapi/routing/routing.go @@ -0,0 +1,136 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "fmt" + "net/http" + "time" + + "github.com/getsentry/sentry-go" + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/httputil" + relayInternal "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +// Setup registers HTTP handlers with the given ServeMux. +// The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly +// path unescape twice (once from the router, once from MakeRelayAPI). We need to have this enabled +// so we can decode paths like foo/bar%2Fbaz as [foo, bar/baz] - by default it will decode to [foo, bar, baz] +func Setup( + fedMux *mux.Router, + cfg *config.FederationAPI, + relayAPI *relayInternal.RelayInternalAPI, + keys gomatrixserverlib.JSONVerifier, +) { + v1fedmux := fedMux.PathPrefix("/v1").Subrouter() + + v1fedmux.Handle("/send_relay/{txnID}/{userID}", MakeRelayAPI( + "send_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + logrus.Infof("Handling send_relay from: %s", request.Origin()) + if !relayAPI.RelayingEnabled() { + logrus.Warnf("Dropping send_relay from: %s", request.Origin()) + return util.JSONResponse{ + Code: http.StatusNotFound, + } + } + + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username was invalid"), + } + } + return SendTransactionToRelay( + httpReq, request, relayAPI, gomatrixserverlib.TransactionID(vars["txnID"]), + *userID, + ) + }, + )).Methods(http.MethodPut, http.MethodOptions) + + v1fedmux.Handle("/relay_txn/{userID}", MakeRelayAPI( + "get_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + logrus.Infof("Handling relay_txn from: %s", request.Origin()) + if !relayAPI.RelayingEnabled() { + logrus.Warnf("Dropping relay_txn from: %s", request.Origin()) + return util.JSONResponse{ + Code: http.StatusNotFound, + } + } + + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username was invalid"), + } + } + return GetTransactionFromRelay(httpReq, request, relayAPI, *userID) + }, + )).Methods(http.MethodGet, http.MethodOptions) +} + +// MakeRelayAPI makes an http.Handler that checks matrix relay authentication. +func MakeRelayAPI( + metricsName string, serverName gomatrixserverlib.ServerName, + isLocalServerName func(gomatrixserverlib.ServerName) bool, + keyRing gomatrixserverlib.JSONVerifier, + f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, +) http.Handler { + h := func(req *http.Request) util.JSONResponse { + fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( + req, time.Now(), serverName, isLocalServerName, keyRing, + ) + if fedReq == nil { + return errResp + } + // add the user to Sentry, if enabled + hub := sentry.GetHubFromContext(req.Context()) + if hub != nil { + hub.Scope().SetTag("origin", string(fedReq.Origin())) + hub.Scope().SetTag("uri", fedReq.RequestURI()) + } + defer func() { + if r := recover(); r != nil { + if hub != nil { + hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path)) + } + // re-panic to return the 500 + panic(r) + } + }() + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params") + } + + jsonRes := f(req, fedReq, vars) + // do not log 4xx as errors as they are client fails, not server fails + if hub != nil && jsonRes.Code >= 500 { + hub.Scope().SetExtra("response", jsonRes) + hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code)) + } + return jsonRes + } + return httputil.MakeExternalAPI(metricsName, h) +} diff --git a/relayapi/routing/sendrelay.go b/relayapi/routing/sendrelay.go new file mode 100644 index 000000000..ce744cb49 --- /dev/null +++ b/relayapi/routing/sendrelay.go @@ -0,0 +1,75 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +// SendTransactionToRelay implements PUT /_matrix/federation/v1/send_relay/{txnID}/{userID} +// This endpoint can be extracted into a separate relay server service. +func SendTransactionToRelay( + httpReq *http.Request, + fedReq *gomatrixserverlib.FederationRequest, + relayAPI api.RelayInternalAPI, + txnID gomatrixserverlib.TransactionID, + userID gomatrixserverlib.UserID, +) util.JSONResponse { + logrus.Infof("Processing send_relay for %s", userID.Raw()) + + var txnEvents gomatrixserverlib.RelayEvents + if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil { + logrus.Info("The request body could not be decoded into valid JSON." + err.Error()) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON." + err.Error()), + } + } + + // Transactions are limited in size; they can have at most 50 PDUs and 100 EDUs. + // https://matrix.org/docs/spec/server_server/latest#transactions + if len(txnEvents.PDUs) > 50 || len(txnEvents.EDUs) > 100 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("max 50 pdus / 100 edus"), + } + } + + t := gomatrixserverlib.Transaction{} + t.PDUs = txnEvents.PDUs + t.EDUs = txnEvents.EDUs + t.Origin = fedReq.Origin() + t.TransactionID = txnID + t.Destination = userID.Domain() + + util.GetLogger(httpReq.Context()).Warnf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, fedReq.Origin(), len(t.PDUs), len(t.EDUs)) + + err := relayAPI.PerformStoreTransaction(httpReq.Context(), t, userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("could not store the transaction for forwarding"), + } + } + + return util.JSONResponse{Code: 200} +} diff --git a/relayapi/routing/sendrelay_test.go b/relayapi/routing/sendrelay_test.go new file mode 100644 index 000000000..66594c47c --- /dev/null +++ b/relayapi/routing/sendrelay_test.go @@ -0,0 +1,209 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") +) + +func createTransaction() gomatrixserverlib.Transaction { + txn := gomatrixserverlib.Transaction{} + txn.PDUs = []json.RawMessage{ + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + } + txn.Origin = testOrigin + return txn +} + +func createFederationRequest( + userID gomatrixserverlib.UserID, + txnID gomatrixserverlib.TransactionID, + origin gomatrixserverlib.ServerName, + destination gomatrixserverlib.ServerName, + content interface{}, +) gomatrixserverlib.FederationRequest { + var federationPathPrefixV1 = "/_matrix/federation/v1" + path := federationPathPrefixV1 + "/send_relay/" + string(txnID) + "/" + userID.Raw() + request := gomatrixserverlib.NewFederationRequest("PUT", origin, destination, path) + request.SetContent(content) + + return request +} + +func TestForwardEmptyReturnsOk(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", true, + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.Equal(t, 200, response.Code) +} + +func TestForwardBadJSONReturnsError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + type BadData struct { + Field bool `json:"pdus"` + } + content := BadData{ + Field: false, + } + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", true, + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestForwardTooManyPDUsReturnsError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + type BadData struct { + Field []json.RawMessage `json:"pdus"` + } + content := BadData{ + Field: []json.RawMessage{}, + } + for i := 0; i < 51; i++ { + content.Field = append(content.Field, []byte{}) + } + assert.Greater(t, len(content.Field), 50) + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", true, + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestForwardTooManyEDUsReturnsError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + type BadData struct { + Field []gomatrixserverlib.EDU `json:"edus"` + } + content := BadData{ + Field: []gomatrixserverlib.EDU{}, + } + for i := 0; i < 101; i++ { + content.Field = append(content.Field, gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping}) + } + assert.Greater(t, len(content.Field), 100) + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", true, + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestUniqueTransactionStoredInDatabase(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", true, + ) + + response := routing.SendTransactionToRelay( + httpReq, &request, relayAPI, txn.TransactionID, *userID) + transaction, _, err := db.GetTransaction(context.Background(), *userID) + assert.NoError(t, err, "Failed retrieving transaction") + + transactionCount, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err, "Failed retrieving transaction count") + + assert.Equal(t, 200, response.Code) + assert.Equal(t, int64(1), transactionCount) + assert.Equal(t, txn.TransactionID, transaction.TransactionID) +} diff --git a/relayapi/storage/interface.go b/relayapi/storage/interface.go new file mode 100644 index 000000000..f5f9a06e5 --- /dev/null +++ b/relayapi/storage/interface.go @@ -0,0 +1,47 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database interface { + // Adds a new transaction to the queue json table. + // Adding a duplicate transaction will result in a new row being added and a new unique nid. + // return: unique nid representing this entry. + StoreTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*receipt.Receipt, error) + + // Adds a new transaction_id: server_name mapping with associated json table nid to the queue + // entry table for each provided destination. + AssociateTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, dbReceipt *receipt.Receipt) error + + // Removes every server_name: receipt pair provided from the queue entries table. + // Will then remove every entry for each receipt provided from the queue json table. + // If any of the entries don't exist in either table, nothing will happen for that entry and + // an error will not be generated. + CleanTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*receipt.Receipt) error + + // Gets the oldest transaction for the provided server_name. + // If no transactions exist, returns nil and no error. + GetTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) + + // Gets the number of transactions being stored for the provided server_name. + // If the server doesn't exist in the database then 0 is returned with no error. + GetTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error) +} diff --git a/relayapi/storage/postgres/relay_queue_json_table.go b/relayapi/storage/postgres/relay_queue_json_table.go new file mode 100644 index 000000000..74410fc88 --- /dev/null +++ b/relayapi/storage/postgres/relay_queue_json_table.go @@ -0,0 +1,113 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const relayQueueJSONSchema = ` +-- The relayapi_queue_json table contains event contents that +-- we are storing for future forwarding. +CREATE TABLE IF NOT EXISTS relayapi_queue_json ( + -- The JSON NID. This allows cross-referencing to find the JSON blob. + json_nid BIGSERIAL, + -- The JSON body. Text so that we preserve UTF-8. + json_body TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx + ON relayapi_queue_json (json_nid); +` + +const insertQueueJSONSQL = "" + + "INSERT INTO relayapi_queue_json (json_body)" + + " VALUES ($1)" + + " RETURNING json_nid" + +const deleteQueueJSONSQL = "" + + "DELETE FROM relayapi_queue_json WHERE json_nid = ANY($1)" + +const selectQueueJSONSQL = "" + + "SELECT json_nid, json_body FROM relayapi_queue_json" + + " WHERE json_nid = ANY($1)" + +type relayQueueJSONStatements struct { + db *sql.DB + insertJSONStmt *sql.Stmt + deleteJSONStmt *sql.Stmt + selectJSONStmt *sql.Stmt +} + +func NewPostgresRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) { + s = &relayQueueJSONStatements{ + db: db, + } + _, err = s.db.Exec(relayQueueJSONSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertJSONStmt, insertQueueJSONSQL}, + {&s.deleteJSONStmt, deleteQueueJSONSQL}, + {&s.selectJSONStmt, selectQueueJSONSQL}, + }.Prepare(db) +} + +func (s *relayQueueJSONStatements) InsertQueueJSON( + ctx context.Context, txn *sql.Tx, json string, +) (int64, error) { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + var lastid int64 + if err := stmt.QueryRowContext(ctx, json).Scan(&lastid); err != nil { + return 0, err + } + return lastid, nil +} + +func (s *relayQueueJSONStatements) DeleteQueueJSON( + ctx context.Context, txn *sql.Tx, nids []int64, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt) + _, err := stmt.ExecContext(ctx, pq.Int64Array(nids)) + return err +} + +func (s *relayQueueJSONStatements) SelectQueueJSON( + ctx context.Context, txn *sql.Tx, jsonNIDs []int64, +) (map[int64][]byte, error) { + blobs := map[int64][]byte{} + stmt := sqlutil.TxStmt(txn, s.selectJSONStmt) + rows, err := stmt.QueryContext(ctx, pq.Int64Array(jsonNIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed") + for rows.Next() { + var nid int64 + var blob []byte + if err = rows.Scan(&nid, &blob); err != nil { + return nil, err + } + blobs[nid] = blob + } + return blobs, err +} diff --git a/relayapi/storage/postgres/relay_queue_table.go b/relayapi/storage/postgres/relay_queue_table.go new file mode 100644 index 000000000..e97cf8cc0 --- /dev/null +++ b/relayapi/storage/postgres/relay_queue_table.go @@ -0,0 +1,156 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayQueueSchema = ` +CREATE TABLE IF NOT EXISTS relayapi_queue ( + -- The transaction ID that was generated before persisting the event. + transaction_id TEXT NOT NULL, + -- The destination server that we will send the event to. + server_name TEXT NOT NULL, + -- The JSON NID from the relayapi_queue_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx + ON relayapi_queue (json_nid, server_name); +CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx + ON relayapi_queue (json_nid); +CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx + ON relayapi_queue (server_name); +` + +const insertQueueEntrySQL = "" + + "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueEntriesSQL = "" + + "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid = ANY($2)" + +const selectQueueEntriesSQL = "" + + "SELECT json_nid FROM relayapi_queue" + + " WHERE server_name = $1" + + " ORDER BY json_nid" + + " LIMIT $2" + +const selectQueueEntryCountSQL = "" + + "SELECT COUNT(*) FROM relayapi_queue" + + " WHERE server_name = $1" + +type relayQueueStatements struct { + db *sql.DB + insertQueueEntryStmt *sql.Stmt + deleteQueueEntriesStmt *sql.Stmt + selectQueueEntriesStmt *sql.Stmt + selectQueueEntryCountStmt *sql.Stmt +} + +func NewPostgresRelayQueueTable( + db *sql.DB, +) (s *relayQueueStatements, err error) { + s = &relayQueueStatements{ + db: db, + } + _, err = s.db.Exec(relayQueueSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertQueueEntryStmt, insertQueueEntrySQL}, + {&s.deleteQueueEntriesStmt, deleteQueueEntriesSQL}, + {&s.selectQueueEntriesStmt, selectQueueEntriesSQL}, + {&s.selectQueueEntryCountStmt, selectQueueEntryCountSQL}, + }.Prepare(db) +} + +func (s *relayQueueStatements) InsertQueueEntry( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *relayQueueStatements) DeleteQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteQueueEntriesStmt) + _, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs)) + return err +} + +func (s *relayQueueStatements) SelectQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + + return result, rows.Err() +} + +func (s *relayQueueStatements) SelectQueueEntryCount( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} diff --git a/relayapi/storage/postgres/storage.go b/relayapi/storage/postgres/storage.go new file mode 100644 index 000000000..1042beba7 --- /dev/null +++ b/relayapi/storage/postgres/storage.go @@ -0,0 +1,64 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "database/sql" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// Database stores information needed by the relayapi +type Database struct { + shared.Database + db *sql.DB + writer sqlutil.Writer +} + +// NewDatabase opens a new database +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (*Database, error) { + var d Database + var err error + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { + return nil, err + } + queue, err := NewPostgresRelayQueueTable(d.db) + if err != nil { + return nil, err + } + queueJSON, err := NewPostgresRelayQueueJSONTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + RelayQueue: queue, + RelayQueueJSON: queueJSON, + } + return &d, nil +} diff --git a/relayapi/storage/shared/storage.go b/relayapi/storage/shared/storage.go new file mode 100644 index 000000000..0993707bf --- /dev/null +++ b/relayapi/storage/shared/storage.go @@ -0,0 +1,170 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database struct { + DB *sql.DB + IsLocalServerName func(gomatrixserverlib.ServerName) bool + Cache caching.FederationCache + Writer sqlutil.Writer + RelayQueue tables.RelayQueue + RelayQueueJSON tables.RelayQueueJSON +} + +func (d *Database) StoreTransaction( + ctx context.Context, + transaction gomatrixserverlib.Transaction, +) (*receipt.Receipt, error) { + var err error + jsonTransaction, err := json.Marshal(transaction) + if err != nil { + return nil, fmt.Errorf("failed to marshal: %w", err) + } + + var nid int64 + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + nid, err = d.RelayQueueJSON.InsertQueueJSON(ctx, txn, string(jsonTransaction)) + return err + }) + if err != nil { + return nil, fmt.Errorf("d.insertQueueJSON: %w", err) + } + + newReceipt := receipt.NewReceipt(nid) + return &newReceipt, nil +} + +func (d *Database) AssociateTransactionWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.UserID]struct{}, + transactionID gomatrixserverlib.TransactionID, + dbReceipt *receipt.Receipt, +) error { + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var lastErr error + for destination := range destinations { + destination := destination + err := d.RelayQueue.InsertQueueEntry( + ctx, + txn, + transactionID, + destination.Domain(), + dbReceipt.GetNID(), + ) + if err != nil { + lastErr = fmt.Errorf("d.insertQueueEntry: %w", err) + } + } + return lastErr + }) + + return err +} + +func (d *Database) CleanTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + receipts []*receipt.Receipt, +) error { + nids := make([]int64, len(receipts)) + for i, dbReceipt := range receipts { + nids[i] = dbReceipt.GetNID() + } + + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + deleteEntryErr := d.RelayQueue.DeleteQueueEntries(ctx, txn, userID.Domain(), nids) + // TODO : If there are still queue entries for any of these nids for other destinations + // then we shouldn't delete the json entries. + // But this can't happen with the current api design. + // There will only ever be one server entry for each nid since each call to send_relay + // only accepts a single server name and inside there we create a new json entry. + // So for multiple destinations we would call send_relay multiple times and have multiple + // json entries of the same transaction. + // + // TLDR; this works as expected right now but can easily be optimised in the future. + deleteJSONErr := d.RelayQueueJSON.DeleteQueueJSON(ctx, txn, nids) + + if deleteEntryErr != nil { + return fmt.Errorf("d.deleteQueueEntries: %w", deleteEntryErr) + } + if deleteJSONErr != nil { + return fmt.Errorf("d.deleteQueueJSON: %w", deleteJSONErr) + } + return nil + }) + + return err +} + +func (d *Database) GetTransaction( + ctx context.Context, + userID gomatrixserverlib.UserID, +) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) { + entriesRequested := 1 + nids, err := d.RelayQueue.SelectQueueEntries(ctx, nil, userID.Domain(), entriesRequested) + if err != nil { + return nil, nil, fmt.Errorf("d.SelectQueueEntries: %w", err) + } + if len(nids) == 0 { + return nil, nil, nil + } + firstNID := nids[0] + + txns := map[int64][]byte{} + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + txns, err = d.RelayQueueJSON.SelectQueueJSON(ctx, txn, nids) + return err + }) + if err != nil { + return nil, nil, fmt.Errorf("d.SelectQueueJSON: %w", err) + } + + transaction := &gomatrixserverlib.Transaction{} + if _, ok := txns[firstNID]; !ok { + return nil, nil, fmt.Errorf("Failed retrieving json blob for transaction: %d", firstNID) + } + + err = json.Unmarshal(txns[firstNID], transaction) + if err != nil { + return nil, nil, fmt.Errorf("Unmarshal transaction: %w", err) + } + + newReceipt := receipt.NewReceipt(firstNID) + return transaction, &newReceipt, nil +} + +func (d *Database) GetTransactionCount( + ctx context.Context, + userID gomatrixserverlib.UserID, +) (int64, error) { + count, err := d.RelayQueue.SelectQueueEntryCount(ctx, nil, userID.Domain()) + if err != nil { + return 0, fmt.Errorf("d.SelectQueueEntryCount: %w", err) + } + return count, nil +} diff --git a/relayapi/storage/sqlite3/relay_queue_json_table.go b/relayapi/storage/sqlite3/relay_queue_json_table.go new file mode 100644 index 000000000..502da3b00 --- /dev/null +++ b/relayapi/storage/sqlite3/relay_queue_json_table.go @@ -0,0 +1,137 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const relayQueueJSONSchema = ` +-- The relayapi_queue_json table contains event contents that +-- we are storing for future forwarding. +CREATE TABLE IF NOT EXISTS relayapi_queue_json ( + -- The JSON NID. This allows cross-referencing to find the JSON blob. + json_nid INTEGER PRIMARY KEY AUTOINCREMENT, + -- The JSON body. Text so that we preserve UTF-8. + json_body TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx + ON relayapi_queue_json (json_nid); +` + +const insertQueueJSONSQL = "" + + "INSERT INTO relayapi_queue_json (json_body)" + + " VALUES ($1)" + +const deleteQueueJSONSQL = "" + + "DELETE FROM relayapi_queue_json WHERE json_nid IN ($1)" + +const selectQueueJSONSQL = "" + + "SELECT json_nid, json_body FROM relayapi_queue_json" + + " WHERE json_nid IN ($1)" + +type relayQueueJSONStatements struct { + db *sql.DB + insertJSONStmt *sql.Stmt + //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic + //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic +} + +func NewSQLiteRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) { + s = &relayQueueJSONStatements{ + db: db, + } + _, err = db.Exec(relayQueueJSONSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertJSONStmt, insertQueueJSONSQL}, + }.Prepare(db) +} + +func (s *relayQueueJSONStatements) InsertQueueJSON( + ctx context.Context, txn *sql.Tx, json string, +) (lastid int64, err error) { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + res, err := stmt.ExecContext(ctx, json) + if err != nil { + return 0, fmt.Errorf("stmt.QueryContext: %w", err) + } + lastid, err = res.LastInsertId() + if err != nil { + return 0, fmt.Errorf("res.LastInsertId: %w", err) + } + return +} + +func (s *relayQueueJSONStatements) DeleteQueueJSON( + ctx context.Context, txn *sql.Tx, nids []int64, +) error { + deleteSQL := strings.Replace(deleteQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(nids)) + for k, v := range nids { + iNIDs[k] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, iNIDs...) + return err +} + +func (s *relayQueueJSONStatements) SelectQueueJSON( + ctx context.Context, txn *sql.Tx, jsonNIDs []int64, +) (map[int64][]byte, error) { + selectSQL := strings.Replace(selectQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) + selectStmt, err := txn.Prepare(selectSQL) + if err != nil { + return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(jsonNIDs)) + for k, v := range jsonNIDs { + iNIDs[k] = v + } + + blobs := map[int64][]byte{} + stmt := sqlutil.TxStmt(txn, selectStmt) + rows, err := stmt.QueryContext(ctx, iNIDs...) + if err != nil { + return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "selectQueueJSON: rows.close() failed") + for rows.Next() { + var nid int64 + var blob []byte + if err = rows.Scan(&nid, &blob); err != nil { + return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err) + } + blobs[nid] = blob + } + return blobs, err +} diff --git a/relayapi/storage/sqlite3/relay_queue_table.go b/relayapi/storage/sqlite3/relay_queue_table.go new file mode 100644 index 000000000..49c6b4de5 --- /dev/null +++ b/relayapi/storage/sqlite3/relay_queue_table.go @@ -0,0 +1,168 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayQueueSchema = ` +CREATE TABLE IF NOT EXISTS relayapi_queue ( + -- The transaction ID that was generated before persisting the event. + transaction_id TEXT NOT NULL, + -- The domain part of the user ID the m.room.member event is for. + server_name TEXT NOT NULL, + -- The JSON NID from the relayapi_queue_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx + ON relayapi_queue (json_nid, server_name); +CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx + ON relayapi_queue (json_nid); +CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx + ON relayapi_queue (server_name); +` + +const insertQueueEntrySQL = "" + + "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueEntriesSQL = "" + + "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid IN ($2)" + +const selectQueueEntriesSQL = "" + + "SELECT json_nid FROM relayapi_queue" + + " WHERE server_name = $1" + + " ORDER BY json_nid" + + " LIMIT $2" + +const selectQueueEntryCountSQL = "" + + "SELECT COUNT(*) FROM relayapi_queue" + + " WHERE server_name = $1" + +type relayQueueStatements struct { + db *sql.DB + insertQueueEntryStmt *sql.Stmt + selectQueueEntriesStmt *sql.Stmt + selectQueueEntryCountStmt *sql.Stmt + // deleteQueueEntriesStmt *sql.Stmt - prepared at runtime due to variadic +} + +func NewSQLiteRelayQueueTable( + db *sql.DB, +) (s *relayQueueStatements, err error) { + s = &relayQueueStatements{ + db: db, + } + _, err = db.Exec(relayQueueSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertQueueEntryStmt, insertQueueEntrySQL}, + {&s.selectQueueEntriesStmt, selectQueueEntriesSQL}, + {&s.selectQueueEntryCountStmt, selectQueueEntryCountSQL}, + }.Prepare(db) +} + +func (s *relayQueueStatements) InsertQueueEntry( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *relayQueueStatements) DeleteQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + deleteSQL := strings.Replace(deleteQueueEntriesSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueEntries s.db.Prepare: %w", err) + } + + params := make([]interface{}, len(jsonNIDs)+1) + params[0] = serverName + for k, v := range jsonNIDs { + params[k+1] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *relayQueueStatements) SelectQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + + return result, rows.Err() +} + +func (s *relayQueueStatements) SelectQueueEntryCount( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} diff --git a/relayapi/storage/sqlite3/storage.go b/relayapi/storage/sqlite3/storage.go new file mode 100644 index 000000000..3ed4ab046 --- /dev/null +++ b/relayapi/storage/sqlite3/storage.go @@ -0,0 +1,64 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "database/sql" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// Database stores information needed by the federation sender +type Database struct { + shared.Database + db *sql.DB + writer sqlutil.Writer +} + +// NewDatabase opens a new database +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (*Database, error) { + var d Database + var err error + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { + return nil, err + } + queue, err := NewSQLiteRelayQueueTable(d.db) + if err != nil { + return nil, err + } + queueJSON, err := NewSQLiteRelayQueueJSONTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + RelayQueue: queue, + RelayQueueJSON: queueJSON, + } + return &d, nil +} diff --git a/keyserver/storage/storage.go b/relayapi/storage/storage.go similarity index 58% rename from keyserver/storage/storage.go rename to relayapi/storage/storage.go index ab6a35401..16ecbcfb7 100644 --- a/keyserver/storage/storage.go +++ b/relayapi/storage/storage.go @@ -1,4 +1,4 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. +// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,20 +20,26 @@ package storage import ( "fmt" - "github.com/matrix-org/dendrite/keyserver/storage/postgres" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" ) -// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) -// and sets postgres connection parameters -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { +// NewDatabase opens a new database +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) + return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties) + return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/relayapi/storage/tables/interface.go b/relayapi/storage/tables/interface.go new file mode 100644 index 000000000..9056a5678 --- /dev/null +++ b/relayapi/storage/tables/interface.go @@ -0,0 +1,66 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "context" + "database/sql" + + "github.com/matrix-org/gomatrixserverlib" +) + +// RelayQueue table contains a mapping of server name to transaction id and the corresponding nid. +// These are the transactions being stored for the given destination server. +// The nids correspond to entries in the RelayQueueJSON table. +type RelayQueue interface { + // Adds a new transaction_id: server_name mapping with associated json table nid to the table. + // Will ensure only one transaction id is present for each server_name: nid mapping. + // Adding duplicates will silently do nothing. + InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error + + // Removes multiple entries from the table corresponding the the list of nids provided. + // If any of the provided nids don't match a row in the table, that deletion is considered + // successful. + DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + + // Get a list of nids associated with the provided server name. + // Returns up to `limit` nids. The entries are returned oldest first. + // Will return an empty list if no matches were found. + SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + + // Get the number of entries in the table associated with the provided server name. + // If there are no matching rows, a count of 0 is returned with err set to nil. + SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) +} + +// RelayQueueJSON table contains a map of nid to the raw transaction json. +type RelayQueueJSON interface { + // Adds a new transaction to the table. + // Adding a duplicate transaction will result in a new row being added and a new unique nid. + // return: unique nid representing this entry. + InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) + + // Removes multiple nids from the table. + // If any of the provided nids don't match a row in the table, that deletion is considered + // successful. + DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error + + // Get the transaction json corresponding to the provided nids. + // Will return a partial result containing any matching nid from the table. + // Will return an empty map if no matches were found. + // It is the caller's responsibility to deal with the results appropriately. + // return: map indexed by nid of each matching transaction json. + SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) +} diff --git a/relayapi/storage/tables/relay_queue_json_table_test.go b/relayapi/storage/tables/relay_queue_json_table_test.go new file mode 100644 index 000000000..efa3363e5 --- /dev/null +++ b/relayapi/storage/tables/relay_queue_json_table_test.go @@ -0,0 +1,173 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables_test + +import ( + "context" + "database/sql" + "encoding/json" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") +) + +func mustCreateTransaction() gomatrixserverlib.Transaction { + txn := gomatrixserverlib.Transaction{} + txn.PDUs = []json.RawMessage{ + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + } + txn.Origin = testOrigin + + return txn +} + +type RelayQueueJSONDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.RelayQueueJSON +} + +func mustCreateQueueJSONTable( + t *testing.T, + dbType test.DBType, +) (database RelayQueueJSONDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.RelayQueueJSON + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresRelayQueueJSONTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteRelayQueueJSONTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = RelayQueueJSONDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func TestShoudInsertTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueJSONTable(t, dbType) + defer close() + + transaction := mustCreateTransaction() + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + _, err = db.Table.InsertQueueJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + }) +} + +func TestShouldRetrieveInsertedTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueJSONTable(t, dbType) + defer close() + + transaction := mustCreateTransaction() + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + var storedJSON map[int64][]byte + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, 1, len(storedJSON)) + + var storedTx gomatrixserverlib.Transaction + json.Unmarshal(storedJSON[1], &storedTx) + + assert.Equal(t, transaction, storedTx) + }) +} + +func TestShouldDeleteTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueJSONTable(t, dbType) + defer close() + + transaction := mustCreateTransaction() + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + storedJSON := map[int64][]byte{} + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + storedJSON = map[int64][]byte{} + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, 0, len(storedJSON)) + }) +} diff --git a/relayapi/storage/tables/relay_queue_table_test.go b/relayapi/storage/tables/relay_queue_table_test.go new file mode 100644 index 000000000..99f9922c0 --- /dev/null +++ b/relayapi/storage/tables/relay_queue_table_test.go @@ -0,0 +1,229 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type RelayQueueDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.RelayQueue +} + +func mustCreateQueueTable( + t *testing.T, + dbType test.DBType, +) (database RelayQueueDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.RelayQueue + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresRelayQueueTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteRelayQueueTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = RelayQueueDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func TestShoudInsertQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + }) +} + +func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 10) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, nid, retrievedNids[0]) + assert.Equal(t, 1, len(retrievedNids)) + }) +} + +func TestShouldRetrieveOldestInsertedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(2) + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName = gomatrixserverlib.ServerName("domain") + oldestNID := int64(1) + err = db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, oldestNID) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 1) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, oldestNID, retrievedNids[0]) + assert.Equal(t, 1, len(retrievedNids)) + + retrievedNids, err = db.Table.SelectQueueEntries(ctx, nil, serverName, 10) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, oldestNID, retrievedNids[0]) + assert.Equal(t, nid, retrievedNids[1]) + assert.Equal(t, 2, len(retrievedNids)) + }) +} + +func TestShouldDeleteQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, int64(0), count) + }) +} + +func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + transactionID2 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d2", time.Now().UnixNano())) + serverName2 := gomatrixserverlib.ServerName("domain2") + nid2 := int64(2) + transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano())) + + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertQueueEntry(ctx, nil, transactionID2, serverName2, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertQueueEntry(ctx, nil, transactionID3, serverName, nid2) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, int64(1), count) + + count, err = db.Table.SelectQueueEntryCount(ctx, nil, serverName2) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, int64(1), count) + }) +} diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 420ef278a..73732ae32 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -17,7 +17,6 @@ type RoomserverInternalAPI interface { ClientRoomserverAPI UserRoomserverAPI FederationRoomserverAPI - KeyserverRoomserverAPI // needed to avoid chicken and egg scenario when setting up the // interdependencies between the roomserver and other input APIs @@ -151,6 +150,7 @@ type ClientRoomserverAPI interface { PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error + PerformAdminPurgeRoom(ctx context.Context, req *PerformAdminPurgeRoomRequest, res *PerformAdminPurgeRoomResponse) error PerformAdminDownloadState(ctx context.Context, req *PerformAdminDownloadStateRequest, res *PerformAdminDownloadStateResponse) error PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error @@ -166,6 +166,7 @@ type ClientRoomserverAPI interface { type UserRoomserverAPI interface { QueryLatestEventsAndStateAPI + KeyserverRoomserverAPI QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go deleted file mode 100644 index b23263d17..000000000 --- a/roomserver/api/api_trace.go +++ /dev/null @@ -1,417 +0,0 @@ -package api - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - - asAPI "github.com/matrix-org/dendrite/appservice/api" - fsAPI "github.com/matrix-org/dendrite/federationapi/api" - userapi "github.com/matrix-org/dendrite/userapi/api" -) - -// RoomserverInternalAPITrace wraps a RoomserverInternalAPI and logs the -// complete request/response/error -type RoomserverInternalAPITrace struct { - Impl RoomserverInternalAPI -} - -func (t *RoomserverInternalAPITrace) QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error { - err := t.Impl.QueryLeftUsers(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryLeftUsers req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) { - t.Impl.SetFederationAPI(fsAPI, keyRing) -} - -func (t *RoomserverInternalAPITrace) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { - t.Impl.SetAppserviceAPI(asAPI) -} - -func (t *RoomserverInternalAPITrace) SetUserAPI(userAPI userapi.RoomserverUserAPI) { - t.Impl.SetUserAPI(userAPI) -} - -func (t *RoomserverInternalAPITrace) InputRoomEvents( - ctx context.Context, - req *InputRoomEventsRequest, - res *InputRoomEventsResponse, -) error { - err := t.Impl.InputRoomEvents(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformInvite( - ctx context.Context, - req *PerformInviteRequest, - res *PerformInviteResponse, -) error { - err := t.Impl.PerformInvite(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformInvite req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformPeek( - ctx context.Context, - req *PerformPeekRequest, - res *PerformPeekResponse, -) error { - err := t.Impl.PerformPeek(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformPeek req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformUnpeek( - ctx context.Context, - req *PerformUnpeekRequest, - res *PerformUnpeekResponse, -) error { - err := t.Impl.PerformUnpeek(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformUnpeek req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformRoomUpgrade( - ctx context.Context, - req *PerformRoomUpgradeRequest, - res *PerformRoomUpgradeResponse, -) error { - err := t.Impl.PerformRoomUpgrade(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformRoomUpgrade req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformJoin( - ctx context.Context, - req *PerformJoinRequest, - res *PerformJoinResponse, -) error { - err := t.Impl.PerformJoin(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformJoin req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformLeave( - ctx context.Context, - req *PerformLeaveRequest, - res *PerformLeaveResponse, -) error { - err := t.Impl.PerformLeave(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformLeave req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformPublish( - ctx context.Context, - req *PerformPublishRequest, - res *PerformPublishResponse, -) error { - err := t.Impl.PerformPublish(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformPublish req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformAdminEvacuateRoom( - ctx context.Context, - req *PerformAdminEvacuateRoomRequest, - res *PerformAdminEvacuateRoomResponse, -) error { - err := t.Impl.PerformAdminEvacuateRoom(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformAdminEvacuateUser( - ctx context.Context, - req *PerformAdminEvacuateUserRequest, - res *PerformAdminEvacuateUserResponse, -) error { - err := t.Impl.PerformAdminEvacuateUser(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateUser req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformAdminDownloadState( - ctx context.Context, - req *PerformAdminDownloadStateRequest, - res *PerformAdminDownloadStateResponse, -) error { - err := t.Impl.PerformAdminDownloadState(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformAdminDownloadState req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformInboundPeek( - ctx context.Context, - req *PerformInboundPeekRequest, - res *PerformInboundPeekResponse, -) error { - err := t.Impl.PerformInboundPeek(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformInboundPeek req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryPublishedRooms( - ctx context.Context, - req *QueryPublishedRoomsRequest, - res *QueryPublishedRoomsResponse, -) error { - err := t.Impl.QueryPublishedRooms(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryPublishedRooms req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryLatestEventsAndState( - ctx context.Context, - req *QueryLatestEventsAndStateRequest, - res *QueryLatestEventsAndStateResponse, -) error { - err := t.Impl.QueryLatestEventsAndState(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryLatestEventsAndState req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryStateAfterEvents( - ctx context.Context, - req *QueryStateAfterEventsRequest, - res *QueryStateAfterEventsResponse, -) error { - err := t.Impl.QueryStateAfterEvents(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryStateAfterEvents req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryEventsByID( - ctx context.Context, - req *QueryEventsByIDRequest, - res *QueryEventsByIDResponse, -) error { - err := t.Impl.QueryEventsByID(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryEventsByID req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryMembershipForUser( - ctx context.Context, - req *QueryMembershipForUserRequest, - res *QueryMembershipForUserResponse, -) error { - err := t.Impl.QueryMembershipForUser(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryMembershipForUser req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryMembershipsForRoom( - ctx context.Context, - req *QueryMembershipsForRoomRequest, - res *QueryMembershipsForRoomResponse, -) error { - err := t.Impl.QueryMembershipsForRoom(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryMembershipsForRoom req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryServerJoinedToRoom( - ctx context.Context, - req *QueryServerJoinedToRoomRequest, - res *QueryServerJoinedToRoomResponse, -) error { - err := t.Impl.QueryServerJoinedToRoom(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryServerJoinedToRoom req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryServerAllowedToSeeEvent( - ctx context.Context, - req *QueryServerAllowedToSeeEventRequest, - res *QueryServerAllowedToSeeEventResponse, -) error { - err := t.Impl.QueryServerAllowedToSeeEvent(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryServerAllowedToSeeEvent req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryMissingEvents( - ctx context.Context, - req *QueryMissingEventsRequest, - res *QueryMissingEventsResponse, -) error { - err := t.Impl.QueryMissingEvents(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryMissingEvents req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryStateAndAuthChain( - ctx context.Context, - req *QueryStateAndAuthChainRequest, - res *QueryStateAndAuthChainResponse, -) error { - err := t.Impl.QueryStateAndAuthChain(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryStateAndAuthChain req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformBackfill( - ctx context.Context, - req *PerformBackfillRequest, - res *PerformBackfillResponse, -) error { - err := t.Impl.PerformBackfill(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformBackfill req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformForget( - ctx context.Context, - req *PerformForgetRequest, - res *PerformForgetResponse, -) error { - err := t.Impl.PerformForget(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformForget req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryRoomVersionCapabilities( - ctx context.Context, - req *QueryRoomVersionCapabilitiesRequest, - res *QueryRoomVersionCapabilitiesResponse, -) error { - err := t.Impl.QueryRoomVersionCapabilities(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryRoomVersionCapabilities req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryRoomVersionForRoom( - ctx context.Context, - req *QueryRoomVersionForRoomRequest, - res *QueryRoomVersionForRoomResponse, -) error { - err := t.Impl.QueryRoomVersionForRoom(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryRoomVersionForRoom req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) SetRoomAlias( - ctx context.Context, - req *SetRoomAliasRequest, - res *SetRoomAliasResponse, -) error { - err := t.Impl.SetRoomAlias(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("SetRoomAlias req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) GetRoomIDForAlias( - ctx context.Context, - req *GetRoomIDForAliasRequest, - res *GetRoomIDForAliasResponse, -) error { - err := t.Impl.GetRoomIDForAlias(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("GetRoomIDForAlias req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) GetAliasesForRoomID( - ctx context.Context, - req *GetAliasesForRoomIDRequest, - res *GetAliasesForRoomIDResponse, -) error { - err := t.Impl.GetAliasesForRoomID(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("GetAliasesForRoomID req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) RemoveRoomAlias( - ctx context.Context, - req *RemoveRoomAliasRequest, - res *RemoveRoomAliasResponse, -) error { - err := t.Impl.RemoveRoomAlias(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("RemoveRoomAlias req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error { - err := t.Impl.QueryCurrentState(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryCurrentState req=%+v res=%+v", js(req), js(res)) - return err -} - -// QueryRoomsForUser retrieves a list of room IDs matching the given query. -func (t *RoomserverInternalAPITrace) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error { - err := t.Impl.QueryRoomsForUser(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryRoomsForUser req=%+v res=%+v", js(req), js(res)) - return err -} - -// QueryBulkStateContent does a bulk query for state event content in the given rooms. -func (t *RoomserverInternalAPITrace) QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error { - err := t.Impl.QueryBulkStateContent(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryBulkStateContent req=%+v res=%+v", js(req), js(res)) - return err -} - -// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. -func (t *RoomserverInternalAPITrace) QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error { - err := t.Impl.QuerySharedUsers(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QuerySharedUsers req=%+v res=%+v", js(req), js(res)) - return err -} - -// QueryKnownUsers returns a list of users that we know about from our joined rooms. -func (t *RoomserverInternalAPITrace) QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error { - err := t.Impl.QueryKnownUsers(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryKnownUsers req=%+v res=%+v", js(req), js(res)) - return err -} - -// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. -func (t *RoomserverInternalAPITrace) QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error { - err := t.Impl.QueryServerBannedFromRoom(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryServerBannedFromRoom req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryAuthChain( - ctx context.Context, - request *QueryAuthChainRequest, - response *QueryAuthChainResponse, -) error { - err := t.Impl.QueryAuthChain(ctx, request, response) - util.GetLogger(ctx).WithError(err).Infof("QueryAuthChain req=%+v res=%+v", js(request), js(response)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryRestrictedJoinAllowed( - ctx context.Context, - request *QueryRestrictedJoinAllowedRequest, - response *QueryRestrictedJoinAllowedResponse, -) error { - err := t.Impl.QueryRestrictedJoinAllowed(ctx, request, response) - util.GetLogger(ctx).WithError(err).Infof("QueryRestrictedJoinAllowed req=%+v res=%+v", js(request), js(response)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryMembershipAtEvent( - ctx context.Context, - request *QueryMembershipAtEventRequest, - response *QueryMembershipAtEventResponse, -) error { - err := t.Impl.QueryMembershipAtEvent(ctx, request, response) - util.GetLogger(ctx).WithError(err).Infof("QueryMembershipAtEvent req=%+v res=%+v", js(request), js(response)) - return err -} - -func js(thing interface{}) string { - b, err := json.Marshal(thing) - if err != nil { - return fmt.Sprintf("Marshal error:%s", err) - } - return string(b) -} diff --git a/roomserver/api/output.go b/roomserver/api/output.go index 36d0625c7..0c0f52c45 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -55,6 +55,8 @@ const ( OutputTypeNewInboundPeek OutputType = "new_inbound_peek" // OutputTypeRetirePeek indicates that the kafka event is an OutputRetirePeek OutputTypeRetirePeek OutputType = "retire_peek" + // OutputTypePurgeRoom indicates the event is an OutputPurgeRoom + OutputTypePurgeRoom OutputType = "purge_room" ) // An OutputEvent is an entry in the roomserver output kafka log. @@ -78,6 +80,8 @@ type OutputEvent struct { NewInboundPeek *OutputNewInboundPeek `json:"new_inbound_peek,omitempty"` // The content of event with type OutputTypeRetirePeek RetirePeek *OutputRetirePeek `json:"retire_peek,omitempty"` + // The content of the event with type OutputPurgeRoom + PurgeRoom *OutputPurgeRoom `json:"purge_room,omitempty"` } // Type of the OutputNewRoomEvent. @@ -257,3 +261,7 @@ type OutputRetirePeek struct { UserID string DeviceID string } + +type OutputPurgeRoom struct { + RoomID string +} diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index e789b9568..83cb0460a 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -241,6 +241,14 @@ type PerformAdminEvacuateUserResponse struct { Error *PerformError } +type PerformAdminPurgeRoomRequest struct { + RoomID string `json:"room_id"` +} + +type PerformAdminPurgeRoomResponse struct { + Error *PerformError `json:"error,omitempty"` +} + type PerformAdminDownloadStateRequest struct { RoomID string `json:"room_id"` UserID string `json:"user_id"` diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 76f8298ca..4ef548e19 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -433,7 +433,7 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error { return nil } -// QueryMembershipAtEventRequest requests the membership events for a user +// QueryMembershipAtEventRequest requests the membership event for a user // for a list of eventIDs. type QueryMembershipAtEventRequest struct { RoomID string @@ -443,9 +443,10 @@ type QueryMembershipAtEventRequest struct { // QueryMembershipAtEventResponse is the response to QueryMembershipAtEventRequest. type QueryMembershipAtEventResponse struct { - // Memberships is a map from eventID to a list of events (if any). Events that - // do not have known state will return an empty array here. - Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"` + // Membership is a map from eventID to membership event. Events that + // do not have known state will return a nil event, resulting in a "leave" membership + // when calculating history visibility. + Membership map[string]*gomatrixserverlib.HeaderedEvent `json:"membership"` } // QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 329e6af7f..fc61b7f4a 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -30,26 +30,6 @@ import ( "github.com/tidwall/sjson" ) -// RoomserverInternalAPIDatabase has the storage APIs needed to implement the alias API. -type RoomserverInternalAPIDatabase interface { - // Save a given room alias with the room ID it refers to. - // Returns an error if there was a problem talking to the database. - SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error - // Look up the room ID a given alias refers to. - // Returns an error if there was a problem talking to the database. - GetRoomIDForAlias(ctx context.Context, alias string) (string, error) - // Look up all aliases referring to a given room ID. - // Returns an error if there was a problem talking to the database. - GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) - // Remove a given room alias. - // Returns an error if there was a problem talking to the database. - RemoveRoomAlias(ctx context.Context, alias string) error - // Look up the room version for a given room. - GetRoomVersionForRoom( - ctx context.Context, roomID string, - ) (gomatrixserverlib.RoomVersion, error) -} - // SetRoomAlias implements alias.RoomserverInternalAPI func (r *RoomserverInternalAPI) SetRoomAlias( ctx context.Context, diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 451b37696..c43b9d049 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -155,7 +155,6 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio r.Unpeeker = &perform.Unpeeker{ ServerName: r.ServerName, Cfg: r.Cfg, - DB: r.DB, FSAPI: r.fsAPI, Inputer: r.Inputer, } diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 03d8bca0b..27c8dd8fa 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -31,7 +31,8 @@ import ( // the soft-fail bool. func CheckForSoftFail( ctx context.Context, - db storage.Database, + db storage.RoomDatabase, + roomInfo *types.RoomInfo, event *gomatrixserverlib.HeaderedEvent, stateEventIDs []string, ) (bool, error) { @@ -45,16 +46,6 @@ func CheckForSoftFail( return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err) } } else { - // Work out if the room exists. - var roomInfo *types.RoomInfo - roomInfo, err = db.RoomInfo(ctx, event.RoomID()) - if err != nil { - return false, fmt.Errorf("db.RoomNID: %w", err) - } - if roomInfo == nil || roomInfo.IsStub() { - return false, nil - } - // Then get the state entries for the current state snapshot. // We'll use this to check if the event is allowed right now. roomState := state.NewStateResolution(db, roomInfo) @@ -76,7 +67,7 @@ func CheckForSoftFail( stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomNID, stateNeeded, authStateEntries) if err != nil { return true, fmt.Errorf("loadAuthEvents: %w", err) } @@ -93,7 +84,8 @@ func CheckForSoftFail( // Returns the numeric IDs for the auth events. func CheckAuthEvents( ctx context.Context, - db storage.Database, + db storage.RoomDatabase, + roomNID types.RoomNID, event *gomatrixserverlib.HeaderedEvent, authEventIDs []string, ) ([]types.EventNID, error) { @@ -108,7 +100,7 @@ func CheckAuthEvents( stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomNID, stateNeeded, authStateEntries) if err != nil { return nil, fmt.Errorf("loadAuthEvents: %w", err) } @@ -201,6 +193,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * func loadAuthEvents( ctx context.Context, db state.StateResolutionStorage, + roomNID types.RoomNID, needed gomatrixserverlib.StateNeeded, state []types.StateEntry, ) (result authEvents, err error) { @@ -223,7 +216,7 @@ func loadAuthEvents( eventNIDs = append(eventNIDs, eventNID) } } - if result.events, err = db.Events(ctx, eventNIDs); err != nil { + if result.events, err = db.Events(ctx, roomNID, eventNIDs); err != nil { return } roomID := "" diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 7efad7af6..ee1610cf2 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -85,7 +85,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam return false, err } - events, err := db.Events(ctx, eventNIDs) + events, err := db.Events(ctx, info.RoomNID, eventNIDs) if err != nil { return false, err } @@ -157,7 +157,7 @@ func IsInvitePending( // only keep the "m.room.member" events with a "join" membership. These events are returned. // Returns an error if there was an issue fetching the events. func GetMembershipsAtState( - ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool, + ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, joinedOnly bool, ) ([]types.Event, error) { var eventNIDs types.EventNIDs @@ -177,7 +177,7 @@ func GetMembershipsAtState( util.Unique(eventNIDs) // Get all of the events in this state - stateEvents, err := db.Events(ctx, eventNIDs) + stateEvents, err := db.Events(ctx, roomNID, eventNIDs) if err != nil { return nil, err } @@ -220,16 +220,16 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room return roomState.LoadCombinedStateAfterEvents(ctx, prevState) } -func MembershipAtEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) { +func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) { roomState := state.NewStateResolution(db, info) // Fetch the state as it was when this event was fired return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID) } func LoadEvents( - ctx context.Context, db storage.Database, eventNIDs []types.EventNID, + ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, eventNIDs []types.EventNID, ) ([]*gomatrixserverlib.Event, error) { - stateEvents, err := db.Events(ctx, eventNIDs) + stateEvents, err := db.Events(ctx, roomNID, eventNIDs) if err != nil { return nil, err } @@ -242,13 +242,13 @@ func LoadEvents( } func LoadStateEvents( - ctx context.Context, db storage.Database, stateEntries []types.StateEntry, + ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, ) ([]*gomatrixserverlib.Event, error) { eventNIDs := make([]types.EventNID, len(stateEntries)) for i := range stateEntries { eventNIDs[i] = stateEntries[i].EventNID } - return LoadEvents(ctx, db, eventNIDs) + return LoadEvents(ctx, db, roomNID, eventNIDs) } func CheckServerAllowedToSeeEvent( @@ -326,7 +326,7 @@ func slowGetHistoryVisibilityState( return nil, nil } - return LoadStateEvents(ctx, db, filteredEntries) + return LoadStateEvents(ctx, db, info.RoomNID, filteredEntries) } // TODO: Remove this when we have tests to assert correctness of this function @@ -366,7 +366,7 @@ BFSLoop: next = make([]string, 0) } // Retrieve the events to process from the database. - events, err = db.EventsFromIDs(ctx, front) + events, err = db.EventsFromIDs(ctx, info.RoomNID, front) if err != nil { return resultNIDs, redactEventIDs, err } @@ -467,7 +467,7 @@ func QueryLatestEventsAndState( return err } - stateEvents, err := LoadStateEvents(ctx, db, stateEntries) + stateEvents, err := LoadStateEvents(ctx, db, roomInfo.RoomNID, stateEntries) if err != nil { return err } diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go index aa5c30e44..62730df1f 100644 --- a/roomserver/internal/helpers/helpers_test.go +++ b/roomserver/internal/helpers/helpers_test.go @@ -38,7 +38,18 @@ func TestIsInvitePendingWithoutNID(t *testing.T) { var authNIDs []types.EventNID for _, x := range room.Events() { - evNID, _, _, _, _, err := db.StoreEvent(context.Background(), x.Event, authNIDs, false) + roomNID, err := db.GetOrCreateRoomNID(context.Background(), x.Unwrap()) + assert.NoError(t, err) + assert.Greater(t, roomNID, types.RoomNID(0)) + + eventTypeNID, err := db.GetOrCreateEventTypeNID(context.Background(), x.Type()) + assert.NoError(t, err) + assert.Greater(t, eventTypeNID, types.EventTypeNID(0)) + + eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey()) + assert.NoError(t, err) + + evNID, _, _, _, err := db.StoreEvent(context.Background(), x.Event, roomNID, eventTypeNID, eventStateKeyNID, authNIDs, false) assert.NoError(t, err) authNIDs = append(authNIDs, evNID) } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 941311030..2ec19f010 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -76,7 +76,7 @@ type Inputer struct { Cfg *config.RoomServer Base *base.BaseDendrite ProcessContext *process.ProcessContext - DB storage.Database + DB storage.RoomDatabase NATSClient *nats.Conn JetStream nats.JetStreamContext Durable nats.SubOpt diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 4179fc1ef..fe35efb27 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -24,6 +24,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/tidwall/gjson" "github.com/matrix-org/gomatrixserverlib" @@ -40,7 +41,6 @@ import ( "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -166,6 +166,7 @@ func (r *Inputer) processRoomEvent( missingPrev = !input.HasState && len(missingPrevIDs) > 0 } + // If we have missing events (auth or prev), we build a list of servers to ask if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ RoomID: event.RoomID(), @@ -200,59 +201,8 @@ func (r *Inputer) processRoomEvent( } } - // First of all, check that the auth events of the event are known. - // If they aren't then we will ask the federation API for them. isRejected := false - authEvents := gomatrixserverlib.NewAuthEvents(nil) - knownEvents := map[string]*types.Event{} - if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { - return fmt.Errorf("r.fetchAuthEvents: %w", err) - } - - // Check if the event is allowed by its auth events. If it isn't then - // we consider the event to be "rejected" — it will still be persisted. var rejectionErr error - if rejectionErr = gomatrixserverlib.Allowed(event, &authEvents); rejectionErr != nil { - isRejected = true - logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) - } - - // Accumulate the auth event NIDs. - authEventIDs := event.AuthEventIDs() - authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) - for _, authEventID := range authEventIDs { - if _, ok := knownEvents[authEventID]; !ok { - // Unknown auth events only really matter if the event actually failed - // auth. If it passed auth then we can assume that everything that was - // known was sufficient, even if extraneous auth events were specified - // but weren't found. - if isRejected { - if event.StateKey() != nil { - return fmt.Errorf( - "missing auth event %s for state event %s (type %q, state key %q)", - authEventID, event.EventID(), event.Type(), *event.StateKey(), - ) - } else { - return fmt.Errorf( - "missing auth event %s for timeline event %s (type %q)", - authEventID, event.EventID(), event.Type(), - ) - } - } - } else { - authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) - } - } - - var softfail bool - if input.Kind == api.KindNew { - // Check that the event passes authentication checks based on the - // current room state. - softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) - if err != nil { - logger.WithError(err).Warn("Error authing soft-failed event") - } - } // At this point we are checking whether we know all of the prev events, and // if we know the state before the prev events. This is necessary before we @@ -314,13 +264,66 @@ func (r *Inputer) processRoomEvent( } } + // Check that the auth events of the event are known. + // If they aren't then we will ask the federation API for them. + authEvents := gomatrixserverlib.NewAuthEvents(nil) + knownEvents := map[string]*types.Event{} + if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + return fmt.Errorf("r.fetchAuthEvents: %w", err) + } + + // Check if the event is allowed by its auth events. If it isn't then + // we consider the event to be "rejected" — it will still be persisted. + if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + isRejected = true + rejectionErr = err + logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) + } + + // Accumulate the auth event NIDs. + authEventIDs := event.AuthEventIDs() + authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) + for _, authEventID := range authEventIDs { + if _, ok := knownEvents[authEventID]; !ok { + // Unknown auth events only really matter if the event actually failed + // auth. If it passed auth then we can assume that everything that was + // known was sufficient, even if extraneous auth events were specified + // but weren't found. + if isRejected { + if event.StateKey() != nil { + return fmt.Errorf( + "missing auth event %s for state event %s (type %q, state key %q)", + authEventID, event.EventID(), event.Type(), *event.StateKey(), + ) + } else { + return fmt.Errorf( + "missing auth event %s for timeline event %s (type %q)", + authEventID, event.EventID(), event.Type(), + ) + } + } + } else { + authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) + } + } + + var softfail bool + if input.Kind == api.KindNew && !isCreateEvent { + // Check that the event passes authentication checks based on the + // current room state. + softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs) + if err != nil { + logger.WithError(err).Warn("Error authing soft-failed event") + } + } + // Get the state before the event so that we can work out if the event was // allowed at the time, and also to get the history visibility. We won't // bother doing this if the event was already rejected as it just ends up // burning CPU time. historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared. - if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected { - historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev) + if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected && !isCreateEvent { + historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo.RoomNID, input, missingPrev) if err != nil { return fmt.Errorf("r.processStateBefore: %w", err) } @@ -329,8 +332,23 @@ func (r *Inputer) processRoomEvent( } } + roomNID, err := r.DB.GetOrCreateRoomNID(ctx, event) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err) + } + + eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, event.Type()) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateEventTypeNID: %w", err) + } + + eventStateKeyNID, err := r.DB.GetOrCreateEventStateKeyNID(ctx, event.StateKey()) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateEventStateKeyNID: %w", err) + } + // Store the event. - _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) + _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) if err != nil { return fmt.Errorf("updater.StoreEvent: %w", err) } @@ -471,6 +489,7 @@ func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixse // nolint:nakedret func (r *Inputer) processStateBefore( ctx context.Context, + roomNID types.RoomNID, input *api.InputRoomEvent, missingPrev bool, ) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) { @@ -486,7 +505,7 @@ func (r *Inputer) processStateBefore( case input.HasState: // If we're overriding the state then we need to go and retrieve // them from the database. It's a hard error if they are missing. - stateEvents, err := r.DB.EventsFromIDs(ctx, input.StateEventIDs) + stateEvents, err := r.DB.EventsFromIDs(ctx, roomNID, input.StateEventIDs) if err != nil { return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err) } @@ -564,6 +583,7 @@ func (r *Inputer) processStateBefore( // we've failed to retrieve the auth chain altogether (in which case // an error is returned) or we've successfully retrieved them all and // they are now in the database. +// nolint: gocyclo func (r *Inputer) fetchAuthEvents( ctx context.Context, logger *logrus.Entry, @@ -584,7 +604,7 @@ func (r *Inputer) fetchAuthEvents( } for _, authEventID := range authEventIDs { - authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID}) + authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo.RoomNID, []string{authEventID}) if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { unknown[authEventID] = struct{}{} continue @@ -670,8 +690,23 @@ nextAuthEvent: logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) } + roomNID, err := r.DB.GetOrCreateRoomNID(ctx, authEvent) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err) + } + + eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, authEvent.Type()) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateEventTypeNID: %w", err) + } + + eventStateKeyNID, err := r.DB.GetOrCreateEventStateKeyNID(ctx, event.StateKey()) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateEventStateKeyNID: %w", err) + } + // Finally, store the event in the database. - eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) + eventNID, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) if err != nil { return fmt.Errorf("updater.StoreEvent: %w", err) } @@ -747,7 +782,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event return err } - memberEvents, err := r.DB.Events(ctx, membershipNIDs) + memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, membershipNIDs) if err != nil { return err } diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 28a54623b..99a012551 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -53,7 +53,7 @@ func (r *Inputer) updateMemberships( // Load the event JSON so we can look up the "membership" key. // TODO: Maybe add a membership key to the events table so we can load that // key without having to load the entire event JSON? - events, err := updater.Events(ctx, eventNIDs) + events, err := updater.Events(ctx, 0, eventNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 03ac2b38d..c8b7d31dd 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -43,7 +43,7 @@ type missingStateReq struct { log *logrus.Entry virtualHost gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName - db storage.Database + db storage.RoomDatabase roomInfo *types.RoomInfo inputer *Inputer keys gomatrixserverlib.JSONVerifier @@ -395,7 +395,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even for _, entry := range stateEntries { stateEventNIDs = append(stateEventNIDs, entry.EventNID) } - stateEvents, err := t.db.Events(ctx, stateEventNIDs) + stateEvents, err := t.db.Events(ctx, t.roomInfo.RoomNID, stateEventNIDs) if err != nil { t.log.WithError(err).Warnf("failed to load state events locally") return nil @@ -432,7 +432,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even missingEventList = append(missingEventList, evID) } t.log.WithField("count", len(missingEventList)).Debugf("Fetching missing auth events") - events, err := t.db.EventsFromIDs(ctx, missingEventList) + events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList) if err != nil { return nil } @@ -702,7 +702,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo } t.haveEventsMutex.Unlock() - events, err := t.db.EventsFromIDs(ctx, missingEventList) + events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList) if err != nil { return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err) } @@ -844,7 +844,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs if localFirst { // fetch from the roomserver - events, err := t.db.EventsFromIDs(ctx, []string{missingEventID}) + events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, []string{missingEventID}) if err != nil { t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) } else if len(events) == 1 { diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index d42f4e45d..2efe2255f 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" ) type Admin struct { @@ -69,7 +70,7 @@ func (r *Admin) PerformAdminEvacuateRoom( return nil } - memberEvents, err := r.DB.Events(ctx, memberNIDs) + memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, memberNIDs) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, @@ -242,6 +243,42 @@ func (r *Admin) PerformAdminEvacuateUser( return nil } +func (r *Admin) PerformAdminPurgeRoom( + ctx context.Context, + req *api.PerformAdminPurgeRoomRequest, + res *api.PerformAdminPurgeRoomResponse, +) error { + // Validate we actually got a room ID and nothing else + if _, _, err := gomatrixserverlib.SplitID('!', req.RoomID); err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("Malformed room ID: %s", err), + } + return nil + } + + logrus.WithField("room_id", req.RoomID).Warn("Purging room from roomserver") + if err := r.DB.PurgeRoom(ctx, req.RoomID); err != nil { + logrus.WithField("room_id", req.RoomID).WithError(err).Warn("Failed to purge room from roomserver") + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: err.Error(), + } + return nil + } + + logrus.WithField("room_id", req.RoomID).Warn("Room purged from roomserver") + + return r.Inputer.OutputProducer.ProduceRoomEvents(req.RoomID, []api.OutputEvent{ + { + Type: api.OutputTypePurgeRoom, + PurgeRoom: &api.OutputPurgeRoom{ + RoomID: req.RoomID, + }, + }, + }) +} + func (r *Admin) PerformAdminDownloadState( ctx context.Context, req *api.PerformAdminDownloadStateRequest, diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index d9214fdc6..3a3a049db 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -86,7 +86,7 @@ func (r *Backfiller) PerformBackfill( // Retrieve events from the list that was filled previously. If we fail to get // events from the database then attempt once to get them from federation instead. var loadedEvents []*gomatrixserverlib.Event - loadedEvents, err = helpers.LoadEvents(ctx, r.DB, resultNIDs) + loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs) if err != nil { if _, ok := err.(types.MissingEventError); ok { return r.backfillViaFederation(ctx, request, response) @@ -258,6 +258,7 @@ type backfillRequester struct { eventIDToBeforeStateIDs map[string][]string eventIDMap map[string]*gomatrixserverlib.Event historyVisiblity gomatrixserverlib.HistoryVisibility + roomInfo types.RoomInfo } func newBackfillRequester( @@ -454,14 +455,14 @@ FindSuccessor: return nil } - stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID]) + stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") return nil } // possibly return all joined servers depending on history visiblity - memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.virtualHost) + memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, info, stateEntries, b.virtualHost) b.historyVisiblity = visibility if err != nil { logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") @@ -472,7 +473,7 @@ FindSuccessor: // Retrieve all "m.room.member" state events of "join" membership, which // contains the list of users in the room before the event, therefore all // the servers in it at that moment. - memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, stateEntries, true) + memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info.RoomNID, stateEntries, true) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event") return nil @@ -523,11 +524,15 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, } eventNIDs := make([]types.EventNID, len(nidMap)) i := 0 + roomNID := b.roomInfo.RoomNID for _, nid := range nidMap { - eventNIDs[i] = nid + eventNIDs[i] = nid.EventNID i++ + if roomNID == 0 { + roomNID = nid.RoomNID + } } - eventsWithNids, err := b.db.Events(ctx, eventNIDs) + eventsWithNids, err := b.db.Events(ctx, roomNID, eventNIDs) if err != nil { logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") return nil, err @@ -544,7 +549,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, // TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just // pull all events and then filter by that table. func joinEventsFromHistoryVisibility( - ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry, + ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, thisServer gomatrixserverlib.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) { var eventNIDs []types.EventNID @@ -557,7 +562,7 @@ func joinEventsFromHistoryVisibility( } // Get all of the events in this state - stateEvents, err := db.Events(ctx, eventNIDs) + stateEvents, err := db.Events(ctx, roomInfo.RoomNID, eventNIDs) if err != nil { // even though the default should be shared, restricting the visibility to joined // feels more secure here. @@ -570,21 +575,17 @@ func joinEventsFromHistoryVisibility( // Can we see events in the room? canSeeEvents := auth.IsServerAllowed(thisServer, true, events) - visibility := gomatrixserverlib.HistoryVisibility(auth.HistoryVisibilityForRoom(events)) + visibility := auth.HistoryVisibilityForRoom(events) if !canSeeEvents { logrus.Infof("ServersAtEvent history not visible to us: %s", visibility) return nil, visibility, nil } // get joined members - info, err := db.RoomInfo(ctx, roomID) - if err != nil { - return nil, visibility, nil - } - joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) + joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, false) if err != nil { return nil, visibility, err } - evs, err := db.Events(ctx, joinEventNIDs) + evs, err := db.Events(ctx, roomInfo.RoomNID, joinEventNIDs) return evs, visibility, err } @@ -601,12 +602,31 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs authNids := make([]types.EventNID, len(nidMap)) i := 0 for _, nid := range nidMap { - authNids[i] = nid + authNids[i] = nid.EventNID i++ } + + roomNID, err = db.GetOrCreateRoomNID(ctx, ev.Unwrap()) + if err != nil { + logrus.WithError(err).Error("failed to get or create roomNID") + continue + } + + eventTypeNID, err := db.GetOrCreateEventTypeNID(ctx, ev.Type()) + if err != nil { + logrus.WithError(err).Error("failed to get or create eventType NID") + continue + } + + eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(ctx, ev.StateKey()) + if err != nil { + logrus.WithError(err).Error("failed to get or create eventStateKey NID") + continue + } + var redactedEventID string var redactionEvent *gomatrixserverlib.Event - eventNID, roomNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), authNids, false) + eventNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), roomNID, eventTypeNID, eventStateKeyNID, authNids, false) if err != nil { logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") continue diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go index 29decd363..9ac9edc4c 100644 --- a/roomserver/internal/perform/perform_inbound_peek.go +++ b/roomserver/internal/perform/perform_inbound_peek.go @@ -29,7 +29,7 @@ import ( ) type InboundPeeker struct { - DB storage.Database + DB storage.RoomDatabase Inputer *input.Inputer } @@ -64,7 +64,7 @@ func (r *InboundPeeker) PerformInboundPeek( if err != nil { return err } - latestEvents, err := r.DB.EventsFromIDs(ctx, []string{latestEventRefs[0].EventID}) + latestEvents, err := r.DB.EventsFromIDs(ctx, info.RoomNID, []string{latestEventRefs[0].EventID}) if err != nil { return err } @@ -88,7 +88,7 @@ func (r *InboundPeeker) PerformInboundPeek( if err != nil { return err } - stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, stateEntries) + stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries) if err != nil { return err } diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index f60247cd7..118e1b879 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -194,7 +194,7 @@ func (r *Inviter) PerformInvite( // try and see if the user is allowed to make this invite. We can't do // this for invites coming in over federation - we have to take those on // trust. - _, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs()) + _, err = helpers.CheckAuthEvents(ctx, r.DB, info.RoomNID, event, event.AuthEventIDs()) if err != nil { logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( "processInviteEvent.checkAuthEvents failed for event", @@ -291,7 +291,7 @@ func buildInviteStrippedState( for _, stateNID := range stateEntries { stateNIDs = append(stateNIDs, stateNID.EventNID) } - stateEvents, err := db.Events(ctx, stateNIDs) + stateEvents, err := db.Events(ctx, info.RoomNID, stateNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index fa998e3e1..86f1dfaee 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -20,6 +20,7 @@ import ( "fmt" "strings" + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -109,7 +110,7 @@ func (r *Leaver) performLeaveRoomByID( // mimic the returned values from Synapse res.Message = "You cannot reject this invite" res.Code = 403 - return nil, fmt.Errorf("You cannot reject this invite") + return nil, jsonerror.LeaveServerNoticeError() } } } diff --git a/roomserver/internal/perform/perform_unpeek.go b/roomserver/internal/perform/perform_unpeek.go index 0d97da4d6..4d714be66 100644 --- a/roomserver/internal/perform/perform_unpeek.go +++ b/roomserver/internal/perform/perform_unpeek.go @@ -22,7 +22,6 @@ import ( fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/input" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -31,9 +30,7 @@ type Unpeeker struct { ServerName gomatrixserverlib.ServerName Cfg *config.RoomServer FSAPI fsAPI.RoomserverFederationAPI - DB storage.Database - - Inputer *input.Inputer + Inputer *input.Inputer } // PerformPeek handles peeking into matrix rooms, including over federation by talking to the federationapi. diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 69d841dda..ac34e0ff0 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -101,7 +102,7 @@ func (r *Queryer) QueryStateAfterEvents( return err } - stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) + stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries) if err != nil { return err } @@ -137,17 +138,7 @@ func (r *Queryer) QueryEventsByID( request *api.QueryEventsByIDRequest, response *api.QueryEventsByIDResponse, ) error { - eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs) - if err != nil { - return err - } - - var eventNIDs []types.EventNID - for _, nid := range eventNIDMap { - eventNIDs = append(eventNIDs, nid) - } - - events, err := helpers.LoadEvents(ctx, r.DB, eventNIDs) + events, err := r.DB.EventsFromIDs(ctx, 0, request.EventIDs) if err != nil { return err } @@ -195,7 +186,7 @@ func (r *Queryer) QueryMembershipForUser( response.IsInRoom = stillInRoom response.HasBeenInRoom = true - evs, err := r.DB.Events(ctx, []types.EventNID{membershipEventNID}) + evs, err := r.DB.Events(ctx, info.RoomNID, []types.EventNID{membershipEventNID}) if err != nil { return err } @@ -216,7 +207,8 @@ func (r *Queryer) QueryMembershipAtEvent( request *api.QueryMembershipAtEventRequest, response *api.QueryMembershipAtEventResponse, ) error { - response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent) + response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent) + info, err := r.DB.RoomInfo(ctx, request.RoomID) if err != nil { return fmt.Errorf("unable to get roomInfo: %w", err) @@ -234,7 +226,17 @@ func (r *Queryer) QueryMembershipAtEvent( return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID) } - stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, request.EventIDs, stateKeyNIDs[request.UserID]) + response.Membership, err = r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[request.UserID], info, request.EventIDs...) + switch err { + case nil: + return nil + case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event + default: + return err + } + + response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent) + stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID]) if err != nil { return fmt.Errorf("unable to get state before event: %w", err) } @@ -258,7 +260,7 @@ func (r *Queryer) QueryMembershipAtEvent( for _, eventID := range request.EventIDs { stateEntry, ok := stateEntries[eventID] if !ok || len(stateEntry) == 0 { - response.Memberships[eventID] = []*gomatrixserverlib.HeaderedEvent{} + response.Membership[eventID] = nil continue } @@ -266,24 +268,24 @@ func (r *Queryer) QueryMembershipAtEvent( // once. If we have more than one membership event, we need to get the state for each state entry. if canShortCircuit { if len(memberships) == 0 { - memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false) } } else { - memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false) } if err != nil { return fmt.Errorf("unable to get memberships at state: %w", err) } - res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships)) - + // Iterate over all membership events we got. Given we only query the membership for + // one user and assuming this user only ever has one membership event associated to + // a given event, overwrite any other existing membership events. for i := range memberships { ev := memberships[i] if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) { - res = append(res, ev.Headered(info.RoomVersion)) + response.Membership[eventID] = ev.Event.Headered(info.RoomVersion) } } - response.Memberships[eventID] = res } return nil @@ -316,7 +318,7 @@ func (r *Queryer) QueryMembershipsForRoom( } return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) } - events, err = r.DB.Events(ctx, eventNIDs) + events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs) if err != nil { return fmt.Errorf("r.DB.Events: %w", err) } @@ -355,14 +357,14 @@ func (r *Queryer) QueryMembershipsForRoom( return err } - events, err = r.DB.Events(ctx, eventNIDs) + events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs) } else { stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err } - events, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly) + events, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntries, request.JoinedOnly) } if err != nil { @@ -413,7 +415,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( request *api.QueryServerAllowedToSeeEventRequest, response *api.QueryServerAllowedToSeeEventResponse, ) (err error) { - events, err := r.DB.EventsFromIDs(ctx, []string{request.EventID}) + events, err := r.DB.EventsFromIDs(ctx, 0, []string{request.EventID}) if err != nil { return } @@ -464,7 +466,7 @@ func (r *Queryer) QueryMissingEvents( eventsToFilter[id] = true } } - events, err := r.DB.EventsFromIDs(ctx, front) + events, err := r.DB.EventsFromIDs(ctx, 0, front) if err != nil { return err } @@ -484,7 +486,7 @@ func (r *Queryer) QueryMissingEvents( return err } - loadedEvents, err := helpers.LoadEvents(ctx, r.DB, resultNIDs) + loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs) if err != nil { return err } @@ -609,11 +611,11 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI return nil, rejected, false, err } - events, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) + events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo.RoomNID, stateEntries) return events, rejected, false, err } -type eventsFromIDs func(context.Context, []string) ([]types.Event, error) +type eventsFromIDs func(context.Context, types.RoomNID, []string) ([]types.Event, error) // GetAuthChain fetches the auth chain for the given auth events. An auth chain // is the list of all events that are referenced in the auth_events section, and @@ -631,7 +633,7 @@ func GetAuthChain( for len(eventsToFetch) > 0 { // Try to retrieve the events from the database. - events, err := fn(ctx, eventsToFetch) + events, err := fn(ctx, 0, eventsToFetch) if err != nil { return nil, err } @@ -969,7 +971,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query // For each of the joined users, let's see if we can get a valid // membership event. for _, joinNID := range joinNIDs { - events, err := r.DB.Events(ctx, []types.EventNID{joinNID}) + events, err := r.DB.Events(ctx, roomInfo.RoomNID, []types.EventNID{joinNID}) if err != nil || len(events) != 1 { continue } diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go index 03627ea97..167611575 100644 --- a/roomserver/internal/query/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -80,7 +80,7 @@ func (db *getEventDB) addFakeEvents(graph map[string][]string) error { } // EventsFromIDs implements RoomserverInternalAPIEventDB -func (db *getEventDB) EventsFromIDs(ctx context.Context, eventIDs []string) (res []types.Event, err error) { +func (db *getEventDB) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) (res []types.Event, err error) { for _, evID := range eventIDs { res = append(res, types.Event{ EventNID: 0, diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go deleted file mode 100644 index 8a2e0a03c..000000000 --- a/roomserver/inthttp/client.go +++ /dev/null @@ -1,563 +0,0 @@ -package inthttp - -import ( - "context" - "errors" - "net/http" - - "github.com/matrix-org/gomatrixserverlib" - - asAPI "github.com/matrix-org/dendrite/appservice/api" - fsInputAPI "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/roomserver/api" - userapi "github.com/matrix-org/dendrite/userapi/api" -) - -const ( - // Alias operations - RoomserverSetRoomAliasPath = "/roomserver/setRoomAlias" - RoomserverGetRoomIDForAliasPath = "/roomserver/GetRoomIDForAlias" - RoomserverGetAliasesForRoomIDPath = "/roomserver/GetAliasesForRoomID" - RoomserverGetCreatorIDForAliasPath = "/roomserver/GetCreatorIDForAlias" - RoomserverRemoveRoomAliasPath = "/roomserver/removeRoomAlias" - - // Input operations - RoomserverInputRoomEventsPath = "/roomserver/inputRoomEvents" - - // Perform operations - RoomserverPerformInvitePath = "/roomserver/performInvite" - RoomserverPerformPeekPath = "/roomserver/performPeek" - RoomserverPerformUnpeekPath = "/roomserver/performUnpeek" - RoomserverPerformRoomUpgradePath = "/roomserver/performRoomUpgrade" - RoomserverPerformJoinPath = "/roomserver/performJoin" - RoomserverPerformLeavePath = "/roomserver/performLeave" - RoomserverPerformBackfillPath = "/roomserver/performBackfill" - RoomserverPerformPublishPath = "/roomserver/performPublish" - RoomserverPerformInboundPeekPath = "/roomserver/performInboundPeek" - RoomserverPerformForgetPath = "/roomserver/performForget" - RoomserverPerformAdminEvacuateRoomPath = "/roomserver/performAdminEvacuateRoom" - RoomserverPerformAdminEvacuateUserPath = "/roomserver/performAdminEvacuateUser" - RoomserverPerformAdminDownloadStatePath = "/roomserver/performAdminDownloadState" - - // Query operations - RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" - RoomserverQueryStateAfterEventsPath = "/roomserver/queryStateAfterEvents" - RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID" - RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser" - RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom" - RoomserverQueryServerJoinedToRoomPath = "/roomserver/queryServerJoinedToRoomPath" - RoomserverQueryServerAllowedToSeeEventPath = "/roomserver/queryServerAllowedToSeeEvent" - RoomserverQueryMissingEventsPath = "/roomserver/queryMissingEvents" - RoomserverQueryStateAndAuthChainPath = "/roomserver/queryStateAndAuthChain" - RoomserverQueryRoomVersionCapabilitiesPath = "/roomserver/queryRoomVersionCapabilities" - RoomserverQueryRoomVersionForRoomPath = "/roomserver/queryRoomVersionForRoom" - RoomserverQueryPublishedRoomsPath = "/roomserver/queryPublishedRooms" - RoomserverQueryCurrentStatePath = "/roomserver/queryCurrentState" - RoomserverQueryRoomsForUserPath = "/roomserver/queryRoomsForUser" - RoomserverQueryBulkStateContentPath = "/roomserver/queryBulkStateContent" - RoomserverQuerySharedUsersPath = "/roomserver/querySharedUsers" - RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers" - RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom" - RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" - RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed" - RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent" - RoomserverQueryLeftMembersPath = "/roomserver/queryLeftMembers" -) - -type httpRoomserverInternalAPI struct { - roomserverURL string - httpClient *http.Client - cache caching.RoomVersionCache -} - -// NewRoomserverClient creates a RoomserverInputAPI implemented by talking to a HTTP POST API. -// If httpClient is nil an error is returned -func NewRoomserverClient( - roomserverURL string, - httpClient *http.Client, - cache caching.RoomVersionCache, -) (api.RoomserverInternalAPI, error) { - if httpClient == nil { - return nil, errors.New("NewRoomserverInternalAPIHTTP: httpClient is ") - } - return &httpRoomserverInternalAPI{ - roomserverURL: roomserverURL, - httpClient: httpClient, - cache: cache, - }, nil -} - -// SetFederationInputAPI no-ops in HTTP client mode as there is no chicken/egg scenario -func (h *httpRoomserverInternalAPI) SetFederationAPI(fsAPI fsInputAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) { -} - -// SetAppserviceAPI no-ops in HTTP client mode as there is no chicken/egg scenario -func (h *httpRoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { -} - -// SetUserAPI no-ops in HTTP client mode as there is no chicken/egg scenario -func (h *httpRoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) { -} - -// SetRoomAlias implements RoomserverAliasAPI -func (h *httpRoomserverInternalAPI) SetRoomAlias( - ctx context.Context, - request *api.SetRoomAliasRequest, - response *api.SetRoomAliasResponse, -) error { - return httputil.CallInternalRPCAPI( - "SetRoomAlias", h.roomserverURL+RoomserverSetRoomAliasPath, - h.httpClient, ctx, request, response, - ) -} - -// GetRoomIDForAlias implements RoomserverAliasAPI -func (h *httpRoomserverInternalAPI) GetRoomIDForAlias( - ctx context.Context, - request *api.GetRoomIDForAliasRequest, - response *api.GetRoomIDForAliasResponse, -) error { - return httputil.CallInternalRPCAPI( - "GetRoomIDForAlias", h.roomserverURL+RoomserverGetRoomIDForAliasPath, - h.httpClient, ctx, request, response, - ) -} - -// GetAliasesForRoomID implements RoomserverAliasAPI -func (h *httpRoomserverInternalAPI) GetAliasesForRoomID( - ctx context.Context, - request *api.GetAliasesForRoomIDRequest, - response *api.GetAliasesForRoomIDResponse, -) error { - return httputil.CallInternalRPCAPI( - "GetAliasesForRoomID", h.roomserverURL+RoomserverGetAliasesForRoomIDPath, - h.httpClient, ctx, request, response, - ) -} - -// RemoveRoomAlias implements RoomserverAliasAPI -func (h *httpRoomserverInternalAPI) RemoveRoomAlias( - ctx context.Context, - request *api.RemoveRoomAliasRequest, - response *api.RemoveRoomAliasResponse, -) error { - return httputil.CallInternalRPCAPI( - "RemoveRoomAlias", h.roomserverURL+RoomserverRemoveRoomAliasPath, - h.httpClient, ctx, request, response, - ) -} - -// InputRoomEvents implements RoomserverInputAPI -func (h *httpRoomserverInternalAPI) InputRoomEvents( - ctx context.Context, - request *api.InputRoomEventsRequest, - response *api.InputRoomEventsResponse, -) error { - if err := httputil.CallInternalRPCAPI( - "InputRoomEvents", h.roomserverURL+RoomserverInputRoomEventsPath, - h.httpClient, ctx, request, response, - ); err != nil { - response.ErrMsg = err.Error() - } - return nil -} - -func (h *httpRoomserverInternalAPI) PerformInvite( - ctx context.Context, - request *api.PerformInviteRequest, - response *api.PerformInviteResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformInvite", h.roomserverURL+RoomserverPerformInvitePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformJoin( - ctx context.Context, - request *api.PerformJoinRequest, - response *api.PerformJoinResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformJoin", h.roomserverURL+RoomserverPerformJoinPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformPeek( - ctx context.Context, - request *api.PerformPeekRequest, - response *api.PerformPeekResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformPeek", h.roomserverURL+RoomserverPerformPeekPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformInboundPeek( - ctx context.Context, - request *api.PerformInboundPeekRequest, - response *api.PerformInboundPeekResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformInboundPeek", h.roomserverURL+RoomserverPerformInboundPeekPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformUnpeek( - ctx context.Context, - request *api.PerformUnpeekRequest, - response *api.PerformUnpeekResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformUnpeek", h.roomserverURL+RoomserverPerformUnpeekPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformRoomUpgrade( - ctx context.Context, - request *api.PerformRoomUpgradeRequest, - response *api.PerformRoomUpgradeResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformRoomUpgrade", h.roomserverURL+RoomserverPerformRoomUpgradePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformLeave( - ctx context.Context, - request *api.PerformLeaveRequest, - response *api.PerformLeaveResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformLeave", h.roomserverURL+RoomserverPerformLeavePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformPublish( - ctx context.Context, - request *api.PerformPublishRequest, - response *api.PerformPublishResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformPublish", h.roomserverURL+RoomserverPerformPublishPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformAdminEvacuateRoom( - ctx context.Context, - request *api.PerformAdminEvacuateRoomRequest, - response *api.PerformAdminEvacuateRoomResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformAdminEvacuateRoom", h.roomserverURL+RoomserverPerformAdminEvacuateRoomPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformAdminDownloadState( - ctx context.Context, - request *api.PerformAdminDownloadStateRequest, - response *api.PerformAdminDownloadStateResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformAdminDownloadState", h.roomserverURL+RoomserverPerformAdminDownloadStatePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformAdminEvacuateUser( - ctx context.Context, - request *api.PerformAdminEvacuateUserRequest, - response *api.PerformAdminEvacuateUserResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformAdminEvacuateUser", h.roomserverURL+RoomserverPerformAdminEvacuateUserPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryLatestEventsAndState implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState( - ctx context.Context, - request *api.QueryLatestEventsAndStateRequest, - response *api.QueryLatestEventsAndStateResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryLatestEventsAndState", h.roomserverURL+RoomserverQueryLatestEventsAndStatePath, - h.httpClient, ctx, request, response, - ) -} - -// QueryStateAfterEvents implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryStateAfterEvents( - ctx context.Context, - request *api.QueryStateAfterEventsRequest, - response *api.QueryStateAfterEventsResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryStateAfterEvents", h.roomserverURL+RoomserverQueryStateAfterEventsPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryEventsByID implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryEventsByID( - ctx context.Context, - request *api.QueryEventsByIDRequest, - response *api.QueryEventsByIDResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryEventsByID", h.roomserverURL+RoomserverQueryEventsByIDPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QueryPublishedRooms( - ctx context.Context, - request *api.QueryPublishedRoomsRequest, - response *api.QueryPublishedRoomsResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryPublishedRooms", h.roomserverURL+RoomserverQueryPublishedRoomsPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryMembershipForUser implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryMembershipForUser( - ctx context.Context, - request *api.QueryMembershipForUserRequest, - response *api.QueryMembershipForUserResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryMembershipForUser", h.roomserverURL+RoomserverQueryMembershipForUserPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryMembershipsForRoom implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryMembershipsForRoom( - ctx context.Context, - request *api.QueryMembershipsForRoomRequest, - response *api.QueryMembershipsForRoomResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryMembershipsForRoom", h.roomserverURL+RoomserverQueryMembershipsForRoomPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryMembershipsForRoom implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryServerJoinedToRoom( - ctx context.Context, - request *api.QueryServerJoinedToRoomRequest, - response *api.QueryServerJoinedToRoomResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryServerJoinedToRoom", h.roomserverURL+RoomserverQueryServerJoinedToRoomPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryServerAllowedToSeeEvent implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryServerAllowedToSeeEvent( - ctx context.Context, - request *api.QueryServerAllowedToSeeEventRequest, - response *api.QueryServerAllowedToSeeEventResponse, -) (err error) { - return httputil.CallInternalRPCAPI( - "QueryServerAllowedToSeeEvent", h.roomserverURL+RoomserverQueryServerAllowedToSeeEventPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryMissingEvents implements RoomServerQueryAPI -func (h *httpRoomserverInternalAPI) QueryMissingEvents( - ctx context.Context, - request *api.QueryMissingEventsRequest, - response *api.QueryMissingEventsResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryMissingEvents", h.roomserverURL+RoomserverQueryMissingEventsPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryStateAndAuthChain implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryStateAndAuthChain( - ctx context.Context, - request *api.QueryStateAndAuthChainRequest, - response *api.QueryStateAndAuthChainResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryStateAndAuthChain", h.roomserverURL+RoomserverQueryStateAndAuthChainPath, - h.httpClient, ctx, request, response, - ) -} - -// PerformBackfill implements RoomServerQueryAPI -func (h *httpRoomserverInternalAPI) PerformBackfill( - ctx context.Context, - request *api.PerformBackfillRequest, - response *api.PerformBackfillResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformBackfill", h.roomserverURL+RoomserverPerformBackfillPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryRoomVersionCapabilities implements RoomServerQueryAPI -func (h *httpRoomserverInternalAPI) QueryRoomVersionCapabilities( - ctx context.Context, - request *api.QueryRoomVersionCapabilitiesRequest, - response *api.QueryRoomVersionCapabilitiesResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryRoomVersionCapabilities", h.roomserverURL+RoomserverQueryRoomVersionCapabilitiesPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryRoomVersionForRoom implements RoomServerQueryAPI -func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom( - ctx context.Context, - request *api.QueryRoomVersionForRoomRequest, - response *api.QueryRoomVersionForRoomResponse, -) error { - if roomVersion, ok := h.cache.GetRoomVersion(request.RoomID); ok { - response.RoomVersion = roomVersion - return nil - } - err := httputil.CallInternalRPCAPI( - "QueryRoomVersionForRoom", h.roomserverURL+RoomserverQueryRoomVersionForRoomPath, - h.httpClient, ctx, request, response, - ) - if err == nil { - h.cache.StoreRoomVersion(request.RoomID, response.RoomVersion) - } - return err -} - -func (h *httpRoomserverInternalAPI) QueryCurrentState( - ctx context.Context, - request *api.QueryCurrentStateRequest, - response *api.QueryCurrentStateResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryCurrentState", h.roomserverURL+RoomserverQueryCurrentStatePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QueryRoomsForUser( - ctx context.Context, - request *api.QueryRoomsForUserRequest, - response *api.QueryRoomsForUserResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryRoomsForUser", h.roomserverURL+RoomserverQueryRoomsForUserPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QueryBulkStateContent( - ctx context.Context, - request *api.QueryBulkStateContentRequest, - response *api.QueryBulkStateContentResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryBulkStateContent", h.roomserverURL+RoomserverQueryBulkStateContentPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QuerySharedUsers( - ctx context.Context, - request *api.QuerySharedUsersRequest, - response *api.QuerySharedUsersResponse, -) error { - return httputil.CallInternalRPCAPI( - "QuerySharedUsers", h.roomserverURL+RoomserverQuerySharedUsersPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QueryKnownUsers( - ctx context.Context, - request *api.QueryKnownUsersRequest, - response *api.QueryKnownUsersResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryKnownUsers", h.roomserverURL+RoomserverQueryKnownUsersPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QueryAuthChain( - ctx context.Context, - request *api.QueryAuthChainRequest, - response *api.QueryAuthChainResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryAuthChain", h.roomserverURL+RoomserverQueryAuthChainPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom( - ctx context.Context, - request *api.QueryServerBannedFromRoomRequest, - response *api.QueryServerBannedFromRoomResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryServerBannedFromRoom", h.roomserverURL+RoomserverQueryServerBannedFromRoomPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QueryRestrictedJoinAllowed( - ctx context.Context, - request *api.QueryRestrictedJoinAllowedRequest, - response *api.QueryRestrictedJoinAllowedResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryRestrictedJoinAllowed", h.roomserverURL+RoomserverQueryRestrictedJoinAllowed, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformForget( - ctx context.Context, - request *api.PerformForgetRequest, - response *api.PerformForgetResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformForget", h.roomserverURL+RoomserverPerformForgetPath, - h.httpClient, ctx, request, response, - ) - -} - -func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context, request *api.QueryMembershipAtEventRequest, response *api.QueryMembershipAtEventResponse) error { - return httputil.CallInternalRPCAPI( - "QueryMembershiptAtEvent", h.roomserverURL+RoomserverQueryMembershipAtEventPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QueryLeftUsers(ctx context.Context, request *api.QueryLeftUsersRequest, response *api.QueryLeftUsersResponse) error { - return httputil.CallInternalRPCAPI( - "RoomserverQueryLeftMembers", h.roomserverURL+RoomserverQueryLeftMembersPath, - h.httpClient, ctx, request, response, - ) -} diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go deleted file mode 100644 index 4d21909b7..000000000 --- a/roomserver/inthttp/server.go +++ /dev/null @@ -1,211 +0,0 @@ -package inthttp - -import ( - "github.com/gorilla/mux" - - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/roomserver/api" -) - -// AddRoutes adds the RoomserverInternalAPI handlers to the http.ServeMux. -// nolint: gocyclo -func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router, enableMetrics bool) { - internalAPIMux.Handle( - RoomserverInputRoomEventsPath, - httputil.MakeInternalRPCAPI("RoomserverInputRoomEvents", enableMetrics, r.InputRoomEvents), - ) - - internalAPIMux.Handle( - RoomserverPerformInvitePath, - httputil.MakeInternalRPCAPI("RoomserverPerformInvite", enableMetrics, r.PerformInvite), - ) - - internalAPIMux.Handle( - RoomserverPerformJoinPath, - httputil.MakeInternalRPCAPI("RoomserverPerformJoin", enableMetrics, r.PerformJoin), - ) - - internalAPIMux.Handle( - RoomserverPerformLeavePath, - httputil.MakeInternalRPCAPI("RoomserverPerformLeave", enableMetrics, r.PerformLeave), - ) - - internalAPIMux.Handle( - RoomserverPerformPeekPath, - httputil.MakeInternalRPCAPI("RoomserverPerformPeek", enableMetrics, r.PerformPeek), - ) - - internalAPIMux.Handle( - RoomserverPerformInboundPeekPath, - httputil.MakeInternalRPCAPI("RoomserverPerformInboundPeek", enableMetrics, r.PerformInboundPeek), - ) - - internalAPIMux.Handle( - RoomserverPerformUnpeekPath, - httputil.MakeInternalRPCAPI("RoomserverPerformUnpeek", enableMetrics, r.PerformUnpeek), - ) - - internalAPIMux.Handle( - RoomserverPerformRoomUpgradePath, - httputil.MakeInternalRPCAPI("RoomserverPerformRoomUpgrade", enableMetrics, r.PerformRoomUpgrade), - ) - - internalAPIMux.Handle( - RoomserverPerformPublishPath, - httputil.MakeInternalRPCAPI("RoomserverPerformPublish", enableMetrics, r.PerformPublish), - ) - - internalAPIMux.Handle( - RoomserverPerformAdminEvacuateRoomPath, - httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateRoom", enableMetrics, r.PerformAdminEvacuateRoom), - ) - - internalAPIMux.Handle( - RoomserverPerformAdminEvacuateUserPath, - httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", enableMetrics, r.PerformAdminEvacuateUser), - ) - - internalAPIMux.Handle( - RoomserverPerformAdminDownloadStatePath, - httputil.MakeInternalRPCAPI("RoomserverPerformAdminDownloadState", enableMetrics, r.PerformAdminDownloadState), - ) - - internalAPIMux.Handle( - RoomserverQueryPublishedRoomsPath, - httputil.MakeInternalRPCAPI("RoomserverQueryPublishedRooms", enableMetrics, r.QueryPublishedRooms), - ) - - internalAPIMux.Handle( - RoomserverQueryLatestEventsAndStatePath, - httputil.MakeInternalRPCAPI("RoomserverQueryLatestEventsAndState", enableMetrics, r.QueryLatestEventsAndState), - ) - - internalAPIMux.Handle( - RoomserverQueryStateAfterEventsPath, - httputil.MakeInternalRPCAPI("RoomserverQueryStateAfterEvents", enableMetrics, r.QueryStateAfterEvents), - ) - - internalAPIMux.Handle( - RoomserverQueryEventsByIDPath, - httputil.MakeInternalRPCAPI("RoomserverQueryEventsByID", enableMetrics, r.QueryEventsByID), - ) - - internalAPIMux.Handle( - RoomserverQueryMembershipForUserPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMembershipForUser", enableMetrics, r.QueryMembershipForUser), - ) - - internalAPIMux.Handle( - RoomserverQueryMembershipsForRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMembershipsForRoom", enableMetrics, r.QueryMembershipsForRoom), - ) - - internalAPIMux.Handle( - RoomserverQueryServerJoinedToRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryServerJoinedToRoom", enableMetrics, r.QueryServerJoinedToRoom), - ) - - internalAPIMux.Handle( - RoomserverQueryServerAllowedToSeeEventPath, - httputil.MakeInternalRPCAPI("RoomserverQueryServerAllowedToSeeEvent", enableMetrics, r.QueryServerAllowedToSeeEvent), - ) - - internalAPIMux.Handle( - RoomserverQueryMissingEventsPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMissingEvents", enableMetrics, r.QueryMissingEvents), - ) - - internalAPIMux.Handle( - RoomserverQueryStateAndAuthChainPath, - httputil.MakeInternalRPCAPI("RoomserverQueryStateAndAuthChain", enableMetrics, r.QueryStateAndAuthChain), - ) - - internalAPIMux.Handle( - RoomserverPerformBackfillPath, - httputil.MakeInternalRPCAPI("RoomserverPerformBackfill", enableMetrics, r.PerformBackfill), - ) - - internalAPIMux.Handle( - RoomserverPerformForgetPath, - httputil.MakeInternalRPCAPI("RoomserverPerformForget", enableMetrics, r.PerformForget), - ) - - internalAPIMux.Handle( - RoomserverQueryRoomVersionCapabilitiesPath, - httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionCapabilities", enableMetrics, r.QueryRoomVersionCapabilities), - ) - - internalAPIMux.Handle( - RoomserverQueryRoomVersionForRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionForRoom", enableMetrics, r.QueryRoomVersionForRoom), - ) - - internalAPIMux.Handle( - RoomserverSetRoomAliasPath, - httputil.MakeInternalRPCAPI("RoomserverSetRoomAlias", enableMetrics, r.SetRoomAlias), - ) - - internalAPIMux.Handle( - RoomserverGetRoomIDForAliasPath, - httputil.MakeInternalRPCAPI("RoomserverGetRoomIDForAlias", enableMetrics, r.GetRoomIDForAlias), - ) - - internalAPIMux.Handle( - RoomserverGetAliasesForRoomIDPath, - httputil.MakeInternalRPCAPI("RoomserverGetAliasesForRoomID", enableMetrics, r.GetAliasesForRoomID), - ) - - internalAPIMux.Handle( - RoomserverRemoveRoomAliasPath, - httputil.MakeInternalRPCAPI("RoomserverRemoveRoomAlias", enableMetrics, r.RemoveRoomAlias), - ) - - internalAPIMux.Handle( - RoomserverQueryCurrentStatePath, - httputil.MakeInternalRPCAPI("RoomserverQueryCurrentState", enableMetrics, r.QueryCurrentState), - ) - - internalAPIMux.Handle( - RoomserverQueryRoomsForUserPath, - httputil.MakeInternalRPCAPI("RoomserverQueryRoomsForUser", enableMetrics, r.QueryRoomsForUser), - ) - - internalAPIMux.Handle( - RoomserverQueryBulkStateContentPath, - httputil.MakeInternalRPCAPI("RoomserverQueryBulkStateContent", enableMetrics, r.QueryBulkStateContent), - ) - - internalAPIMux.Handle( - RoomserverQuerySharedUsersPath, - httputil.MakeInternalRPCAPI("RoomserverQuerySharedUsers", enableMetrics, r.QuerySharedUsers), - ) - - internalAPIMux.Handle( - RoomserverQueryKnownUsersPath, - httputil.MakeInternalRPCAPI("RoomserverQueryKnownUsers", enableMetrics, r.QueryKnownUsers), - ) - - internalAPIMux.Handle( - RoomserverQueryServerBannedFromRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryServerBannedFromRoom", enableMetrics, r.QueryServerBannedFromRoom), - ) - - internalAPIMux.Handle( - RoomserverQueryAuthChainPath, - httputil.MakeInternalRPCAPI("RoomserverQueryAuthChain", enableMetrics, r.QueryAuthChain), - ) - - internalAPIMux.Handle( - RoomserverQueryRestrictedJoinAllowed, - httputil.MakeInternalRPCAPI("RoomserverQueryRestrictedJoinAllowed", enableMetrics, r.QueryRestrictedJoinAllowed), - ) - internalAPIMux.Handle( - RoomserverQueryMembershipAtEventPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", enableMetrics, r.QueryMembershipAtEvent), - ) - - internalAPIMux.Handle( - RoomserverQueryLeftMembersPath, - httputil.MakeInternalRPCAPI("RoomserverQueryLeftMembersPath", enableMetrics, r.QueryLeftUsers), - ) -} diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 0f6b48bf9..5a8d8b570 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -15,24 +15,15 @@ package roomserver import ( - "github.com/gorilla/mux" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal" - "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/base" ) -// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions -// on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.RoomserverInternalAPI, enableMetrics bool) { - inthttp.AddRoutes(intAPI, router, enableMetrics) -} - -// NewInternalAPI returns a concerete implementation of the internal API. Callers -// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. +// NewInternalAPI returns a concrete implementation of the internal API. func NewInternalAPI( base *base.BaseDendrite, ) api.RoomserverInternalAPI { diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 595ceb526..304311c4c 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -2,23 +2,22 @@ package roomserver_test import ( "context" - "net/http" "reflect" "testing" "time" - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/userapi" userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" @@ -47,7 +46,7 @@ func TestUsers(t *testing.T) { }) t.Run("kick users", func(t *testing.T) { - usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) + usrAPI := userapi.NewInternalAPI(base, rsAPI, nil) rsAPI.SetUserAPI(usrAPI) testKickUsers(t, rsAPI, usrAPI) }) @@ -203,23 +202,166 @@ func Test_QueryLeftUsers(t *testing.T) { } } - t.Run("HTTP API", func(t *testing.T) { - router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - roomserver.AddInternalRoutes(router, rsAPI, false) - apiURL, cancel := test.ListenAndServe(t, router, false) - defer cancel() - httpAPI, err := inthttp.NewRoomserverClient(apiURL, &http.Client{Timeout: time.Second * 5}, nil) - if err != nil { - t.Fatalf("failed to create HTTP client") - } - testCase(httpAPI) - }) - t.Run("Monolith", func(t *testing.T) { - testCase(rsAPI) - // also test tracing - traceAPI := &api.RoomserverInternalAPITrace{Impl: rsAPI} - testCase(traceAPI) - }) - + testCase(rsAPI) + }) +} + +func TestPurgeRoom(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite Bob + inviteEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, db, close := mustCreateDatabase(t, dbType) + defer close() + + jsCtx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsCtx, &base.Cfg.Global.JetStream) + + fedClient := base.CreateFederationClient() + rsAPI := roomserver.NewInternalAPI(base) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + + // this starts the JetStream consumers + syncapi.AddPublicRoutes(base, userAPI, rsAPI) + federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) + rsAPI.SetFederationAPI(nil, nil) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // some dummy entries to validate after purging + publishResp := &api.PerformPublishResponse{} + if err := rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{RoomID: room.ID, Visibility: "public"}, publishResp); err != nil { + t.Fatal(err) + } + if publishResp.Error != nil { + t.Fatal(publishResp.Error) + } + + isPublished, err := db.GetPublishedRoom(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if !isPublished { + t.Fatalf("room should be published before purging") + } + + aliasResp := &api.SetRoomAliasResponse{} + if err = rsAPI.SetRoomAlias(ctx, &api.SetRoomAliasRequest{RoomID: room.ID, Alias: "myalias", UserID: alice.ID}, aliasResp); err != nil { + t.Fatal(err) + } + // check the alias is actually there + aliasesResp := &api.GetAliasesForRoomIDResponse{} + if err = rsAPI.GetAliasesForRoomID(ctx, &api.GetAliasesForRoomIDRequest{RoomID: room.ID}, aliasesResp); err != nil { + t.Fatal(err) + } + wantAliases := 1 + if gotAliases := len(aliasesResp.Aliases); gotAliases != wantAliases { + t.Fatalf("expected %d aliases, got %d", wantAliases, gotAliases) + } + + // validate the room exists before purging + roomInfo, err := db.RoomInfo(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if roomInfo == nil { + t.Fatalf("room does not exist") + } + // remember the roomInfo before purging + existingRoomInfo := roomInfo + + // validate there is an invite for bob + nids, err := db.EventStateKeyNIDs(ctx, []string{bob.ID}) + if err != nil { + t.Fatal(err) + } + bobNID, ok := nids[bob.ID] + if !ok { + t.Fatalf("%s does not exist", bob.ID) + } + + _, inviteEventIDs, _, err := db.GetInvitesForUser(ctx, roomInfo.RoomNID, bobNID) + if err != nil { + t.Fatal(err) + } + wantInviteCount := 1 + if inviteCount := len(inviteEventIDs); inviteCount != wantInviteCount { + t.Fatalf("expected there to be only %d invite events, got %d", wantInviteCount, inviteCount) + } + if inviteEventIDs[0] != inviteEvent.EventID() { + t.Fatalf("expected invite event ID %s, got %s", inviteEvent.EventID(), inviteEventIDs[0]) + } + + // purge the room from the database + purgeResp := &api.PerformAdminPurgeRoomResponse{} + if err = rsAPI.PerformAdminPurgeRoom(ctx, &api.PerformAdminPurgeRoomRequest{RoomID: room.ID}, purgeResp); err != nil { + t.Fatal(err) + } + + // wait for all consumers to process the purge event + var sum = 1 + timeout := time.Second * 5 + deadline, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + for sum > 0 { + if deadline.Err() != nil { + t.Fatalf("test timed out after %s", timeout) + } + sum = 0 + consumerCh := jsCtx.Consumers(base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)) + for x := range consumerCh { + sum += x.NumAckPending + } + time.Sleep(time.Millisecond) + } + + roomInfo, err = db.RoomInfo(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if roomInfo != nil { + t.Fatalf("room should not exist after purging: %+v", roomInfo) + } + + // validation below + + // There should be no invite left + _, inviteEventIDs, _, err = db.GetInvitesForUser(ctx, existingRoomInfo.RoomNID, bobNID) + if err != nil { + t.Fatal(err) + } + + if inviteCount := len(inviteEventIDs); inviteCount > 0 { + t.Fatalf("expected there to be only %d invite events, got %d", wantInviteCount, inviteCount) + } + + // aliases should be deleted + aliases, err := db.GetAliasesForRoomID(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if aliasCount := len(aliases); aliasCount > 0 { + t.Fatalf("expected there to be only %d invite events, got %d", 0, aliasCount) + } + + // published room should be deleted + isPublished, err = db.GetPublishedRoom(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if isPublished { + t.Fatalf("room should not be published after purging") + } }) } diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 1cfde5e4b..cec542d7e 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -41,8 +41,8 @@ type StateResolutionStorage interface { StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) - Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) - EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) + Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) + EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) } type StateResolution struct { @@ -975,7 +975,7 @@ func (v *StateResolution) resolveConflictsV2( // Store the newly found auth events in the auth set for this event. var authEventMap map[string]types.StateEntry - authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, conflictedEvent, knownAuthEvents) + authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, v.roomInfo.RoomNID, conflictedEvent, knownAuthEvents) if err != nil { return err } @@ -1091,7 +1091,7 @@ func (v *StateResolution) loadStateEvents( eventNIDs = append(eventNIDs, entry.EventNID) } } - events, err := v.db.Events(ctx, eventNIDs) + events, err := v.db.Events(ctx, v.roomInfo.RoomNID, eventNIDs) if err != nil { return nil, nil, err } @@ -1120,7 +1120,7 @@ type authEventLoader struct { // loadAuthEvents loads all of the auth events for a given event recursively, // along with a map that contains state entries for all of the auth events. func (l *authEventLoader) loadAuthEvents( - ctx context.Context, event *gomatrixserverlib.Event, eventMap map[string]types.Event, + ctx context.Context, roomNID types.RoomNID, event *gomatrixserverlib.Event, eventMap map[string]types.Event, ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { l.Lock() defer l.Unlock() @@ -1155,7 +1155,7 @@ func (l *authEventLoader) loadAuthEvents( // If we need to get events from the database, go and fetch // those now. if len(l.lookupFromDB) > 0 { - eventsFromDB, err := l.v.db.EventsFromIDs(ctx, l.lookupFromDB) + eventsFromDB, err := l.v.db.EventsFromIDs(ctx, roomNID, l.lookupFromDB) if err != nil { return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 92bc2e66f..88ec56670 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -69,15 +69,12 @@ type Database interface { ) ([]types.StateEntryList, error) // Look up the Events for a list of numeric event IDs. // Returns a sorted list of events. - Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) + Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) // Look up snapshot NID for an event ID string SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. - StoreEvent( - ctx context.Context, event *gomatrixserverlib.Event, authEventNIDs []types.EventNID, - isRejected bool, - ) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) + StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) // Look up the state entries for a list of string event IDs // Returns an error if the there is an error talking to the database // Returns a types.MissingEventError if the event IDs aren't in the database. @@ -87,7 +84,7 @@ type Database interface { EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) // Look up the numeric IDs for a list of events. // Returns an error if there was a problem talking to the database. - EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) + EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventMetadata, error) // Set the state at an event. FIXME TODO: "at" SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error // Lookup the event IDs for a batch of event numeric IDs. @@ -138,7 +135,7 @@ type Database interface { // EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was // not found. // Returns an error if the retrieval went wrong. - EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) + EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) // Publish or unpublish a room from the room directory. PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error // Returns a list of room IDs for rooms which are published. @@ -173,5 +170,45 @@ type Database interface { GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) + PurgeRoom(ctx context.Context, roomID string) error UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error + + // GetMembershipForHistoryVisibility queries the membership events for the given eventIDs. + // Returns a map from (input) eventID -> membership event. If no membership event is found, returns an empty event, resulting in + // a membership of "leave" when calculating history visibility. + GetMembershipForHistoryVisibility( + ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string, + ) (map[string]*gomatrixserverlib.HeaderedEvent, error) + GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error) + GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) + GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) +} + +type RoomDatabase interface { + // RoomInfo returns room information for the given room ID, or nil if there is no room. + RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) + // IsEventRejected returns true if the event is known and rejected. + IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (rejected bool, err error) + MissingAuthPrevEvents(ctx context.Context, e *gomatrixserverlib.Event) (missingAuth, missingPrev []string, err error) + // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. + StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) + UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error + GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error) + StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error) + GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error) + StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) + StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) + StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) + Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) + EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) + LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) + EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error) + GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) + GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) } diff --git a/roomserver/storage/postgres/deltas/20230131091021_published_appservice_pkey.go b/roomserver/storage/postgres/deltas/20230131091021_published_appservice_pkey.go new file mode 100644 index 000000000..add66446b --- /dev/null +++ b/roomserver/storage/postgres/deltas/20230131091021_published_appservice_pkey.go @@ -0,0 +1,32 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpPulishedAppservicePrimaryKey(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_published RENAME CONSTRAINT roomserver_published_pkey TO roomserver_published_pkeyold; +CREATE UNIQUE INDEX roomserver_published_pkey ON roomserver_published (room_id, appservice_id, network_id); +ALTER TABLE roomserver_published DROP CONSTRAINT roomserver_published_pkeyold; +ALTER TABLE roomserver_published ADD PRIMARY KEY USING INDEX roomserver_published_pkey;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 9b5ed6eda..c935608a5 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -69,6 +69,9 @@ CREATE TABLE IF NOT EXISTS roomserver_events ( auth_event_nids BIGINT[] NOT NULL, is_rejected BOOLEAN NOT NULL DEFAULT FALSE ); + +-- Create an index which helps in resolving membership events (event_type_nid = 5) - (used for history visibility) +CREATE INDEX IF NOT EXISTS roomserver_events_memberships_idx ON roomserver_events (room_nid, event_state_key_nid) WHERE (event_type_nid = 5); ` const insertEventSQL = "" + @@ -137,10 +140,10 @@ const bulkSelectEventIDSQL = "" + "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid = ANY($1)" const bulkSelectEventNIDSQL = "" + - "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1)" + "SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE event_id = ANY($1)" const bulkSelectUnsentEventNIDSQL = "" + - "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE" + "SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE" const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)" @@ -517,20 +520,20 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) { return s.bulkSelectEventNID(ctx, txn, eventIDs, false) } // BulkSelectEventNIDs returns a map from string event ID to numeric event ID // only for events that haven't already been sent to the roomserver output. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) { return s.bulkSelectEventNID(ctx, txn, eventIDs, true) } // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) { +func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventMetadata, error) { var stmt *sql.Stmt if onlyUnsent { stmt = sqlutil.TxStmt(txn, s.bulkSelectUnsentEventNIDStmt) @@ -542,14 +545,18 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") - results := make(map[string]types.EventNID, len(eventIDs)) + results := make(map[string]types.EventMetadata, len(eventIDs)) var eventID string var eventNID int64 + var roomNID int64 for rows.Next() { - if err = rows.Scan(&eventID, &eventNID); err != nil { + if err = rows.Scan(&eventID, &eventNID, &roomNID); err != nil { return nil, err } - results[eventID] = types.EventNID(eventNID) + results[eventID] = types.EventMetadata{ + EventNID: types.EventNID(eventNID), + RoomNID: types.RoomNID(roomNID), + } } return results, rows.Err() } diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index 61caccb0e..eca81d81f 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -65,10 +65,16 @@ func CreatePublishedTable(db *sql.DB) error { return err } m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "roomserver: published appservice", - Up: deltas.UpPulishedAppservice, - }) + m.AddMigrations([]sqlutil.Migration{ + { + Version: "roomserver: published appservice", + Up: deltas.UpPulishedAppservice, + }, + { + Version: "roomserver: published appservice pkey", + Up: deltas.UpPulishedAppservicePrimaryKey, + }, + }...) return m.Up(context.Background()) } diff --git a/roomserver/storage/postgres/purge_statements.go b/roomserver/storage/postgres/purge_statements.go new file mode 100644 index 000000000..efba439bd --- /dev/null +++ b/roomserver/storage/postgres/purge_statements.go @@ -0,0 +1,133 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const purgeEventJSONSQL = "" + + "DELETE FROM roomserver_event_json WHERE event_nid = ANY(" + + " SELECT event_nid FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeEventsSQL = "" + + "DELETE FROM roomserver_events WHERE room_nid = $1" + +const purgeInvitesSQL = "" + + "DELETE FROM roomserver_invites WHERE room_nid = $1" + +const purgeMembershipsSQL = "" + + "DELETE FROM roomserver_membership WHERE room_nid = $1" + +const purgePreviousEventsSQL = "" + + "DELETE FROM roomserver_previous_events WHERE event_nids && ANY(" + + " SELECT ARRAY_AGG(event_nid) FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgePublishedSQL = "" + + "DELETE FROM roomserver_published WHERE room_id = $1" + +const purgeRedactionsSQL = "" + + "DELETE FROM roomserver_redactions WHERE redaction_event_id = ANY(" + + " SELECT event_id FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeRoomAliasesSQL = "" + + "DELETE FROM roomserver_room_aliases WHERE room_id = $1" + +const purgeRoomSQL = "" + + "DELETE FROM roomserver_rooms WHERE room_nid = $1" + +const purgeStateBlockEntriesSQL = "" + + "DELETE FROM roomserver_state_block WHERE state_block_nid = ANY(" + + " SELECT DISTINCT UNNEST(state_block_nids) FROM roomserver_state_snapshots WHERE room_nid = $1" + + ")" + +const purgeStateSnapshotEntriesSQL = "" + + "DELETE FROM roomserver_state_snapshots WHERE room_nid = $1" + +type purgeStatements struct { + purgeEventJSONStmt *sql.Stmt + purgeEventsStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt + purgePreviousEventsStmt *sql.Stmt + purgePublishedStmt *sql.Stmt + purgeRedactionStmt *sql.Stmt + purgeRoomAliasesStmt *sql.Stmt + purgeRoomStmt *sql.Stmt + purgeStateBlockEntriesStmt *sql.Stmt + purgeStateSnapshotEntriesStmt *sql.Stmt +} + +func PreparePurgeStatements(db *sql.DB) (*purgeStatements, error) { + s := &purgeStatements{} + + return s, sqlutil.StatementList{ + {&s.purgeEventJSONStmt, purgeEventJSONSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, + {&s.purgePublishedStmt, purgePublishedSQL}, + {&s.purgePreviousEventsStmt, purgePreviousEventsSQL}, + {&s.purgeRedactionStmt, purgeRedactionsSQL}, + {&s.purgeRoomAliasesStmt, purgeRoomAliasesSQL}, + {&s.purgeRoomStmt, purgeRoomSQL}, + {&s.purgeStateBlockEntriesStmt, purgeStateBlockEntriesSQL}, + {&s.purgeStateSnapshotEntriesStmt, purgeStateSnapshotEntriesSQL}, + }.Prepare(db) +} + +func (s *purgeStatements) PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, +) error { + + // purge by roomID + purgeByRoomID := []*sql.Stmt{ + s.purgeRoomAliasesStmt, + s.purgePublishedStmt, + } + for _, stmt := range purgeByRoomID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomID) + if err != nil { + return err + } + } + + // purge by roomNID + purgeByRoomNID := []*sql.Stmt{ + s.purgeStateBlockEntriesStmt, + s.purgeStateSnapshotEntriesStmt, + s.purgeInvitesStmt, + s.purgeMembershipsStmt, + s.purgePreviousEventsStmt, + s.purgeEventJSONStmt, + s.purgeRedactionStmt, + s.purgeEventsStmt, + s.purgeRoomStmt, + } + for _, stmt := range purgeByRoomNID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomNID) + if err != nil { + return err + } + } + return nil +} diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 994399532..c8346733d 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -58,6 +58,9 @@ const insertRoomNIDSQL = "" + const selectRoomNIDSQL = "" + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" +const selectRoomNIDForUpdateSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1 FOR UPDATE" + const selectLatestEventNIDsSQL = "" + "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" @@ -85,6 +88,7 @@ const bulkSelectRoomNIDsSQL = "" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt + selectRoomNIDForUpdateStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt @@ -106,6 +110,7 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) { return s, sqlutil.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, + {&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, @@ -169,6 +174,15 @@ func (s *roomStatements) SelectRoomNID( return types.RoomNID(roomNID), err } +func (s *roomStatements) SelectRoomNIDForUpdate( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDForUpdateStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index a00c026f4..0e83cfc25 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -21,10 +21,10 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -99,10 +99,26 @@ const bulkSelectStateForHistoryVisibilitySQL = ` AND (event_type_nid = 7 OR event_state_key LIKE '%:' || $2); ` +// bulkSelectMembershipForHistoryVisibilitySQL is an optimization to get membership events for a specific user for defined set of events. +// Returns the event_id of the event we want the membership event for, the event_id of the membership event and the membership event JSON. +const bulkSelectMembershipForHistoryVisibilitySQL = ` +SELECT re.event_id, re2.event_id, rej.event_json +FROM roomserver_events re +LEFT JOIN roomserver_state_snapshots rss on re.state_snapshot_nid = rss.state_snapshot_nid +CROSS JOIN unnest(rss.state_block_nids) AS blocks(block_nid) +LEFT JOIN roomserver_state_block rsb ON rsb.state_block_nid = blocks.block_nid +CROSS JOIN unnest(rsb.event_nids) AS rsb2(event_nid) +JOIN roomserver_events re2 ON re2.room_nid = $3 AND re2.event_type_nid = 5 AND re2.event_nid = rsb2.event_nid AND re2.event_state_key_nid = $1 +LEFT JOIN roomserver_event_json rej ON rej.event_nid = re2.event_nid +WHERE re.event_id = ANY($2) + +` + type stateSnapshotStatements struct { - insertStateStmt *sql.Stmt - bulkSelectStateBlockNIDsStmt *sql.Stmt - bulkSelectStateForHistoryVisibilityStmt *sql.Stmt + insertStateStmt *sql.Stmt + bulkSelectStateBlockNIDsStmt *sql.Stmt + bulkSelectStateForHistoryVisibilityStmt *sql.Stmt + bulktSelectMembershipForHistoryVisibilityStmt *sql.Stmt } func CreateStateSnapshotTable(db *sql.DB) error { @@ -110,13 +126,14 @@ func CreateStateSnapshotTable(db *sql.DB) error { return err } -func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func PrepareStateSnapshotTable(db *sql.DB) (*stateSnapshotStatements, error) { s := &stateSnapshotStatements{} return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, {&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL}, + {&s.bulktSelectMembershipForHistoryVisibilityStmt, bulkSelectMembershipForHistoryVisibilitySQL}, }.Prepare(db) } @@ -185,3 +202,45 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( } return results, rows.Err() } + +func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility( + ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, +) (map[string]*gomatrixserverlib.HeaderedEvent, error) { + stmt := sqlutil.TxStmt(txn, s.bulktSelectMembershipForHistoryVisibilityStmt) + rows, err := stmt.QueryContext(ctx, userNID, pq.Array(eventIDs), roomInfo.RoomNID) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + result := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs)) + var evJson []byte + var eventID string + var membershipEventID string + + knownEvents := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs)) + + for rows.Next() { + if err = rows.Scan(&eventID, &membershipEventID, &evJson); err != nil { + return nil, err + } + if len(evJson) == 0 { + result[eventID] = &gomatrixserverlib.HeaderedEvent{} + continue + } + // If we already know this event, don't try to marshal the json again + if ev, ok := knownEvents[membershipEventID]; ok { + result[eventID] = ev + continue + } + event, err := gomatrixserverlib.NewEventFromTrustedJSON(evJson, false, roomInfo.RoomVersion) + if err != nil { + result[eventID] = &gomatrixserverlib.HeaderedEvent{} + // not fatal + continue + } + he := event.Headered(roomInfo.RoomVersion) + result[eventID] = he + knownEvents[membershipEventID] = he + } + return result, rows.Err() +} diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 23a5f79eb..872084383 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -189,6 +189,10 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + purge, err := PreparePurgeStatements(db) + if err != nil { + return err + } d.Database = shared.Database{ DB: db, Cache: cache, @@ -206,6 +210,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room MembershipTable: membership, PublishedTable: published, RedactionsTable: redactions, + Purge: purge, } return nil } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index cc880a6c8..5006c3c55 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -116,10 +116,8 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent }) } -func (u *RoomUpdater) Events( - ctx context.Context, eventNIDs []types.EventNID, -) ([]types.Event, error) { - return u.d.events(ctx, u.txn, eventNIDs) +func (u *RoomUpdater) Events(ctx context.Context, _ types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) { + return u.d.events(ctx, u.txn, u.roomInfo.RoomNID, eventNIDs) } func (u *RoomUpdater) SnapshotNIDFromEventID( @@ -197,12 +195,8 @@ func (u *RoomUpdater) StateAtEventIDs( return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) } -func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false) -} - -func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true) +func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) { + return u.d.eventsFromIDs(ctx, u.txn, roomNID, eventIDs, NoFilter) } // IsReferenced implements types.RoomRecentEventsUpdater diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 725cc5bc7..aac5bc365 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -9,6 +9,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/caching" @@ -43,6 +44,7 @@ type Database struct { MembershipTable tables.Membership PublishedTable tables.Published RedactionsTable tables.Redactions + Purge tables.Purge GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } @@ -50,6 +52,12 @@ func (d *Database) SupportsConcurrentRoomInputs() bool { return true } +func (d *Database) GetMembershipForHistoryVisibility( + ctx context.Context, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, +) (map[string]*gomatrixserverlib.HeaderedEvent, error) { + return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...) +} + func (d *Database) EventTypeNIDs( ctx context.Context, eventTypes []string, ) (map[string]types.EventTypeNID, error) { @@ -60,12 +68,25 @@ func (d *Database) eventTypeNIDs( ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { result := make(map[string]types.EventTypeNID) - nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, eventTypes) - if err != nil { - return nil, err + // first try the cache + fetchEventTypes := make([]string, 0, len(eventTypes)) + for _, eventType := range eventTypes { + eventTypeNID, ok := d.Cache.GetEventTypeKey(eventType) + if ok { + result[eventType] = eventTypeNID + continue + } + fetchEventTypes = append(fetchEventTypes, eventType) } - for eventType, nid := range nids { - result[eventType] = nid + if len(fetchEventTypes) > 0 { + nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, fetchEventTypes) + if err != nil { + return nil, err + } + for eventType, nid := range nids { + result[eventType] = nid + d.Cache.StoreEventTypeKey(nid, eventType) + } } return result, nil } @@ -82,13 +103,15 @@ func (d *Database) EventStateKeys( fetch = append(fetch, nid) } } - fromDB, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, fetch) - if err != nil { - return nil, err - } - for nid, key := range fromDB { - result[nid] = key - d.Cache.StoreEventStateKey(nid, key) + if len(fetch) > 0 { + fromDB, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, fetch) + if err != nil { + return nil, err + } + for nid, key := range fromDB { + result[nid] = key + d.Cache.StoreEventStateKey(nid, key) + } } return result, nil } @@ -104,16 +127,32 @@ func (d *Database) eventStateKeyNIDs( ) (map[string]types.EventStateKeyNID, error) { result := make(map[string]types.EventStateKeyNID) eventStateKeys = util.UniqueStrings(eventStateKeys) - nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) - if err != nil { - return nil, err + // first ask the cache about these keys + fetchEventStateKeys := make([]string, 0, len(eventStateKeys)) + for _, eventStateKey := range eventStateKeys { + eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey) + if ok { + result[eventStateKey] = eventStateKeyNID + continue + } + fetchEventStateKeys = append(fetchEventStateKeys, eventStateKey) } - for eventStateKey, nid := range nids { - result[eventStateKey] = nid + + if len(fetchEventStateKeys) > 0 { + nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, fetchEventStateKeys) + if err != nil { + return nil, err + } + for eventStateKey, nid := range nids { + result[eventStateKey] = nid + d.Cache.StoreEventStateKey(nid, eventStateKey) + } } + // We received some nids, but are still missing some, work out which and create them if len(eventStateKeys) > len(result) { var nid types.EventStateKeyNID + var err error err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { for _, eventStateKey := range eventStateKeys { if _, ok := result[eventStateKey]; ok { @@ -255,7 +294,7 @@ func (d *Database) addState( func (d *Database) EventNIDs( ctx context.Context, eventIDs []string, -) (map[string]types.EventNID, error) { +) (map[string]types.EventMetadata, error) { return d.eventNIDs(ctx, nil, eventIDs, NoFilter) } @@ -268,7 +307,7 @@ const ( func (d *Database) eventNIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter, -) (map[string]types.EventNID, error) { +) (map[string]types.EventMetadata, error) { switch filter { case FilterUnsentOnly: return d.EventsTable.BulkSelectUnsentEventNID(ctx, txn, eventIDs) @@ -318,11 +357,11 @@ func (d *Database) EventIDs( return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) } -func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - return d.eventsFromIDs(ctx, nil, eventIDs, NoFilter) +func (d *Database) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) { + return d.eventsFromIDs(ctx, nil, roomNID, eventIDs, NoFilter) } -func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter) ([]types.Event, error) { +func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventIDs []string, filter UnsentFilter) ([]types.Event, error) { nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter) if err != nil { return nil, err @@ -330,10 +369,16 @@ func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []st var nids []types.EventNID for _, nid := range nidMap { - nids = append(nids, nid) + nids = append(nids, nid.EventNID) + if roomNID != 0 && roomNID != nid.RoomNID { + logrus.Errorf("expected events from room %d, but also found %d", roomNID, nid.RoomNID) + } + if roomNID == 0 { + roomNID = nid.RoomNID + } } - return d.events(ctx, txn, nids) + return d.events(ctx, txn, roomNID, nids) } func (d *Database) LatestEventIDs( @@ -473,14 +518,18 @@ func (d *Database) GetInvitesForUser( } func (d *Database) Events( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID, ) ([]types.Event, error) { - return d.events(ctx, nil, eventNIDs) + return d.events(ctx, nil, roomNID, eventNIDs) } func (d *Database) events( - ctx context.Context, txn *sql.Tx, inputEventNIDs types.EventNIDs, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, inputEventNIDs types.EventNIDs, ) ([]types.Event, error) { + if roomNID == 0 { + // No need to go further, as we won't find any events for this room. + return nil, nil + } sort.Sort(inputEventNIDs) events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs)) eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs)) @@ -512,40 +561,34 @@ func (d *Database) events( if err != nil { return nil, err } - eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs) + eventIDs, err := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs) if err != nil { eventIDs = map[types.EventNID]string{} } - var roomNIDs map[types.EventNID]types.RoomNID - roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs) - if err != nil { - return nil, err - } - uniqueRoomNIDs := make(map[types.RoomNID]struct{}) - for _, n := range roomNIDs { - uniqueRoomNIDs[n] = struct{}{} - } - roomVersions := make(map[types.RoomNID]gomatrixserverlib.RoomVersion) - fetchNIDList := make([]types.RoomNID, 0, len(uniqueRoomNIDs)) - for n := range uniqueRoomNIDs { - if roomID, ok := d.Cache.GetRoomServerRoomID(n); ok { - if roomVersion, ok := d.Cache.GetRoomVersion(roomID); ok { - roomVersions[n] = roomVersion - continue - } + + var roomVersion gomatrixserverlib.RoomVersion + var fetchRoomVersion bool + var ok bool + var roomID string + if roomID, ok = d.Cache.GetRoomServerRoomID(roomNID); ok { + roomVersion, ok = d.Cache.GetRoomVersion(roomID) + if !ok { + fetchRoomVersion = true } - fetchNIDList = append(fetchNIDList, n) } - dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList) - if err != nil { - return nil, err - } - for n, v := range dbRoomVersions { - roomVersions[n] = v + + if roomVersion == "" || fetchRoomVersion { + var dbRoomVersions map[types.RoomNID]gomatrixserverlib.RoomVersion + dbRoomVersions, err = d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, []types.RoomNID{roomNID}) + if err != nil { + return nil, err + } + if roomVersion, ok = dbRoomVersions[roomNID]; !ok { + return nil, fmt.Errorf("unable to find roomversion for room %d", roomNID) + } } + for _, eventJSON := range eventJSONs { - roomNID := roomNIDs[eventJSON.EventNID] - roomVersion := roomVersions[roomNID] events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID( eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion, ) @@ -617,73 +660,71 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID) } -func (d *Database) StoreEvent( - ctx context.Context, event *gomatrixserverlib.Event, - authEventNIDs []types.EventNID, isRejected bool, -) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { - return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected) +// GetOrCreateRoomNID gets or creates a new roomNID for the given event +func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (roomNID types.RoomNID, err error) { + // Get the default room version. If the client doesn't supply a room_version + // then we will use our configured default to create the room. + // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom + // Note that the below logic depends on the m.room.create event being the + // first event that is persisted to the database when creating or joining a + // room. + var roomVersion gomatrixserverlib.RoomVersion + if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { + return 0, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) + } + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion) + if err != nil { + return err + } + return nil + }) + return roomNID, err } -func (d *Database) storeEvent( - ctx context.Context, updater *RoomUpdater, event *gomatrixserverlib.Event, - authEventNIDs []types.EventNID, isRejected bool, -) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { - var ( - roomNID types.RoomNID - eventTypeNID types.EventTypeNID - eventStateKeyNID types.EventStateKeyNID - eventNID types.EventNID - stateNID types.StateSnapshotNID - redactionEvent *gomatrixserverlib.Event - redactedEventID string - err error - ) - var txn *sql.Tx - if updater != nil && updater.txn != nil { - txn = updater.txn - } - // First writer is with a database-provided transaction, so that NIDs are assigned - // globally outside of the updater context, to help avoid races. +func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - // TODO: Here we should aim to have two different code paths for new rooms - // vs existing ones. - - // Get the default room version. If the client doesn't supply a room_version - // then we will use our configured default to create the room. - // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom - // Note that the below logic depends on the m.room.create event being the - // first event that is persisted to the database when creating or joining a - // room. - var roomVersion gomatrixserverlib.RoomVersion - if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { - return fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) - } - - if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil { - return fmt.Errorf("d.assignRoomNID: %w", err) - } - - if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil { + if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, eventType); err != nil { return fmt.Errorf("d.assignEventTypeNID: %w", err) } + return nil + }) + return eventTypeNID, err +} - eventStateKey := event.StateKey() - // Assigned a numeric ID for the state_key if there is one present. - // Otherwise set the numeric ID for the state_key to 0. - if eventStateKey != nil { - if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { - return fmt.Errorf("d.assignStateKeyNID: %w", err) - } +func (d *Database) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (eventStateKeyNID types.EventStateKeyNID, err error) { + if eventStateKey == nil { + return 0, nil + } + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { + return fmt.Errorf("d.assignStateKeyNID: %w", err) } - return nil }) if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) + return 0, err } + + return eventStateKeyNID, nil +} + +func (d *Database) StoreEvent( + ctx context.Context, event *gomatrixserverlib.Event, + roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, + authEventNIDs []types.EventNID, isRejected bool, +) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { + var ( + eventNID types.EventNID + stateNID types.StateSnapshotNID + redactionEvent *gomatrixserverlib.Event + redactedEventID string + err error + ) // Second writer is using the database-provided transaction, probably from the // room updater, for easy roll-back if required. - err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if eventNID, stateNID, err = d.EventsTable.InsertEvent( ctx, txn, @@ -711,7 +752,7 @@ func (d *Database) storeEvent( return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) } if !isRejected { // ignore rejected redaction events - redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event) + redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, roomNID, eventNID, event) if err != nil { return fmt.Errorf("d.handleRedactions: %w", err) } @@ -719,7 +760,7 @@ func (d *Database) storeEvent( return nil }) if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) + return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) } // We should attempt to update the previous events table with any @@ -734,28 +775,28 @@ func (d *Database) storeEvent( // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater // to do writes however then this will need to go inside `Writer.Do`. succeeded := false - if updater == nil { - var roomInfo *types.RoomInfo - roomInfo, err = d.roomInfo(ctx, txn, event.RoomID()) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) - } - if roomInfo == nil && len(prevEvents) > 0 { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) - } - updater, err = d.GetRoomUpdater(ctx, roomInfo) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err) - } - defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) + var roomInfo *types.RoomInfo + roomInfo, err = d.roomInfo(ctx, nil, event.RoomID()) + if err != nil { + return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) } + if roomInfo == nil && len(prevEvents) > 0 { + return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) + } + var updater *RoomUpdater + updater, err = d.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err) + } + defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) + if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err) + return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err) } succeeded = true } - return eventNID, roomNID, types.StateAtEvent{ + return eventNID, types.StateAtEvent{ BeforeStateSnapshotNID: stateNID, StateEntry: types.StateEntry{ StateKeyTuple: types.StateKeyTuple{ @@ -807,6 +848,10 @@ func (d *Database) MissingAuthPrevEvents( func (d *Database) assignRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (types.RoomNID, error) { + roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID) + if ok { + return roomNID, nil + } // Check if we already have a numeric ID in the database. roomNID, err := d.RoomsTable.SelectRoomNID(ctx, txn, roomID) if err == sql.ErrNoRows { @@ -817,12 +862,20 @@ func (d *Database) assignRoomNID( roomNID, err = d.RoomsTable.SelectRoomNID(ctx, txn, roomID) } } - return roomNID, err + if err != nil { + return 0, err + } + d.Cache.StoreRoomServerRoomID(roomNID, roomID) + return roomNID, nil } func (d *Database) assignEventTypeNID( ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { + eventTypeNID, ok := d.Cache.GetEventTypeKey(eventType) + if ok { + return eventTypeNID, nil + } // Check if we already have a numeric ID in the database. eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) if err == sql.ErrNoRows { @@ -833,12 +886,20 @@ func (d *Database) assignEventTypeNID( eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) } } - return eventTypeNID, err + if err != nil { + return 0, err + } + d.Cache.StoreEventTypeKey(eventTypeNID, eventType) + return eventTypeNID, nil } func (d *Database) assignStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { + eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey) + if ok { + return eventStateKeyNID, nil + } // Check if we already have a numeric ID in the database. eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) if err == sql.ErrNoRows { @@ -849,6 +910,7 @@ func (d *Database) assignStateKeyNID( eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) } } + d.Cache.StoreEventStateKey(eventStateKeyNID, eventStateKey) return eventStateKeyNID, err } @@ -892,7 +954,7 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) ( // // Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction. func (d *Database) handleRedactions( - ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event *gomatrixserverlib.Event, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event, ) (*gomatrixserverlib.Event, string, error) { var err error isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil @@ -912,7 +974,7 @@ func (d *Database) handleRedactions( } } - redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, eventNID, event) + redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, roomNID, eventNID, event) if err != nil { return nil, "", fmt.Errorf("d.loadRedactionPair: %w", err) } @@ -978,7 +1040,7 @@ func (d *Database) handleRedactions( // loadRedactionPair returns both the redaction event and the redacted event, else nil. func (d *Database) loadRedactionPair( - ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event *gomatrixserverlib.Event, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event, ) (*types.Event, *types.Event, bool, error) { var redactionEvent, redactedEvent *types.Event var info *tables.RedactionInfo @@ -1010,9 +1072,9 @@ func (d *Database) loadRedactionPair( } if isRedactionEvent { - redactedEvent = d.loadEvent(ctx, info.RedactsEventID) + redactedEvent = d.loadEvent(ctx, roomNID, info.RedactsEventID) } else { - redactionEvent = d.loadEvent(ctx, info.RedactionEventID) + redactionEvent = d.loadEvent(ctx, roomNID, info.RedactionEventID) } return redactionEvent, redactedEvent, info.Validated, nil @@ -1028,7 +1090,7 @@ func (d *Database) applyRedactions(events []types.Event) { } // loadEvent loads a single event or returns nil on any problems/missing event -func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event { +func (d *Database) loadEvent(ctx context.Context, roomNID types.RoomNID, eventID string) *types.Event { nids, err := d.EventNIDs(ctx, []string{eventID}) if err != nil { return nil @@ -1036,7 +1098,7 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event { if len(nids) == 0 { return nil } - evs, err := d.Events(ctx, []types.EventNID{nids[eventID]}) + evs, err := d.Events(ctx, roomNID, []types.EventNID{nids[eventID].EventNID}) if err != nil { return nil } @@ -1445,16 +1507,37 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget }) } +// PurgeRoom removes all information about a given room from the roomserver. +// For large rooms this operation may take a considerable amount of time. +func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomNID, err := d.RoomsTable.SelectRoomNIDForUpdate(ctx, txn, roomID) + if err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("room %s does not exist", roomID) + } + return fmt.Errorf("failed to lock the room: %w", err) + } + return d.Purge.PurgeRoom(ctx, txn, roomNID, roomID) + }) +} + func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - // un-publish old room - if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, oldRoomID, "", "", false); err != nil { - return fmt.Errorf("failed to unpublish room: %w", err) + published, err := d.PublishedTable.SelectPublishedFromRoomID(ctx, txn, oldRoomID) + if err != nil { + return fmt.Errorf("failed to get published room: %w", err) } - // publish new room - if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, newRoomID, "", "", true); err != nil { - return fmt.Errorf("failed to publish room: %w", err) + if published { + // un-publish old room + if err = d.PublishedTable.UpsertRoomPublished(ctx, txn, oldRoomID, "", "", false); err != nil { + return fmt.Errorf("failed to unpublish room: %w", err) + } + // publish new room + if err = d.PublishedTable.UpsertRoomPublished(ctx, txn, newRoomID, "", "", true); err != nil { + return fmt.Errorf("failed to publish room: %w", err) + } } // Migrate any existing room aliases diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index 58724340c..3acb55a3a 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -3,7 +3,9 @@ package shared_test import ( "context" "testing" + "time" + "github.com/matrix-org/dendrite/internal/caching" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -48,11 +50,14 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat } assert.NoError(t, err) + cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false) + return &shared.Database{ DB: db, EventStateKeysTable: stateKeyTable, MembershipTable: membershipTable, Writer: sqlutil.NewExclusiveWriter(), + Cache: cache, }, func() { err := base.Close() assert.NoError(t, err) diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index f39b9902d..aacf4bc9a 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -110,10 +110,10 @@ const bulkSelectEventIDSQL = "" + "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)" const bulkSelectEventNIDSQL = "" + - "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)" + "SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE event_id IN ($1)" const bulkSelectUnsentEventNIDSQL = "" + - "SELECT event_id, event_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)" + "SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)" const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" @@ -572,20 +572,20 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) { return s.bulkSelectEventNID(ctx, txn, eventIDs, false) } // BulkSelectEventNIDs returns a map from string event ID to numeric event ID // only for events that haven't already been sent to the roomserver output. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) { return s.bulkSelectEventNID(ctx, txn, eventIDs, true) } // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) { +func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventMetadata, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { @@ -609,14 +609,18 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") - results := make(map[string]types.EventNID, len(eventIDs)) + results := make(map[string]types.EventMetadata, len(eventIDs)) var eventID string var eventNID int64 + var roomNID int64 for rows.Next() { - if err = rows.Scan(&eventID, &eventNID); err != nil { + if err = rows.Scan(&eventID, &eventNID, &roomNID); err != nil { return nil, err } - results[eventID] = types.EventNID(eventNID) + results[eventID] = types.EventMetadata{ + EventNID: types.EventNID(eventNID), + RoomNID: types.RoomNID(roomNID), + } } return results, nil } diff --git a/roomserver/storage/sqlite3/purge_statements.go b/roomserver/storage/sqlite3/purge_statements.go new file mode 100644 index 000000000..c7b4d27a5 --- /dev/null +++ b/roomserver/storage/sqlite3/purge_statements.go @@ -0,0 +1,153 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const purgeEventJSONSQL = "" + + "DELETE FROM roomserver_event_json WHERE event_nid IN (" + + " SELECT event_nid FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeEventsSQL = "" + + "DELETE FROM roomserver_events WHERE room_nid = $1" + +const purgeInvitesSQL = "" + + "DELETE FROM roomserver_invites WHERE room_nid = $1" + +const purgeMembershipsSQL = "" + + "DELETE FROM roomserver_membership WHERE room_nid = $1" + +const purgePreviousEventsSQL = "" + + "DELETE FROM roomserver_previous_events WHERE event_nids IN(" + + " SELECT event_nid FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgePublishedSQL = "" + + "DELETE FROM roomserver_published WHERE room_id = $1" + +const purgeRedactionsSQL = "" + + "DELETE FROM roomserver_redactions WHERE redaction_event_id IN(" + + " SELECT event_id FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeRoomAliasesSQL = "" + + "DELETE FROM roomserver_room_aliases WHERE room_id = $1" + +const purgeRoomSQL = "" + + "DELETE FROM roomserver_rooms WHERE room_nid = $1" + +const purgeStateSnapshotEntriesSQL = "" + + "DELETE FROM roomserver_state_snapshots WHERE room_nid = $1" + +type purgeStatements struct { + purgeEventJSONStmt *sql.Stmt + purgeEventsStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt + purgePreviousEventsStmt *sql.Stmt + purgePublishedStmt *sql.Stmt + purgeRedactionStmt *sql.Stmt + purgeRoomAliasesStmt *sql.Stmt + purgeRoomStmt *sql.Stmt + purgeStateSnapshotEntriesStmt *sql.Stmt + stateSnapshot *stateSnapshotStatements +} + +func PreparePurgeStatements(db *sql.DB, stateSnapshot *stateSnapshotStatements) (*purgeStatements, error) { + s := &purgeStatements{stateSnapshot: stateSnapshot} + return s, sqlutil.StatementList{ + {&s.purgeEventJSONStmt, purgeEventJSONSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, + {&s.purgePublishedStmt, purgePublishedSQL}, + {&s.purgePreviousEventsStmt, purgePreviousEventsSQL}, + {&s.purgeRedactionStmt, purgeRedactionsSQL}, + {&s.purgeRoomAliasesStmt, purgeRoomAliasesSQL}, + {&s.purgeRoomStmt, purgeRoomSQL}, + //{&s.purgeStateBlockEntriesStmt, purgeStateBlockEntriesSQL}, + {&s.purgeStateSnapshotEntriesStmt, purgeStateSnapshotEntriesSQL}, + }.Prepare(db) +} + +func (s *purgeStatements) PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, +) error { + + // purge by roomID + purgeByRoomID := []*sql.Stmt{ + s.purgeRoomAliasesStmt, + s.purgePublishedStmt, + } + for _, stmt := range purgeByRoomID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomID) + if err != nil { + return err + } + } + + // purge by roomNID + if err := s.purgeStateBlocks(ctx, txn, roomNID); err != nil { + return err + } + + purgeByRoomNID := []*sql.Stmt{ + s.purgeStateSnapshotEntriesStmt, + s.purgeInvitesStmt, + s.purgeMembershipsStmt, + s.purgePreviousEventsStmt, + s.purgeEventJSONStmt, + s.purgeRedactionStmt, + s.purgeEventsStmt, + s.purgeRoomStmt, + } + for _, stmt := range purgeByRoomNID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomNID) + if err != nil { + return err + } + } + return nil +} + +func (s *purgeStatements) purgeStateBlocks( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) error { + // Get all stateBlockNIDs + stateBlockNIDs, err := s.stateSnapshot.selectStateBlockNIDsForRoomNID(ctx, txn, roomNID) + if err != nil { + return err + } + params := make([]interface{}, len(stateBlockNIDs)) + seenNIDs := make(map[types.StateBlockNID]struct{}, len(stateBlockNIDs)) + // dedupe NIDs + for k, v := range stateBlockNIDs { + if _, ok := seenNIDs[v]; ok { + continue + } + params[k] = v + seenNIDs[v] = struct{}{} + } + + query := "DELETE FROM roomserver_state_block WHERE state_block_nid IN($1)" + return sqlutil.RunLimitedVariablesExec(ctx, query, txn, params, sqlutil.SQLite3MaxVariables) +} diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 25b611b3e..7556b3461 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -74,10 +74,14 @@ const bulkSelectRoomIDsSQL = "" + const bulkSelectRoomNIDsSQL = "" + "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" +const selectRoomNIDForUpdateSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" + type roomStatements struct { db *sql.DB insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt + selectRoomNIDForUpdateStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt @@ -105,6 +109,7 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) { //{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, {&s.selectRoomIDsStmt, selectRoomIDsSQL}, + {&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL}, }.Prepare(db) } @@ -169,6 +174,15 @@ func (s *roomStatements) SelectRoomNID( return types.RoomNID(roomNID), err } +func (s *roomStatements) SelectRoomNIDForUpdate( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDForUpdateStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 4e67d4da1..ae8181cfa 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -24,7 +24,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" ) @@ -68,7 +67,7 @@ func CreateStateBlockTable(db *sql.DB) error { return err } -func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { +func PrepareStateBlockTable(db *sql.DB) (*stateBlockStatements, error) { s := &stateBlockStatements{ db: db, } diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 73827522c..e57e1a4bf 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -62,10 +63,14 @@ const bulkSelectStateBlockNIDsSQL = "" + "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" +const selectStateBlockNIDsForRoomNID = "" + + "SELECT state_block_nids FROM roomserver_state_snapshots WHERE room_nid = $1" + type stateSnapshotStatements struct { db *sql.DB insertStateStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt + selectStateBlockNIDsStmt *sql.Stmt } func CreateStateSnapshotTable(db *sql.DB) error { @@ -73,7 +78,7 @@ func CreateStateSnapshotTable(db *sql.DB) error { return err } -func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func PrepareStateSnapshotTable(db *sql.DB) (*stateSnapshotStatements, error) { s := &stateSnapshotStatements{ db: db, } @@ -81,6 +86,7 @@ func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, + {&s.selectStateBlockNIDsStmt, selectStateBlockNIDsForRoomNID}, }.Prepare(db) } @@ -146,3 +152,33 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( ) ([]types.EventNID, error) { return nil, tables.OptimisationNotSupportedError } + +func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string) (map[string]*gomatrixserverlib.HeaderedEvent, error) { + return nil, tables.OptimisationNotSupportedError +} + +func (s *stateSnapshotStatements) selectStateBlockNIDsForRoomNID( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) ([]types.StateBlockNID, error) { + var res []types.StateBlockNID + rows, err := sqlutil.TxStmt(txn, s.selectStateBlockNIDsStmt).QueryContext(ctx, roomNID) + if err != nil { + return res, nil + } + defer internal.CloseAndLogIfError(ctx, rows, "selectStateBlockNIDsForRoomNID: rows.close() failed") + + var stateBlockNIDs []types.StateBlockNID + var stateBlockNIDsJSON string + for rows.Next() { + if err = rows.Scan(&stateBlockNIDsJSON); err != nil { + return nil, err + } + if err = json.Unmarshal([]byte(stateBlockNIDsJSON), &stateBlockNIDs); err != nil { + return nil, err + } + + res = append(res, stateBlockNIDs...) + } + + return res, rows.Err() +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 01c3f879c..392edd289 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -197,6 +197,11 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + purge, err := PreparePurgeStatements(db, stateSnapshot) + if err != nil { + return err + } + d.Database = shared.Database{ DB: db, Cache: cache, @@ -215,6 +220,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room PublishedTable: published, RedactionsTable: redactions, GetRoomUpdaterFn: d.GetRoomUpdater, + Purge: purge, } return nil } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 80fcf72dd..4ce2a9c4e 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -63,8 +63,8 @@ type Events interface { BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. - BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) - BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) + BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) + BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) SelectEventRejected(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventID string) (rejected bool, err error) @@ -73,6 +73,7 @@ type Events interface { type Rooms interface { InsertRoomNID(ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion) (types.RoomNID, error) SelectRoomNID(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error) + SelectRoomNIDForUpdate(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error) SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error @@ -90,6 +91,10 @@ type StateSnapshot interface { // which users are in a room faster than having to load the entire room state. In the // case of SQLite, this will return tables.OptimisationNotSupportedError. BulkSelectStateForHistoryVisibility(ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string) ([]types.EventNID, error) + + BulkSelectMembershipForHistoryVisibility( + ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, + ) (map[string]*gomatrixserverlib.HeaderedEvent, error) } type StateBlock interface { @@ -173,6 +178,12 @@ type Redactions interface { MarkRedactionValidated(ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool) error } +type Purge interface { + PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, + ) error +} + // StrippedEvent represents a stripped event for returning extracted content values. type StrippedEvent struct { RoomID string diff --git a/roomserver/storage/tables/state_snapshot_table_test.go b/roomserver/storage/tables/state_snapshot_table_test.go index b2e59377d..c7c991b20 100644 --- a/roomserver/storage/tables/state_snapshot_table_test.go +++ b/roomserver/storage/tables/state_snapshot_table_test.go @@ -29,6 +29,8 @@ func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables. assert.NoError(t, err) err = postgres.CreateEventsTable(db) assert.NoError(t, err) + err = postgres.CreateEventJSONTable(db) + assert.NoError(t, err) err = postgres.CreateStateBlockTable(db) assert.NoError(t, err) // ... and then the snapshot table itself diff --git a/roomserver/types/types.go b/roomserver/types/types.go index f40980994..6401a94be 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -38,6 +38,11 @@ type EventNID int64 // RoomNID is a numeric ID for a room. type RoomNID int64 +type EventMetadata struct { + EventNID EventNID + RoomNID RoomNID +} + // StateSnapshotNID is a numeric ID for the state at an event. type StateSnapshotNID int64 diff --git a/roomserver/version/version.go b/roomserver/version/version.go index 729d00a80..c40d8e0f7 100644 --- a/roomserver/version/version.go +++ b/roomserver/version/version.go @@ -23,7 +23,7 @@ import ( // DefaultRoomVersion contains the room version that will, by // default, be used to create new rooms on this server. func DefaultRoomVersion() gomatrixserverlib.RoomVersion { - return gomatrixserverlib.RoomVersionV9 + return gomatrixserverlib.RoomVersionV10 } // RoomVersions returns a map of all known room versions to this diff --git a/setup/base/base.go b/setup/base/base.go index d3adbf53f..aabdd7937 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -15,11 +15,13 @@ package base import ( + "bytes" "context" - "crypto/tls" "database/sql" + "embed" "encoding/json" "fmt" + "html/template" "io" "net" "net/http" @@ -35,8 +37,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/prometheus/client_golang/prometheus/promhttp" "go.uber.org/atomic" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/caching" @@ -50,21 +50,14 @@ import ( "github.com/sirupsen/logrus" - appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - asinthttp "github.com/matrix-org/dendrite/appservice/inthttp" - federationAPI "github.com/matrix-org/dendrite/federationapi/api" - federationIntHTTP "github.com/matrix-org/dendrite/federationapi/inthttp" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" - keyinthttp "github.com/matrix-org/dendrite/keyserver/inthttp" - roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - rsinthttp "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" - userapi "github.com/matrix-org/dendrite/userapi/api" - userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp" ) +//go:embed static/*.gotmpl +var staticContent embed.FS + // BaseDendrite is a base for creating new instances of dendrite. It parses // command line flags and config, and exposes methods for creating various // resources. All errors are handled by logging then exiting, so all methods @@ -72,19 +65,16 @@ import ( // Must be closed when shutting down. type BaseDendrite struct { *process.ProcessContext - componentName string tracerCloser io.Closer PublicClientAPIMux *mux.Router PublicFederationAPIMux *mux.Router PublicKeyAPIMux *mux.Router PublicMediaAPIMux *mux.Router PublicWellKnownAPIMux *mux.Router - InternalAPIMux *mux.Router + PublicStaticMux *mux.Router DendriteAdminMux *mux.Router SynapseAdminMux *mux.Router NATS *jetstream.NATSInstance - UseHTTPAPIs bool - apiHttpClient *http.Client Cfg *config.Dendrite Caches *caching.Caches DNSCache *gomatrixserverlib.DNSCache @@ -98,38 +88,26 @@ type BaseDendrite struct { const NoListener = "" const HTTPServerTimeout = time.Minute * 5 -const HTTPClientTimeout = time.Second * 30 type BaseDendriteOptions int const ( DisableMetrics BaseDendriteOptions = iota - UseHTTPAPIs - PolylithMode ) // NewBaseDendrite creates a new instance to be used by a component. -// The componentName is used for logging purposes, and should be a friendly name -// of the compontent running, e.g. "SyncAPI" -func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...BaseDendriteOptions) *BaseDendrite { +func NewBaseDendrite(cfg *config.Dendrite, options ...BaseDendriteOptions) *BaseDendrite { platformSanityChecks() - useHTTPAPIs := false enableMetrics := true - isMonolith := true for _, opt := range options { switch opt { case DisableMetrics: enableMetrics = false - case UseHTTPAPIs: - useHTTPAPIs = true - case PolylithMode: - isMonolith = false - useHTTPAPIs = true } } configErrors := &config.ConfigErrors{} - cfg.Verify(configErrors, isMonolith) + cfg.Verify(configErrors) if len(*configErrors) > 0 { for _, err := range *configErrors { logrus.Errorf("Configuration error: %s", err) @@ -138,7 +116,7 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base } internal.SetupStdLogging() - internal.SetupHookLogging(cfg.Logging, componentName) + internal.SetupHookLogging(cfg.Logging) internal.SetupPprof() logrus.Infof("Dendrite version %s", internal.VersionString()) @@ -147,14 +125,13 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base logrus.Warn("Open registration is enabled") } - closer, err := cfg.SetupTracing("Dendrite" + componentName) + closer, err := cfg.SetupTracing() if err != nil { logrus.WithError(err).Panicf("failed to start opentracing") } var fts *fulltext.Search - isSyncOrMonolith := componentName == "syncapi" || isMonolith - if cfg.SyncAPI.Fulltext.Enabled && isSyncOrMonolith { + if cfg.SyncAPI.Fulltext.Enabled { fts, err = fulltext.New(cfg.SyncAPI.Fulltext) if err != nil { logrus.WithError(err).Panicf("failed to create full text") @@ -189,32 +166,12 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base ) } - apiClient := http.Client{ - Timeout: time.Minute * 10, - Transport: &http2.Transport{ - AllowHTTP: true, - DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) { - // Ordinarily HTTP/2 would expect TLS, but the remote listener is - // H2C-enabled (HTTP/2 without encryption). Overriding the DialTLS - // function with a plain Dial allows us to trick the HTTP client - // into establishing a HTTP/2 connection without TLS. - // TODO: Eventually we will want to look at authenticating and - // encrypting these internal HTTP APIs, at which point we will have - // to reconsider H2C and change all this anyway. - return net.Dial(network, addr) - }, - }, - } - // If we're in monolith mode, we'll set up a global pool of database // connections. A component is welcome to use this pool if they don't // have a separate database config of their own. var db *sql.DB var writer sqlutil.Writer if cfg.Global.DatabaseOptions.ConnectionString != "" { - if !isMonolith { - logrus.Panic("Using a global database connection pool is not supported in polylith deployments") - } if cfg.Global.DatabaseOptions.ConnectionString.IsSQLite() { logrus.Panic("Using a global database connection pool is not supported with SQLite databases") } @@ -239,8 +196,6 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base return &BaseDendrite{ ProcessContext: process.NewProcessContext(), - componentName: componentName, - UseHTTPAPIs: useHTTPAPIs, tracerCloser: closer, Cfg: cfg, Caches: caching.NewRistrettoCache(cfg.Global.Cache.EstimatedMaxSize, cfg.Global.Cache.MaxAge, enableMetrics), @@ -250,11 +205,10 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base PublicKeyAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicKeyPathPrefix).Subrouter().UseEncodedPath(), PublicMediaAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicMediaPathPrefix).Subrouter().UseEncodedPath(), PublicWellKnownAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicWellKnownPrefix).Subrouter().UseEncodedPath(), - InternalAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.InternalPathPrefix).Subrouter().UseEncodedPath(), + PublicStaticMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicStaticPath).Subrouter().UseEncodedPath(), DendriteAdminMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.DendriteAdminPathPrefix).Subrouter().UseEncodedPath(), SynapseAdminMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.SynapseAdminPathPrefix).Subrouter().UseEncodedPath(), NATS: &jetstream.NATSInstance{}, - apiHttpClient: &apiClient, Database: db, // set if monolith with global connection pool only DatabaseWriter: writer, // set if monolith with global connection pool only EnableMetrics: enableMetrics, @@ -264,6 +218,8 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base // Close implements io.Closer func (b *BaseDendrite) Close() error { + b.ProcessContext.ShutdownDendrite() + b.ProcessContext.WaitForShutdown() return b.tracerCloser.Close() } @@ -290,52 +246,6 @@ func (b *BaseDendrite) DatabaseConnection(dbProperties *config.DatabaseOptions, return nil, nil, fmt.Errorf("no database connections configured") } -// AppserviceHTTPClient returns the AppServiceInternalAPI for hitting the appservice component over HTTP. -func (b *BaseDendrite) AppserviceHTTPClient() appserviceAPI.AppServiceInternalAPI { - a, err := asinthttp.NewAppserviceClient(b.Cfg.AppServiceURL(), b.apiHttpClient) - if err != nil { - logrus.WithError(err).Panic("CreateHTTPAppServiceAPIs failed") - } - return a -} - -// RoomserverHTTPClient returns RoomserverInternalAPI for hitting the roomserver over HTTP. -func (b *BaseDendrite) RoomserverHTTPClient() roomserverAPI.RoomserverInternalAPI { - rsAPI, err := rsinthttp.NewRoomserverClient(b.Cfg.RoomServerURL(), b.apiHttpClient, b.Caches) - if err != nil { - logrus.WithError(err).Panic("RoomserverHTTPClient failed", b.apiHttpClient) - } - return rsAPI -} - -// UserAPIClient returns UserInternalAPI for hitting the userapi over HTTP. -func (b *BaseDendrite) UserAPIClient() userapi.UserInternalAPI { - userAPI, err := userapiinthttp.NewUserAPIClient(b.Cfg.UserAPIURL(), b.apiHttpClient) - if err != nil { - logrus.WithError(err).Panic("UserAPIClient failed", b.apiHttpClient) - } - return userAPI -} - -// FederationAPIHTTPClient returns FederationInternalAPI for hitting -// the federation API server over HTTP -func (b *BaseDendrite) FederationAPIHTTPClient() federationAPI.FederationInternalAPI { - f, err := federationIntHTTP.NewFederationAPIClient(b.Cfg.FederationAPIURL(), b.apiHttpClient, b.Caches) - if err != nil { - logrus.WithError(err).Panic("FederationAPIHTTPClient failed", b.apiHttpClient) - } - return f -} - -// KeyServerHTTPClient returns KeyInternalAPI for hitting the key server over HTTP -func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI { - f, err := keyinthttp.NewKeyServerClient(b.Cfg.KeyServerURL(), b.apiHttpClient) - if err != nil { - logrus.WithError(err).Panic("KeyServerHTTPClient failed", b.apiHttpClient) - } - return f -} - // PushGatewayHTTPClient returns a new client for interacting with (external) Push Gateways. func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client { return pushgateway.NewHTTPClient(b.Cfg.UserAPI.PushGatewayDisableTLSValidation) @@ -403,6 +313,7 @@ func (b *BaseDendrite) configureHTTPErrors() { for _, router := range []*mux.Router{ b.PublicMediaAPIMux, b.DendriteAdminMux, b.SynapseAdminMux, b.PublicWellKnownAPIMux, + b.PublicStaticMux, } { router.NotFoundHandler = notFoundCORSHandler router.MethodNotAllowedHandler = notAllowedCORSHandler @@ -431,20 +342,18 @@ func (b *BaseDendrite) ConfigureAdminEndpoints() { }) } -// SetupAndServeHTTP sets up the HTTP server to serve endpoints registered on -// ApiMux under /api/ and adds a prometheus handler under /metrics. +// SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs +// and adds a prometheus handler under /_dendrite/metrics. func (b *BaseDendrite) SetupAndServeHTTP( - internalHTTPAddr, externalHTTPAddr config.HTTPAddress, + externalHTTPAddr config.HTTPAddress, certFile, keyFile *string, ) { // Manually unlocked right before actually serving requests, // as we don't return from this method (defer doesn't work). b.startupLock.Lock() - internalAddr, _ := internalHTTPAddr.Address() externalAddr, _ := externalHTTPAddr.Address() externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() - internalRouter := externalRouter externalServ := &http.Server{ Addr: string(externalAddr), @@ -454,35 +363,33 @@ func (b *BaseDendrite) SetupAndServeHTTP( return b.ProcessContext.Context() }, } - internalServ := externalServ - - if internalAddr != NoListener && externalAddr != internalAddr { - // H2C allows us to accept HTTP/2 connections without TLS - // encryption. Since we don't currently require any form of - // authentication or encryption on these internal HTTP APIs, - // H2C gives us all of the advantages of HTTP/2 (such as - // stream multiplexing and avoiding head-of-line blocking) - // without enabling TLS. - internalH2S := &http2.Server{} - internalRouter = mux.NewRouter().SkipClean(true).UseEncodedPath() - internalServ = &http.Server{ - Addr: string(internalAddr), - Handler: h2c.NewHandler(internalRouter, internalH2S), - BaseContext: func(_ net.Listener) context.Context { - return b.ProcessContext.Context() - }, - } - } b.configureHTTPErrors() - internalRouter.PathPrefix(httputil.InternalPathPrefix).Handler(b.InternalAPIMux) + //Redirect for Landing Page + externalRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, httputil.PublicStaticPath, http.StatusFound) + }) + if b.Cfg.Global.Metrics.Enabled { - internalRouter.Handle("/metrics", httputil.WrapHandlerInBasicAuth(promhttp.Handler(), b.Cfg.Global.Metrics.BasicAuth)) + externalRouter.Handle("/metrics", httputil.WrapHandlerInBasicAuth(promhttp.Handler(), b.Cfg.Global.Metrics.BasicAuth)) } b.ConfigureAdminEndpoints() + // Parse and execute the landing page template + tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl")) + landingPage := &bytes.Buffer{} + if err := tmpl.ExecuteTemplate(landingPage, "index.gotmpl", map[string]string{ + "Version": internal.VersionString(), + }); err != nil { + logrus.WithError(err).Fatal("failed to execute landing page template") + } + + b.PublicStaticMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write(landingPage.Bytes()) + }) + var clientHandler http.Handler clientHandler = b.PublicClientAPIMux if b.Cfg.Global.Sentry.Enabled { @@ -499,7 +406,7 @@ func (b *BaseDendrite) SetupAndServeHTTP( }) federationHandler = sentryHandler.Handle(b.PublicFederationAPIMux) } - internalRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(b.DendriteAdminMux) + externalRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(b.DendriteAdminMux) externalRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(clientHandler) if !b.Cfg.Global.DisableFederation { externalRouter.PathPrefix(httputil.PublicKeyPathPrefix).Handler(b.PublicKeyAPIMux) @@ -508,40 +415,14 @@ func (b *BaseDendrite) SetupAndServeHTTP( externalRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(b.SynapseAdminMux) externalRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(b.PublicMediaAPIMux) externalRouter.PathPrefix(httputil.PublicWellKnownPrefix).Handler(b.PublicWellKnownAPIMux) + externalRouter.PathPrefix(httputil.PublicStaticPath).Handler(b.PublicStaticMux) b.startupLock.Unlock() - if internalAddr != NoListener && internalAddr != externalAddr { - go func() { - var internalShutdown atomic.Bool // RegisterOnShutdown can be called more than once - logrus.Infof("Starting internal %s listener on %s", b.componentName, internalServ.Addr) - b.ProcessContext.ComponentStarted() - internalServ.RegisterOnShutdown(func() { - if internalShutdown.CompareAndSwap(false, true) { - b.ProcessContext.ComponentFinished() - logrus.Infof("Stopped internal HTTP listener") - } - }) - if certFile != nil && keyFile != nil { - if err := internalServ.ListenAndServeTLS(*certFile, *keyFile); err != nil { - if err != http.ErrServerClosed { - logrus.WithError(err).Fatal("failed to serve HTTPS") - } - } - } else { - if err := internalServ.ListenAndServe(); err != nil { - if err != http.ErrServerClosed { - logrus.WithError(err).Fatal("failed to serve HTTP") - } - } - } - logrus.Infof("Stopped internal %s listener on %s", b.componentName, internalServ.Addr) - }() - } if externalAddr != NoListener { go func() { var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once - logrus.Infof("Starting external %s listener on %s", b.componentName, externalServ.Addr) + logrus.Infof("Starting external listener on %s", externalServ.Addr) b.ProcessContext.ComponentStarted() externalServ.RegisterOnShutdown(func() { if externalShutdown.CompareAndSwap(false, true) { @@ -562,7 +443,7 @@ func (b *BaseDendrite) SetupAndServeHTTP( } } } - logrus.Infof("Stopped external %s listener on %s", b.componentName, externalServ.Addr) + logrus.Infof("Stopped external listener on %s", externalServ.Addr) }() } @@ -570,7 +451,6 @@ func (b *BaseDendrite) SetupAndServeHTTP( <-b.ProcessContext.WaitForShutdown() logrus.Infof("Stopping HTTP listeners") - _ = internalServ.Shutdown(context.Background()) _ = externalServ.Shutdown(context.Background()) logrus.Infof("Stopped HTTP listeners") } @@ -593,6 +473,12 @@ func (b *BaseDendrite) WaitForShutdown() { logrus.Warnf("failed to flush all Sentry events!") } } + if b.Fulltext != nil { + err := b.Fulltext.Close() + if err != nil { + logrus.Warnf("failed to close full text search!") + } + } logrus.Warnf("Dendrite is exiting now") } diff --git a/setup/base/base_test.go b/setup/base/base_test.go new file mode 100644 index 000000000..d906294c0 --- /dev/null +++ b/setup/base/base_test.go @@ -0,0 +1,57 @@ +package base_test + +import ( + "bytes" + "embed" + "html/template" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/stretchr/testify/assert" +) + +//go:embed static/*.gotmpl +var staticContent embed.FS + +func TestLandingPage(t *testing.T) { + // generate the expected result + tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl")) + expectedRes := &bytes.Buffer{} + err := tmpl.ExecuteTemplate(expectedRes, "index.gotmpl", map[string]string{ + "Version": internal.VersionString(), + }) + assert.NoError(t, err) + + b, _, _ := testrig.Base(nil) + defer b.Close() + + // hack: create a server and close it immediately, just to get a random port assigned + s := httptest.NewServer(nil) + s.Close() + + // start base with the listener and wait for it to be started + go b.SetupAndServeHTTP(config.HTTPAddress(s.URL), nil, nil) + time.Sleep(time.Millisecond * 10) + + // When hitting /, we should be redirected to /_matrix/static, which should contain the landing page + req, err := http.NewRequest(http.MethodGet, s.URL, nil) + assert.NoError(t, err) + + // do the request + resp, err := s.Client().Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // read the response + buf := &bytes.Buffer{} + _, err = buf.ReadFrom(resp.Body) + assert.NoError(t, err) + + // Using .String() for user friendly output + assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch") +} diff --git a/setup/base/static/index.gotmpl b/setup/base/static/index.gotmpl new file mode 100644 index 000000000..b3c5576eb --- /dev/null +++ b/setup/base/static/index.gotmpl @@ -0,0 +1,63 @@ + + + + Dendrite is running + + + + +

It works! Dendrite {{ .Version }} is running

+

Your Dendrite server is listening on this port and is ready for messages.

+

To use this server you'll need a Matrix client. +

+

Welcome to the Matrix universe :)

+
+

+ + + matrix.org + + +

+ + diff --git a/setup/config/config.go b/setup/config/config.go index 41d2b6674..848766162 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -62,6 +62,7 @@ type Dendrite struct { RoomServer RoomServer `yaml:"room_server"` SyncAPI SyncAPI `yaml:"sync_api"` UserAPI UserAPI `yaml:"user_api"` + RelayAPI RelayAPI `yaml:"relay_api"` MSCs MSCs `yaml:"mscs"` @@ -78,8 +79,6 @@ type Dendrite struct { // Any information derived from the configuration options for later use. Derived Derived `yaml:"-"` - - IsMonolith bool `yaml:"-"` } // TODO: Kill Derived @@ -113,15 +112,6 @@ type Derived struct { // servers from creating RoomIDs in exclusive application service namespaces } -type InternalAPIOptions struct { - Listen HTTPAddress `yaml:"listen"` - Connect HTTPAddress `yaml:"connect"` -} - -type ExternalAPIOptions struct { - Listen HTTPAddress `yaml:"listen"` -} - // A Path on the filesystem. type Path string @@ -190,7 +180,7 @@ type ConfigErrors []string // Load a yaml config file for a server run as multiple processes or as a monolith. // Checks the config to ensure that it is valid. -func Load(configPath string, monolith bool) (*Dendrite, error) { +func Load(configPath string) (*Dendrite, error) { configData, err := os.ReadFile(configPath) if err != nil { return nil, err @@ -201,28 +191,26 @@ func Load(configPath string, monolith bool) (*Dendrite, error) { } // Pass the current working directory and os.ReadFile so that they can // be mocked in the tests - return loadConfig(basePath, configData, os.ReadFile, monolith) + return loadConfig(basePath, configData, os.ReadFile) } func loadConfig( basePath string, configData []byte, readFile func(string) ([]byte, error), - monolithic bool, ) (*Dendrite, error) { var c Dendrite c.Defaults(DefaultOpts{ - Generate: false, - Monolithic: monolithic, + Generate: false, + SingleDatabase: true, }) - c.IsMonolith = monolithic var err error if err = yaml.Unmarshal(configData, &c); err != nil { return nil, err } - if err = c.check(monolithic); err != nil { + if err = c.check(); err != nil { return nil, err } @@ -332,8 +320,8 @@ func (config *Dendrite) Derive() error { } type DefaultOpts struct { - Generate bool - Monolithic bool + Generate bool + SingleDatabase bool } // SetDefaults sets default config values if they are not explicitly set. @@ -349,21 +337,22 @@ func (c *Dendrite) Defaults(opts DefaultOpts) { c.SyncAPI.Defaults(opts) c.UserAPI.Defaults(opts) c.AppServiceAPI.Defaults(opts) + c.RelayAPI.Defaults(opts) c.MSCs.Defaults(opts) c.Wiring() } -func (c *Dendrite) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *Dendrite) Verify(configErrs *ConfigErrors) { type verifiable interface { - Verify(configErrs *ConfigErrors, isMonolith bool) + Verify(configErrs *ConfigErrors) } for _, c := range []verifiable{ &c.Global, &c.ClientAPI, &c.FederationAPI, &c.KeyServer, &c.MediaAPI, &c.RoomServer, &c.SyncAPI, &c.UserAPI, - &c.AppServiceAPI, &c.MSCs, + &c.AppServiceAPI, &c.RelayAPI, &c.MSCs, } { - c.Verify(configErrs, isMonolith) + c.Verify(configErrs) } } @@ -377,6 +366,7 @@ func (c *Dendrite) Wiring() { c.SyncAPI.Matrix = &c.Global c.UserAPI.Matrix = &c.Global c.AppServiceAPI.Matrix = &c.Global + c.RelayAPI.Matrix = &c.Global c.MSCs.Matrix = &c.Global c.ClientAPI.Derived = &c.Derived @@ -412,14 +402,6 @@ func checkNotEmpty(configErrs *ConfigErrors, key, value string) { } } -// checkNotZero verifies the given value is not zero in the configuration. -// If it is, adds an error to the list. -func checkNotZero(configErrs *ConfigErrors, key string, value int64) { - if value == 0 { - configErrs.Add(fmt.Sprintf("missing config key %q", key)) - } -} - // checkPositive verifies the given value is positive (zero included) // in the configuration. If it is not, adds an error to the list. func checkPositive(configErrs *ConfigErrors, key string, value int64) { @@ -428,26 +410,6 @@ func checkPositive(configErrs *ConfigErrors, key string, value int64) { } } -// checkURL verifies that the parameter is a valid URL -func checkURL(configErrs *ConfigErrors, key, value string) { - if value == "" { - configErrs.Add(fmt.Sprintf("missing config key %q", key)) - return - } - url, err := url.Parse(value) - if err != nil { - configErrs.Add(fmt.Sprintf("config key %q contains invalid URL (%s)", key, err.Error())) - return - } - switch url.Scheme { - case "http": - case "https": - default: - configErrs.Add(fmt.Sprintf("config key %q URL should be http:// or https://", key)) - return - } -} - // checkLogging verifies the parameters logging.* are valid. func (config *Dendrite) checkLogging(configErrs *ConfigErrors) { for _, logrusHook := range config.Logging { @@ -458,7 +420,7 @@ func (config *Dendrite) checkLogging(configErrs *ConfigErrors) { // check returns an error type containing all errors found within the config // file. -func (config *Dendrite) check(_ bool) error { // monolithic +func (config *Dendrite) check() error { // monolithic var configErrs ConfigErrors if config.Version != Version { @@ -525,58 +487,13 @@ func readKeyPEM(path string, data []byte, enforceKeyIDFormat bool) (gomatrixserv } } -// AppServiceURL returns a HTTP URL for where the appservice component is listening. -func (config *Dendrite) AppServiceURL() string { - // Hard code the appservice server to talk HTTP for now. - // If we support HTTPS we need to think of a practical way to do certificate validation. - // People setting up servers shouldn't need to get a certificate valid for the public - // internet for an internal API. - return string(config.AppServiceAPI.InternalAPI.Connect) -} - -// FederationAPIURL returns an HTTP URL for where the federation API is listening. -func (config *Dendrite) FederationAPIURL() string { - // Hard code the federationapi to talk HTTP for now. - // If we support HTTPS we need to think of a practical way to do certificate validation. - // People setting up servers shouldn't need to get a certificate valid for the public - // internet for an internal API. - return string(config.FederationAPI.InternalAPI.Connect) -} - -// RoomServerURL returns an HTTP URL for where the roomserver is listening. -func (config *Dendrite) RoomServerURL() string { - // Hard code the roomserver to talk HTTP for now. - // If we support HTTPS we need to think of a practical way to do certificate validation. - // People setting up servers shouldn't need to get a certificate valid for the public - // internet for an internal API. - return string(config.RoomServer.InternalAPI.Connect) -} - -// UserAPIURL returns an HTTP URL for where the userapi is listening. -func (config *Dendrite) UserAPIURL() string { - // Hard code the userapi to talk HTTP for now. - // If we support HTTPS we need to think of a practical way to do certificate validation. - // People setting up servers shouldn't need to get a certificate valid for the public - // internet for an internal API. - return string(config.UserAPI.InternalAPI.Connect) -} - -// KeyServerURL returns an HTTP URL for where the key server is listening. -func (config *Dendrite) KeyServerURL() string { - // Hard code the key server to talk HTTP for now. - // If we support HTTPS we need to think of a practical way to do certificate validation. - // People setting up servers shouldn't need to get a certificate valid for the public - // internet for an internal API. - return string(config.KeyServer.InternalAPI.Connect) -} - // SetupTracing configures the opentracing using the supplied configuration. -func (config *Dendrite) SetupTracing(serviceName string) (closer io.Closer, err error) { +func (config *Dendrite) SetupTracing() (closer io.Closer, err error) { if !config.Tracing.Enabled { return io.NopCloser(bytes.NewReader([]byte{})), nil } return config.Tracing.Jaeger.InitGlobalTracer( - serviceName, + "Dendrite", jaegerconfig.Logger(logrusLogger{logrus.StandardLogger()}), jaegerconfig.Metrics(jaegermetrics.NullFactory), ) diff --git a/setup/config/config_appservice.go b/setup/config/config_appservice.go index bd21826fe..37e20a978 100644 --- a/setup/config/config_appservice.go +++ b/setup/config/config_appservice.go @@ -22,15 +22,13 @@ import ( "strings" log "github.com/sirupsen/logrus" - yaml "gopkg.in/yaml.v2" + "gopkg.in/yaml.v2" ) type AppServiceAPI struct { Matrix *Global `yaml:"-"` Derived *Derived `yaml:"-"` // TODO: Nuke Derived from orbit - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - // DisableTLSValidation disables the validation of X.509 TLS certs // on appservice endpoints. This is not recommended in production! DisableTLSValidation bool `yaml:"disable_tls_validation"` @@ -39,18 +37,9 @@ type AppServiceAPI struct { } func (c *AppServiceAPI) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7777" - c.InternalAPI.Connect = "http://localhost:7777" - } } -func (c *AppServiceAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { - if isMonolith { // polylith required configs below - return - } - checkURL(configErrs, "app_service_api.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "app_service_api.internal_api.connect", string(c.InternalAPI.Connect)) +func (c *AppServiceAPI) Verify(configErrs *ConfigErrors) { } // ApplicationServiceNamespace is the namespace that a specific application diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 11628b1b0..b6c74a75f 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -9,9 +9,6 @@ type ClientAPI struct { Matrix *Global `yaml:"-"` Derived *Derived `yaml:"-"` // TODO: Nuke Derived from orbit - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"` - // If set disables new users from registering (except via shared // secrets) RegistrationDisabled bool `yaml:"registration_disabled"` @@ -58,11 +55,6 @@ type ClientAPI struct { } func (c *ClientAPI) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7771" - c.InternalAPI.Connect = "http://localhost:7771" - c.ExternalAPI.Listen = "http://[::]:8071" - } c.RegistrationSharedSecret = "" c.RecaptchaPublicKey = "" c.RecaptchaPrivateKey = "" @@ -74,7 +66,7 @@ func (c *ClientAPI) Defaults(opts DefaultOpts) { c.RateLimiting.Defaults() } -func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *ClientAPI) Verify(configErrs *ConfigErrors) { c.TURN.Verify(configErrs) c.RateLimiting.Verify(configErrs) if c.RecaptchaEnabled { @@ -85,10 +77,10 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { c.RecaptchaApiJsUrl = "https://www.google.com/recaptcha/api.js" } if c.RecaptchaFormField == "" { - c.RecaptchaFormField = "g-recaptcha" + c.RecaptchaFormField = "g-recaptcha-response" } if c.RecaptchaSitekeyClass == "" { - c.RecaptchaSitekeyClass = "g-recaptcha-response" + c.RecaptchaSitekeyClass = "g-recaptcha" } checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) @@ -108,12 +100,6 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { ) } } - if isMonolith { // polylith required configs below - return - } - checkURL(configErrs, "client_api.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "client_api.internal_api.connect", string(c.InternalAPI.Connect)) - checkURL(configErrs, "client_api.external_api.listen", string(c.ExternalAPI.Listen)) } type TURN struct { diff --git a/setup/config/config_federationapi.go b/setup/config/config_federationapi.go index 0f853865f..8c1540b57 100644 --- a/setup/config/config_federationapi.go +++ b/setup/config/config_federationapi.go @@ -1,13 +1,12 @@ package config -import "github.com/matrix-org/gomatrixserverlib" +import ( + "github.com/matrix-org/gomatrixserverlib" +) type FederationAPI struct { Matrix *Global `yaml:"-"` - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"` - // The database stores information used by the federation destination queues to // send transactions to remote servers. Database DatabaseOptions `yaml:"database,omitempty"` @@ -18,6 +17,12 @@ type FederationAPI struct { // The default value is 16 if not specified, which is circa 18 hours. FederationMaxRetries uint32 `yaml:"send_max_retries"` + // P2P Feature: How many consecutive failures that we should tolerate when + // sending federation requests to a specific server until we should assume they + // are offline. If we assume they are offline then we will attempt to send + // messages to their relay server if we know of one that is appropriate. + P2PFederationRetriesUntilAssumedOffline uint32 `yaml:"p2p_retries_until_assumed_offline"` + // FederationDisableTLSValidation disables the validation of X.509 TLS certs // on remote federation endpoints. This is not recommended in production! DisableTLSValidation bool `yaml:"disable_tls_validation"` @@ -36,13 +41,8 @@ type FederationAPI struct { } func (c *FederationAPI) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7772" - c.InternalAPI.Connect = "http://localhost:7772" - c.ExternalAPI.Listen = "http://[::]:8072" - c.Database.Defaults(10) - } c.FederationMaxRetries = 16 + c.P2PFederationRetriesUntilAssumedOffline = 1 c.DisableTLSValidation = false c.DisableHTTPKeepalives = false if opts.Generate { @@ -61,22 +61,16 @@ func (c *FederationAPI) Defaults(opts DefaultOpts) { }, }, } - if !opts.Monolithic { + if !opts.SingleDatabase { c.Database.ConnectionString = "file:federationapi.db" } } } -func (c *FederationAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { - if isMonolith { // polylith required configs below - return - } +func (c *FederationAPI) Verify(configErrs *ConfigErrors) { if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "federation_api.database.connection_string", string(c.Database.ConnectionString)) } - checkURL(configErrs, "federation_api.external_api.listen", string(c.ExternalAPI.Listen)) - checkURL(configErrs, "federation_api.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "federation_api.internal_api.connect", string(c.InternalAPI.Connect)) } // The config for setting a proxy to use for server->server requests diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 804eb1a2d..ed980afaa 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -38,7 +38,6 @@ type Global struct { // component does not specify any database options of its own, then this pool of // connections will be used instead. This way we don't have to manage connection // counts on a per-component basis, but can instead do it for the entire monolith. - // In a polylith deployment, this will be ignored. DatabaseOptions DatabaseOptions `yaml:"database,omitempty"` // The server name to delegate server-server communications to, with optional port @@ -93,7 +92,7 @@ func (c *Global) Defaults(opts DefaultOpts) { } } c.KeyValidityPeriod = time.Hour * 24 * 7 - if opts.Monolithic { + if opts.SingleDatabase { c.DatabaseOptions.Defaults(90) } c.JetStream.Defaults(opts) @@ -105,7 +104,7 @@ func (c *Global) Defaults(opts DefaultOpts) { c.Cache.Defaults() } -func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *Global) Verify(configErrs *ConfigErrors) { checkNotEmpty(configErrs, "global.server_name", string(c.ServerName)) checkNotEmpty(configErrs, "global.private_key", string(c.PrivateKeyPath)) @@ -113,13 +112,13 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { v.Verify(configErrs) } - c.JetStream.Verify(configErrs, isMonolith) - c.Metrics.Verify(configErrs, isMonolith) - c.Sentry.Verify(configErrs, isMonolith) - c.DNSCache.Verify(configErrs, isMonolith) - c.ServerNotices.Verify(configErrs, isMonolith) - c.ReportStats.Verify(configErrs, isMonolith) - c.Cache.Verify(configErrs, isMonolith) + c.JetStream.Verify(configErrs) + c.Metrics.Verify(configErrs) + c.Sentry.Verify(configErrs) + c.DNSCache.Verify(configErrs) + c.ServerNotices.Verify(configErrs) + c.ReportStats.Verify(configErrs) + c.Cache.Verify(configErrs) } func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool { @@ -267,7 +266,7 @@ func (c *Metrics) Defaults(opts DefaultOpts) { } } -func (c *Metrics) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *Metrics) Verify(configErrs *ConfigErrors) { } // ServerNotices defines the configuration used for sending server notices @@ -293,7 +292,7 @@ func (c *ServerNotices) Defaults(opts DefaultOpts) { } } -func (c *ServerNotices) Verify(errors *ConfigErrors, isMonolith bool) {} +func (c *ServerNotices) Verify(errors *ConfigErrors) {} type Cache struct { EstimatedMaxSize DataUnit `yaml:"max_size_estimated"` @@ -305,7 +304,7 @@ func (c *Cache) Defaults() { c.MaxAge = time.Hour } -func (c *Cache) Verify(errors *ConfigErrors, isMonolith bool) { +func (c *Cache) Verify(errors *ConfigErrors) { checkPositive(errors, "max_size_estimated", int64(c.EstimatedMaxSize)) } @@ -323,7 +322,7 @@ func (c *ReportStats) Defaults() { c.Endpoint = "https://matrix.org/report-usage-stats/push" } -func (c *ReportStats) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *ReportStats) Verify(configErrs *ConfigErrors) { if c.Enabled { checkNotEmpty(configErrs, "global.report_stats.endpoint", c.Endpoint) } @@ -344,7 +343,7 @@ func (c *Sentry) Defaults() { c.Enabled = false } -func (c *Sentry) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *Sentry) Verify(configErrs *ConfigErrors) { } type DatabaseOptions struct { @@ -364,8 +363,7 @@ func (c *DatabaseOptions) Defaults(conns int) { c.ConnMaxLifetimeSeconds = -1 } -func (c *DatabaseOptions) Verify(configErrs *ConfigErrors, isMonolith bool) { -} +func (c *DatabaseOptions) Verify(configErrs *ConfigErrors) {} // MaxIdleConns returns maximum idle connections to the DB func (c DatabaseOptions) MaxIdleConns() int { @@ -397,7 +395,7 @@ func (c *DNSCacheOptions) Defaults() { c.CacheLifetime = time.Minute * 5 } -func (c *DNSCacheOptions) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *DNSCacheOptions) Verify(configErrs *ConfigErrors) { checkPositive(configErrs, "cache_size", int64(c.CacheSize)) checkPositive(configErrs, "cache_lifetime", int64(c.CacheLifetime)) } diff --git a/setup/config/config_jetstream.go b/setup/config/config_jetstream.go index ef8bf014b..b8abed25c 100644 --- a/setup/config/config_jetstream.go +++ b/setup/config/config_jetstream.go @@ -41,11 +41,4 @@ func (c *JetStream) Defaults(opts DefaultOpts) { } } -func (c *JetStream) Verify(configErrs *ConfigErrors, isMonolith bool) { - if isMonolith { // polylith required configs below - return - } - // If we are running in a polylith deployment then we need at least - // one NATS JetStream server to talk to. - checkNotZero(configErrs, "global.jetstream.addresses", int64(len(c.Addresses))) -} +func (c *JetStream) Verify(configErrs *ConfigErrors) {} diff --git a/setup/config/config_keyserver.go b/setup/config/config_keyserver.go index dca9ca9f5..64710d957 100644 --- a/setup/config/config_keyserver.go +++ b/setup/config/config_keyserver.go @@ -3,31 +3,19 @@ package config type KeyServer struct { Matrix *Global `yaml:"-"` - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - Database DatabaseOptions `yaml:"database,omitempty"` } func (c *KeyServer) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7779" - c.InternalAPI.Connect = "http://localhost:7779" - c.Database.Defaults(10) - } if opts.Generate { - if !opts.Monolithic { + if !opts.SingleDatabase { c.Database.ConnectionString = "file:keyserver.db" } } } -func (c *KeyServer) Verify(configErrs *ConfigErrors, isMonolith bool) { - if isMonolith { // polylith required configs below - return - } +func (c *KeyServer) Verify(configErrs *ConfigErrors) { if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "key_server.database.connection_string", string(c.Database.ConnectionString)) } - checkURL(configErrs, "key_server.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "key_server.internal_api.connect", string(c.InternalAPI.Connect)) } diff --git a/setup/config/config_mediaapi.go b/setup/config/config_mediaapi.go index 53a8219eb..030bc3754 100644 --- a/setup/config/config_mediaapi.go +++ b/setup/config/config_mediaapi.go @@ -7,9 +7,6 @@ import ( type MediaAPI struct { Matrix *Global `yaml:"-"` - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"` - // The MediaAPI database stores information about files uploaded and downloaded // by local users. It is only accessed by the MediaAPI. Database DatabaseOptions `yaml:"database,omitempty"` @@ -39,12 +36,6 @@ type MediaAPI struct { var DefaultMaxFileSizeBytes = FileSizeBytes(10485760) func (c *MediaAPI) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7774" - c.InternalAPI.Connect = "http://localhost:7774" - c.ExternalAPI.Listen = "http://[::]:8074" - c.Database.Defaults(5) - } c.MaxFileSizeBytes = DefaultMaxFileSizeBytes c.MaxThumbnailGenerators = 10 if opts.Generate { @@ -65,14 +56,14 @@ func (c *MediaAPI) Defaults(opts DefaultOpts) { ResizeMethod: "scale", }, } - if !opts.Monolithic { + if !opts.SingleDatabase { c.Database.ConnectionString = "file:mediaapi.db" } c.BasePath = "./media_store" } } -func (c *MediaAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *MediaAPI) Verify(configErrs *ConfigErrors) { checkNotEmpty(configErrs, "media_api.base_path", string(c.BasePath)) checkPositive(configErrs, "media_api.max_file_size_bytes", int64(c.MaxFileSizeBytes)) checkPositive(configErrs, "media_api.max_thumbnail_generators", int64(c.MaxThumbnailGenerators)) @@ -81,13 +72,8 @@ func (c *MediaAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { checkPositive(configErrs, fmt.Sprintf("media_api.thumbnail_sizes[%d].width", i), int64(size.Width)) checkPositive(configErrs, fmt.Sprintf("media_api.thumbnail_sizes[%d].height", i), int64(size.Height)) } - if isMonolith { // polylith required configs below - return - } + if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "media_api.database.connection_string", string(c.Database.ConnectionString)) } - checkURL(configErrs, "media_api.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "media_api.internal_api.connect", string(c.InternalAPI.Connect)) - checkURL(configErrs, "media_api.external_api.listen", string(c.ExternalAPI.Listen)) } diff --git a/setup/config/config_mscs.go b/setup/config/config_mscs.go index 6d5ff39a5..21d4b4da0 100644 --- a/setup/config/config_mscs.go +++ b/setup/config/config_mscs.go @@ -14,11 +14,8 @@ type MSCs struct { } func (c *MSCs) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.Database.Defaults(5) - } if opts.Generate { - if !opts.Monolithic { + if !opts.SingleDatabase { c.Database.ConnectionString = "file:mscs.db" } } @@ -34,10 +31,7 @@ func (c *MSCs) Enabled(msc string) bool { return false } -func (c *MSCs) Verify(configErrs *ConfigErrors, isMonolith bool) { - if isMonolith { // polylith required configs below - return - } +func (c *MSCs) Verify(configErrs *ConfigErrors) { if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "mscs.database.connection_string", string(c.Database.ConnectionString)) } diff --git a/setup/config/config_relayapi.go b/setup/config/config_relayapi.go new file mode 100644 index 000000000..ba7b78082 --- /dev/null +++ b/setup/config/config_relayapi.go @@ -0,0 +1,37 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +type RelayAPI struct { + Matrix *Global `yaml:"-"` + + // The database stores information used by the relay queue to + // forward transactions to remote servers. + Database DatabaseOptions `yaml:"database,omitempty"` +} + +func (c *RelayAPI) Defaults(opts DefaultOpts) { + if opts.Generate { + if !opts.SingleDatabase { + c.Database.ConnectionString = "file:relayapi.db" + } + } +} + +func (c *RelayAPI) Verify(configErrs *ConfigErrors) { + if c.Matrix.DatabaseOptions.ConnectionString == "" { + checkNotEmpty(configErrs, "relay_api.database.connection_string", string(c.Database.ConnectionString)) + } +} diff --git a/setup/config/config_roomserver.go b/setup/config/config_roomserver.go index 5e3b7f2ec..319c2419c 100644 --- a/setup/config/config_roomserver.go +++ b/setup/config/config_roomserver.go @@ -3,31 +3,19 @@ package config type RoomServer struct { Matrix *Global `yaml:"-"` - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - Database DatabaseOptions `yaml:"database,omitempty"` } func (c *RoomServer) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7770" - c.InternalAPI.Connect = "http://localhost:7770" - c.Database.Defaults(20) - } if opts.Generate { - if !opts.Monolithic { + if !opts.SingleDatabase { c.Database.ConnectionString = "file:roomserver.db" } } } -func (c *RoomServer) Verify(configErrs *ConfigErrors, isMonolith bool) { - if isMonolith { // polylith required configs below - return - } +func (c *RoomServer) Verify(configErrs *ConfigErrors) { if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "room_server.database.connection_string", string(c.Database.ConnectionString)) } - checkURL(configErrs, "room_server.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "room_server.internal_ap.connect", string(c.InternalAPI.Connect)) } diff --git a/setup/config/config_syncapi.go b/setup/config/config_syncapi.go index a87da3732..756f4cfb3 100644 --- a/setup/config/config_syncapi.go +++ b/setup/config/config_syncapi.go @@ -3,9 +3,6 @@ package config type SyncAPI struct { Matrix *Global `yaml:"-"` - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"` - Database DatabaseOptions `yaml:"database,omitempty"` RealIPHeader string `yaml:"real_ip_header"` @@ -14,31 +11,19 @@ type SyncAPI struct { } func (c *SyncAPI) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7773" - c.InternalAPI.Connect = "http://localhost:7773" - c.ExternalAPI.Listen = "http://localhost:8073" - c.Database.Defaults(20) - } c.Fulltext.Defaults(opts) if opts.Generate { - if !opts.Monolithic { + if !opts.SingleDatabase { c.Database.ConnectionString = "file:syncapi.db" } } } -func (c *SyncAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { - c.Fulltext.Verify(configErrs, isMonolith) - if isMonolith { // polylith required configs below - return - } +func (c *SyncAPI) Verify(configErrs *ConfigErrors) { + c.Fulltext.Verify(configErrs) if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "sync_api.database", string(c.Database.ConnectionString)) } - checkURL(configErrs, "sync_api.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "sync_api.internal_api.connect", string(c.InternalAPI.Connect)) - checkURL(configErrs, "sync_api.external_api.listen", string(c.ExternalAPI.Listen)) } type Fulltext struct { @@ -54,7 +39,7 @@ func (f *Fulltext) Defaults(opts DefaultOpts) { f.Language = "en" } -func (f *Fulltext) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (f *Fulltext) Verify(configErrs *ConfigErrors) { if !f.Enabled { return } diff --git a/setup/config/config_test.go b/setup/config/config_test.go index 3408bf46d..79407f30d 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -20,20 +20,29 @@ import ( "testing" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" ) func TestLoadConfigRelative(t *testing.T) { - _, err := loadConfig("/my/config/dir", []byte(testConfig), + cfg, err := loadConfig("/my/config/dir", []byte(testConfig), mockReadFile{ "/my/config/dir/matrix_key.pem": testKey, "/my/config/dir/tls_cert.pem": testCert, }.readFile, - false, ) if err != nil { t.Error("failed to load config:", err) } + + configErrors := &ConfigErrors{} + cfg.Verify(configErrors) + if len(*configErrors) > 0 { + for _, err := range *configErrors { + logrus.Errorf("Configuration error: %s", err) + } + t.Error("configuration verification failed") + } } const testConfig = ` @@ -68,10 +77,9 @@ global: display_name: "Server alerts" avatar: "" room_name: "Server Alerts" + jetstream: + addresses: ["test"] app_service_api: - internal_api: - listen: http://localhost:7777 - connect: http://localhost:7777 database: connection_string: file:appservice.db max_open_conns: 100 @@ -79,12 +87,7 @@ app_service_api: conn_max_lifetime: -1 config_files: [] client_api: - internal_api: - listen: http://localhost:7771 - connect: http://localhost:7771 - external_api: - listen: http://[::]:8071 - registration_disabled: false + registration_disabled: true registration_shared_secret: "" enable_registration_captcha: false recaptcha_public_key: "" @@ -97,36 +100,16 @@ client_api: turn_shared_secret: "" turn_username: "" turn_password: "" -current_state_server: - internal_api: - listen: http://localhost:7782 - connect: http://localhost:7782 - database: - connection_string: file:currentstate.db - max_open_conns: 100 - max_idle_conns: 2 - conn_max_lifetime: -1 federation_api: - internal_api: - listen: http://localhost:7772 - connect: http://localhost:7772 - external_api: - listen: http://[::]:8072 + database: + connection_string: file:federationapi.db key_server: - internal_api: - listen: http://localhost:7779 - connect: http://localhost:7779 database: connection_string: file:keyserver.db max_open_conns: 100 max_idle_conns: 2 conn_max_lifetime: -1 media_api: - internal_api: - listen: http://localhost:7774 - connect: http://localhost:7774 - external_api: - listen: http://[::]:8074 database: connection_string: file:mediaapi.db max_open_conns: 100 @@ -147,18 +130,12 @@ media_api: height: 480 method: scale room_server: - internal_api: - listen: http://localhost:7770 - connect: http://localhost:7770 database: connection_string: file:roomserver.db max_open_conns: 100 max_idle_conns: 2 conn_max_lifetime: -1 server_key_api: - internal_api: - listen: http://localhost:7780 - connect: http://localhost:7780 database: connection_string: file:serverkeyapi.db max_open_conns: 100 @@ -172,18 +149,12 @@ server_key_api: - key_id: ed25519:a_RXGa public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ sync_api: - internal_api: - listen: http://localhost:7773 - connect: http://localhost:7773 database: connection_string: file:syncapi.db max_open_conns: 100 max_idle_conns: 2 conn_max_lifetime: -1 user_api: - internal_api: - listen: http://localhost:7781 - connect: http://localhost:7781 account_database: connection_string: file:userapi_accounts.db max_open_conns: 100 @@ -194,6 +165,12 @@ user_api: max_open_conns: 100 max_idle_conns: 2 conn_max_lifetime: -1 +relay_api: + database: + connection_string: file:relayapi.db +mscs: + database: + connection_string: file:mscs.db tracing: enabled: false jaeger: diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go index f8ad41d93..e64a3910c 100644 --- a/setup/config/config_userapi.go +++ b/setup/config/config_userapi.go @@ -5,8 +5,6 @@ import "golang.org/x/crypto/bcrypt" type UserAPI struct { Matrix *Global `yaml:"-"` - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - // The cost when hashing passwords. BCryptCost int `yaml:"bcrypt_cost"` @@ -28,28 +26,18 @@ type UserAPI struct { const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes func (c *UserAPI) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7781" - c.InternalAPI.Connect = "http://localhost:7781" - c.AccountDatabase.Defaults(10) - } c.BCryptCost = bcrypt.DefaultCost c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS if opts.Generate { - if !opts.Monolithic { + if !opts.SingleDatabase { c.AccountDatabase.ConnectionString = "file:userapi_accounts.db" } } } -func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *UserAPI) Verify(configErrs *ConfigErrors) { checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS) - if isMonolith { // polylith required configs below - return - } if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString)) } - checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect)) } diff --git a/setup/flags.go b/setup/flags.go index a9dac61a1..869caa280 100644 --- a/setup/flags.go +++ b/setup/flags.go @@ -43,7 +43,7 @@ func ParseFlags(monolith bool) *config.Dendrite { logrus.Fatal("--config must be supplied") } - cfg, err := config.Load(*configPath, monolith) + cfg, err := config.Load(*configPath) if err != nil { logrus.Fatalf("Invalid config file: %s", err) diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index c1ce9583f..533652160 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -77,6 +77,11 @@ func JetStreamConsumer( // The consumer was deleted so stop. return } else { + // Unfortunately, there's no ErrServerShutdown or similar, so we need to compare the string + if err.Error() == "nats: Server Shutdown" { + logrus.WithContext(ctx).Warn("nats server shutting down") + return + } // Something else went wrong, so we'll panic. sentry.CaptureException(err) logrus.WithContext(ctx).WithField("subject", subj).Fatal(err) diff --git a/setup/jetstream/log.go b/setup/jetstream/log.go new file mode 100644 index 000000000..880f7120b --- /dev/null +++ b/setup/jetstream/log.go @@ -0,0 +1,42 @@ +package jetstream + +import ( + "github.com/nats-io/nats-server/v2/server" + "github.com/sirupsen/logrus" +) + +var _ server.Logger = &LogAdapter{} + +type LogAdapter struct { + entry *logrus.Entry +} + +func NewLogAdapter() *LogAdapter { + return &LogAdapter{ + entry: logrus.StandardLogger().WithField("component", "jetstream"), + } +} + +func (l *LogAdapter) Noticef(format string, v ...interface{}) { + l.entry.Infof(format, v...) +} + +func (l *LogAdapter) Warnf(format string, v ...interface{}) { + l.entry.Warnf(format, v...) +} + +func (l *LogAdapter) Fatalf(format string, v ...interface{}) { + l.entry.Fatalf(format, v...) +} + +func (l *LogAdapter) Errorf(format string, v ...interface{}) { + l.entry.Errorf(format, v...) +} + +func (l *LogAdapter) Debugf(format string, v ...interface{}) { + l.entry.Debugf(format, v...) +} + +func (l *LogAdapter) Tracef(format string, v ...interface{}) { + l.entry.Tracef(format, v...) +} diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index af4eb2949..01fec9ad6 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -40,7 +40,7 @@ func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetS } if s.Server == nil { var err error - s.Server, err = natsserver.NewServer(&natsserver.Options{ + opts := &natsserver.Options{ ServerName: "monolith", DontListen: true, JetStream: true, @@ -49,11 +49,12 @@ func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetS MaxPayload: 16 * 1024 * 1024, NoSigs: true, NoLog: cfg.NoLog, - }) + } + s.Server, err = natsserver.NewServer(opts) if err != nil { panic(err) } - s.ConfigureLogger() + s.SetLogger(NewLogAdapter(), opts.Debug, opts.Trace) go func() { process.ComponentStarted() s.Start() diff --git a/setup/monolith.go b/setup/monolith.go index 41a897024..d8c652234 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -21,8 +21,9 @@ import ( "github.com/matrix-org/dendrite/federationapi" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/transactions" - keyAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/mediaapi" + "github.com/matrix-org/dendrite/relayapi" + relayAPI "github.com/matrix-org/dendrite/relayapi/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" @@ -43,7 +44,7 @@ type Monolith struct { FederationAPI federationAPI.FederationInternalAPI RoomserverAPI roomserverAPI.RoomserverInternalAPI UserAPI userapi.UserInternalAPI - KeyAPI keyAPI.KeyInternalAPI + RelayAPI relayAPI.RelayInternalAPI // Optional ExtPublicRoomsProvider api.ExtraPublicRoomsProvider @@ -58,17 +59,16 @@ func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) { } clientapi.AddPublicRoutes( base, m.FedClient, m.RoomserverAPI, m.AppserviceAPI, transactions.New(), - m.FederationAPI, m.UserAPI, userDirectoryProvider, m.KeyAPI, + m.FederationAPI, m.UserAPI, userDirectoryProvider, m.ExtPublicRoomsProvider, ) federationapi.AddPublicRoutes( - base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, - m.KeyAPI, nil, - ) - mediaapi.AddPublicRoutes( - base, m.UserAPI, m.Client, - ) - syncapi.AddPublicRoutes( - base, m.UserAPI, m.RoomserverAPI, m.KeyAPI, + base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, nil, ) + mediaapi.AddPublicRoutes(base, m.UserAPI, m.Client) + syncapi.AddPublicRoutes(base, m.UserAPI, m.RoomserverAPI) + + if m.RelayAPI != nil { + relayapi.AddPublicRoutes(base, m.KeyRing, m.RelayAPI) + } } diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 0388fcc53..f12fbbfcb 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -499,7 +499,7 @@ func assertUnsignedChildren(t *testing.T, ev gomatrixserverlib.ClientEvent, relT } type testUserAPI struct { - userapi.UserInternalAPITrace + userapi.UserInternalAPI accessTokens map[string]userapi.Device } @@ -516,7 +516,7 @@ func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAc type testRoomserverAPI struct { // use a trace API as it implements method stubs so we don't need to have them here. // We'll override the functions we care about. - roomserver.RoomserverInternalAPITrace + roomserver.RoomserverInternalAPI userToJoinedRooms map[string][]string events map[string]*gomatrixserverlib.HeaderedEvent } @@ -548,8 +548,8 @@ func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserve t.Helper() cfg := &config.Dendrite{} cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: true, }) cfg.Global.ServerName = "localhost" cfg.MSCs.Database.ConnectionString = "file:msc2836_test.db" diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 92f081500..5faaefb8e 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -19,7 +19,6 @@ import ( "encoding/json" "github.com/getsentry/sentry-go" - "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" @@ -28,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 1b67f5684..21838039a 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -23,6 +23,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -127,6 +128,12 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms s.onRetirePeek(s.ctx, *output.RetirePeek) case api.OutputTypeRedactedEvent: err = s.onRedactEvent(s.ctx, *output.RedactedEvent) + case api.OutputTypePurgeRoom: + err = s.onPurgeRoom(s.ctx, *output.PurgeRoom) + if err != nil { + logrus.WithField("room_id", output.PurgeRoom.RoomID).WithError(err).Error("Failed to purge room from sync API") + return true // non-fatal, as otherwise we end up in a loop of trying to purge the room + } default: log.WithField("type", output.Type).Debug( "roomserver output log: ignoring unknown output type", @@ -473,6 +480,20 @@ func (s *OutputRoomEventConsumer) onRetirePeek( s.notifier.OnRetirePeek(msg.RoomID, msg.UserID, msg.DeviceID, types.StreamingToken{PDUPosition: sp}) } +func (s *OutputRoomEventConsumer) onPurgeRoom( + ctx context.Context, req api.OutputPurgeRoom, +) error { + logrus.WithField("room_id", req.RoomID).Warn("Purging room from sync API") + + if err := s.db.PurgeRoom(ctx, req.RoomID); err != nil { + logrus.WithField("room_id", req.RoomID).WithError(err).Error("Failed to purge room from sync API") + return err + } else { + logrus.WithField("room_id", req.RoomID).Warn("Room purged from sync API") + return nil + } +} + func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.HeaderedEvent) (*gomatrixserverlib.HeaderedEvent, error) { if event.StateKey() == nil { return event, nil diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index 356e83263..32208c585 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -25,7 +25,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -33,6 +32,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" ) // OutputSendToDeviceEventConsumer consumes events that originated in the EDU server. @@ -42,7 +42,7 @@ type OutputSendToDeviceEventConsumer struct { durable string topic string db storage.Database - keyAPI keyapi.SyncKeyAPI + userAPI api.SyncKeyAPI isLocalServerName func(gomatrixserverlib.ServerName) bool stream streams.StreamProvider notifier *notifier.Notifier @@ -55,7 +55,7 @@ func NewOutputSendToDeviceEventConsumer( cfg *config.SyncAPI, js nats.JetStreamContext, store storage.Database, - keyAPI keyapi.SyncKeyAPI, + userAPI api.SyncKeyAPI, notifier *notifier.Notifier, stream streams.StreamProvider, ) *OutputSendToDeviceEventConsumer { @@ -65,7 +65,7 @@ func NewOutputSendToDeviceEventConsumer( topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), durable: cfg.Matrix.JetStream.Durable("SyncAPISendToDeviceConsumer"), db: store, - keyAPI: keyAPI, + userAPI: userAPI, isLocalServerName: cfg.Matrix.IsLocalServerName, notifier: notifier, stream: stream, @@ -116,7 +116,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs [] _, senderDomain, _ := gomatrixserverlib.SplitID('@', output.Sender) if requestingDeviceID != "" && !s.isLocalServerName(senderDomain) { // Mark the requesting device as stale, if we don't know about it. - if err = s.keyAPI.PerformMarkAsStaleIfNeeded(ctx, &keyapi.PerformMarkAsStaleRequest{ + if err = s.userAPI.PerformMarkAsStaleIfNeeded(ctx, &api.PerformMarkAsStaleRequest{ UserID: output.Sender, Domain: senderDomain, DeviceID: requestingDeviceID, }, &struct{}{}); err != nil { logger.WithError(err).Errorf("failed to mark as stale if needed") diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index 71d7ddd15..ee695f0f5 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -33,8 +33,7 @@ func init() { } // calculateHistoryVisibilityDuration stores the time it takes to -// calculate the history visibility. In polylith mode the roundtrip -// to the roomserver is included in this time. +// calculate the history visibility. var calculateHistoryVisibilityDuration = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: "dendrite", @@ -121,10 +120,7 @@ func ApplyHistoryVisibilityFilter( // Get the mapping from eventID -> eventVisibility eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events)) - visibilities, err := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID()) - if err != nil { - return eventsFiltered, err - } + visibilities := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID()) for _, ev := range events { evVis := visibilities[ev.EventID()] evVis.membershipCurrent = membershipCurrent @@ -175,7 +171,7 @@ func visibilityForEvents( rsAPI api.SyncRoomserverAPI, events []*gomatrixserverlib.HeaderedEvent, userID, roomID string, -) (map[string]eventVisibility, error) { +) map[string]eventVisibility { eventIDs := make([]string, len(events)) for i := range events { eventIDs[i] = events[i].EventID() @@ -185,6 +181,7 @@ func visibilityForEvents( // get the membership events for all eventIDs membershipResp := &api.QueryMembershipAtEventResponse{} + err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{ RoomID: roomID, EventIDs: eventIDs, @@ -201,19 +198,20 @@ func visibilityForEvents( membershipAtEvent: gomatrixserverlib.Leave, // default to leave, to not expose events by accident visibility: event.Visibility, } - membershipEvs, ok := membershipResp.Memberships[eventID] - if !ok { + ev, ok := membershipResp.Membership[eventID] + if !ok || ev == nil { result[eventID] = vis continue } - for _, ev := range membershipEvs { - membership, err := ev.Membership() - if err != nil { - return result, err - } - vis.membershipAtEvent = membership + + membership, err := ev.Membership() + if err != nil { + result[eventID] = vis + continue } + vis.membershipAtEvent = membership + result[eventID] = vis } - return result, nil + return result } diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 3d6b2a7f3..e7f677c85 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -18,22 +18,22 @@ import ( "context" "strings" + keytypes "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - keyapi "github.com/matrix-org/dendrite/keyserver/api" - keytypes "github.com/matrix-org/dendrite/keyserver/types" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" ) // DeviceOTKCounts adds one-time key counts to the /sync response -func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error { - var queryRes keyapi.QueryOneTimeKeysResponse - _ = keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{ +func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceID string, res *types.Response) error { + var queryRes api.QueryOneTimeKeysResponse + _ = keyAPI.QueryOneTimeKeys(ctx, &api.QueryOneTimeKeysRequest{ UserID: userID, DeviceID: deviceID, }, &queryRes) @@ -48,7 +48,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, devi // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // be already filled in with join/leave information. func DeviceListCatchup( - ctx context.Context, db storage.SharedUsers, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI, + ctx context.Context, db storage.SharedUsers, userAPI api.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI, userID string, res *types.Response, from, to types.StreamPosition, ) (newPos types.StreamPosition, hasNew bool, err error) { @@ -74,8 +74,8 @@ func DeviceListCatchup( if from > 0 { offset = int64(from) } - var queryRes keyapi.QueryKeyChangesResponse - _ = keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{ + var queryRes api.QueryKeyChangesResponse + _ = userAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{ Offset: offset, ToOffset: toOffset, }, &queryRes) @@ -144,38 +144,42 @@ func TrackChangedUsers( // - Loop set of users and decrement by 1 for each user in newly left room. // - If count=0 then they share no more rooms so inform BOTH parties of this via 'left'=[...] in /sync. var queryRes roomserverAPI.QuerySharedUsersResponse - err = rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ - UserID: userID, - IncludeRoomIDs: newlyLeftRooms, - }, &queryRes) - if err != nil { - return nil, nil, err - } var stateRes roomserverAPI.QueryBulkStateContentResponse - err = rsAPI.QueryBulkStateContent(ctx, &roomserverAPI.QueryBulkStateContentRequest{ - RoomIDs: newlyLeftRooms, - StateTuples: []gomatrixserverlib.StateKeyTuple{ - { - EventType: gomatrixserverlib.MRoomMember, - StateKey: "*", - }, - }, - AllowWildcards: true, - }, &stateRes) - if err != nil { - return nil, nil, err - } - for _, state := range stateRes.Rooms { - for tuple, membership := range state { - if membership != gomatrixserverlib.Join { - continue - } - queryRes.UserIDsToCount[tuple.StateKey]-- + if len(newlyLeftRooms) > 0 { + err = rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ + UserID: userID, + IncludeRoomIDs: newlyLeftRooms, + }, &queryRes) + if err != nil { + return nil, nil, err } - } - for userID, count := range queryRes.UserIDsToCount { - if count <= 0 { - left = append(left, userID) // left is returned + + err = rsAPI.QueryBulkStateContent(ctx, &roomserverAPI.QueryBulkStateContentRequest{ + RoomIDs: newlyLeftRooms, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + { + EventType: gomatrixserverlib.MRoomMember, + StateKey: "*", + }, + }, + AllowWildcards: true, + }, &stateRes) + if err != nil { + return nil, nil, err + } + for _, state := range stateRes.Rooms { + for tuple, membership := range state { + if membership != gomatrixserverlib.Join { + continue + } + queryRes.UserIDsToCount[tuple.StateKey]-- + } + } + + for userID, count := range queryRes.UserIDsToCount { + if count <= 0 { + left = append(left, userID) // left is returned + } } } diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index 53f3e5a40..4bb851668 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -9,7 +9,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -22,49 +21,49 @@ var ( type mockKeyAPI struct{} -func (k *mockKeyAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *keyapi.PerformMarkAsStaleRequest, res *struct{}) error { +func (k *mockKeyAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *userapi.PerformMarkAsStaleRequest, res *struct{}) error { return nil } -func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) error { +func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *userapi.PerformUploadKeysRequest, res *userapi.PerformUploadKeysResponse) error { return nil } func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {} // PerformClaimKeys claims one-time keys for use in pre-key messages -func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) error { +func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *userapi.PerformClaimKeysRequest, res *userapi.PerformClaimKeysResponse) error { return nil } -func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) error { +func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *userapi.PerformDeleteKeysRequest, res *userapi.PerformDeleteKeysResponse) error { return nil } -func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) error { +func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *userapi.PerformUploadDeviceKeysRequest, res *userapi.PerformUploadDeviceKeysResponse) error { return nil } -func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) error { +func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *userapi.PerformUploadDeviceSignaturesRequest, res *userapi.PerformUploadDeviceSignaturesResponse) error { return nil } -func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) error { +func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *userapi.QueryKeysRequest, res *userapi.QueryKeysResponse) error { return nil } -func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error { +func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyChangesRequest, res *userapi.QueryKeyChangesResponse) error { return nil } -func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error { +func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error { return nil } -func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) error { +func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error { return nil } -func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) error { +func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *userapi.QuerySignaturesRequest, res *userapi.QuerySignaturesResponse) error { return nil } type mockRoomserverAPI struct { - api.RoomserverInternalAPITrace + api.RoomserverInternalAPI roomIDToJoinedMembers map[string][]string } diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 095a868c7..76f003671 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -67,6 +67,8 @@ func Context( errMsg = "unable to parse filter" case *strconv.NumError: errMsg = "unable to parse limit" + default: + errMsg = err.Error() } return util.JSONResponse{ Code: http.StatusBadRequest, @@ -167,7 +169,18 @@ func Context( eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBeforeFiltered, gomatrixserverlib.FormatAll) eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll) - newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache) + + newState := state + if filter.LazyLoadMembers { + allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) + allEvents = append(allEvents, &requestedEvent) + evs := gomatrixserverlib.HeaderedToClientEvents(allEvents, gomatrixserverlib.FormatAll) + newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) + if err != nil { + logrus.WithError(err).Error("unable to load membership events") + return jsonerror.InternalServerError() + } + } ev := gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll) response := ContextRespsonse{ @@ -244,41 +257,43 @@ func getStartEnd(ctx context.Context, snapshot storage.DatabaseTransaction, star } func applyLazyLoadMembers( + ctx context.Context, device *userapi.Device, - filter *gomatrixserverlib.RoomEventFilter, - eventsAfter, eventsBefore []gomatrixserverlib.ClientEvent, - state []*gomatrixserverlib.HeaderedEvent, + snapshot storage.DatabaseTransaction, + roomID string, + events []gomatrixserverlib.ClientEvent, lazyLoadCache caching.LazyLoadCache, -) []*gomatrixserverlib.HeaderedEvent { - if filter == nil || !filter.LazyLoadMembers { - return state - } - allEvents := append(eventsBefore, eventsAfter...) - x := make(map[string]struct{}) +) ([]*gomatrixserverlib.HeaderedEvent, error) { + eventSenders := make(map[string]struct{}) // get members who actually send an event - for _, e := range allEvents { + for _, e := range events { // Don't add membership events the client should already know about if _, cached := lazyLoadCache.IsLazyLoadedUserCached(device, e.RoomID, e.Sender); cached { continue } - x[e.Sender] = struct{}{} + eventSenders[e.Sender] = struct{}{} } - newState := []*gomatrixserverlib.HeaderedEvent{} - membershipEvents := []*gomatrixserverlib.HeaderedEvent{} - for _, event := range state { - if event.Type() != gomatrixserverlib.MRoomMember { - newState = append(newState, event) - } else { - // did the user send an event? - if _, ok := x[event.Sender()]; ok { - membershipEvents = append(membershipEvents, event) - lazyLoadCache.StoreLazyLoadedUser(device, event.RoomID(), event.Sender(), event.EventID()) - } - } + wantUsers := make([]string, 0, len(eventSenders)) + for userID := range eventSenders { + wantUsers = append(wantUsers, userID) } - // Add the membershipEvents to the end of the list, to make Sytest happy - return append(newState, membershipEvents...) + + // Query missing membership events + filter := gomatrixserverlib.DefaultStateFilter() + filter.Senders = &wantUsers + filter.Types = &[]string{gomatrixserverlib.MRoomMember} + memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter) + if err != nil { + return nil, err + } + + // cache the membership events + for _, membership := range memberships { + lazyLoadCache.StoreLazyLoadedUser(device, roomID, *membership.StateKey(), membership.EventID()) + } + + return memberships, nil } func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) { diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 3fcc3235c..9ffdf513f 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -16,16 +16,16 @@ package routing import ( "encoding/json" + "math" "net/http" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" ) type getMembershipResponse struct { @@ -87,19 +87,18 @@ func GetMemberships( if err != nil { return jsonerror.InternalServerError() } + defer db.Rollback() // nolint: errcheck atToken, err := types.NewTopologyTokenFromString(at) if err != nil { + atToken = types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} if queryRes.HasBeenInRoom && !queryRes.IsInRoom { // If you have left the room then this will be the members of the room when you left. atToken, err = db.EventPositionInTopology(req.Context(), queryRes.EventID) - } else { - // If you are joined to the room then this will be the current members of the room. - atToken, err = db.MaxTopologicalPosition(req.Context(), roomID) - } - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'") - return jsonerror.InternalServerError() + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'") + return jsonerror.InternalServerError() + } } } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 0d740ebfc..02d8fcc7e 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -17,6 +17,7 @@ package routing import ( "context" "fmt" + "math" "net/http" "sort" "time" @@ -57,12 +58,13 @@ type messagesResp struct { StartStream string `json:"start_stream,omitempty"` // NOTSPEC: used by Cerulean, so clients can hit /messages then immediately /sync with a latest sync token End string `json:"end,omitempty"` Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` - State []gomatrixserverlib.ClientEvent `json:"state"` + State []gomatrixserverlib.ClientEvent `json:"state,omitempty"` } // OnIncomingMessagesRequest implements the /messages endpoint from the // client-server API. // See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages +// nolint:gocyclo func OnIncomingMessagesRequest( req *http.Request, db storage.Database, roomID string, device *userapi.Device, rsAPI api.SyncRoomserverAPI, @@ -177,10 +179,11 @@ func OnIncomingMessagesRequest( // If "to" isn't provided, it defaults to either the earliest stream // position (if we're going backward) or to the latest one (if we're // going forward). - to, err = setToDefault(req.Context(), snapshot, backwardOrdering, roomID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("setToDefault failed") - return jsonerror.InternalServerError() + to = types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} + if backwardOrdering { + // go 1 earlier than the first event so we correctly fetch the earliest event + // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. + to = types.TopologyToken{} } wasToProvided = false } @@ -244,7 +247,14 @@ func OnIncomingMessagesRequest( Start: start.String(), End: end.String(), } - res.applyLazyLoadMembers(req.Context(), snapshot, roomID, device, filter.LazyLoadMembers, lazyLoadCache) + if filter.LazyLoadMembers { + membershipEvents, err := applyLazyLoadMembers(req.Context(), device, snapshot, roomID, clientEvents, lazyLoadCache) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("failed to apply lazy loading") + return jsonerror.InternalServerError() + } + res.State = append(res.State, gomatrixserverlib.HeaderedToClientEvents(membershipEvents, gomatrixserverlib.FormatAll)...) + } // If we didn't return any events, set the end to an empty string, so it will be omitted // in the response JSON. @@ -263,40 +273,6 @@ func OnIncomingMessagesRequest( } } -// applyLazyLoadMembers loads membership events for users returned in Chunk, if the filter has -// LazyLoadMembers enabled. -func (m *messagesResp) applyLazyLoadMembers( - ctx context.Context, - db storage.DatabaseTransaction, - roomID string, - device *userapi.Device, - lazyLoad bool, - lazyLoadCache caching.LazyLoadCache, -) { - if !lazyLoad { - return - } - membershipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent) - for _, evt := range m.Chunk { - // Don't add membership events the client should already know about - if _, cached := lazyLoadCache.IsLazyLoadedUserCached(device, roomID, evt.Sender); cached { - continue - } - membership, err := db.GetStateEvent(ctx, roomID, gomatrixserverlib.MRoomMember, evt.Sender) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("failed to get membership event for user") - continue - } - if membership != nil { - membershipToUser[evt.Sender] = membership - lazyLoadCache.StoreLazyLoadedUser(device, roomID, evt.Sender, membership.EventID()) - } - } - for _, evt := range membershipToUser { - m.State = append(m.State, gomatrixserverlib.HeaderedToClientEvent(evt, gomatrixserverlib.FormatAll)) - } -} - func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) { req := api.QueryMembershipForUserRequest{ RoomID: roomID, @@ -577,24 +553,3 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] return events, nil } - -// setToDefault returns the default value for the "to" query parameter of a -// request to /messages if not provided. It defaults to either the earliest -// topological position (if we're going backward) or to the latest one (if we're -// going forward). -// Returns an error if there was an issue with retrieving the latest position -// from the database -func setToDefault( - ctx context.Context, snapshot storage.DatabaseTransaction, backwardOrdering bool, - roomID string, -) (to types.TopologyToken, err error) { - if backwardOrdering { - // go 1 earlier than the first event so we correctly fetch the earliest event - // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. - to = types.TopologyToken{} - } else { - to, err = snapshot.MaxTopologicalPosition(ctx, roomID) - } - - return -} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 75afbce15..04c2020a0 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -45,8 +45,8 @@ type DatabaseTransaction interface { GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) - GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) - RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) + GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error) + RecentEvents(ctx context.Context, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) @@ -84,8 +84,6 @@ type DatabaseTransaction interface { EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error) - // MaxTopologicalPosition returns the highest topological position for a given room. - MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error) // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // matches the streamevent.transactionID device then the transaction ID gets // added to the unsigned section of the output event. @@ -134,6 +132,8 @@ type Database interface { // PurgeRoomState completely purges room state from the sync API. This is done when // receiving an output event that completely resets the state. PurgeRoomState(ctx context.Context, roomID string) error + // PurgeRoom entirely eliminates a room from the sync API, timeline, state and all. + PurgeRoom(ctx context.Context, roomID string) error // UpsertAccountData keeps track of new or updated account data, by saving the type // of the new/updated data, and the user ID and room ID the data is related to (empty) // room ID means the data isn't specific to any room) diff --git a/syncapi/storage/postgres/backwards_extremities_table.go b/syncapi/storage/postgres/backwards_extremities_table.go index 8fc92091f..c20d860a7 100644 --- a/syncapi/storage/postgres/backwards_extremities_table.go +++ b/syncapi/storage/postgres/backwards_extremities_table.go @@ -47,10 +47,14 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" +const purgeBackwardExtremitiesSQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + type backwardExtremitiesStatements struct { insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt + purgeBackwardExtremitiesStmt *sql.Stmt } func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -59,16 +63,12 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti if err != nil { return nil, err } - if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { - return nil, err - } - if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { - return nil, err - } - if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertBackwardExtremityStmt, insertBackwardExtremitySQL}, + {&s.selectBackwardExtremitiesForRoomStmt, selectBackwardExtremitiesForRoomSQL}, + {&s.deleteBackwardExtremityStmt, deleteBackwardExtremitySQL}, + {&s.purgeBackwardExtremitiesStmt, purgeBackwardExtremitiesSQL}, + }.Prepare(db) } func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( @@ -106,3 +106,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return } + +func (s *backwardExtremitiesStatements) PurgeBackwardExtremities( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeBackwardExtremitiesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 48ed20021..0d607b7c0 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" @@ -110,6 +111,15 @@ const selectSharedUsersSQL = "" + " SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" + ") AND type = 'm.room.member' AND state_key = ANY($2) AND membership IN ('join', 'invite');" +const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2` + +const selectRoomHeroes = ` +SELECT state_key FROM syncapi_current_room_state +WHERE type = 'm.room.member' AND room_id = $1 AND membership = ANY($2) AND state_key != $3 +ORDER BY added_at, state_key +LIMIT 5 +` + type currentRoomStateStatements struct { upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt @@ -122,6 +132,8 @@ type currentRoomStateStatements struct { selectEventsWithEventIDsStmt *sql.Stmt selectStateEventStmt *sql.Stmt selectSharedUsersStmt *sql.Stmt + selectMembershipCountStmt *sql.Stmt + selectRoomHeroesStmt *sql.Stmt } func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { @@ -141,40 +153,21 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro return nil, err } - if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { - return nil, err - } - if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { - return nil, err - } - if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil { - return nil, err - } - if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil { - return nil, err - } - if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { - return nil, err - } - if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { - return nil, err - } - if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil { - return nil, err - } - if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { - return nil, err - } - if s.selectSharedUsersStmt, err = db.Prepare(selectSharedUsersSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertRoomStateStmt, upsertRoomStateSQL}, + {&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL}, + {&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL}, + {&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL}, + {&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL}, + {&s.selectCurrentStateStmt, selectCurrentStateSQL}, + {&s.selectJoinedUsersStmt, selectJoinedUsersSQL}, + {&s.selectJoinedUsersInRoomStmt, selectJoinedUsersInRoomSQL}, + {&s.selectEventsWithEventIDsStmt, selectEventsWithEventIDsSQL}, + {&s.selectStateEventStmt, selectStateEventSQL}, + {&s.selectSharedUsersStmt, selectSharedUsersSQL}, + {&s.selectMembershipCountStmt, selectMembershipCount}, + {&s.selectRoomHeroesStmt, selectRoomHeroes}, + }.Prepare(db) } // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. @@ -282,6 +275,15 @@ func (s *currentRoomStateStatements) SelectCurrentState( ) ([]*gomatrixserverlib.HeaderedEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) senders, notSenders := getSendersStateFilterFilter(stateFilter) + // We're going to query members later, so remove them from this request + if stateFilter.LazyLoadMembers && !stateFilter.IncludeRedundantMembers { + notTypes := &[]string{gomatrixserverlib.MRoomMember} + if stateFilter.NotTypes != nil { + *stateFilter.NotTypes = append(*stateFilter.NotTypes, gomatrixserverlib.MRoomMember) + } else { + stateFilter.NotTypes = notTypes + } + } rows, err := stmt.QueryContext(ctx, roomID, pq.StringArray(senders), pq.StringArray(notSenders), @@ -447,3 +449,34 @@ func (s *currentRoomStateStatements) SelectSharedUsers( } return result, rows.Err() } + +func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomHeroesStmt) + rows, err := stmt.QueryContext(ctx, roomID, pq.StringArray(memberships), excludeUserID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroesStmt: rows.close() failed") + + var stateKey string + result := make([]string, 0, 5) + for rows.Next() { + if err = rows.Scan(&stateKey); err != nil { + return nil, err + } + result = append(result, stateKey) + } + return result, rows.Err() +} + +func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt) + err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + return 0, err + } + return count, nil +} diff --git a/cmd/dendrite-polylith-multi/personalities/mediaapi.go b/syncapi/storage/postgres/deltas/20230201152200_rename_index.go similarity index 52% rename from cmd/dendrite-polylith-multi/personalities/mediaapi.go rename to syncapi/storage/postgres/deltas/20230201152200_rename_index.go index 69d5fd5a8..5a0ec5050 100644 --- a/cmd/dendrite-polylith-multi/personalities/mediaapi.go +++ b/syncapi/storage/postgres/deltas/20230201152200_rename_index.go @@ -1,4 +1,4 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. +// Copyright 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,25 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -package personalities +package deltas import ( - "github.com/matrix-org/dendrite/mediaapi" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" + "context" + "database/sql" + "fmt" ) -func MediaAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { - userAPI := base.UserAPIClient() - client := base.CreateClient() - - mediaapi.AddPublicRoutes( - base, userAPI, client, - ) - - base.SetupAndServeHTTP( - base.Cfg.MediaAPI.InternalAPI.Listen, - base.Cfg.MediaAPI.ExternalAPI.Listen, - nil, nil, - ) +func UpRenameOutputRoomEventsIndex(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, `ALTER TABLE syncapi_output_room_events RENAME CONSTRAINT syncapi_event_id_idx TO syncapi_output_room_event_id_idx;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil } diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index aada70d5e..151bffa5d 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -62,11 +62,15 @@ const selectInviteEventsInRangeSQL = "" + const selectMaxInviteIDSQL = "" + "SELECT MAX(id) FROM syncapi_invite_events" +const purgeInvitesSQL = "" + + "DELETE FROM syncapi_invite_events WHERE room_id = $1" + type inviteEventsStatements struct { insertInviteEventStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt deleteInviteEventStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt } func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { @@ -75,19 +79,13 @@ func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { if err != nil { return nil, err } - if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { - return nil, err - } - if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { - return nil, err - } - if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { - return nil, err - } - if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertInviteEventStmt, insertInviteEventSQL}, + {&s.selectInviteEventsInRangeStmt, selectInviteEventsInRangeSQL}, + {&s.deleteInviteEventStmt, deleteInviteEventSQL}, + {&s.selectMaxInviteIDStmt, selectMaxInviteIDSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + }.Prepare(db) } func (s *inviteEventsStatements) InsertInviteEvent( @@ -181,3 +179,10 @@ func (s *inviteEventsStatements) SelectMaxInviteID( } return } + +func (s *inviteEventsStatements) PurgeInvites( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeInvitesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index b555e8456..47833893a 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -19,10 +19,8 @@ import ( "database/sql" "fmt" - "github.com/lib/pq" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" @@ -64,25 +62,25 @@ const selectMembershipCountSQL = "" + " SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" + ") t WHERE t.membership = $3" -const selectHeroesSQL = "" + - "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5" - const selectMembershipBeforeSQL = "" + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" +const purgeMembershipsSQL = "" + + "DELETE FROM syncapi_memberships WHERE room_id = $1" + const selectMembersSQL = ` -SELECT event_id FROM ( - SELECT DISTINCT ON (room_id, user_id) room_id, user_id, event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC -) t -WHERE ($3::text IS NULL OR t.membership = $3) - AND ($4::text IS NULL OR t.membership <> $4) + SELECT event_id FROM ( + SELECT DISTINCT ON (room_id, user_id) room_id, user_id, event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC + ) t + WHERE ($3::text IS NULL OR t.membership = $3) + AND ($4::text IS NULL OR t.membership <> $4) ` type membershipsStatements struct { upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt - selectHeroesStmt *sql.Stmt selectMembershipForUserStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt selectMembersStmt *sql.Stmt } @@ -95,8 +93,8 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { return s, sqlutil.StatementList{ {&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL}, - {&s.selectHeroesStmt, selectHeroesSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, {&s.selectMembersStmt, selectMembersSQL}, }.Prepare(db) } @@ -129,26 +127,6 @@ func (s *membershipsStatements) SelectMembershipCount( return } -func (s *membershipsStatements) SelectHeroes( - ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string, -) (heroes []string, err error) { - stmt := sqlutil.TxStmt(txn, s.selectHeroesStmt) - var rows *sql.Rows - rows, err = stmt.QueryContext(ctx, roomID, userID, pq.StringArray(memberships)) - if err != nil { - return - } - defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed") - var hero string - for rows.Next() { - if err = rows.Scan(&hero); err != nil { - return - } - heroes = append(heroes, hero) - } - return heroes, rows.Err() -} - // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty // string as the membership. @@ -166,6 +144,13 @@ func (s *membershipsStatements) SelectMembershipForUser( return membership, topologyPos, nil } +func (s *membershipsStatements) PurgeMemberships( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeMembershipsStmt).ExecContext(ctx, roomID) + return err +} + func (s *membershipsStatements) SelectMemberships( ctx context.Context, txn *sql.Tx, roomID string, pos types.TopologyToken, diff --git a/syncapi/storage/postgres/notification_data_table.go b/syncapi/storage/postgres/notification_data_table.go index 2c7b24800..7edfd54a6 100644 --- a/syncapi/storage/postgres/notification_data_table.go +++ b/syncapi/storage/postgres/notification_data_table.go @@ -37,6 +37,7 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, {&r.selectMaxID, selectMaxNotificationIDSQL}, + {&r.purgeNotificationData, purgeNotificationDataSQL}, }.Prepare(db) } @@ -44,6 +45,7 @@ type notificationDataStatements struct { upsertRoomUnreadCounts *sql.Stmt selectUserUnreadCountsForRooms *sql.Stmt selectMaxID *sql.Stmt + purgeNotificationData *sql.Stmt } const notificationDataSchema = ` @@ -70,6 +72,9 @@ const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_coun const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` +const purgeNotificationDataSQL = "" + + "DELETE FROM syncapi_notification_data WHERE room_id = $1" + func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { err = sqlutil.TxStmt(txn, r.upsertRoomUnreadCounts).QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos) return @@ -106,3 +111,10 @@ func (r *notificationDataStatements) SelectMaxID(ctx context.Context, txn *sql.T err := sqlutil.TxStmt(txn, r.selectMaxID).QueryRowContext(ctx).Scan(&id) return id, err } + +func (s *notificationDataStatements) PurgeNotificationData( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeNotificationData).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 3b69b26f6..59fb99aa3 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -19,18 +19,17 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "sort" + "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - - "github.com/lib/pq" "github.com/matrix-org/gomatrixserverlib" - - "github.com/matrix-org/dendrite/internal/sqlutil" ) const outputRoomEventsSchema = ` @@ -44,7 +43,7 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events ( -- This isn't a problem for us since we just want to order by this field. id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'), -- The event ID for the event - event_id TEXT NOT NULL CONSTRAINT syncapi_event_id_idx UNIQUE, + event_id TEXT NOT NULL CONSTRAINT syncapi_output_room_event_id_idx UNIQUE, -- The 'room_id' key for the event. room_id TEXT NOT NULL, -- The headered JSON for the event, containing potentially additional metadata such as @@ -79,13 +78,16 @@ CREATE INDEX IF NOT EXISTS syncapi_output_room_events_room_id_idx ON syncapi_out CREATE INDEX IF NOT EXISTS syncapi_output_room_events_exclude_from_sync_idx ON syncapi_output_room_events (exclude_from_sync); CREATE INDEX IF NOT EXISTS syncapi_output_room_events_add_state_ids_idx ON syncapi_output_room_events ((add_state_ids IS NOT NULL)); CREATE INDEX IF NOT EXISTS syncapi_output_room_events_remove_state_ids_idx ON syncapi_output_room_events ((remove_state_ids IS NOT NULL)); +CREATE INDEX IF NOT EXISTS syncapi_output_room_events_recent_events_idx ON syncapi_output_room_events (room_id, exclude_from_sync, id, sender, type); + + ` const insertEventSQL = "" + "INSERT INTO syncapi_output_room_events (" + "room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync, history_visibility" + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " + - "ON CONFLICT ON CONSTRAINT syncapi_event_id_idx DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $11) " + + "ON CONFLICT ON CONSTRAINT syncapi_output_room_event_id_idx DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $11) " + "RETURNING id" const selectEventsSQL = "" + @@ -109,14 +111,29 @@ const selectRecentEventsSQL = "" + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + " ORDER BY id DESC LIMIT $8" -const selectRecentEventsForSyncSQL = "" + - "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" + - " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + - " AND ( $4::text[] IS NULL OR sender = ANY($4) )" + - " AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" + - " AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" + - " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + - " ORDER BY id DESC LIMIT $8" +// selectRecentEventsForSyncSQL contains an optimization to get the recent events for a list of rooms, using a LATERAL JOIN +// The sub select inside LATERAL () is executed for all room_ids it gets as a parameter $1 +const selectRecentEventsForSyncSQL = ` +WITH room_ids AS ( + SELECT unnest($1::text[]) AS room_id +) +SELECT x.* +FROM room_ids, + LATERAL ( + SELECT room_id, event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility + FROM syncapi_output_room_events recent_events + WHERE + recent_events.room_id = room_ids.room_id + AND recent_events.exclude_from_sync = FALSE + AND id > $2 AND id <= $3 + AND ( $4::text[] IS NULL OR sender = ANY($4) ) + AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) ) + AND ( $6::text[] IS NULL OR type LIKE ANY($6) ) + AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) ) + ORDER BY recent_events.id DESC + LIMIT $8 + ) AS x +` const selectEarlyEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" + @@ -176,6 +193,9 @@ const selectContextAfterEventSQL = "" + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + " ORDER BY id ASC LIMIT $3" +const purgeEventsSQL = "" + + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type = ANY($2) ORDER BY id ASC LIMIT $3" type outputRoomEventsStatements struct { @@ -193,6 +213,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + purgeEventsStmt *sql.Stmt selectSearchStmt *sql.Stmt } @@ -203,12 +224,30 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { return nil, err } + migrationName := "syncapi: rename dupe index (output_room_events)" + + var cName string + err = db.QueryRowContext(context.Background(), "select constraint_name from information_schema.table_constraints where table_name = 'syncapi_output_room_events' AND constraint_name = 'syncapi_event_id_idx'").Scan(&cName) + switch err { + case sql.ErrNoRows: // migration was already executed, as the index was renamed + if err = sqlutil.InsertMigration(context.Background(), db, migrationName); err != nil { + return nil, fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) + } + case nil: + default: + return nil, err + } + m := sqlutil.NewMigrator(db) m.AddMigrations( sqlutil.Migration{ Version: "syncapi: add history visibility column (output_room_events)", Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents, }, + sqlutil.Migration{ + Version: migrationName, + Up: deltas.UpRenameOutputRoomEventsIndex, + }, ) err = m.Up(context.Background()) if err != nil { @@ -230,6 +269,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, {&s.selectSearchStmt, selectSearchSQL}, }.Prepare(db) } @@ -393,9 +433,9 @@ func (s *outputRoomEventsStatements) InsertEvent( // from sync. func (s *outputRoomEventsStatements) SelectRecentEvents( ctx context.Context, txn *sql.Tx, - roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, + roomIDs []string, ra types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool, -) ([]types.StreamEvent, bool, error) { +) (map[string]types.RecentEvents, error) { var stmt *sql.Stmt if onlySyncEvents { stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt) @@ -403,8 +443,9 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) } senders, notSenders := getSendersRoomEventFilter(eventFilter) + rows, err := stmt.QueryContext( - ctx, roomID, r.Low(), r.High(), + ctx, pq.StringArray(roomIDs), ra.Low(), ra.High(), pq.StringArray(senders), pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), @@ -412,34 +453,80 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( eventFilter.Limit+1, ) if err != nil { - return nil, false, err + return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") - events, err := rowsToStreamEvents(rows) - if err != nil { - return nil, false, err - } - if chronologicalOrder { - // The events need to be returned from oldest to latest, which isn't - // necessary the way the SQL query returns them, so a sort is necessary to - // ensure the events are in the right order in the slice. - sort.SliceStable(events, func(i int, j int) bool { - return events[i].StreamPosition < events[j].StreamPosition - }) - } - // we queried for 1 more than the limit, so if we returned one more mark limited=true - limited := false - if len(events) > eventFilter.Limit { - limited = true - // re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. - if chronologicalOrder { - events = events[1:] - } else { - events = events[:len(events)-1] + + result := make(map[string]types.RecentEvents) + + for rows.Next() { + var ( + roomID string + eventID string + streamPos types.StreamPosition + eventBytes []byte + excludeFromSync bool + sessionID *int64 + txnID *string + transactionID *api.TransactionID + historyVisibility gomatrixserverlib.HistoryVisibility + ) + if err := rows.Scan(&roomID, &eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID, &historyVisibility); err != nil { + return nil, err } + // TODO: Handle redacted events + var ev gomatrixserverlib.HeaderedEvent + if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + return nil, err + } + + if sessionID != nil && txnID != nil { + transactionID = &api.TransactionID{ + SessionID: *sessionID, + TransactionID: *txnID, + } + } + + r := result[roomID] + + ev.Visibility = historyVisibility + r.Events = append(r.Events, types.StreamEvent{ + HeaderedEvent: &ev, + StreamPosition: streamPos, + TransactionID: transactionID, + ExcludeFromSync: excludeFromSync, + }) + + result[roomID] = r } - return events, limited, nil + if chronologicalOrder { + for roomID, evs := range result { + // The events need to be returned from oldest to latest, which isn't + // necessary the way the SQL query returns them, so a sort is necessary to + // ensure the events are in the right order in the slice. + sort.SliceStable(evs.Events, func(i int, j int) bool { + return evs.Events[i].StreamPosition < evs.Events[j].StreamPosition + }) + + if len(evs.Events) > eventFilter.Limit { + evs.Limited = true + evs.Events = evs.Events[1:] + } + + result[roomID] = evs + } + } else { + for roomID, evs := range result { + if len(evs.Events) > eventFilter.Limit { + evs.Limited = true + evs.Events = evs.Events[:len(evs.Events)-1] + } + + result[roomID] = evs + } + } + return result, rows.Err() } // selectEarlyEvents returns the earliest events in the given room, starting @@ -658,6 +745,13 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { return result, rows.Err() } +func (s *outputRoomEventsStatements) PurgeEvents( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsStmt).ExecContext(ctx, roomID) + return err +} + func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { rows, err := sqlutil.TxStmt(txn, s.selectSearchStmt).QueryContext(ctx, afterID, pq.StringArray(types), limit) if err != nil { diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index 6fab900eb..2382fca5c 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -18,11 +18,12 @@ import ( "context" "database/sql" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const outputRoomEventsTopologySchema = ` @@ -65,28 +66,23 @@ const selectPositionInTopologySQL = "" + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE event_id = $1" - // Select the max topological position for the room, then sort by stream position and take the highest, - // returning both topological and stream positions. -const selectMaxPositionInTopologySQL = "" + - "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + - " WHERE topological_position=(" + - "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" + - ") ORDER BY stream_position DESC LIMIT 1" - const selectStreamToTopologicalPositionAscSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;" const selectStreamToTopologicalPositionDescSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position <= $2 ORDER BY topological_position DESC LIMIT 1;" +const purgeEventsTopologySQL = "" + + "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + type outputRoomEventsTopologyStatements struct { insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt - selectMaxPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt + purgeEventsTopologyStmt *sql.Stmt } func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { @@ -95,28 +91,15 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { if err != nil { return nil, err } - if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { - return nil, err - } - if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { - return nil, err - } - if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionDescStmt, err = db.Prepare(selectStreamToTopologicalPositionDescSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertEventInTopologyStmt, insertEventInTopologySQL}, + {&s.selectEventIDsInRangeASCStmt, selectEventIDsInRangeASCSQL}, + {&s.selectEventIDsInRangeDESCStmt, selectEventIDsInRangeDESCSQL}, + {&s.selectPositionInTopologyStmt, selectPositionInTopologySQL}, + {&s.selectStreamToTopologicalPositionAscStmt, selectStreamToTopologicalPositionAscSQL}, + {&s.selectStreamToTopologicalPositionDescStmt, selectStreamToTopologicalPositionDescSQL}, + {&s.purgeEventsTopologyStmt, purgeEventsTopologySQL}, + }.Prepare(db) } // InsertEventInTopology inserts the given event in the room's topology, based @@ -190,9 +173,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( return } -func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( +func (s *outputRoomEventsTopologyStatements) PurgeEventsTopology( ctx context.Context, txn *sql.Tx, roomID string, -) (pos types.StreamPosition, spos types.StreamPosition, err error) { - err = sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt).QueryRowContext(ctx, roomID).Scan(&pos, &spos) - return +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsTopologyStmt).ExecContext(ctx, roomID) + return err } diff --git a/syncapi/storage/postgres/peeks_table.go b/syncapi/storage/postgres/peeks_table.go index e20a4882f..64183073d 100644 --- a/syncapi/storage/postgres/peeks_table.go +++ b/syncapi/storage/postgres/peeks_table.go @@ -65,6 +65,9 @@ const selectPeekingDevicesSQL = "" + const selectMaxPeekIDSQL = "" + "SELECT MAX(id) FROM syncapi_peeks" +const purgePeeksSQL = "" + + "DELETE FROM syncapi_peeks WHERE room_id = $1" + type peekStatements struct { db *sql.DB insertPeekStmt *sql.Stmt @@ -73,6 +76,7 @@ type peekStatements struct { selectPeeksInRangeStmt *sql.Stmt selectPeekingDevicesStmt *sql.Stmt selectMaxPeekIDStmt *sql.Stmt + purgePeeksStmt *sql.Stmt } func NewPostgresPeeksTable(db *sql.DB) (tables.Peeks, error) { @@ -83,25 +87,15 @@ func NewPostgresPeeksTable(db *sql.DB) (tables.Peeks, error) { s := &peekStatements{ db: db, } - if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { - return nil, err - } - if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil { - return nil, err - } - if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil { - return nil, err - } - if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil { - return nil, err - } - if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil { - return nil, err - } - if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertPeekStmt, insertPeekSQL}, + {&s.deletePeekStmt, deletePeekSQL}, + {&s.deletePeeksStmt, deletePeeksSQL}, + {&s.selectPeeksInRangeStmt, selectPeeksInRangeSQL}, + {&s.selectPeekingDevicesStmt, selectPeekingDevicesSQL}, + {&s.selectMaxPeekIDStmt, selectMaxPeekIDSQL}, + {&s.purgePeeksStmt, purgePeeksSQL}, + }.Prepare(db) } func (s *peekStatements) InsertPeek( @@ -184,3 +178,10 @@ func (s *peekStatements) SelectMaxPeekID( } return } + +func (s *peekStatements) PurgePeeks( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgePeeksStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go index 327a7a372..0fcbebfcb 100644 --- a/syncapi/storage/postgres/receipt_table.go +++ b/syncapi/storage/postgres/receipt_table.go @@ -62,11 +62,15 @@ const selectRoomReceipts = "" + const selectMaxReceiptIDSQL = "" + "SELECT MAX(id) FROM syncapi_receipts" +const purgeReceiptsSQL = "" + + "DELETE FROM syncapi_receipts WHERE room_id = $1" + type receiptStatements struct { db *sql.DB upsertReceipt *sql.Stmt selectRoomReceipts *sql.Stmt selectMaxReceiptID *sql.Stmt + purgeReceiptsStmt *sql.Stmt } func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { @@ -86,16 +90,12 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { r := &receiptStatements{ db: db, } - if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { - return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) - } - if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - return r, nil + return r, sqlutil.StatementList{ + {&r.upsertReceipt, upsertReceipt}, + {&r.selectRoomReceipts, selectRoomReceipts}, + {&r.selectMaxReceiptID, selectMaxReceiptIDSQL}, + {&r.purgeReceiptsStmt, purgeReceiptsSQL}, + }.Prepare(db) } func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { @@ -138,3 +138,10 @@ func (s *receiptStatements) SelectMaxReceiptID( } return } + +func (s *receiptStatements) PurgeReceipts( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeReceiptsStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index df2338cf8..aeeebb1d2 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -242,20 +242,6 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e return nil } -func (d *Database) PurgeRoomState( - ctx context.Context, roomID string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - // If the event is a create event then we'll delete all of the existing - // data for the room. The only reason that a create event would be replayed - // to us in this way is if we're about to receive the entire room state. - if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { - return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) - } - return nil - }) -} - func (d *Database) WriteEvent( ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 77afa0290..931bc9e23 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -4,8 +4,10 @@ import ( "context" "database/sql" "fmt" + "math" "github.com/matrix-org/gomatrixserverlib" + "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/syncapi/types" @@ -92,12 +94,65 @@ func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membe return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos) } -func (d *DatabaseTransaction) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) { - return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships) +func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID string) (*types.Summary, error) { + summary := &types.Summary{Heroes: []string{}} + + joinCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Join) + if err != nil { + return summary, err + } + inviteCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Invite) + if err != nil { + return summary, err + } + summary.InvitedMemberCount = &inviteCount + summary.JoinedMemberCount = &joinCount + + // Get the room name and canonical alias, if any + filter := gomatrixserverlib.DefaultStateFilter() + filterTypes := []string{gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias} + filterRooms := []string{roomID} + + filter.Types = &filterTypes + filter.Rooms = &filterRooms + evs, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, &filter, nil) + if err != nil { + return summary, err + } + + for _, ev := range evs { + switch ev.Type() { + case gomatrixserverlib.MRoomName: + if gjson.GetBytes(ev.Content(), "name").Str != "" { + return summary, nil + } + case gomatrixserverlib.MRoomCanonicalAlias: + if gjson.GetBytes(ev.Content(), "alias").Str != "" { + return summary, nil + } + } + } + + // If there's no room name or canonical alias, get the room heroes, excluding the user + heroes, err := d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Join, gomatrixserverlib.Invite}) + if err != nil { + return summary, err + } + + // "When no joined or invited members are available, this should consist of the banned and left users" + if len(heroes) == 0 { + heroes, err = d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Leave, gomatrixserverlib.Ban}) + if err != nil { + return summary, err + } + } + summary.Heroes = heroes + + return summary, nil } -func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { - return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) +func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) { + return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomIDs, r, eventFilter, chronologicalOrder, onlySyncEvents) } func (d *DatabaseTransaction) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) { @@ -215,16 +270,6 @@ func (d *DatabaseTransaction) BackwardExtremitiesForRoom( return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID) } -func (d *DatabaseTransaction) MaxTopologicalPosition( - ctx context.Context, roomID string, -) (types.TopologyToken, error) { - depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) - if err != nil { - return types.TopologyToken{}, err - } - return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil -} - func (d *DatabaseTransaction) EventPositionInTopology( ctx context.Context, eventID string, ) (types.TopologyToken, error) { @@ -243,11 +288,7 @@ func (d *DatabaseTransaction) StreamToTopologicalPosition( case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward return types.TopologyToken{PDUPosition: streamPos}, nil case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward - topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) - if err != nil { - return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err) - } - return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil + return types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, nil case err != nil: // some other error happened return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err) default: @@ -329,19 +370,25 @@ func (d *DatabaseTransaction) GetStateDeltas( } // get all the state events ever (i.e. for all available rooms) between these two positions - stateNeededFiltered, eventMapFiltered, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil + stateFiltered := state + // avoid hitting the database if the result would be the same as above + if !isStatefilterEmpty(stateFilter) { + var stateNeededFiltered map[string]map[string]bool + var eventMapFiltered map[string]types.StreamEvent + stateNeededFiltered, eventMapFiltered, err = d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err } - return nil, nil, err - } - stateFiltered, err := d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil + stateFiltered, err = d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err } - return nil, nil, err } // find out which rooms this user is peeking, if any. @@ -608,11 +655,80 @@ func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) return d.Presence.GetMaxPresenceID(ctx, d.txn) } +func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.BackwardExtremities.PurgeBackwardExtremities(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge backward extremities: %w", err) + } + if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge current room state: %w", err) + } + if err := d.Invites.PurgeInvites(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge invites: %w", err) + } + if err := d.Memberships.PurgeMemberships(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge memberships: %w", err) + } + if err := d.NotificationData.PurgeNotificationData(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge notification data: %w", err) + } + if err := d.OutputEvents.PurgeEvents(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge events: %w", err) + } + if err := d.Topology.PurgeEventsTopology(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge events topology: %w", err) + } + if err := d.Peeks.PurgePeeks(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge peeks: %w", err) + } + if err := d.Receipts.PurgeReceipts(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge receipts: %w", err) + } + return nil + }) +} + +func (d *Database) PurgeRoomState( + ctx context.Context, roomID string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // If the event is a create event then we'll delete all of the existing + // data for the room. The only reason that a create event would be replayed + // to us in this way is if we're about to receive the entire room state. + if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) + } + return nil + }) +} + func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) { id, err := d.Relations.SelectMaxRelationID(ctx, d.txn) return types.StreamPosition(id), err } +func isStatefilterEmpty(filter *gomatrixserverlib.StateFilter) bool { + if filter == nil { + return true + } + switch { + case filter.NotTypes != nil && len(*filter.NotTypes) > 0: + return false + case filter.Types != nil && len(*filter.Types) > 0: + return false + case filter.Senders != nil && len(*filter.Senders) > 0: + return false + case filter.NotSenders != nil && len(*filter.NotSenders) > 0: + return false + case filter.NotRooms != nil && len(*filter.NotRooms) > 0: + return false + case filter.ContainsURL != nil: + return false + default: + return true + } +} + func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) ( events []types.StreamEvent, prevBatch, nextBatch string, err error, ) { diff --git a/syncapi/storage/shared/storage_sync_test.go b/syncapi/storage/shared/storage_sync_test.go new file mode 100644 index 000000000..c56720db7 --- /dev/null +++ b/syncapi/storage/shared/storage_sync_test.go @@ -0,0 +1,72 @@ +package shared + +import ( + "testing" + + "github.com/matrix-org/gomatrixserverlib" +) + +func Test_isStatefilterEmpty(t *testing.T) { + filterSet := []string{"a"} + boolValue := false + + tests := []struct { + name string + filter *gomatrixserverlib.StateFilter + want bool + }{ + { + name: "nil filter is empty", + filter: nil, + want: true, + }, + { + name: "Empty filter is empty", + filter: &gomatrixserverlib.StateFilter{}, + want: true, + }, + { + name: "NotTypes is set", + filter: &gomatrixserverlib.StateFilter{ + NotTypes: &filterSet, + }, + }, + { + name: "Types is set", + filter: &gomatrixserverlib.StateFilter{ + Types: &filterSet, + }, + }, + { + name: "Senders is set", + filter: &gomatrixserverlib.StateFilter{ + Senders: &filterSet, + }, + }, + { + name: "NotSenders is set", + filter: &gomatrixserverlib.StateFilter{ + NotSenders: &filterSet, + }, + }, + { + name: "NotRooms is set", + filter: &gomatrixserverlib.StateFilter{ + NotRooms: &filterSet, + }, + }, + { + name: "ContainsURL is set", + filter: &gomatrixserverlib.StateFilter{ + ContainsURL: &boolValue, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isStatefilterEmpty(tt.filter); got != tt.want { + t.Errorf("isStatefilterEmpty() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index d8967113a..de0e72dbd 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -105,7 +105,9 @@ func (s *accountDataStatements) SelectAccountDataInRange( filter.Senders, filter.NotSenders, filter.Types, filter.NotTypes, []string{}, nil, filter.Limit, FilterOrderAsc) - + if err != nil { + return + } rows, err := stmt.QueryContext(ctx, params...) if err != nil { return diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index 3a5fd6be3..2d8cf2ed2 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -47,11 +47,15 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" +const purgeBackwardExtremitiesSQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + type backwardExtremitiesStatements struct { db *sql.DB insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt + purgeBackwardExtremitiesStmt *sql.Stmt } func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -62,16 +66,12 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities if err != nil { return nil, err } - if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { - return nil, err - } - if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { - return nil, err - } - if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertBackwardExtremityStmt, insertBackwardExtremitySQL}, + {&s.selectBackwardExtremitiesForRoomStmt, selectBackwardExtremitiesForRoomSQL}, + {&s.deleteBackwardExtremityStmt, deleteBackwardExtremitySQL}, + {&s.purgeBackwardExtremitiesStmt, purgeBackwardExtremitiesSQL}, + }.Prepare(db) } func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( @@ -109,3 +109,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return err } + +func (s *backwardExtremitiesStatements) PurgeBackwardExtremities( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeBackwardExtremitiesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 7a381f68b..35b746c5c 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "strings" @@ -95,6 +96,15 @@ const selectSharedUsersSQL = "" + " SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" + ") AND type = 'm.room.member' AND state_key IN ($2) AND membership IN ('join', 'invite');" +const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2` + +const selectRoomHeroes = ` +SELECT state_key FROM syncapi_current_room_state +WHERE type = 'm.room.member' AND room_id = $1 AND state_key != $2 AND membership IN ($3) +ORDER BY added_at, state_key +LIMIT 5 +` + type currentRoomStateStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -107,6 +117,8 @@ type currentRoomStateStatements struct { //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic selectStateEventStmt *sql.Stmt //selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic + selectMembershipCountStmt *sql.Stmt + //selectRoomHeroes *sql.Stmt - prepared at runtime due to variadic } func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { @@ -129,31 +141,16 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t return nil, err } - if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { - return nil, err - } - if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { - return nil, err - } - if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil { - return nil, err - } - if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { - return nil, err - } - //if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { - // return nil, err - //} - if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertRoomStateStmt, upsertRoomStateSQL}, + {&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL}, + {&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL}, + {&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL}, + {&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL}, + {&s.selectJoinedUsersStmt, selectJoinedUsersSQL}, + {&s.selectStateEventStmt, selectStateEventSQL}, + {&s.selectMembershipCountStmt, selectMembershipCount}, + }.Prepare(db) } // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. @@ -270,6 +267,15 @@ func (s *currentRoomStateStatements) SelectCurrentState( stateFilter *gomatrixserverlib.StateFilter, excludeEventIDs []string, ) ([]*gomatrixserverlib.HeaderedEvent, error) { + // We're going to query members later, so remove them from this request + if stateFilter.LazyLoadMembers && !stateFilter.IncludeRedundantMembers { + notTypes := &[]string{gomatrixserverlib.MRoomMember} + if stateFilter.NotTypes != nil { + *stateFilter.NotTypes = append(*stateFilter.NotTypes, gomatrixserverlib.MRoomMember) + } else { + stateFilter.NotTypes = notTypes + } + } stmt, params, err := prepareWithFilters( s.db, txn, selectCurrentStateSQL, []interface{}{ @@ -485,3 +491,53 @@ func (s *currentRoomStateStatements) SelectSharedUsers( return result, err } + +func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) { + params := make([]interface{}, len(memberships)+2) + params[0] = roomID + params[1] = excludeUserID + for k, v := range memberships { + params[k+2] = v + } + + query := strings.Replace(selectRoomHeroes, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1) + var stmt *sql.Stmt + var err error + if txn != nil { + stmt, err = txn.Prepare(query) + } else { + stmt, err = s.db.Prepare(query) + } + if err != nil { + return []string{}, err + } + defer internal.CloseAndLogIfError(ctx, stmt, "selectRoomHeroes: stmt.close() failed") + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroes: rows.close() failed") + + var stateKey string + result := make([]string, 0, 5) + for rows.Next() { + if err = rows.Scan(&stateKey); err != nil { + return nil, err + } + result = append(result, stateKey) + } + return result, rows.Err() +} + +func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt) + err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + return 0, err + } + return count, nil +} diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index e2dbcd5c8..19450099a 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -57,6 +57,9 @@ const selectInviteEventsInRangeSQL = "" + const selectMaxInviteIDSQL = "" + "SELECT MAX(id) FROM syncapi_invite_events" +const purgeInvitesSQL = "" + + "DELETE FROM syncapi_invite_events WHERE room_id = $1" + type inviteEventsStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -64,6 +67,7 @@ type inviteEventsStatements struct { selectInviteEventsInRangeStmt *sql.Stmt deleteInviteEventStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt } func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) { @@ -75,19 +79,13 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Inv if err != nil { return nil, err } - if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { - return nil, err - } - if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { - return nil, err - } - if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { - return nil, err - } - if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertInviteEventStmt, insertInviteEventSQL}, + {&s.selectInviteEventsInRangeStmt, selectInviteEventsInRangeSQL}, + {&s.deleteInviteEventStmt, deleteInviteEventSQL}, + {&s.selectMaxInviteIDStmt, selectMaxInviteIDSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + }.Prepare(db) } func (s *inviteEventsStatements) InsertInviteEvent( @@ -192,3 +190,10 @@ func (s *inviteEventsStatements) SelectMaxInviteID( } return } + +func (s *inviteEventsStatements) PurgeInvites( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeInvitesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index 7e54fac17..2cc46a10a 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -18,11 +18,9 @@ import ( "context" "database/sql" "fmt" - "strings" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" @@ -64,9 +62,6 @@ const selectMembershipCountSQL = "" + " SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" + ") t WHERE t.membership = $3" -const selectHeroesSQL = "" + - "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5" - const selectMembershipBeforeSQL = "" + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" @@ -77,6 +72,9 @@ SELECT event_id FROM AND ($4 IS NULL OR t.membership <> $4) ` +const purgeMembershipsSQL = "" + + "DELETE FROM syncapi_memberships WHERE room_id = $1" + type membershipsStatements struct { db *sql.DB upsertMembershipStmt *sql.Stmt @@ -84,6 +82,7 @@ type membershipsStatements struct { //selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic selectMembershipForUserStmt *sql.Stmt selectMembersStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt } func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -99,7 +98,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, {&s.selectMembersStmt, selectMembersSQL}, - // {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, }.Prepare(db) } @@ -131,39 +130,6 @@ func (s *membershipsStatements) SelectMembershipCount( return } -func (s *membershipsStatements) SelectHeroes( - ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string, -) (heroes []string, err error) { - stmtSQL := strings.Replace(selectHeroesSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1) - stmt, err := s.db.PrepareContext(ctx, stmtSQL) - if err != nil { - return - } - defer internal.CloseAndLogIfError(ctx, stmt, "SelectHeroes: stmt.close() failed") - params := []interface{}{ - roomID, userID, - } - for _, membership := range memberships { - params = append(params, membership) - } - - stmt = sqlutil.TxStmt(txn, stmt) - var rows *sql.Rows - rows, err = stmt.QueryContext(ctx, params...) - if err != nil { - return - } - defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed") - var hero string - for rows.Next() { - if err = rows.Scan(&hero); err != nil { - return - } - heroes = append(heroes, hero) - } - return heroes, rows.Err() -} - // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty // string as the membership. @@ -181,6 +147,13 @@ func (s *membershipsStatements) SelectMembershipForUser( return membership, topologyPos, nil } +func (s *membershipsStatements) PurgeMemberships( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeMembershipsStmt).ExecContext(ctx, roomID) + return err +} + func (s *membershipsStatements) SelectMemberships( ctx context.Context, txn *sql.Tx, roomID string, pos types.TopologyToken, diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go index 6242898e1..af2b2c074 100644 --- a/syncapi/storage/sqlite3/notification_data_table.go +++ b/syncapi/storage/sqlite3/notification_data_table.go @@ -38,6 +38,7 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t return r, sqlutil.StatementList{ {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL}, + {&r.purgeNotificationData, purgeNotificationDataSQL}, // {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime }.Prepare(db) } @@ -47,6 +48,7 @@ type notificationDataStatements struct { streamIDStatements *StreamIDStatements upsertRoomUnreadCounts *sql.Stmt selectMaxID *sql.Stmt + purgeNotificationData *sql.Stmt //selectUserUnreadCountsForRooms *sql.Stmt } @@ -73,6 +75,9 @@ const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_coun const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` +const purgeNotificationDataSQL = "" + + "DELETE FROM syncapi_notification_data WHERE room_id = $1" + func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { pos, err = r.streamIDStatements.nextNotificationID(ctx, nil) if err != nil { @@ -124,3 +129,10 @@ func (r *notificationDataStatements) SelectMaxID(ctx context.Context, txn *sql.T err := sqlutil.TxStmt(txn, r.selectMaxID).QueryRowContext(ctx).Scan(&id) return id, err } + +func (s *notificationDataStatements) PurgeNotificationData( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeNotificationData).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 1aa4bfff7..23bc68a41 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -120,6 +120,9 @@ const selectContextAfterEventSQL = "" + const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE type IN ($1) AND id > $2 LIMIT $3 ORDER BY id ASC" +const purgeEventsSQL = "" + + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + type outputRoomEventsStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -130,6 +133,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + purgeEventsStmt *sql.Stmt //selectSearchStmt *sql.Stmt - prepared at runtime } @@ -163,6 +167,7 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, //{&s.selectSearchStmt, selectSearchSQL}, - prepared at runtime }.Prepare(db) } @@ -363,9 +368,9 @@ func (s *outputRoomEventsStatements) InsertEvent( func (s *outputRoomEventsStatements) SelectRecentEvents( ctx context.Context, txn *sql.Tx, - roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, + roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool, -) ([]types.StreamEvent, bool, error) { +) (map[string]types.RecentEvents, error) { var query string if onlySyncEvents { query = selectRecentEventsForSyncSQL @@ -373,49 +378,55 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( query = selectRecentEventsSQL } - stmt, params, err := prepareWithFilters( - s.db, txn, query, - []interface{}{ - roomID, r.Low(), r.High(), - }, - eventFilter.Senders, eventFilter.NotSenders, - eventFilter.Types, eventFilter.NotTypes, - nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc, - ) - if err != nil { - return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) - } - defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed") - - rows, err := stmt.QueryContext(ctx, params...) - if err != nil { - return nil, false, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") - events, err := rowsToStreamEvents(rows) - if err != nil { - return nil, false, err - } - if chronologicalOrder { - // The events need to be returned from oldest to latest, which isn't - // necessary the way the SQL query returns them, so a sort is necessary to - // ensure the events are in the right order in the slice. - sort.SliceStable(events, func(i int, j int) bool { - return events[i].StreamPosition < events[j].StreamPosition - }) - } - // we queried for 1 more than the limit, so if we returned one more mark limited=true - limited := false - if len(events) > eventFilter.Limit { - limited = true - // re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. - if chronologicalOrder { - events = events[1:] - } else { - events = events[:len(events)-1] + result := make(map[string]types.RecentEvents, len(roomIDs)) + for _, roomID := range roomIDs { + stmt, params, err := prepareWithFilters( + s.db, txn, query, + []interface{}{ + roomID, r.Low(), r.High(), + }, + eventFilter.Senders, eventFilter.NotSenders, + eventFilter.Types, eventFilter.NotTypes, + nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc, + ) + if err != nil { + return nil, fmt.Errorf("s.prepareWithFilters: %w", err) } + defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed") + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") + events, err := rowsToStreamEvents(rows) + if err != nil { + return nil, err + } + if chronologicalOrder { + // The events need to be returned from oldest to latest, which isn't + // necessary the way the SQL query returns them, so a sort is necessary to + // ensure the events are in the right order in the slice. + sort.SliceStable(events, func(i int, j int) bool { + return events[i].StreamPosition < events[j].StreamPosition + }) + } + res := types.RecentEvents{} + // we queried for 1 more than the limit, so if we returned one more mark limited=true + if len(events) > eventFilter.Limit { + res.Limited = true + // re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. + if chronologicalOrder { + events = events[1:] + } else { + events = events[:len(events)-1] + } + } + res.Events = events + result[roomID] = res } - return events, limited, nil + + return result, nil } func (s *outputRoomEventsStatements) SelectEarlyEvents( @@ -666,6 +677,13 @@ func unmarshalStateIDs(addIDsJSON, delIDsJSON string) (addIDs []string, delIDs [ return } +func (s *outputRoomEventsStatements) PurgeEvents( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsStmt).ExecContext(ctx, roomID) + return err +} + func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { params := make([]interface{}, len(types)) for i := range types { diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 81b264988..dc698de2d 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -18,10 +18,11 @@ import ( "context" "database/sql" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const outputRoomEventsTopologySchema = ` @@ -61,25 +62,24 @@ const selectPositionInTopologySQL = "" + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE event_id = $1" -const selectMaxPositionInTopologySQL = "" + - "SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1 ORDER BY stream_position DESC" - const selectStreamToTopologicalPositionAscSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;" const selectStreamToTopologicalPositionDescSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position <= $2 ORDER BY topological_position DESC LIMIT 1;" +const purgeEventsTopologySQL = "" + + "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + type outputRoomEventsTopologyStatements struct { db *sql.DB insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt - selectMaxPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt + purgeEventsTopologyStmt *sql.Stmt } func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { @@ -90,28 +90,15 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { if err != nil { return nil, err } - if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { - return nil, err - } - if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { - return nil, err - } - if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionDescStmt, err = db.Prepare(selectStreamToTopologicalPositionDescSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertEventInTopologyStmt, insertEventInTopologySQL}, + {&s.selectEventIDsInRangeASCStmt, selectEventIDsInRangeASCSQL}, + {&s.selectEventIDsInRangeDESCStmt, selectEventIDsInRangeDESCSQL}, + {&s.selectPositionInTopologyStmt, selectPositionInTopologySQL}, + {&s.selectStreamToTopologicalPositionAscStmt, selectStreamToTopologicalPositionAscSQL}, + {&s.selectStreamToTopologicalPositionDescStmt, selectStreamToTopologicalPositionDescSQL}, + {&s.purgeEventsTopologyStmt, purgeEventsTopologySQL}, + }.Prepare(db) } // insertEventInTopology inserts the given event in the room's topology, based @@ -183,10 +170,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( return } -func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( +func (s *outputRoomEventsTopologyStatements) PurgeEventsTopology( ctx context.Context, txn *sql.Tx, roomID string, -) (pos types.StreamPosition, spos types.StreamPosition, err error) { - stmt := sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt) - err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) - return +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsTopologyStmt).ExecContext(ctx, roomID) + return err } diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go index 4ef51b103..5d5200abc 100644 --- a/syncapi/storage/sqlite3/peeks_table.go +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -64,6 +64,9 @@ const selectPeekingDevicesSQL = "" + const selectMaxPeekIDSQL = "" + "SELECT MAX(id) FROM syncapi_peeks" +const purgePeeksSQL = "" + + "DELETE FROM syncapi_peeks WHERE room_id = $1" + type peekStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -73,6 +76,7 @@ type peekStatements struct { selectPeeksInRangeStmt *sql.Stmt selectPeekingDevicesStmt *sql.Stmt selectMaxPeekIDStmt *sql.Stmt + purgePeeksStmt *sql.Stmt } func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) { @@ -84,25 +88,15 @@ func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks db: db, streamIDStatements: streamID, } - if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { - return nil, err - } - if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil { - return nil, err - } - if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil { - return nil, err - } - if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil { - return nil, err - } - if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil { - return nil, err - } - if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertPeekStmt, insertPeekSQL}, + {&s.deletePeekStmt, deletePeekSQL}, + {&s.deletePeeksStmt, deletePeeksSQL}, + {&s.selectPeeksInRangeStmt, selectPeeksInRangeSQL}, + {&s.selectPeekingDevicesStmt, selectPeekingDevicesSQL}, + {&s.selectMaxPeekIDStmt, selectMaxPeekIDSQL}, + {&s.purgePeeksStmt, purgePeeksSQL}, + }.Prepare(db) } func (s *peekStatements) InsertPeek( @@ -204,3 +198,10 @@ func (s *peekStatements) SelectMaxPeekID( } return } + +func (s *peekStatements) PurgePeeks( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgePeeksStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index a4a9b4395..ca3d80fb4 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -58,12 +58,16 @@ const selectRoomReceipts = "" + const selectMaxReceiptIDSQL = "" + "SELECT MAX(id) FROM syncapi_receipts" +const purgeReceiptsSQL = "" + + "DELETE FROM syncapi_receipts WHERE room_id = $1" + type receiptStatements struct { db *sql.DB streamIDStatements *StreamIDStatements upsertReceipt *sql.Stmt selectRoomReceipts *sql.Stmt selectMaxReceiptID *sql.Stmt + purgeReceiptsStmt *sql.Stmt } func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) { @@ -84,16 +88,12 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Re db: db, streamIDStatements: streamID, } - if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { - return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) - } - if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - return r, nil + return r, sqlutil.StatementList{ + {&r.upsertReceipt, upsertReceipt}, + {&r.selectRoomReceipts, selectRoomReceipts}, + {&r.selectMaxReceiptID, selectMaxReceiptIDSQL}, + {&r.purgeReceiptsStmt, purgeReceiptsSQL}, + }.Prepare(db) } // UpsertReceipt creates new user receipts @@ -153,3 +153,10 @@ func (s *receiptStatements) SelectMaxReceiptID( } return } + +func (s *receiptStatements) PurgeReceipts( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeReceiptsStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 5ff185a32..05d498bc2 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "math" "reflect" "testing" @@ -14,6 +15,7 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" ) var ctx = context.Background() @@ -154,12 +156,12 @@ func TestRecentEventsPDU(t *testing.T) { tc := testCases[i] t.Run(tc.Name, func(st *testing.T) { var filter gomatrixserverlib.RoomEventFilter - var gotEvents []types.StreamEvent + var gotEvents map[string]types.RecentEvents var limited bool filter.Limit = tc.Limit WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { var err error - gotEvents, limited, err = snapshot.RecentEvents(ctx, r.ID, types.Range{ + gotEvents, err = snapshot.RecentEvents(ctx, []string{r.ID}, types.Range{ From: tc.From, To: tc.To, }, &filter, !tc.ReverseOrder, true) @@ -167,15 +169,18 @@ func TestRecentEventsPDU(t *testing.T) { st.Fatalf("failed to do sync: %s", err) } }) + streamEvents := gotEvents[r.ID] + limited = streamEvents.Limited if limited != tc.WantLimited { st.Errorf("got limited=%v want %v", limited, tc.WantLimited) } - if len(gotEvents) != len(tc.WantEvents) { + if len(streamEvents.Events) != len(tc.WantEvents) { st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents)) } - for j := range gotEvents { - if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) { - st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON())) + + for j := range streamEvents.Events { + if !reflect.DeepEqual(streamEvents.Events[j].JSON(), tc.WantEvents[j].JSON()) { + st.Errorf("event %d got %s want %s", j, string(streamEvents.Events[j].JSON()), string(tc.WantEvents[j].JSON())) } } }) @@ -198,10 +203,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { _ = MustWriteEvents(t, db, events) WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { - from, err := snapshot.MaxTopologicalPosition(ctx, r.ID) - if err != nil { - t.Fatalf("failed to get MaxTopologicalPosition: %s", err) - } + from := types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} t.Logf("max topo pos = %+v", from) // head towards the beginning of time to := types.TopologyToken{} @@ -218,6 +220,88 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { }) } +func TestStreamToTopologicalPosition(t *testing.T) { + alice := test.NewUser(t) + r := test.NewRoom(t, alice) + + testCases := []struct { + name string + roomID string + streamPos types.StreamPosition + backwardOrdering bool + wantToken types.TopologyToken + }{ + { + name: "forward ordering found streamPos returns found position", + roomID: r.ID, + streamPos: 1, + backwardOrdering: false, + wantToken: types.TopologyToken{Depth: 1, PDUPosition: 1}, + }, + { + name: "forward ordering not found streamPos returns max position", + roomID: r.ID, + streamPos: 100, + backwardOrdering: false, + wantToken: types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, + }, + { + name: "backward ordering found streamPos returns found position", + roomID: r.ID, + streamPos: 1, + backwardOrdering: true, + wantToken: types.TopologyToken{Depth: 1, PDUPosition: 1}, + }, + { + name: "backward ordering not found streamPos returns maxDepth with param pduPosition", + roomID: r.ID, + streamPos: 100, + backwardOrdering: true, + wantToken: types.TopologyToken{Depth: 5, PDUPosition: 100}, + }, + { + name: "backward non-existent room returns zero token", + roomID: "!doesnotexist:localhost", + streamPos: 1, + backwardOrdering: true, + wantToken: types.TopologyToken{Depth: 0, PDUPosition: 1}, + }, + { + name: "forward non-existent room returns max token", + roomID: "!doesnotexist:localhost", + streamPos: 1, + backwardOrdering: false, + wantToken: types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close, closeBase := MustCreateDatabase(t, dbType) + defer close() + defer closeBase() + + txn, err := db.NewDatabaseTransaction(ctx) + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + MustWriteEvents(t, db, r.Events()) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + token, err := txn.StreamToTopologicalPosition(ctx, tc.roomID, tc.streamPos, tc.backwardOrdering) + if err != nil { + t.Fatal(err) + } + if tc.wantToken != token { + t.Fatalf("expected token %q, got %q", tc.wantToken, token) + } + }) + } + + }) +} + /* // The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth. // For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent @@ -664,3 +748,239 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ return &tok } */ + +func pointer[t any](s t) *t { + return &s +} + +func TestRoomSummary(t *testing.T) { + + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t) + + // Create some dummy users + moreUsers := []*test.User{} + moreUserIDs := []string{} + for i := 0; i < 10; i++ { + u := test.NewUser(t) + moreUsers = append(moreUsers, u) + moreUserIDs = append(moreUserIDs, u.ID) + } + + testCases := []struct { + name string + wantSummary *types.Summary + additionalEvents func(t *testing.T, room *test.Room) + }{ + { + name: "after initial creation", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{}}, + }, + { + name: "invited user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "invited user, but declined", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "joined user after invitation", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "multiple joined user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(3), InvitedMemberCount: pointer(0), Heroes: []string{charlie.ID, bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, charlie, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(charlie.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "multiple joined/invited user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID, bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(charlie.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "multiple joined/invited/left user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(charlie.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "leaving user after joining", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "many users", // heroes ordered by stream id + wantSummary: &types.Summary{JoinedMemberCount: pointer(len(moreUserIDs) + 1), InvitedMemberCount: pointer(0), Heroes: moreUserIDs[:5]}, + additionalEvents: func(t *testing.T, room *test.Room) { + for _, x := range moreUsers { + room.CreateAndInsert(t, x, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(x.ID)) + } + }, + }, + { + name: "canonical alias set", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomCanonicalAlias, map[string]interface{}{ + "alias": "myalias", + }, test.WithStateKey("")) + }, + }, + { + name: "room name set", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomName, map[string]interface{}{ + "name": "my room name", + }, test.WithStateKey("")) + }, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close, closeBase := MustCreateDatabase(t, dbType) + defer close() + defer closeBase() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + r := test.NewRoom(t, alice) + + if tc.additionalEvents != nil { + tc.additionalEvents(t, r) + } + + // write the room before creating a transaction + MustWriteEvents(t, db, r.Events()) + + transaction, err := db.NewDatabaseTransaction(ctx) + assert.NoError(t, err) + defer transaction.Rollback() + + summary, err := transaction.GetRoomSummary(ctx, r.ID, alice.ID) + assert.NoError(t, err) + assert.Equal(t, tc.wantSummary, summary) + }) + } + }) +} + +func TestRecentEvents(t *testing.T) { + alice := test.NewUser(t) + room1 := test.NewRoom(t, alice) + room2 := test.NewRoom(t, alice) + roomIDs := []string{room1.ID, room2.ID} + rooms := map[string]*test.Room{ + room1.ID: room1, + room2.ID: room2, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + filter := gomatrixserverlib.DefaultRoomEventFilter() + db, close, closeBase := MustCreateDatabase(t, dbType) + t.Cleanup(func() { + close() + closeBase() + }) + + MustWriteEvents(t, db, room1.Events()) + MustWriteEvents(t, db, room2.Events()) + + transaction, err := db.NewDatabaseTransaction(ctx) + assert.NoError(t, err) + defer transaction.Rollback() + + // get all recent events from 0 to 100 (we only created 5 events, so we should get 5 back) + roomEvs, err := transaction.RecentEvents(ctx, roomIDs, types.Range{From: 0, To: 100}, &filter, true, true) + assert.NoError(t, err) + assert.Equal(t, len(roomEvs), 2, "unexpected recent events response") + for _, recentEvents := range roomEvs { + assert.Equal(t, 5, len(recentEvents.Events), "unexpected recent events for room") + } + + // update the filter to only return one event + filter.Limit = 1 + roomEvs, err = transaction.RecentEvents(ctx, roomIDs, types.Range{From: 0, To: 100}, &filter, true, true) + assert.NoError(t, err) + assert.Equal(t, len(roomEvs), 2, "unexpected recent events response") + for roomID, recentEvents := range roomEvs { + origEvents := rooms[roomID].Events() + assert.Equal(t, true, recentEvents.Limited, "expected events to be limited") + assert.Equal(t, 1, len(recentEvents.Events), "unexpected recent events for room") + assert.Equal(t, origEvents[len(origEvents)-1].EventID(), recentEvents.Events[0].EventID()) + } + + // not chronologically ordered still returns the events in order (given ORDER BY id DESC) + roomEvs, err = transaction.RecentEvents(ctx, roomIDs, types.Range{From: 0, To: 100}, &filter, false, true) + assert.NoError(t, err) + assert.Equal(t, len(roomEvs), 2, "unexpected recent events response") + for roomID, recentEvents := range roomEvs { + origEvents := rooms[roomID].Events() + assert.Equal(t, true, recentEvents.Limited, "expected events to be limited") + assert.Equal(t, 1, len(recentEvents.Events), "unexpected recent events for room") + assert.Equal(t, origEvents[len(origEvents)-1].EventID(), recentEvents.Events[0].EventID()) + } + }) +} diff --git a/syncapi/storage/tables/current_room_state_test.go b/syncapi/storage/tables/current_room_state_test.go index 23287c500..c7af4f977 100644 --- a/syncapi/storage/tables/current_room_state_test.go +++ b/syncapi/storage/tables/current_room_state_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" ) func newCurrentRoomStateTable(t *testing.T, dbType test.DBType) (tables.CurrentRoomState, *sql.DB, func()) { @@ -79,6 +80,9 @@ func TestCurrentRoomStateTable(t *testing.T) { return fmt.Errorf("SelectEventsWithEventIDs\nexpected id %q not returned", id) } } + + testCurrentState(t, ctx, txn, tab, room) + return nil }) if err != nil { @@ -86,3 +90,39 @@ func TestCurrentRoomStateTable(t *testing.T) { } }) } + +func testCurrentState(t *testing.T, ctx context.Context, txn *sql.Tx, tab tables.CurrentRoomState, room *test.Room) { + t.Run("test currentState", func(t *testing.T) { + // returns the complete state of the room with a default filter + filter := gomatrixserverlib.DefaultStateFilter() + evs, err := tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil) + if err != nil { + t.Fatal(err) + } + expectCount := 5 + if gotCount := len(evs); gotCount != expectCount { + t.Fatalf("expected %d state events, got %d", expectCount, gotCount) + } + // When lazy loading, we expect no membership event, so only 4 events + filter.LazyLoadMembers = true + expectCount = 4 + evs, err = tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil) + if err != nil { + t.Fatal(err) + } + if gotCount := len(evs); gotCount != expectCount { + t.Fatalf("expected %d state events, got %d", expectCount, gotCount) + } + // same as above, but with existing NotTypes defined + notTypes := []string{gomatrixserverlib.MRoomMember} + filter.NotTypes = ¬Types + evs, err = tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil) + if err != nil { + t.Fatal(err) + } + if gotCount := len(evs); gotCount != expectCount { + t.Fatalf("expected %d state events, got %d", expectCount, gotCount) + } + }) + +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index a0574b257..727d6bf2c 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -39,6 +39,7 @@ type Invites interface { // for the room. SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, maxID types.StreamPosition, err error) SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error) + PurgeInvites(ctx context.Context, txn *sql.Tx, roomID string) error } type Peeks interface { @@ -48,6 +49,7 @@ type Peeks interface { SelectPeeksInRange(ctxt context.Context, txn *sql.Tx, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) SelectPeekingDevices(ctxt context.Context, txn *sql.Tx) (peekingDevices map[string][]types.PeekingDevice, err error) SelectMaxPeekID(ctx context.Context, txn *sql.Tx) (id int64, err error) + PurgePeeks(ctx context.Context, txn *sql.Tx, roomID string) error } type Events interface { @@ -64,7 +66,7 @@ type Events interface { // SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high. // If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync. // Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`. - SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) + SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) // SelectEarlyEvents returns the earliest events in the given room. SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error) @@ -75,6 +77,8 @@ type Events interface { SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) + + PurgeEvents(ctx context.Context, txn *sql.Tx, roomID string) error ReIndex(ctx context.Context, txn *sql.Tx, limit, offset int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) } @@ -91,10 +95,9 @@ type Topology interface { SelectEventIDsInRange(ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, err error) // SelectPositionInTopology returns the depth and stream position of a given event in the topology of the room it belongs to. SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) - // SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position. - SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error) // SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room. SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error) + PurgeEventsTopology(ctx context.Context, txn *sql.Tx, roomID string) error } type CurrentRoomState interface { @@ -115,6 +118,9 @@ type CurrentRoomState interface { SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error) // SelectSharedUsers returns a subset of otherUserIDs that share a room with userID. SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error) + + SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) + SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (int, error) } // BackwardsExtremities keeps track of backwards extremities for a room. @@ -145,6 +151,7 @@ type BackwardsExtremities interface { SelectBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (bwExtrems map[string][]string, err error) // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) + PurgeBackwardExtremities(ctx context.Context, txn *sql.Tx, roomID string) error } // SendToDevice tracks send-to-device messages which are sent to individual @@ -180,13 +187,14 @@ type Receipts interface { UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error) + PurgeReceipts(ctx context.Context, txn *sql.Tx, roomID string) error } type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) - SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error) SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) + PurgeMemberships(ctx context.Context, txn *sql.Tx, roomID string) error SelectMemberships( ctx context.Context, txn *sql.Tx, roomID string, pos types.TopologyToken, @@ -198,6 +206,7 @@ type NotificationData interface { UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) SelectUserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) + PurgeNotificationData(ctx context.Context, txn *sql.Tx, roomID string) error } type Ignores interface { diff --git a/syncapi/storage/tables/memberships_test.go b/syncapi/storage/tables/memberships_test.go index 0cee7f5a5..df593ae78 100644 --- a/syncapi/storage/tables/memberships_test.go +++ b/syncapi/storage/tables/memberships_test.go @@ -3,8 +3,6 @@ package tables_test import ( "context" "database/sql" - "reflect" - "sort" "testing" "time" @@ -88,43 +86,9 @@ func TestMembershipsTable(t *testing.T) { testUpsert(t, ctx, table, userEvents[0], alice, room) testMembershipCount(t, ctx, table, room) - testHeroes(t, ctx, table, alice, room, users) }) } -func testHeroes(t *testing.T, ctx context.Context, table tables.Memberships, user *test.User, room *test.Room, users []string) { - - // Re-slice and sort the expected users - users = users[1:] - sort.Strings(users) - type testCase struct { - name string - memberships []string - wantHeroes []string - } - - testCases := []testCase{ - {name: "no memberships queried", memberships: []string{}}, - {name: "joined memberships queried should be limited", memberships: []string{gomatrixserverlib.Join}, wantHeroes: users[:5]}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - got, err := table.SelectHeroes(ctx, nil, room.ID, user.ID, tc.memberships) - if err != nil { - t.Fatalf("unable to select heroes: %s", err) - } - if gotLen := len(got); gotLen != len(tc.wantHeroes) { - t.Fatalf("expected %d heroes, got %d", len(tc.wantHeroes), gotLen) - } - - if !reflect.DeepEqual(got, tc.wantHeroes) { - t.Fatalf("expected heroes to be %+v, got %+v", tc.wantHeroes, got) - } - }) - } -} - func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) { t.Run("membership counts are correct", func(t *testing.T) { // After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users) diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go index 7996c2038..e8189c352 100644 --- a/syncapi/streams/stream_devicelist.go +++ b/syncapi/streams/stream_devicelist.go @@ -3,17 +3,17 @@ package streams import ( "context" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" ) type DeviceListStreamProvider struct { DefaultStreamProvider - rsAPI api.SyncRoomserverAPI - keyAPI keyapi.SyncKeyAPI + rsAPI api.SyncRoomserverAPI + userAPI userapi.SyncKeyAPI } func (p *DeviceListStreamProvider) CompleteSync( @@ -31,12 +31,12 @@ func (p *DeviceListStreamProvider) IncrementalSync( from, to types.StreamPosition, ) types.StreamPosition { var err error - to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) + to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.userAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) if err != nil { req.Log.WithError(err).Error("internal.DeviceListCatchup failed") return from } - err = internal.DeviceOTKCounts(req.Context, p.keyAPI, req.Device.UserID, req.Device.ID, req.Response) + err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response) if err != nil { req.Log.WithError(err).Error("internal.DeviceListCatchup failed") return from diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index dd7845574..6af25c028 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "fmt" - "sort" "time" "github.com/matrix-org/dendrite/internal/caching" @@ -14,11 +13,9 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - - "github.com/matrix-org/dendrite/syncapi/notifier" ) // The max number of per-room goroutines to have running. @@ -85,19 +82,24 @@ func (p *PDUStreamProvider) CompleteSync( req.Log.WithError(err).Error("unable to update event filter with ignored users") } - // Invalidate the lazyLoadCache, otherwise we end up with missing displaynames/avatars - // TODO: This might be inefficient, when joined to many and/or large rooms. + recentEvents, err := snapshot.RecentEvents(ctx, joinedRoomIDs, r, &eventFilter, true, true) + if err != nil { + return from + } + // Build up a /sync response. Add joined rooms. for _, roomID := range joinedRoomIDs { + events := recentEvents[roomID] + // Invalidate the lazyLoadCache, otherwise we end up with missing displaynames/avatars + // TODO: This might be inefficient, when joined to many and/or large rooms. joinedUsers := p.notifier.JoinedUsers(roomID) for _, sharedUser := range joinedUsers { p.lazyLoadCache.InvalidateLazyLoadedUser(req.Device, roomID, sharedUser) } - } - // Build up a /sync response. Add joined rooms. - for _, roomID := range joinedRoomIDs { + // get the join response for each room jr, jerr := p.getJoinResponseForCompleteSync( - ctx, snapshot, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false, + ctx, snapshot, roomID, &stateFilter, req.WantFullState, req.Device, false, + events.Events, events.Limited, ) if jerr != nil { req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed") @@ -116,11 +118,25 @@ func (p *PDUStreamProvider) CompleteSync( req.Log.WithError(err).Error("p.DB.PeeksInRange failed") return from } - for _, peek := range peeks { - if !peek.Deleted { + if len(peeks) > 0 { + peekRooms := make([]string, 0, len(peeks)) + for _, peek := range peeks { + if !peek.Deleted { + peekRooms = append(peekRooms, peek.RoomID) + } + } + + recentEvents, err = snapshot.RecentEvents(ctx, peekRooms, r, &eventFilter, true, true) + if err != nil { + return from + } + + for _, roomID := range peekRooms { var jr *types.JoinResponse + events := recentEvents[roomID] jr, err = p.getJoinResponseForCompleteSync( - ctx, snapshot, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true, + ctx, snapshot, roomID, &stateFilter, req.WantFullState, req.Device, true, + events.Events, events.Limited, ) if err != nil { req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") @@ -129,7 +145,7 @@ func (p *PDUStreamProvider) CompleteSync( } continue } - req.Response.Rooms.Peek[peek.RoomID] = jr + req.Response.Rooms.Peek[roomID] = jr } } @@ -230,7 +246,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( stateFilter *gomatrixserverlib.StateFilter, req *types.SyncRequest, ) (types.StreamPosition, error) { - + var err error originalLimit := eventFilter.Limit // If we're going backwards, grep at least X events, this is mostly to satisfy Sytest if r.Backwards && originalLimit < recentEventBackwardsLimit { @@ -241,8 +257,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } - recentStreamEvents, limited, err := snapshot.RecentEvents( - ctx, delta.RoomID, r, + dbEvents, err := snapshot.RecentEvents( + ctx, []string{delta.RoomID}, r, eventFilter, true, true, ) if err != nil { @@ -251,6 +267,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err) } + + recentStreamEvents := dbEvents[delta.RoomID].Events + limited := dbEvents[delta.RoomID].Limited + recentEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering( snapshot.StreamEventsToEvents(device, recentStreamEvents), gomatrixserverlib.TopologicalOrderByPrevEvents, @@ -339,7 +359,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case gomatrixserverlib.Join: jr := types.NewJoinResponse() if hasMembershipChange { - p.addRoomSummary(ctx, snapshot, jr, delta.RoomID, device.UserID, latestPosition) + jr.Summary, err = snapshot.GetRoomSummary(ctx, delta.RoomID, device.UserID) + if err != nil { + logrus.WithError(err).Warn("failed to get room summary") + } } jr.Timeline.PrevBatch = &prevBatch jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) @@ -384,19 +407,32 @@ func applyHistoryVisibilityFilter( roomID, userID string, recentEvents []*gomatrixserverlib.HeaderedEvent, ) ([]*gomatrixserverlib.HeaderedEvent, error) { - // We need to make sure we always include the latest states events, if they are in the timeline. - // We grep at least limit * 2 events, to ensure we really get the needed events. - filter := gomatrixserverlib.DefaultStateFilter() - stateEvents, err := snapshot.CurrentState(ctx, roomID, &filter, nil) - if err != nil { - // Not a fatal error, we can continue without the stateEvents, - // they are only needed if there are state events in the timeline. - logrus.WithError(err).Warnf("Failed to get current room state for history visibility") + // We need to make sure we always include the latest state events, if they are in the timeline. + alwaysIncludeIDs := make(map[string]struct{}) + var stateTypes []string + var senders []string + for _, ev := range recentEvents { + if ev.StateKey() != nil { + stateTypes = append(stateTypes, ev.Type()) + senders = append(senders, ev.Sender()) + } } - alwaysIncludeIDs := make(map[string]struct{}, len(stateEvents)) - for _, ev := range stateEvents { - alwaysIncludeIDs[ev.EventID()] = struct{}{} + + // Only get the state again if there are state events in the timeline + if len(stateTypes) > 0 { + filter := gomatrixserverlib.DefaultStateFilter() + filter.Types = &stateTypes + filter.Senders = &senders + stateEvents, err := snapshot.CurrentState(ctx, roomID, &filter, nil) + if err != nil { + return nil, fmt.Errorf("failed to get current room state for history visibility calculation: %w", err) + } + + for _, ev := range stateEvents { + alwaysIncludeIDs[ev.EventID()] = struct{}{} + } } + startTime := time.Now() events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync") if err != nil { @@ -407,72 +443,24 @@ func applyHistoryVisibilityFilter( "room_id": roomID, "before": len(recentEvents), "after": len(events), - }).Trace("Applied history visibility (sync)") + }).Debugf("Applied history visibility (sync)") return events, nil } -func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage.DatabaseTransaction, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) { - // Work out how many members are in the room. - joinedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition) - invitedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition) - - jr.Summary.JoinedMemberCount = &joinedCount - jr.Summary.InvitedMemberCount = &invitedCount - - fetchStates := []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomName}, - {EventType: gomatrixserverlib.MRoomCanonicalAlias}, - } - // Check if the room has a name or a canonical alias - latestState := &roomserverAPI.QueryLatestEventsAndStateResponse{} - err := p.rsAPI.QueryLatestEventsAndState(ctx, &roomserverAPI.QueryLatestEventsAndStateRequest{StateToFetch: fetchStates, RoomID: roomID}, latestState) - if err != nil { - return - } - // Check if the room has a name or canonical alias, if so, return. - for _, ev := range latestState.StateEvents { - switch ev.Type() { - case gomatrixserverlib.MRoomName: - if gjson.GetBytes(ev.Content(), "name").Str != "" { - return - } - case gomatrixserverlib.MRoomCanonicalAlias: - if gjson.GetBytes(ev.Content(), "alias").Str != "" { - return - } - } - } - heroes, err := snapshot.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"}) - if err != nil { - return - } - sort.Strings(heroes) - jr.Summary.Heroes = heroes -} - func (p *PDUStreamProvider) getJoinResponseForCompleteSync( ctx context.Context, snapshot storage.DatabaseTransaction, roomID string, - r types.Range, stateFilter *gomatrixserverlib.StateFilter, - eventFilter *gomatrixserverlib.RoomEventFilter, wantFullState bool, device *userapi.Device, isPeek bool, + recentStreamEvents []types.StreamEvent, + limited bool, ) (jr *types.JoinResponse, err error) { jr = types.NewJoinResponse() // TODO: When filters are added, we may need to call this multiple times to get enough events. // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 - recentStreamEvents, limited, err := snapshot.RecentEvents( - ctx, roomID, r, eventFilter, true, true, - ) - if err != nil { - if err == sql.ErrNoRows { - return jr, nil - } - return - } // Work our way through the timeline events and pick out the event IDs // of any state events that appear in the timeline. We'll specifically @@ -493,7 +481,10 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( return } - p.addRoomSummary(ctx, snapshot, jr, roomID, device.UserID, r.From) + jr.Summary, err = snapshot.GetRoomSummary(ctx, roomID, device.UserID) + if err != nil { + logrus.WithError(err).Warn("failed to get room summary") + } // We don't include a device here as we don't need to send down // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index dc8547621..a35491acf 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -5,7 +5,6 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" @@ -27,7 +26,7 @@ type Streams struct { func NewSyncStreamProviders( d storage.Database, userAPI userapi.SyncUserAPI, - rsAPI rsapi.SyncRoomserverAPI, keyAPI keyapi.SyncKeyAPI, + rsAPI rsapi.SyncRoomserverAPI, eduCache *caching.EDUCache, lazyLoadCache caching.LazyLoadCache, notifier *notifier.Notifier, ) *Streams { streams := &Streams{ @@ -60,7 +59,7 @@ func NewSyncStreamProviders( DeviceListStreamProvider: &DeviceListStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, rsAPI: rsAPI, - keyAPI: keyAPI, + userAPI: userAPI, }, PresenceStreamProvider: &PresenceStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index b086567b8..68f913871 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -32,7 +32,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/sqlutil" - keyapi "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/internal" @@ -48,7 +47,6 @@ type RequestPool struct { db storage.Database cfg *config.SyncAPI userAPI userapi.SyncUserAPI - keyAPI keyapi.SyncKeyAPI rsAPI roomserverAPI.SyncRoomserverAPI lastseen *sync.Map presence *sync.Map @@ -69,7 +67,7 @@ type PresenceConsumer interface { // NewRequestPool makes a new RequestPool func NewRequestPool( db storage.Database, cfg *config.SyncAPI, - userAPI userapi.SyncUserAPI, keyAPI keyapi.SyncKeyAPI, + userAPI userapi.SyncUserAPI, rsAPI roomserverAPI.SyncRoomserverAPI, streams *streams.Streams, notifier *notifier.Notifier, producer PresencePublisher, consumer PresenceConsumer, enableMetrics bool, @@ -83,7 +81,6 @@ func NewRequestPool( db: db, cfg: cfg, userAPI: userAPI, - keyAPI: keyAPI, rsAPI: rsAPI, lastseen: &sync.Map{}, presence: &sync.Map{}, @@ -280,7 +277,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. // https://github.com/matrix-org/synapse/blob/29f06704b8871a44926f7c99e73cf4a978fb8e81/synapse/rest/client/sync.py#L276-L281 // Only try to get OTKs if the context isn't already done. if syncReq.Context.Err() == nil { - err = internal.DeviceOTKCounts(syncReq.Context, rp.keyAPI, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Response) + err = internal.DeviceOTKCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Response) if err != nil && err != context.Canceled { syncReq.Log.WithError(err).Warn("failed to get OTK counts") } @@ -551,7 +548,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), snapshot, syncReq, fromToken.PDUPosition, toToken.PDUPosition) _, _, err = internal.DeviceListCatchup( - req.Context(), snapshot, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID, + req.Context(), snapshot, rp.userAPI, rp.rsAPI, syncReq.Device.UserID, syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition, ) if err != nil { diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index be19310f2..153f7af5d 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/internal/caching" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" @@ -42,7 +41,6 @@ func AddPublicRoutes( base *base.BaseDendrite, userAPI userapi.SyncUserAPI, rsAPI api.SyncRoomserverAPI, - keyAPI keyapi.SyncKeyAPI, ) { cfg := &base.Cfg.SyncAPI @@ -55,7 +53,7 @@ func AddPublicRoutes( eduCache := caching.NewTypingCache() notifier := notifier.NewNotifier() - streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, base.Caches, notifier) + streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, base.Caches, notifier) notifier.SetCurrentPosition(streams.Latest(context.Background())) if err = notifier.Load(context.Background(), syncDB); err != nil { logrus.WithError(err).Panicf("failed to load notifier ") @@ -71,7 +69,7 @@ func AddPublicRoutes( userAPI, ) - requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics) + requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics) if err = presenceConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start presence consumer") @@ -117,7 +115,7 @@ func AddPublicRoutes( } sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer( - base.ProcessContext, cfg, js, syncDB, keyAPI, notifier, streams.SendToDeviceStreamProvider, + base.ProcessContext, cfg, js, syncDB, userAPI, notifier, streams.SendToDeviceStreamProvider, ) if err = sendToDeviceConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start send-to-device consumer") diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 483274481..e748660f7 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -10,12 +10,12 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/syncapi/routing" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/clientapi/producers" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" @@ -82,21 +82,18 @@ func (s *syncUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAc return nil } +func (s *syncUserAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyChangesRequest, res *userapi.QueryKeyChangesResponse) error { + return nil +} + +func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error { + return nil +} + func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error { return nil } -type syncKeyAPI struct { - keyapi.SyncKeyAPI -} - -func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error { - return nil -} -func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error { - return nil -} - func TestSyncAPIAccessTokens(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { testSyncAccessTokens(t, dbType) @@ -120,7 +117,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) msgs := toNATSMsgs(t, base, room.Events()...) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}) testrig.MustPublishMsgs(t, jsctx, msgs...) testCases := []struct { @@ -219,7 +216,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { // m.room.history_visibility msgs := toNATSMsgs(t, base, room.Events()...) sinceTokens := make([]string, len(msgs)) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}) for i, msg := range msgs { testrig.MustPublishMsgs(t, jsctx, msg) time.Sleep(100 * time.Millisecond) @@ -303,7 +300,7 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) { jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}) w := httptest.NewRecorder() base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ "access_token": alice.AccessToken, @@ -421,7 +418,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { rsAPI := roomserver.NewInternalAPI(base) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI) for _, tc := range testCases { testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType) @@ -448,6 +445,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{ "access_token": bobDev.AccessToken, "dir": "b", + "filter": `{"lazy_load_members":true}`, // check that lazy loading doesn't break history visibility }))) if w.Code != 200 { t.Logf("%s", w.Body.String()) @@ -521,6 +519,252 @@ func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatr } } +func TestGetMembership(t *testing.T) { + alice := test.NewUser(t) + + aliceDev := userapi.Device{ + ID: "ALICEID", + UserID: alice.ID, + AccessToken: "ALICE_BEARER_TOKEN", + DisplayName: "Alice", + AccountType: userapi.AccountTypeUser, + } + + bob := test.NewUser(t) + bobDev := userapi.Device{ + ID: "BOBID", + UserID: bob.ID, + AccessToken: "notjoinedtoanyrooms", + } + + testCases := []struct { + name string + roomID string + additionalEvents func(t *testing.T, room *test.Room) + request func(t *testing.T, room *test.Room) *http.Request + wantOK bool + wantMemberCount int + useSleep bool // :/ + }{ + { + name: "/members - Alice joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "/members - Bob never joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": bobDev.AccessToken, + })) + }, + wantOK: false, + }, + { + name: "/joined_members - Bob never joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": bobDev.AccessToken, + })) + }, + wantOK: false, + }, + { + name: "/joined_members - Alice joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + wantOK: true, + }, + { + name: "Alice leaves before Bob joins, should not be able to see Bob", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + useSleep: true, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "Alice leaves after Bob joins, should be able to see Bob", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + }, + useSleep: true, + wantOK: true, + wantMemberCount: 2, + }, + { + name: "/joined_members - Alice leaves, shouldn't be able to see members ", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + }, + useSleep: true, + wantOK: false, + }, + { + name: "'at' specified, returns memberships before Bob joins", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "at": "t2_5", + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + useSleep: true, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "'membership=leave' specified, returns no memberships", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "membership": "leave", + })) + }, + wantOK: true, + wantMemberCount: 0, + }, + { + name: "'not_membership=join' specified, returns no memberships", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "not_membership": "join", + })) + }, + wantOK: true, + wantMemberCount: 0, + }, + { + name: "'not_membership=leave' & 'membership=join' specified, returns correct memberships", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "not_membership": "leave", + "membership": "join", + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "non-existent room ID", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", "!notavalidroom:test"), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + wantOK: false, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + + // Use an actual roomserver for this + rsAPI := roomserver.NewInternalAPI(base) + rsAPI.SetFederationAPI(nil, nil) + + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + room := test.NewRoom(t, alice) + t.Cleanup(func() { + t.Logf("running cleanup for %s", tc.name) + }) + // inject additional events + if tc.additionalEvents != nil { + tc.additionalEvents(t, room) + } + if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // wait for the events to come down sync + if tc.useSleep { + time.Sleep(time.Millisecond * 100) + } else { + syncUntil(t, base, aliceDev.AccessToken, false, func(syncBody string) bool { + // wait for the last sent eventID to come down sync + path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID()) + return gjson.Get(syncBody, path).Exists() + }) + } + + w := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(w, tc.request(t, room)) + if w.Code != 200 && tc.wantOK { + t.Logf("%s", w.Body.String()) + t.Fatalf("got HTTP %d want %d", w.Code, 200) + } + t.Logf("[%s] Resp: %s", tc.name, w.Body.String()) + + // check we got the expected events + if tc.wantOK { + memberCount := len(gjson.GetBytes(w.Body.Bytes(), "chunk").Array()) + if memberCount != tc.wantMemberCount { + t.Fatalf("expected %d members, got %d", tc.wantMemberCount, memberCount) + } + } + }) + } + }) +} + func TestSendToDevice(t *testing.T) { test.WithAllDatabases(t, testSendToDevice) } @@ -541,7 +785,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}) producer := producers.SyncAPIProducer{ TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), @@ -659,6 +903,177 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { } } +func TestContext(t *testing.T) { + test.WithAllDatabases(t, testContext) +} + +func testContext(t *testing.T, dbType test.DBType) { + + tests := []struct { + name string + roomID string + eventID string + params map[string]string + wantError bool + wantStateLength int + wantBeforeLength int + wantAfterLength int + }{ + { + name: "invalid filter", + params: map[string]string{ + "filter": "{", + }, + wantError: true, + }, + { + name: "invalid limit", + params: map[string]string{ + "limit": "abc", + }, + wantError: true, + }, + { + name: "high limit", + params: map[string]string{ + "limit": "100000", + }, + }, + { + name: "fine limit", + params: map[string]string{ + "limit": "10", + }, + }, + { + name: "last event without lazy loading", + wantStateLength: 5, + }, + { + name: "last event with lazy loading", + params: map[string]string{ + "filter": `{"lazy_load_members":true}`, + }, + wantStateLength: 1, + }, + { + name: "invalid room", + roomID: "!doesnotexist", + wantError: true, + }, + { + name: "invalid eventID", + eventID: "$doesnotexist", + wantError: true, + }, + { + name: "state is limited", + params: map[string]string{ + "limit": "1", + }, + wantStateLength: 1, + }, + { + name: "events are not limited", + wantBeforeLength: 7, + }, + { + name: "all events are limited", + params: map[string]string{ + "limit": "1", + }, + wantStateLength: 1, + wantBeforeLength: 1, + wantAfterLength: 1, + }, + } + + user := test.NewUser(t) + alice := userapi.Device{ + ID: "ALICEID", + UserID: user.ID, + AccessToken: "ALICE_BEARER_TOKEN", + DisplayName: "Alice", + AccountType: userapi.AccountTypeUser, + } + + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + // Use an actual roomserver for this + rsAPI := roomserver.NewInternalAPI(base) + rsAPI.SetFederationAPI(nil, nil) + + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI) + + room := test.NewRoom(t, user) + + room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world 1!"}) + room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world 2!"}) + thirdMsg := room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world3!"}) + room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world4!"}) + + if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + + syncUntil(t, base, alice.AccessToken, false, func(syncBody string) bool { + // wait for the last sent eventID to come down sync + path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, thirdMsg.EventID()) + return gjson.Get(syncBody, path).Exists() + }) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + params := map[string]string{ + "access_token": alice.AccessToken, + } + w := httptest.NewRecorder() + // test overrides + roomID := room.ID + if tc.roomID != "" { + roomID = tc.roomID + } + eventID := thirdMsg.EventID() + if tc.eventID != "" { + eventID = tc.eventID + } + requestPath := fmt.Sprintf("/_matrix/client/v3/rooms/%s/context/%s", roomID, eventID) + if tc.params != nil { + for k, v := range tc.params { + params[k] = v + } + } + base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", requestPath, test.WithQueryParams(params))) + + if tc.wantError && w.Code == 200 { + t.Fatalf("Expected an error, but got none") + } + t.Log(w.Body.String()) + resp := routing.ContextRespsonse{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if tc.wantStateLength > 0 && tc.wantStateLength != len(resp.State) { + t.Fatalf("expected %d state events, got %d", tc.wantStateLength, len(resp.State)) + } + if tc.wantBeforeLength > 0 && tc.wantBeforeLength != len(resp.EventsBefore) { + t.Fatalf("expected %d before events, got %d", tc.wantBeforeLength, len(resp.EventsBefore)) + } + if tc.wantAfterLength > 0 && tc.wantAfterLength != len(resp.EventsAfter) { + t.Fatalf("expected %d after events, got %d", tc.wantAfterLength, len(resp.EventsAfter)) + } + + if !tc.wantError && resp.Event.EventID != eventID { + t.Fatalf("unexpected eventID %s, expected %s", resp.Event.EventID, eventID) + } + }) + } +} + func syncUntil(t *testing.T, base *base.BaseDendrite, accessToken string, skip bool, diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 9fbadc06c..6495dd535 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -63,6 +63,11 @@ type StreamEvent struct { ExcludeFromSync bool } +type RecentEvents struct { + Limited bool + Events []StreamEvent +} + // Range represents a range between two stream positions. type Range struct { // From is the position the client has already received. diff --git a/sytest-blacklist b/sytest-blacklist index 99cfbabc8..49a3cc870 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -1,54 +1,19 @@ -# Relies on a rejected PL event which will never be accepted into the DAG - -# Caused by - -Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state - -# We don't implement lazy membership loading yet - +# Blacklisted due to https://github.com/matrix-org/matrix-spec/issues/942 The only membership state included in a gapped incremental sync is for senders in the timeline -# Blacklisted out of flakiness after #1479 - -Invited user can reject local invite after originator leaves -Invited user can reject invite for empty room -If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes - -# Blacklisted due to flakiness - -Forgotten room messages cannot be paginated - -# Blacklisted due to flakiness after #1774 - -Local device key changes get to remote servers with correct prev_id - -# we don't support groups - -Remove group category -Remove group role - # Flakey - AS-ghosted users can use rooms themselves AS-ghosted users can use rooms via AS Events in rooms with AS-hosted room aliases are sent to AS server Inviting an AS-hosted user asks the AS server Accesing an AS-hosted room alias asks the AS server - -# More flakey - -Guest users can join guest_access rooms +If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes # This will fail in HTTP API mode, so blacklisted for now - If a device list update goes missing, the server resyncs on the next one # Might be a bug in the test because leaves do appear :-( - Leaves are present in non-gapped incremental syncs -# Below test was passing for the wrong reason, failing correctly since #2858 -New federated private chats get full presence information (SYN-115) - # We don't have any state to calculate m.room.guest_access when accepting invites Guest users can accept invites to private rooms over federation \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 215889a49..c61e0bc3c 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -766,4 +766,21 @@ remote user has tags copied to the new room Local and remote users' homeservers remove a room from their public directory on upgrade Guest users denied access over federation if guest access prohibited Guest users are kicked from guest_access rooms on revocation of guest_access -Guest users are kicked from guest_access rooms on revocation of guest_access over federation \ No newline at end of file +Guest users are kicked from guest_access rooms on revocation of guest_access over federation +User can create and send/receive messages in a room with version 10 +local user can join room with version 10 +User can invite local user to room with version 10 +remote user can join room with version 10 +User can invite remote user to room with version 10 +Remote user can backfill in a room with version 10 +Can reject invites over federation for rooms with version 10 +Can receive redactions from regular users over federation in room version 10 +New federated private chats get full presence information (SYN-115) +/state returns M_NOT_FOUND for an outlier +/state_ids returns M_NOT_FOUND for an outlier +Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state +Invited user can reject invite for empty room +Invited user can reject local invite after originator leaves +Guest users can join guest_access rooms +Forgotten room messages cannot be paginated +Local device key changes get to remote servers with correct prev_id \ No newline at end of file diff --git a/test/db.go b/test/db.go index 17f637e18..d2f405d49 100644 --- a/test/db.go +++ b/test/db.go @@ -22,6 +22,7 @@ import ( "os" "os/exec" "os/user" + "path/filepath" "testing" "github.com/lib/pq" @@ -100,16 +101,12 @@ func currentUser() string { // Returns the connection string to use and a close function which must be called when the test finishes. // Calling this function twice will return the same database, which will have data from previous tests // unless close() is called. -// TODO: namespace for concurrent package tests func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) { if dbType == DBTypeSQLite { - // this will be made in the current working directory which namespaces concurrent package runs correctly - dbname := "dendrite_test.db" + // this will be made in the t.TempDir, which is unique per test + dbname := filepath.Join(t.TempDir(), "dendrite_test.db") return fmt.Sprintf("file:%s", dbname), func() { - err := os.Remove(dbname) - if err != nil { - t.Fatalf("failed to cleanup sqlite db '%s': %s", dbname, err) - } + t.Cleanup(func() {}) // removes the t.TempDir } } @@ -176,7 +173,7 @@ func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) { for dbName, dbType := range dbs { dbt := dbType t.Run(dbName, func(tt *testing.T) { - //tt.Parallel() + tt.Parallel() testFn(tt, dbt) }) } diff --git a/test/memory_federation_db.go b/test/memory_federation_db.go new file mode 100644 index 000000000..de0dc54eb --- /dev/null +++ b/test/memory_federation_db.go @@ -0,0 +1,511 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "context" + "encoding/json" + "errors" + "sync" + "time" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/dendrite/federationapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var nidMutex sync.Mutex +var nid = int64(0) + +type InMemoryFederationDatabase struct { + dbMutex sync.Mutex + pendingPDUServers map[gomatrixserverlib.ServerName]struct{} + pendingEDUServers map[gomatrixserverlib.ServerName]struct{} + blacklistedServers map[gomatrixserverlib.ServerName]struct{} + assumedOffline map[gomatrixserverlib.ServerName]struct{} + pendingPDUs map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent + pendingEDUs map[*receipt.Receipt]*gomatrixserverlib.EDU + associatedPDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{} + associatedEDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{} + relayServers map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName +} + +func NewInMemoryFederationDatabase() *InMemoryFederationDatabase { + return &InMemoryFederationDatabase{ + pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), + assumedOffline: make(map[gomatrixserverlib.ServerName]struct{}), + pendingPDUs: make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent), + pendingEDUs: make(map[*receipt.Receipt]*gomatrixserverlib.EDU), + associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}), + associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}), + relayServers: make(map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName), + } +} + +func (d *InMemoryFederationDatabase) StoreJSON( + ctx context.Context, + js string, +) (*receipt.Receipt, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal([]byte(js), &event); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + newReceipt := receipt.NewReceipt(nid) + d.pendingPDUs[&newReceipt] = &event + return &newReceipt, nil + } + + var edu gomatrixserverlib.EDU + if err := json.Unmarshal([]byte(js), &edu); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + newReceipt := receipt.NewReceipt(nid) + d.pendingEDUs[&newReceipt] = &edu + return &newReceipt, nil + } + + return nil, errors.New("Failed to determine type of json to store") +} + +func (d *InMemoryFederationDatabase) GetPendingPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + limit int, +) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + pduCount := 0 + pdus = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent) + if receipts, ok := d.associatedPDUs[serverName]; ok { + for dbReceipt := range receipts { + if event, ok := d.pendingPDUs[dbReceipt]; ok { + pdus[dbReceipt] = event + pduCount++ + if pduCount == limit { + break + } + } + } + } + return pdus, nil +} + +func (d *InMemoryFederationDatabase) GetPendingEDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + limit int, +) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + eduCount := 0 + edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU) + if receipts, ok := d.associatedEDUs[serverName]; ok { + for dbReceipt := range receipts { + if event, ok := d.pendingEDUs[dbReceipt]; ok { + edus[dbReceipt] = event + eduCount++ + if eduCount == limit { + break + } + } + } + } + return edus, nil +} + +func (d *InMemoryFederationDatabase) AssociatePDUWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.ServerName]struct{}, + dbReceipt *receipt.Receipt, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingPDUs[dbReceipt]; ok { + for destination := range destinations { + if _, ok := d.associatedPDUs[destination]; !ok { + d.associatedPDUs[destination] = make(map[*receipt.Receipt]struct{}) + } + d.associatedPDUs[destination][dbReceipt] = struct{}{} + } + + return nil + } else { + return errors.New("PDU doesn't exist") + } +} + +func (d *InMemoryFederationDatabase) AssociateEDUWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.ServerName]struct{}, + dbReceipt *receipt.Receipt, + eduType string, + expireEDUTypes map[string]time.Duration, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingEDUs[dbReceipt]; ok { + for destination := range destinations { + if _, ok := d.associatedEDUs[destination]; !ok { + d.associatedEDUs[destination] = make(map[*receipt.Receipt]struct{}) + } + d.associatedEDUs[destination][dbReceipt] = struct{}{} + } + + return nil + } else { + return errors.New("EDU doesn't exist") + } +} + +func (d *InMemoryFederationDatabase) CleanPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + receipts []*receipt.Receipt, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if pdus, ok := d.associatedPDUs[serverName]; ok { + for _, dbReceipt := range receipts { + delete(pdus, dbReceipt) + } + } + + return nil +} + +func (d *InMemoryFederationDatabase) CleanEDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + receipts []*receipt.Receipt, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if edus, ok := d.associatedEDUs[serverName]; ok { + for _, dbReceipt := range receipts { + delete(edus, dbReceipt) + } + } + + return nil +} + +func (d *InMemoryFederationDatabase) GetPendingPDUCount( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if pdus, ok := d.associatedPDUs[serverName]; ok { + count = int64(len(pdus)) + } + return count, nil +} + +func (d *InMemoryFederationDatabase) GetPendingEDUCount( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if edus, ok := d.associatedEDUs[serverName]; ok { + count = int64(len(edus)) + } + return count, nil +} + +func (d *InMemoryFederationDatabase) GetPendingPDUServerNames( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingPDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *InMemoryFederationDatabase) GetPendingEDUServerNames( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingEDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *InMemoryFederationDatabase) AddServerToBlacklist( + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers[serverName] = struct{}{} + return nil +} + +func (d *InMemoryFederationDatabase) RemoveServerFromBlacklist( + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + delete(d.blacklistedServers, serverName) + return nil +} + +func (d *InMemoryFederationDatabase) RemoveAllServersFromBlacklist() error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *InMemoryFederationDatabase) IsServerBlacklisted( + serverName gomatrixserverlib.ServerName, +) (bool, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + isBlacklisted := false + if _, ok := d.blacklistedServers[serverName]; ok { + isBlacklisted = true + } + + return isBlacklisted, nil +} + +func (d *InMemoryFederationDatabase) SetServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.assumedOffline[serverName] = struct{}{} + return nil +} + +func (d *InMemoryFederationDatabase) RemoveServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + delete(d.assumedOffline, serverName) + return nil +} + +func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine( + ctx context.Context, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.assumedOffline = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *InMemoryFederationDatabase) IsServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (bool, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + assumedOffline := false + if _, ok := d.assumedOffline[serverName]; ok { + assumedOffline = true + } + + return assumedOffline, nil +} + +func (d *InMemoryFederationDatabase) P2PGetRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + knownRelayServers := []gomatrixserverlib.ServerName{} + if relayServers, ok := d.relayServers[serverName]; ok { + knownRelayServers = relayServers + } + + return knownRelayServers, nil +} + +func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if knownRelayServers, ok := d.relayServers[serverName]; ok { + for _, relayServer := range relayServers { + alreadyKnown := false + for _, knownRelayServer := range knownRelayServers { + if relayServer == knownRelayServer { + alreadyKnown = true + } + } + if !alreadyKnown { + d.relayServers[serverName] = append(d.relayServers[serverName], relayServer) + } + } + } else { + d.relayServers[serverName] = relayServers + } + + return nil +} + +func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if knownRelayServers, ok := d.relayServers[serverName]; ok { + for _, relayServer := range relayServers { + for i, knownRelayServer := range knownRelayServers { + if relayServer == knownRelayServer { + d.relayServers[serverName] = append( + d.relayServers[serverName][:i], + d.relayServers[serverName][i+1:]..., + ) + break + } + } + } + } else { + d.relayServers[serverName] = relayServers + } + + return nil +} + +func (d *InMemoryFederationDatabase) FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) FetcherName() string { + return "" +} + +func (d *InMemoryFederationDatabase) StoreKeys(ctx context.Context, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) error { + return nil +} + +func (d *InMemoryFederationDatabase) UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline(ctx context.Context) error { + return nil +} + +func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error { + return nil +} + +func (d *InMemoryFederationDatabase) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { + return nil +} + +func (d *InMemoryFederationDatabase) GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) DeleteExpiredEDUs(ctx context.Context) error { + return nil +} + +func (d *InMemoryFederationDatabase) PurgeRoom(ctx context.Context, roomID string) error { + return nil +} diff --git a/test/memory_relay_db.go b/test/memory_relay_db.go new file mode 100644 index 000000000..db93919df --- /dev/null +++ b/test/memory_relay_db.go @@ -0,0 +1,140 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "context" + "database/sql" + "encoding/json" + "sync" + + "github.com/matrix-org/gomatrixserverlib" +) + +type InMemoryRelayDatabase struct { + nid int64 + nidMutex sync.Mutex + transactions map[int64]json.RawMessage + associations map[gomatrixserverlib.ServerName][]int64 +} + +func NewInMemoryRelayDatabase() *InMemoryRelayDatabase { + return &InMemoryRelayDatabase{ + nid: 1, + nidMutex: sync.Mutex{}, + transactions: make(map[int64]json.RawMessage), + associations: make(map[gomatrixserverlib.ServerName][]int64), + } +} + +func (d *InMemoryRelayDatabase) InsertQueueEntry( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + if _, ok := d.associations[serverName]; !ok { + d.associations[serverName] = []int64{} + } + d.associations[serverName] = append(d.associations[serverName], nid) + return nil +} + +func (d *InMemoryRelayDatabase) DeleteQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + for _, nid := range jsonNIDs { + for index, associatedNID := range d.associations[serverName] { + if associatedNID == nid { + d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...) + } + } + } + + return nil +} + +func (d *InMemoryRelayDatabase) SelectQueueEntries( + ctx context.Context, + txn *sql.Tx, serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + results := []int64{} + resultCount := limit + if limit > len(d.associations[serverName]) { + resultCount = len(d.associations[serverName]) + } + if resultCount > 0 { + for i := 0; i < resultCount; i++ { + results = append(results, d.associations[serverName][i]) + } + } + + return results, nil +} + +func (d *InMemoryRelayDatabase) SelectQueueEntryCount( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + return int64(len(d.associations[serverName])), nil +} + +func (d *InMemoryRelayDatabase) InsertQueueJSON( + ctx context.Context, + txn *sql.Tx, + json string, +) (int64, error) { + d.nidMutex.Lock() + defer d.nidMutex.Unlock() + + nid := d.nid + d.transactions[nid] = []byte(json) + d.nid++ + + return nid, nil +} + +func (d *InMemoryRelayDatabase) DeleteQueueJSON( + ctx context.Context, + txn *sql.Tx, + nids []int64, +) error { + for _, nid := range nids { + delete(d.transactions, nid) + } + + return nil +} + +func (d *InMemoryRelayDatabase) SelectQueueJSON( + ctx context.Context, + txn *sql.Tx, + jsonNIDs []int64, +) (map[int64][]byte, error) { + result := make(map[int64][]byte) + for _, nid := range jsonNIDs { + if transaction, ok := d.transactions[nid]; ok { + result[nid] = transaction + } + } + + return result, nil +} diff --git a/test/testrig/base.go b/test/testrig/base.go index 7bc26a5c5..bb8fded21 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -15,41 +15,37 @@ package testrig import ( - "errors" "fmt" - "io/fs" - "os" - "strings" + "path/filepath" "testing" - "github.com/nats-io/nats.go" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/nats-io/nats.go" ) func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, func()) { var cfg config.Dendrite cfg.Defaults(config.DefaultOpts{ - Generate: false, - Monolithic: true, + Generate: false, + SingleDatabase: true, }) cfg.Global.JetStream.InMemory = true cfg.FederationAPI.KeyPerspectives = nil switch dbType { case test.DBTypePostgres: cfg.Global.Defaults(config.DefaultOpts{ // autogen a signing key - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: true, }) cfg.MediaAPI.Defaults(config.DefaultOpts{ // autogen a media path - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: true, }) cfg.SyncAPI.Fulltext.Defaults(config.DefaultOpts{ // use in memory fts - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: true, }) cfg.Global.ServerName = "test" // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use @@ -62,34 +58,39 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f MaxIdleConnections: 2, ConnMaxLifetimeSeconds: 60, } - return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close + base := base.NewBaseDendrite(&cfg, base.DisableMetrics) + return base, func() { + base.ShutdownDendrite() + base.WaitForShutdown() + close() + } case test.DBTypeSQLite: cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: false, // because we need a database per component + Generate: true, + SingleDatabase: false, }) cfg.Global.ServerName = "test" + // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) - return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), func() { - // cleanup db files. This risks getting out of sync as we add more database strings :( - dbFiles := []config.DataSource{ - cfg.FederationAPI.Database.ConnectionString, - cfg.KeyServer.Database.ConnectionString, - cfg.MSCs.Database.ConnectionString, - cfg.MediaAPI.Database.ConnectionString, - cfg.RoomServer.Database.ConnectionString, - cfg.SyncAPI.Database.ConnectionString, - cfg.UserAPI.AccountDatabase.ConnectionString, - } - for _, fileURI := range dbFiles { - path := strings.TrimPrefix(string(fileURI), "file:") - err := os.Remove(path) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - t.Fatalf("failed to cleanup sqlite db '%s': %s", fileURI, err) - } - } + + // Use a temp dir provided by go for tests, this will be cleanup by a call to t.CleanUp() + tempDir := t.TempDir() + cfg.FederationAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "federationapi.db")) + cfg.KeyServer.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "keyserver.db")) + cfg.MSCs.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "mscs.db")) + cfg.MediaAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "mediaapi.db")) + cfg.RoomServer.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "roomserver.db")) + cfg.SyncAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "syncapi.db")) + cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "userapi.db")) + cfg.RelayAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "relayapi.db")) + + base := base.NewBaseDendrite(&cfg, base.DisableMetrics) + return base, func() { + base.ShutdownDendrite() + base.WaitForShutdown() + t.Cleanup(func() {}) // removes t.TempDir, where all database files are created } default: t.Fatalf("unknown db type: %v", dbType) @@ -101,14 +102,14 @@ func Base(cfg *config.Dendrite) (*base.BaseDendrite, nats.JetStreamContext, *nat if cfg == nil { cfg = &config.Dendrite{} cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: false, }) } cfg.Global.JetStream.InMemory = true cfg.SyncAPI.Fulltext.InMemory = true cfg.FederationAPI.KeyPerspectives = nil - base := base.NewBaseDendrite(cfg, "Tests", base.DisableMetrics) + base := base.NewBaseDendrite(cfg, base.DisableMetrics) js, jc := base.NATS.Prepare(base.ProcessContext, &cfg.Global.JetStream) return base, js, jc } diff --git a/userapi/api/api.go b/userapi/api/api.go index 4ea2e91c3..fa297f773 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -15,9 +15,13 @@ package api import ( + "bytes" "context" "encoding/json" + "strings" + "time" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -26,15 +30,12 @@ import ( // UserInternalAPI is the internal API for information about users and devices. type UserInternalAPI interface { - AppserviceUserAPI SyncUserAPI ClientUserAPI - MediaUserAPI FederationUserAPI - RoomserverUserAPI - KeyserverUserAPI QuerySearchProfilesAPI // used by p2p demos + QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) } // api functions required by the appservice api @@ -43,11 +44,6 @@ type AppserviceUserAPI interface { PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error } -type KeyserverUserAPI interface { - QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error - QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error -} - type RoomserverUserAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) @@ -60,13 +56,20 @@ type MediaUserAPI interface { // api functions required by the federation api type FederationUserAPI interface { + UploadDeviceKeysAPI QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error + QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error } // api functions required by the sync api type SyncUserAPI interface { QueryAcccessTokenAPI + SyncKeyAPI QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error @@ -79,6 +82,7 @@ type ClientUserAPI interface { QueryAcccessTokenAPI LoginTokenInternalAPI UserLoginAPI + ClientKeyAPI QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error @@ -681,3 +685,310 @@ type QueryAccountByLocalpartRequest struct { type QueryAccountByLocalpartResponse struct { Account *Account } + +// API functions required by the clientapi +type ClientKeyAPI interface { + UploadDeviceKeysAPI + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error + + PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error + // PerformClaimKeys claims one-time keys for use in pre-key messages + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error + PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error +} + +type UploadDeviceKeysAPI interface { + PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error +} + +// API functions required by the syncapi +type SyncKeyAPI interface { + QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error + QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error + PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error +} + +type FederationKeyAPI interface { + UploadDeviceKeysAPI + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error +} + +// KeyError is returned if there was a problem performing/querying the server +type KeyError struct { + Err string `json:"error"` + IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE + IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM + IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM +} + +func (k *KeyError) Error() string { + return k.Err +} + +type DeviceMessageType int + +const ( + TypeDeviceKeyUpdate DeviceMessageType = iota + TypeCrossSigningUpdate +) + +// DeviceMessage represents the message produced into Kafka by the key server. +type DeviceMessage struct { + Type DeviceMessageType `json:"Type,omitempty"` + *DeviceKeys `json:"DeviceKeys,omitempty"` + *OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` + // A monotonically increasing number which represents device changes for this user. + StreamID int64 + DeviceChangeID int64 +} + +// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log +type OutputCrossSigningKeyUpdate struct { + CrossSigningKeyUpdate `json:"signing_keys"` +} + +type CrossSigningKeyUpdate struct { + MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"` + SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"` + UserID string `json:"user_id"` +} + +// DeviceKeysEqual returns true if the device keys updates contain the +// same display name and key JSON. This will return false if either of +// the updates is not a device keys update, or if the user ID/device ID +// differ between the two. +func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool { + if m1.DeviceKeys == nil || m2.DeviceKeys == nil { + return false + } + if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID { + return false + } + if m1.DisplayName != m2.DisplayName { + return false // different display names + } + if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 { + return false // either is empty + } + return bytes.Equal(m1.KeyJSON, m2.KeyJSON) +} + +// DeviceKeys represents a set of device keys for a single device +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload +type DeviceKeys struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // The device display name + DisplayName string + // The raw device key JSON + KeyJSON []byte +} + +// WithStreamID returns a copy of this device message with the given stream ID +func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage { + return DeviceMessage{ + DeviceKeys: k, + StreamID: streamID, + } +} + +// OneTimeKeys represents a set of one-time keys for a single device +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload +type OneTimeKeys struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // A map of algorithm:key_id => key JSON + KeyJSON map[string]json.RawMessage +} + +// Split a key in KeyJSON into algorithm and key ID +func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) { + segments := strings.Split(keyIDWithAlgo, ":") + return segments[0], segments[1] +} + +// OneTimeKeysCount represents the counts of one-time keys for a single device +type OneTimeKeysCount struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // algorithm to count e.g: + // { + // "curve25519": 10, + // "signed_curve25519": 20 + // } + KeyCount map[string]int +} + +// PerformUploadKeysRequest is the request to PerformUploadKeys +type PerformUploadKeysRequest struct { + UserID string // Required - User performing the request + DeviceID string // Optional - Device performing the request, for fetching OTK count + DeviceKeys []DeviceKeys + OneTimeKeys []OneTimeKeys + // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update + // the display name for their respective device, and NOT to modify the keys. The key + // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. + // Without this flag, requests to modify device display names would delete device keys. + OnlyDisplayNameUpdates bool +} + +// PerformUploadKeysResponse is the response to PerformUploadKeys +type PerformUploadKeysResponse struct { + // A fatal error when processing e.g database failures + Error *KeyError + // A map of user_id -> device_id -> Error for tracking failures. + KeyErrors map[string]map[string]*KeyError + OneTimeKeyCounts []OneTimeKeysCount +} + +// PerformDeleteKeysRequest asks the keyserver to forget about certain +// keys, and signatures related to those keys. +type PerformDeleteKeysRequest struct { + UserID string + KeyIDs []gomatrixserverlib.KeyID +} + +// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest. +type PerformDeleteKeysResponse struct { + Error *KeyError +} + +// KeyError sets a key error field on KeyErrors +func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) { + if r.KeyErrors[userID] == nil { + r.KeyErrors[userID] = make(map[string]*KeyError) + } + r.KeyErrors[userID][deviceID] = err +} + +type PerformClaimKeysRequest struct { + // Map of user_id to device_id to algorithm name + OneTimeKeys map[string]map[string]string + Timeout time.Duration +} + +type PerformClaimKeysResponse struct { + // Map of user_id to device_id to algorithm:key_id to key JSON + OneTimeKeys map[string]map[string]map[string]json.RawMessage + // Map of remote server domain to error JSON + Failures map[string]interface{} + // Set if there was a fatal error processing this action + Error *KeyError +} + +type PerformUploadDeviceKeysRequest struct { + gomatrixserverlib.CrossSigningKeys + // The user that uploaded the key, should be populated by the clientapi. + UserID string +} + +type PerformUploadDeviceKeysResponse struct { + Error *KeyError +} + +type PerformUploadDeviceSignaturesRequest struct { + Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice + // The user that uploaded the sig, should be populated by the clientapi. + UserID string +} + +type PerformUploadDeviceSignaturesResponse struct { + Error *KeyError +} + +type QueryKeysRequest struct { + // The user ID asking for the keys, e.g. if from a client API request. + // Will not be populated if the key request came from federation. + UserID string + // Maps user IDs to a list of devices + UserToDevices map[string][]string + Timeout time.Duration +} + +type QueryKeysResponse struct { + // Map of remote server domain to error JSON + Failures map[string]interface{} + // Map of user_id to device_id to device_key + DeviceKeys map[string]map[string]json.RawMessage + // Maps of user_id to cross signing key + MasterKeys map[string]gomatrixserverlib.CrossSigningKey + SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey + UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // Set if there was a fatal error processing this query + Error *KeyError +} + +type QueryKeyChangesRequest struct { + // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning + Offset int64 + // The inclusive offset where to track key changes up to. Messages with this offset are included in the response. + // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). + ToOffset int64 +} + +type QueryKeyChangesResponse struct { + // The set of users who have had their keys change. + UserIDs []string + // The latest offset represented in this response. + Offset int64 + // Set if there was a problem handling the request. + Error *KeyError +} + +type QueryOneTimeKeysRequest struct { + // The local user to query OTK counts for + UserID string + // The device to query OTK counts for + DeviceID string +} + +type QueryOneTimeKeysResponse struct { + // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84 + Count OneTimeKeysCount + Error *KeyError +} + +type QueryDeviceMessagesRequest struct { + UserID string +} + +type QueryDeviceMessagesResponse struct { + // The latest stream ID + StreamID int64 + Devices []DeviceMessage + Error *KeyError +} + +type QuerySignaturesRequest struct { + // A map of target user ID -> target key/device IDs to retrieve signatures for + TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"` +} + +type QuerySignaturesResponse struct { + // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures + Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap + // A map of target user ID -> cross-signing master key + MasterKeys map[string]gomatrixserverlib.CrossSigningKey + // A map of target user ID -> cross-signing self-signing key + SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // A map of target user ID -> cross-signing user-signing key + UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // The request error, if any + Error *KeyError +} + +type PerformMarkAsStaleRequest struct { + UserID string + Domain gomatrixserverlib.ServerName + DeviceID string +} diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go deleted file mode 100644 index d10b5767b..000000000 --- a/userapi/api/api_trace.go +++ /dev/null @@ -1,219 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package api - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/matrix-org/util" -) - -// UserInternalAPITrace wraps a RoomserverInternalAPI and logs the -// complete request/response/error -type UserInternalAPITrace struct { - Impl UserInternalAPI -} - -func (t *UserInternalAPITrace) InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error { - err := t.Impl.InputAccountData(ctx, req, res) - util.GetLogger(ctx).Infof("InputAccountData req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error { - err := t.Impl.PerformAccountCreation(ctx, req, res) - util.GetLogger(ctx).Infof("PerformAccountCreation req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error { - err := t.Impl.PerformPasswordUpdate(ctx, req, res) - util.GetLogger(ctx).Infof("PerformPasswordUpdate req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error { - err := t.Impl.PerformDeviceCreation(ctx, req, res) - util.GetLogger(ctx).Infof("PerformDeviceCreation req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error { - err := t.Impl.PerformDeviceDeletion(ctx, req, res) - util.GetLogger(ctx).Infof("PerformDeviceDeletion req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error { - err := t.Impl.PerformLastSeenUpdate(ctx, req, res) - util.GetLogger(ctx).Infof("PerformLastSeenUpdate req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error { - err := t.Impl.PerformDeviceUpdate(ctx, req, res) - util.GetLogger(ctx).Infof("PerformDeviceUpdate req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error { - err := t.Impl.PerformAccountDeactivation(ctx, req, res) - util.GetLogger(ctx).Infof("PerformAccountDeactivation req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error { - err := t.Impl.PerformOpenIDTokenCreation(ctx, req, res) - util.GetLogger(ctx).Infof("PerformOpenIDTokenCreation req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error { - err := t.Impl.PerformKeyBackup(ctx, req, res) - util.GetLogger(ctx).Infof("PerformKeyBackup req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error { - err := t.Impl.PerformPusherSet(ctx, req, res) - util.GetLogger(ctx).Infof("PerformPusherSet req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error { - err := t.Impl.PerformPusherDeletion(ctx, req, res) - util.GetLogger(ctx).Infof("PerformPusherDeletion req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error { - err := t.Impl.PerformPushRulesPut(ctx, req, res) - util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error { - err := t.Impl.QueryKeyBackup(ctx, req, res) - util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error { - err := t.Impl.QueryProfile(ctx, req, res) - util.GetLogger(ctx).Infof("QueryProfile req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error { - err := t.Impl.QueryAccessToken(ctx, req, res) - util.GetLogger(ctx).Infof("QueryAccessToken req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error { - err := t.Impl.QueryDevices(ctx, req, res) - util.GetLogger(ctx).Infof("QueryDevices req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error { - err := t.Impl.QueryAccountData(ctx, req, res) - util.GetLogger(ctx).Infof("QueryAccountData req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error { - err := t.Impl.QueryDeviceInfos(ctx, req, res) - util.GetLogger(ctx).Infof("QueryDeviceInfos req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error { - err := t.Impl.QuerySearchProfiles(ctx, req, res) - util.GetLogger(ctx).Infof("QuerySearchProfiles req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error { - err := t.Impl.QueryOpenIDToken(ctx, req, res) - util.GetLogger(ctx).Infof("QueryOpenIDToken req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error { - err := t.Impl.QueryPushers(ctx, req, res) - util.GetLogger(ctx).Infof("QueryPushers req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error { - err := t.Impl.QueryPushRules(ctx, req, res) - util.GetLogger(ctx).Infof("QueryPushRules req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error { - err := t.Impl.QueryNotifications(ctx, req, res) - util.GetLogger(ctx).Infof("QueryNotifications req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error { - err := t.Impl.SetAvatarURL(ctx, req, res) - util.GetLogger(ctx).Infof("SetAvatarURL req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error { - err := t.Impl.QueryNumericLocalpart(ctx, req, res) - util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res)) - return err -} - -func (t *UserInternalAPITrace) QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error { - err := t.Impl.QueryAccountAvailability(ctx, req, res) - util.GetLogger(ctx).Infof("QueryAccountAvailability req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error { - err := t.Impl.SetDisplayName(ctx, req, res) - util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) QueryAccountByPassword(ctx context.Context, req *QueryAccountByPasswordRequest, res *QueryAccountByPasswordResponse) error { - err := t.Impl.QueryAccountByPassword(ctx, req, res) - util.GetLogger(ctx).Infof("QueryAccountByPassword req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error { - err := t.Impl.QueryLocalpartForThreePID(ctx, req, res) - util.GetLogger(ctx).Infof("QueryLocalpartForThreePID req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error { - err := t.Impl.QueryThreePIDsForLocalpart(ctx, req, res) - util.GetLogger(ctx).Infof("QueryThreePIDsForLocalpart req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error { - err := t.Impl.PerformForgetThreePID(ctx, req, res) - util.GetLogger(ctx).Infof("PerformForgetThreePID req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error { - err := t.Impl.PerformSaveThreePIDAssociation(ctx, req, res) - util.GetLogger(ctx).Infof("PerformSaveThreePIDAssociation req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error { - err := t.Impl.QueryAccountByLocalpart(ctx, req, res) - util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res)) - return err -} - -func js(thing interface{}) string { - b, err := json.Marshal(thing) - if err != nil { - return fmt.Sprintf("Marshal error:%s", err) - } - return string(b) -} diff --git a/userapi/api/api_trace_logintoken.go b/userapi/api/api_trace_logintoken.go deleted file mode 100644 index e60dae594..000000000 --- a/userapi/api/api_trace_logintoken.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package api - -import ( - "context" - - "github.com/matrix-org/util" -) - -func (t *UserInternalAPITrace) PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error { - err := t.Impl.PerformLoginTokenCreation(ctx, req, res) - util.GetLogger(ctx).Infof("PerformLoginTokenCreation req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) PerformLoginTokenDeletion(ctx context.Context, req *PerformLoginTokenDeletionRequest, res *PerformLoginTokenDeletionResponse) error { - err := t.Impl.PerformLoginTokenDeletion(ctx, req, res) - util.GetLogger(ctx).Infof("PerformLoginTokenDeletion req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) QueryLoginToken(ctx context.Context, req *QueryLoginTokenRequest, res *QueryLoginTokenResponse) error { - err := t.Impl.QueryLoginToken(ctx, req, res) - util.GetLogger(ctx).Infof("QueryLoginToken req=%+v res=%+v", js(req), js(res)) - return err -} diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go index 42ae72e77..51bd2753a 100644 --- a/userapi/consumers/clientapi.go +++ b/userapi/consumers/clientapi.go @@ -37,7 +37,7 @@ type OutputReceiptEventConsumer struct { jetstream nats.JetStreamContext durable string topic string - db storage.Database + db storage.UserDatabase serverName gomatrixserverlib.ServerName syncProducer *producers.SyncAPI pgClient pushgateway.Client @@ -49,7 +49,7 @@ func NewOutputReceiptEventConsumer( process *process.ProcessContext, cfg *config.UserAPI, js nats.JetStreamContext, - store storage.Database, + store storage.UserDatabase, syncProducer *producers.SyncAPI, pgClient pushgateway.Client, ) *OutputReceiptEventConsumer { diff --git a/keyserver/consumers/devicelistupdate.go b/userapi/consumers/devicelistupdate.go similarity index 97% rename from keyserver/consumers/devicelistupdate.go rename to userapi/consumers/devicelistupdate.go index cd911f8c6..a65889fcc 100644 --- a/keyserver/consumers/devicelistupdate.go +++ b/userapi/consumers/devicelistupdate.go @@ -18,11 +18,11 @@ import ( "context" "encoding/json" + "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/keyserver/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -41,7 +41,7 @@ type DeviceListUpdateConsumer struct { // NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. func NewDeviceListUpdateConsumer( process *process.ProcessContext, - cfg *config.KeyServer, + cfg *config.UserAPI, js nats.JetStreamContext, updater *internal.DeviceListUpdater, ) *DeviceListUpdateConsumer { diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 3ce5af621..47d330959 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct { rsAPI rsapi.UserRoomserverAPI jetstream nats.JetStreamContext durable string - db storage.Database + db storage.UserDatabase topic string pgClient pushgateway.Client syncProducer *producers.SyncAPI @@ -53,7 +53,7 @@ func NewOutputRoomEventConsumer( process *process.ProcessContext, cfg *config.UserAPI, js nats.JetStreamContext, - store storage.Database, + store storage.UserDatabase, pgClient pushgateway.Client, rsAPI rsapi.UserRoomserverAPI, syncProducer *producers.SyncAPI, diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 39f4aab4a..bc5ae652d 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -18,11 +18,11 @@ import ( userAPITypes "github.com/matrix-org/dendrite/userapi/types" ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { base, baseclose := testrig.CreateBaseDendrite(t, dbType) t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "", 4, 0, 0, "") if err != nil { diff --git a/keyserver/consumers/signingkeyupdate.go b/userapi/consumers/signingkeyupdate.go similarity index 87% rename from keyserver/consumers/signingkeyupdate.go rename to userapi/consumers/signingkeyupdate.go index bcceaad15..f4ff017db 100644 --- a/keyserver/consumers/signingkeyupdate.go +++ b/userapi/consumers/signingkeyupdate.go @@ -22,11 +22,10 @@ import ( "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" - keyapi "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/api" ) // SigningKeyUpdateConsumer consumes signing key updates that came in over federation. @@ -35,24 +34,24 @@ type SigningKeyUpdateConsumer struct { jetstream nats.JetStreamContext durable string topic string - keyAPI *internal.KeyInternalAPI - cfg *config.KeyServer + userAPI api.UploadDeviceKeysAPI + cfg *config.UserAPI isLocalServerName func(gomatrixserverlib.ServerName) bool } // NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. func NewSigningKeyUpdateConsumer( process *process.ProcessContext, - cfg *config.KeyServer, + cfg *config.UserAPI, js nats.JetStreamContext, - keyAPI *internal.KeyInternalAPI, + userAPI api.UploadDeviceKeysAPI, ) *SigningKeyUpdateConsumer { return &SigningKeyUpdateConsumer{ ctx: process.Context(), jetstream: js, durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - keyAPI: keyAPI, + userAPI: userAPI, cfg: cfg, isLocalServerName: cfg.Matrix.IsLocalServerName, } @@ -70,7 +69,7 @@ func (t *SigningKeyUpdateConsumer) Start() error { // signing key update events topic from the key server. func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { msg := msgs[0] // Guaranteed to exist if onMessage is called - var updatePayload keyapi.CrossSigningKeyUpdate + var updatePayload api.CrossSigningKeyUpdate if err := json.Unmarshal(msg.Data, &updatePayload); err != nil { logrus.WithError(err).Errorf("Failed to read from signing key update input topic") return true @@ -94,12 +93,12 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M if updatePayload.SelfSigningKey != nil { keys.SelfSigningKey = *updatePayload.SelfSigningKey } - uploadReq := &keyapi.PerformUploadDeviceKeysRequest{ + uploadReq := &api.PerformUploadDeviceKeysRequest{ CrossSigningKeys: keys, UserID: updatePayload.UserID, } - uploadRes := &keyapi.PerformUploadDeviceKeysResponse{} - if err := t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil { + uploadRes := &api.PerformUploadDeviceKeysResponse{} + if err := t.userAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil { logrus.WithError(err).Error("failed to upload device keys") return false } diff --git a/keyserver/internal/cross_signing.go b/userapi/internal/cross_signing.go similarity index 91% rename from keyserver/internal/cross_signing.go rename to userapi/internal/cross_signing.go index 99859dff6..8b9704d1b 100644 --- a/keyserver/internal/cross_signing.go +++ b/userapi/internal/cross_signing.go @@ -22,8 +22,8 @@ import ( "fmt" "strings" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" "golang.org/x/crypto/curve25519" @@ -103,7 +103,7 @@ func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpos } // nolint:gocyclo -func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { +func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { // Find the keys to store. byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} toStore := types.CrossSigningKeyMap{} @@ -169,7 +169,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P // something if any of the specified keys in the request are different // to what we've got in the database, to avoid generating key change // notifications unnecessarily. - existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) + existingKeys, err := a.KeyDatabase.CrossSigningKeysDataForUser(ctx, req.UserID) if err != nil { res.Error = &api.KeyError{ Err: "Retrieving cross-signing keys from database failed: " + err.Error(), @@ -216,7 +216,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P } // Store the keys. - if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil { + if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err), } @@ -234,7 +234,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P continue } for sigKeyID, sigBytes := range forSigUserID { - if err := a.DB.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil { + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err), } @@ -257,7 +257,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P if update.MasterKey == nil && update.SelfSigningKey == nil { return nil } - if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { + if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } @@ -266,7 +266,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P return nil } -func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error { +func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error { // Before we do anything, we need the master and self-signing keys for this user. // Then we can verify the signatures make sense. queryReq := &api.QueryKeysRequest{ @@ -342,7 +342,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req MasterKey: &masterKey, SelfSigningKey: &selfSigningKey, } - if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { + if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } @@ -352,7 +352,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req return nil } -func (a *KeyInternalAPI) processSelfSignatures( +func (a *UserInternalAPI) processSelfSignatures( ctx context.Context, signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, ) error { @@ -373,7 +373,7 @@ func (a *KeyInternalAPI) processSelfSignatures( } for originUserID, forOriginUserID := range sig.Signatures { for originKeyID, originSig := range forOriginUserID { - if err := a.DB.StoreCrossSigningSigsForTarget( + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, ); err != nil { return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) @@ -384,7 +384,7 @@ func (a *KeyInternalAPI) processSelfSignatures( case *gomatrixserverlib.DeviceKeys: for originUserID, forOriginUserID := range sig.Signatures { for originKeyID, originSig := range forOriginUserID { - if err := a.DB.StoreCrossSigningSigsForTarget( + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, ); err != nil { return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) @@ -401,7 +401,7 @@ func (a *KeyInternalAPI) processSelfSignatures( return nil } -func (a *KeyInternalAPI) processOtherSignatures( +func (a *UserInternalAPI) processOtherSignatures( ctx context.Context, userID string, queryRes *api.QueryKeysResponse, signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, ) error { @@ -442,7 +442,7 @@ func (a *KeyInternalAPI) processOtherSignatures( } for originKeyID, originSig := range userSigs { - if err := a.DB.StoreCrossSigningSigsForTarget( + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( ctx, userID, originKeyID, targetUserID, targetKeyID, originSig, ); err != nil { return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) @@ -461,11 +461,11 @@ func (a *KeyInternalAPI) processOtherSignatures( return nil } -func (a *KeyInternalAPI) crossSigningKeysFromDatabase( +func (a *UserInternalAPI) crossSigningKeysFromDatabase( ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse, ) { for targetUserID := range req.UserToDevices { - keys, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) + keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) if err != nil { logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID) continue @@ -478,7 +478,7 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase( break } - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID) continue @@ -522,9 +522,9 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase( } } -func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error { +func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error { for targetUserID, forTargetUser := range req.TargetIDs { - keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) + keyMap, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) if err != nil && err != sql.ErrNoRows { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err), @@ -556,7 +556,7 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign for _, targetKeyID := range forTargetUser { // Get own signatures only. - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID) if err != nil && err != sql.ErrNoRows { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err), diff --git a/keyserver/internal/device_list_update.go b/userapi/internal/device_list_update.go similarity index 98% rename from keyserver/internal/device_list_update.go rename to userapi/internal/device_list_update.go index c7bf8da53..3b4dcf98e 100644 --- a/keyserver/internal/device_list_update.go +++ b/userapi/internal/device_list_update.go @@ -33,8 +33,8 @@ import ( "github.com/sirupsen/logrus" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/api" ) var ( @@ -477,10 +477,6 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go return e.RetryAfter, err } else if e.Blacklisted { return time.Hour * 8, err - } else if e.Code >= 300 { - // We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError - // are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds. - return hourWaitTime, err } case net.Error: // Use the default waitTime, if it's a timeout. diff --git a/keyserver/internal/device_list_update_default.go b/userapi/internal/device_list_update_default.go similarity index 100% rename from keyserver/internal/device_list_update_default.go rename to userapi/internal/device_list_update_default.go diff --git a/keyserver/internal/device_list_update_sytest.go b/userapi/internal/device_list_update_sytest.go similarity index 100% rename from keyserver/internal/device_list_update_sytest.go rename to userapi/internal/device_list_update_sytest.go diff --git a/keyserver/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go similarity index 98% rename from keyserver/internal/device_list_update_test.go rename to userapi/internal/device_list_update_test.go index 60a2c2f30..868fc9be8 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/userapi/internal/device_list_update_test.go @@ -29,13 +29,13 @@ import ( "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" roomserver "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" ) var ( @@ -360,12 +360,12 @@ func TestDebounce(t *testing.T) { } } -func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { t.Helper() base, _, _ := testrig.Base(nil) connStr, clearDB := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) + db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) if err != nil { t.Fatal(err) } diff --git a/keyserver/internal/internal.go b/userapi/internal/key_api.go similarity index 83% rename from keyserver/internal/internal.go rename to userapi/internal/key_api.go index 9a08a0bb7..be816fe5d 100644 --- a/keyserver/internal/internal.go +++ b/userapi/internal/key_api.go @@ -29,29 +29,11 @@ import ( "github.com/tidwall/gjson" "github.com/tidwall/sjson" - fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/producers" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/config" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/api" ) -type KeyInternalAPI struct { - DB storage.Database - Cfg *config.KeyServer - FedClient fedsenderapi.KeyserverFederationAPI - UserAPI userapi.KeyserverUserAPI - Producer *producers.KeyChange - Updater *DeviceListUpdater -} - -func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) { - a.UserAPI = i -} - -func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error { - userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset) +func (a *UserInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error { + userIDs, latest, err := a.KeyDatabase.KeyChanges(ctx, req.Offset, req.ToOffset) if err != nil { res.Error = &api.KeyError{ Err: err.Error(), @@ -63,7 +45,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC return nil } -func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error { +func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error { res.KeyErrors = make(map[string]map[string]*api.KeyError) if len(req.DeviceKeys) > 0 { a.uploadLocalDeviceKeys(ctx, req, res) @@ -71,7 +53,7 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform if len(req.OneTimeKeys) > 0 { a.uploadOneTimeKeys(ctx, req, res) } - otks, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { return err } @@ -79,7 +61,7 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform return nil } -func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error { +func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error { res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage) res.Failures = make(map[string]interface{}) // wrap request map in a top-level by-domain map @@ -97,11 +79,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC domainToDeviceKeys[string(serverName)] = nested } for domain, local := range domainToDeviceKeys { - if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + if !a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } // claim local keys - keys, err := a.DB.ClaimKeys(ctx, local) + keys, err := a.KeyDatabase.ClaimKeys(ctx, local) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err), @@ -129,7 +111,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC return nil } -func (a *KeyInternalAPI) claimRemoteKeys( +func (a *UserInternalAPI) claimRemoteKeys( ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string, ) { var wg sync.WaitGroup // Wait for fan-out goroutines to finish @@ -146,7 +128,7 @@ func (a *KeyInternalAPI) claimRemoteKeys( defer cancel() defer wg.Done() - claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) + claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) mu.Lock() defer mu.Unlock() @@ -177,8 +159,8 @@ func (a *KeyInternalAPI) claimRemoteKeys( }).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys)) } -func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error { - if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { +func (a *UserInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error { + if err := a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("Failed to delete device keys: %s", err), } @@ -186,8 +168,8 @@ func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.Perform return nil } -func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error { - count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) +func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error { + count, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("Failed to query OTK counts: %s", err), @@ -198,8 +180,8 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne return nil } -func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error { - msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false) +func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error { + msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query DB for device keys: %s", err), @@ -225,8 +207,8 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query // PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present // in our database. -func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error { - knownDevices, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{}, true) +func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error { + knownDevices, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, []string{}, true) if err != nil { return err } @@ -244,7 +226,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap } // nolint:gocyclo -func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { +func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { var respMu sync.Mutex res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) @@ -262,8 +244,8 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } domain := string(serverName) // query local devices - if a.Cfg.Matrix.IsLocalServerName(serverName) { - deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) + if a.Config.Matrix.IsLocalServerName(serverName) { + deviceKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query local device keys: %s", err), @@ -276,8 +258,8 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques for _, dk := range deviceKeys { dids = append(dids, dk.DeviceID) } - var queryRes userapi.QueryDeviceInfosResponse - err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{ + var queryRes api.QueryDeviceInfosResponse + err = a.QueryDeviceInfos(ctx, &api.QueryDeviceInfosRequest{ DeviceIDs: dids, }, &queryRes) if err != nil { @@ -341,14 +323,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} } for targetKeyID := range masterKey.Keys { - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) if err != nil { // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil } - logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") + logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") continue } if len(sigMap) == 0 { @@ -367,14 +349,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques for targetUserID, forUserID := range res.DeviceKeys { for targetKeyID, key := range forUserID { - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID)) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID)) if err != nil { // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil } - logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") + logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") continue } if len(sigMap) == 0 { @@ -403,7 +385,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques return nil } -func (a *KeyInternalAPI) remoteKeysFromDatabase( +func (a *UserInternalAPI) remoteKeysFromDatabase( ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string, ) map[string]map[string][]string { fetchRemote := make(map[string]map[string][]string) @@ -429,7 +411,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase( return fetchRemote } -func (a *KeyInternalAPI) queryRemoteKeys( +func (a *UserInternalAPI) queryRemoteKeys( ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{}, ) { @@ -441,13 +423,13 @@ func (a *KeyInternalAPI) queryRemoteKeys( domains := map[string]struct{}{} for domain := range domainToDeviceKeys { - if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } domains[domain] = struct{}{} } for domain := range domainToCrossSigningKeys { - if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } domains[domain] = struct{}{} @@ -499,7 +481,7 @@ func (a *KeyInternalAPI) queryRemoteKeys( } } -func (a *KeyInternalAPI) queryRemoteKeysOnServer( +func (a *UserInternalAPI) queryRemoteKeysOnServer( ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{}, wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys, res *api.QueryKeysResponse, @@ -559,7 +541,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( if len(devKeys) == 0 { return } - queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) + queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) if err == nil { resultCh <- &queryKeysResp return @@ -586,10 +568,10 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( respMu.Unlock() } -func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( +func (a *UserInternalAPI) populateResponseWithDeviceKeysFromDatabase( ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string, ) error { - keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) + keys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false) // if we can't query the db or there are fewer keys than requested, fetch from remote. if err != nil { return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err) @@ -621,11 +603,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( return nil } -func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { +func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { // get a list of devices from the user API that actually exist, as // we won't store keys for devices that don't exist - uapidevices := &userapi.QueryDevicesResponse{} - if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil { + uapidevices := &api.QueryDevicesResponse{} + if err := a.QueryDevices(ctx, &api.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil { res.Error = &api.KeyError{ Err: err.Error(), } @@ -643,7 +625,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per } // Get all of the user existing device keys so we can check for changes. - existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true) + existingKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, true) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), @@ -662,7 +644,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per } if len(toClean) > 0 { - if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { + if err = a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean)) } else { logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean)) @@ -693,7 +675,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per if err != nil { continue // ignore invalid users } - if !a.Cfg.Matrix.IsLocalServerName(serverName) { + if !a.Config.Matrix.IsLocalServerName(serverName) { continue // ignore remote users } if len(key.KeyJSON) == 0 { @@ -722,30 +704,30 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per } // store the device keys and emit changes - err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore) + err = a.KeyDatabase.StoreLocalDeviceKeys(ctx, keysToStore) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), } return } - err = emitDeviceKeyChanges(a.Producer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates) + err = emitDeviceKeyChanges(a.KeyChangeProducer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates) if err != nil { util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err) } } -func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { +func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { if req.UserID == "" { res.Error = &api.KeyError{ Err: "user ID missing", } } if req.DeviceID != "" && len(req.OneTimeKeys) == 0 { - counts, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + counts, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err), + Err: fmt.Sprintf("a.KeyDatabase.OneTimeKeysCount: %s", err), } } if counts != nil { @@ -761,7 +743,7 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform keyIDsWithAlgorithms[i] = keyIDWithAlgo i++ } - existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms) + existingKeys, err := a.KeyDatabase.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms) if err != nil { res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ Err: "failed to query existing one-time keys: " + err.Error(), @@ -778,7 +760,7 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform } } // store one-time keys - counts, err := a.DB.StoreOneTimeKeys(ctx, key) + counts, err := a.KeyDatabase.StoreOneTimeKeys(ctx, key) if err != nil { res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()), diff --git a/keyserver/internal/internal_test.go b/userapi/internal/key_api_test.go similarity index 89% rename from keyserver/internal/internal_test.go rename to userapi/internal/key_api_test.go index 8a2c9c5d9..fc7e7e0df 100644 --- a/keyserver/internal/internal_test.go +++ b/userapi/internal/key_api_test.go @@ -5,23 +5,28 @@ import ( "reflect" "testing" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/internal" - "github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/internal" + "github.com/matrix-org/dendrite/userapi/storage" ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewDatabase(nil, &config.DatabaseOptions{ + base, _, _ := testrig.Base(nil) + db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) if err != nil { t.Fatalf("failed to create new user db: %v", err) } - return db, close + return db, func() { + base.Close() + close() + } } func Test_QueryDeviceMessages(t *testing.T) { @@ -140,8 +145,8 @@ func Test_QueryDeviceMessages(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a := &internal.KeyInternalAPI{ - DB: db, + a := &internal.UserInternalAPI{ + KeyDatabase: db, } if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr { t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr) diff --git a/userapi/internal/api.go b/userapi/internal/user_api.go similarity index 95% rename from userapi/internal/api.go rename to userapi/internal/user_api.go index 0bb480da6..8977697b5 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/user_api.go @@ -23,6 +23,7 @@ import ( "strconv" "time" + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -32,7 +33,6 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/sqlutil" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" synctypes "github.com/matrix-org/dendrite/syncapi/types" @@ -44,17 +44,19 @@ import ( ) type UserInternalAPI struct { - DB storage.Database - SyncProducer *producers.SyncAPI - Config *config.UserAPI + DB storage.UserDatabase + KeyDatabase storage.KeyDatabase + SyncProducer *producers.SyncAPI + KeyChangeProducer *producers.KeyChange + Config *config.UserAPI DisableTLSValidation bool // AppServices is the list of all registered AS AppServices []config.ApplicationService - KeyAPI keyapi.UserKeyAPI RSAPI rsapi.UserRoomserverAPI PgClient pushgateway.Client - Cfg *config.UserAPI + FedClient fedsenderapi.KeyserverFederationAPI + Updater *DeviceListUpdater } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { @@ -221,7 +223,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return fmt.Errorf("a.DB.SetDisplayName: %w", err) } - postRegisterJoinRooms(a.Cfg, acc, a.RSAPI) + postRegisterJoinRooms(a.Config, acc, a.RSAPI) res.AccountCreated = true res.Account = acc @@ -252,6 +254,17 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe if !a.Config.Matrix.IsLocalServerName(serverName) { return fmt.Errorf("server name %s is not local", serverName) } + // If a device ID was specified, check if it already exists and + // avoid sending an empty device list update which would remove + // existing device keys. + isExisting := false + if req.DeviceID != nil && *req.DeviceID != "" { + existingDev, err := a.DB.GetDeviceByID(ctx, req.Localpart, req.ServerName, *req.DeviceID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + isExisting = existingDev.ID == *req.DeviceID + } util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, "device_id": req.DeviceID, @@ -263,7 +276,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe } res.DeviceCreated = true res.Device = dev - if req.NoDeviceListUpdate { + if req.NoDeviceListUpdate || isExisting { return nil } // create empty device keys and upload them to trigger device list changes @@ -293,14 +306,14 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe return err } // Ask the keyserver to delete device keys and signatures for those devices - deleteReq := &keyapi.PerformDeleteKeysRequest{ + deleteReq := &api.PerformDeleteKeysRequest{ UserID: req.UserID, } for _, keyID := range req.DeviceIDs { deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID)) } - deleteRes := &keyapi.PerformDeleteKeysResponse{} - if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil { + deleteRes := &api.PerformDeleteKeysResponse{} + if err := a.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil { return err } if err := deleteRes.Error; err != nil { @@ -311,17 +324,17 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe } func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { - deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs)) + deviceKeys := make([]api.DeviceKeys, len(deviceIDs)) for i, did := range deviceIDs { - deviceKeys[i] = keyapi.DeviceKeys{ + deviceKeys[i] = api.DeviceKeys{ UserID: userID, DeviceID: did, KeyJSON: nil, } } - var uploadRes keyapi.PerformUploadKeysResponse - if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ + var uploadRes api.PerformUploadKeysResponse + if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{ UserID: userID, DeviceKeys: deviceKeys, }, &uploadRes); err != nil { @@ -385,10 +398,10 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf } if req.DisplayName != nil && dev.DisplayName != *req.DisplayName { // display name has changed: update the device key - var uploadRes keyapi.PerformUploadKeysResponse - if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ + var uploadRes api.PerformUploadKeysResponse + if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{ UserID: req.RequestingUserID, - DeviceKeys: []keyapi.DeviceKeys{ + DeviceKeys: []api.DeviceKeys{ { DeviceID: dev.ID, DisplayName: *req.DisplayName, diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go deleted file mode 100644 index 51b0fe3ef..000000000 --- a/userapi/inthttp/client.go +++ /dev/null @@ -1,454 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package inthttp - -import ( - "context" - "errors" - "net/http" - - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/userapi/api" -) - -// HTTP paths for the internal HTTP APIs -const ( - InputAccountDataPath = "/userapi/inputAccountData" - - PerformDeviceCreationPath = "/userapi/performDeviceCreation" - PerformAccountCreationPath = "/userapi/performAccountCreation" - PerformPasswordUpdatePath = "/userapi/performPasswordUpdate" - PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" - PerformLastSeenUpdatePath = "/userapi/performLastSeenUpdate" - PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" - PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" - PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" - PerformKeyBackupPath = "/userapi/performKeyBackup" - PerformPusherSetPath = "/pushserver/performPusherSet" - PerformPusherDeletionPath = "/pushserver/performPusherDeletion" - PerformPushRulesPutPath = "/pushserver/performPushRulesPut" - PerformSetAvatarURLPath = "/userapi/performSetAvatarURL" - PerformSetDisplayNamePath = "/userapi/performSetDisplayName" - PerformForgetThreePIDPath = "/userapi/performForgetThreePID" - PerformSaveThreePIDAssociationPath = "/userapi/performSaveThreePIDAssociation" - - QueryKeyBackupPath = "/userapi/queryKeyBackup" - QueryProfilePath = "/userapi/queryProfile" - QueryAccessTokenPath = "/userapi/queryAccessToken" - QueryDevicesPath = "/userapi/queryDevices" - QueryAccountDataPath = "/userapi/queryAccountData" - QueryDeviceInfosPath = "/userapi/queryDeviceInfos" - QuerySearchProfilesPath = "/userapi/querySearchProfiles" - QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" - QueryPushersPath = "/pushserver/queryPushers" - QueryPushRulesPath = "/pushserver/queryPushRules" - QueryNotificationsPath = "/pushserver/queryNotifications" - QueryNumericLocalpartPath = "/userapi/queryNumericLocalpart" - QueryAccountAvailabilityPath = "/userapi/queryAccountAvailability" - QueryAccountByPasswordPath = "/userapi/queryAccountByPassword" - QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID" - QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart" - QueryAccountByLocalpartPath = "/userapi/queryAccountType" -) - -// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. -// If httpClient is nil an error is returned -func NewUserAPIClient( - apiURL string, - httpClient *http.Client, -) (api.UserInternalAPI, error) { - if httpClient == nil { - return nil, errors.New("NewUserAPIClient: httpClient is ") - } - return &httpUserInternalAPI{ - apiURL: apiURL, - httpClient: httpClient, - }, nil -} - -type httpUserInternalAPI struct { - apiURL string - httpClient *http.Client -} - -func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { - return httputil.CallInternalRPCAPI( - "InputAccountData", h.apiURL+InputAccountDataPath, - h.httpClient, ctx, req, res, - ) -} - -func (h *httpUserInternalAPI) PerformAccountCreation( - ctx context.Context, - request *api.PerformAccountCreationRequest, - response *api.PerformAccountCreationResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformAccountCreation", h.apiURL+PerformAccountCreationPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformPasswordUpdate( - ctx context.Context, - request *api.PerformPasswordUpdateRequest, - response *api.PerformPasswordUpdateResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformPasswordUpdate", h.apiURL+PerformPasswordUpdatePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformDeviceCreation( - ctx context.Context, - request *api.PerformDeviceCreationRequest, - response *api.PerformDeviceCreationResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformDeviceCreation", h.apiURL+PerformDeviceCreationPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformDeviceDeletion( - ctx context.Context, - request *api.PerformDeviceDeletionRequest, - response *api.PerformDeviceDeletionResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformDeviceDeletion", h.apiURL+PerformDeviceDeletionPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformLastSeenUpdate( - ctx context.Context, - request *api.PerformLastSeenUpdateRequest, - response *api.PerformLastSeenUpdateResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformLastSeen", h.apiURL+PerformLastSeenUpdatePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformDeviceUpdate( - ctx context.Context, - request *api.PerformDeviceUpdateRequest, - response *api.PerformDeviceUpdateResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformDeviceUpdate", h.apiURL+PerformDeviceUpdatePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformAccountDeactivation( - ctx context.Context, - request *api.PerformAccountDeactivationRequest, - response *api.PerformAccountDeactivationResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformAccountDeactivation", h.apiURL+PerformAccountDeactivationPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformOpenIDTokenCreation( - ctx context.Context, - request *api.PerformOpenIDTokenCreationRequest, - response *api.PerformOpenIDTokenCreationResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformOpenIDTokenCreation", h.apiURL+PerformOpenIDTokenCreationPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryProfile( - ctx context.Context, - request *api.QueryProfileRequest, - response *api.QueryProfileResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryProfile", h.apiURL+QueryProfilePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryDeviceInfos( - ctx context.Context, - request *api.QueryDeviceInfosRequest, - response *api.QueryDeviceInfosResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryDeviceInfos", h.apiURL+QueryDeviceInfosPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryAccessToken( - ctx context.Context, - request *api.QueryAccessTokenRequest, - response *api.QueryAccessTokenResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryAccessToken", h.apiURL+QueryAccessTokenPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryDevices( - ctx context.Context, - request *api.QueryDevicesRequest, - response *api.QueryDevicesResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryDevices", h.apiURL+QueryDevicesPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryAccountData( - ctx context.Context, - request *api.QueryAccountDataRequest, - response *api.QueryAccountDataResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryAccountData", h.apiURL+QueryAccountDataPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QuerySearchProfiles( - ctx context.Context, - request *api.QuerySearchProfilesRequest, - response *api.QuerySearchProfilesResponse, -) error { - return httputil.CallInternalRPCAPI( - "QuerySearchProfiles", h.apiURL+QuerySearchProfilesPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryOpenIDToken( - ctx context.Context, - request *api.QueryOpenIDTokenRequest, - response *api.QueryOpenIDTokenResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryOpenIDToken", h.apiURL+QueryOpenIDTokenPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformKeyBackup( - ctx context.Context, - request *api.PerformKeyBackupRequest, - response *api.PerformKeyBackupResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformKeyBackup", h.apiURL+PerformKeyBackupPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryKeyBackup( - ctx context.Context, - request *api.QueryKeyBackupRequest, - response *api.QueryKeyBackupResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryKeyBackup", h.apiURL+QueryKeyBackupPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryNotifications( - ctx context.Context, - request *api.QueryNotificationsRequest, - response *api.QueryNotificationsResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryNotifications", h.apiURL+QueryNotificationsPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformPusherSet( - ctx context.Context, - request *api.PerformPusherSetRequest, - response *struct{}, -) error { - return httputil.CallInternalRPCAPI( - "PerformPusherSet", h.apiURL+PerformPusherSetPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformPusherDeletion( - ctx context.Context, - request *api.PerformPusherDeletionRequest, - response *struct{}, -) error { - return httputil.CallInternalRPCAPI( - "PerformPusherDeletion", h.apiURL+PerformPusherDeletionPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryPushers( - ctx context.Context, - request *api.QueryPushersRequest, - response *api.QueryPushersResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryPushers", h.apiURL+QueryPushersPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformPushRulesPut( - ctx context.Context, - request *api.PerformPushRulesPutRequest, - response *struct{}, -) error { - return httputil.CallInternalRPCAPI( - "PerformPushRulesPut", h.apiURL+PerformPushRulesPutPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryPushRules( - ctx context.Context, - request *api.QueryPushRulesRequest, - response *api.QueryPushRulesResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryPushRules", h.apiURL+QueryPushRulesPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) SetAvatarURL( - ctx context.Context, - request *api.PerformSetAvatarURLRequest, - response *api.PerformSetAvatarURLResponse, -) error { - return httputil.CallInternalRPCAPI( - "SetAvatarURL", h.apiURL+PerformSetAvatarURLPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryNumericLocalpart( - ctx context.Context, - request *api.QueryNumericLocalpartRequest, - response *api.QueryNumericLocalpartResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryAccountAvailability( - ctx context.Context, - request *api.QueryAccountAvailabilityRequest, - response *api.QueryAccountAvailabilityResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryAccountAvailability", h.apiURL+QueryAccountAvailabilityPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryAccountByPassword( - ctx context.Context, - request *api.QueryAccountByPasswordRequest, - response *api.QueryAccountByPasswordResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryAccountByPassword", h.apiURL+QueryAccountByPasswordPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) SetDisplayName( - ctx context.Context, - request *api.PerformUpdateDisplayNameRequest, - response *api.PerformUpdateDisplayNameResponse, -) error { - return httputil.CallInternalRPCAPI( - "SetDisplayName", h.apiURL+PerformSetDisplayNamePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryLocalpartForThreePID( - ctx context.Context, - request *api.QueryLocalpartForThreePIDRequest, - response *api.QueryLocalpartForThreePIDResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryLocalpartForThreePID", h.apiURL+QueryLocalpartForThreePIDPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryThreePIDsForLocalpart( - ctx context.Context, - request *api.QueryThreePIDsForLocalpartRequest, - response *api.QueryThreePIDsForLocalpartResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryThreePIDsForLocalpart", h.apiURL+QueryThreePIDsForLocalpartPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformForgetThreePID( - ctx context.Context, - request *api.PerformForgetThreePIDRequest, - response *struct{}, -) error { - return httputil.CallInternalRPCAPI( - "PerformForgetThreePID", h.apiURL+PerformForgetThreePIDPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation( - ctx context.Context, - request *api.PerformSaveThreePIDAssociationRequest, - response *struct{}, -) error { - return httputil.CallInternalRPCAPI( - "PerformSaveThreePIDAssociation", h.apiURL+PerformSaveThreePIDAssociationPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryAccountByLocalpart( - ctx context.Context, - req *api.QueryAccountByLocalpartRequest, - res *api.QueryAccountByLocalpartResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath, - h.httpClient, ctx, req, res, - ) -} diff --git a/userapi/inthttp/client_logintoken.go b/userapi/inthttp/client_logintoken.go deleted file mode 100644 index 211b1b7a1..000000000 --- a/userapi/inthttp/client_logintoken.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package inthttp - -import ( - "context" - - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/userapi/api" -) - -const ( - PerformLoginTokenCreationPath = "/userapi/performLoginTokenCreation" - PerformLoginTokenDeletionPath = "/userapi/performLoginTokenDeletion" - QueryLoginTokenPath = "/userapi/queryLoginToken" -) - -func (h *httpUserInternalAPI) PerformLoginTokenCreation( - ctx context.Context, - request *api.PerformLoginTokenCreationRequest, - response *api.PerformLoginTokenCreationResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformLoginTokenCreation", h.apiURL+PerformLoginTokenCreationPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformLoginTokenDeletion( - ctx context.Context, - request *api.PerformLoginTokenDeletionRequest, - response *api.PerformLoginTokenDeletionResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformLoginTokenDeletion", h.apiURL+PerformLoginTokenDeletionPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryLoginToken( - ctx context.Context, - request *api.QueryLoginTokenRequest, - response *api.QueryLoginTokenResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryLoginToken", h.apiURL+QueryLoginTokenPath, - h.httpClient, ctx, request, response, - ) -} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go deleted file mode 100644 index b40b507c2..000000000 --- a/userapi/inthttp/server.go +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package inthttp - -import ( - "github.com/gorilla/mux" - - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/userapi/api" -) - -// nolint: gocyclo -func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics bool) { - addRoutesLoginToken(internalAPIMux, s, enableMetrics) - - internalAPIMux.Handle( - PerformAccountCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformAccountCreation", enableMetrics, s.PerformAccountCreation), - ) - - internalAPIMux.Handle( - PerformPasswordUpdatePath, - httputil.MakeInternalRPCAPI("UserAPIPerformPasswordUpdate", enableMetrics, s.PerformPasswordUpdate), - ) - - internalAPIMux.Handle( - PerformDeviceCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformDeviceCreation", enableMetrics, s.PerformDeviceCreation), - ) - - internalAPIMux.Handle( - PerformLastSeenUpdatePath, - httputil.MakeInternalRPCAPI("UserAPIPerformLastSeenUpdate", enableMetrics, s.PerformLastSeenUpdate), - ) - - internalAPIMux.Handle( - PerformDeviceUpdatePath, - httputil.MakeInternalRPCAPI("UserAPIPerformDeviceUpdate", enableMetrics, s.PerformDeviceUpdate), - ) - - internalAPIMux.Handle( - PerformDeviceDeletionPath, - httputil.MakeInternalRPCAPI("UserAPIPerformDeviceDeletion", enableMetrics, s.PerformDeviceDeletion), - ) - - internalAPIMux.Handle( - PerformAccountDeactivationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformAccountDeactivation", enableMetrics, s.PerformAccountDeactivation), - ) - - internalAPIMux.Handle( - PerformOpenIDTokenCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformOpenIDTokenCreation", enableMetrics, s.PerformOpenIDTokenCreation), - ) - - internalAPIMux.Handle( - QueryProfilePath, - httputil.MakeInternalRPCAPI("UserAPIQueryProfile", enableMetrics, s.QueryProfile), - ) - - internalAPIMux.Handle( - QueryAccessTokenPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccessToken", enableMetrics, s.QueryAccessToken), - ) - - internalAPIMux.Handle( - QueryDevicesPath, - httputil.MakeInternalRPCAPI("UserAPIQueryDevices", enableMetrics, s.QueryDevices), - ) - - internalAPIMux.Handle( - QueryAccountDataPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccountData", enableMetrics, s.QueryAccountData), - ) - - internalAPIMux.Handle( - QueryDeviceInfosPath, - httputil.MakeInternalRPCAPI("UserAPIQueryDeviceInfos", enableMetrics, s.QueryDeviceInfos), - ) - - internalAPIMux.Handle( - QuerySearchProfilesPath, - httputil.MakeInternalRPCAPI("UserAPIQuerySearchProfiles", enableMetrics, s.QuerySearchProfiles), - ) - - internalAPIMux.Handle( - QueryOpenIDTokenPath, - httputil.MakeInternalRPCAPI("UserAPIQueryOpenIDToken", enableMetrics, s.QueryOpenIDToken), - ) - - internalAPIMux.Handle( - InputAccountDataPath, - httputil.MakeInternalRPCAPI("UserAPIInputAccountData", enableMetrics, s.InputAccountData), - ) - - internalAPIMux.Handle( - QueryKeyBackupPath, - httputil.MakeInternalRPCAPI("UserAPIQueryKeyBackup", enableMetrics, s.QueryKeyBackup), - ) - - internalAPIMux.Handle( - PerformKeyBackupPath, - httputil.MakeInternalRPCAPI("UserAPIPerformKeyBackup", enableMetrics, s.PerformKeyBackup), - ) - - internalAPIMux.Handle( - QueryNotificationsPath, - httputil.MakeInternalRPCAPI("UserAPIQueryNotifications", enableMetrics, s.QueryNotifications), - ) - - internalAPIMux.Handle( - PerformPusherSetPath, - httputil.MakeInternalRPCAPI("UserAPIPerformPusherSet", enableMetrics, s.PerformPusherSet), - ) - - internalAPIMux.Handle( - PerformPusherDeletionPath, - httputil.MakeInternalRPCAPI("UserAPIPerformPusherDeletion", enableMetrics, s.PerformPusherDeletion), - ) - - internalAPIMux.Handle( - QueryPushersPath, - httputil.MakeInternalRPCAPI("UserAPIQueryPushers", enableMetrics, s.QueryPushers), - ) - - internalAPIMux.Handle( - PerformPushRulesPutPath, - httputil.MakeInternalRPCAPI("UserAPIPerformPushRulesPut", enableMetrics, s.PerformPushRulesPut), - ) - - internalAPIMux.Handle( - QueryPushRulesPath, - httputil.MakeInternalRPCAPI("UserAPIQueryPushRules", enableMetrics, s.QueryPushRules), - ) - - internalAPIMux.Handle( - PerformSetAvatarURLPath, - httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", enableMetrics, s.SetAvatarURL), - ) - - internalAPIMux.Handle( - QueryNumericLocalpartPath, - httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", enableMetrics, s.QueryNumericLocalpart), - ) - - internalAPIMux.Handle( - QueryAccountAvailabilityPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccountAvailability", enableMetrics, s.QueryAccountAvailability), - ) - - internalAPIMux.Handle( - QueryAccountByPasswordPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccountByPassword", enableMetrics, s.QueryAccountByPassword), - ) - - internalAPIMux.Handle( - PerformSetDisplayNamePath, - httputil.MakeInternalRPCAPI("UserAPISetDisplayName", enableMetrics, s.SetDisplayName), - ) - - internalAPIMux.Handle( - QueryLocalpartForThreePIDPath, - httputil.MakeInternalRPCAPI("UserAPIQueryLocalpartForThreePID", enableMetrics, s.QueryLocalpartForThreePID), - ) - - internalAPIMux.Handle( - QueryThreePIDsForLocalpartPath, - httputil.MakeInternalRPCAPI("UserAPIQueryThreePIDsForLocalpart", enableMetrics, s.QueryThreePIDsForLocalpart), - ) - - internalAPIMux.Handle( - PerformForgetThreePIDPath, - httputil.MakeInternalRPCAPI("UserAPIPerformForgetThreePID", enableMetrics, s.PerformForgetThreePID), - ) - - internalAPIMux.Handle( - PerformSaveThreePIDAssociationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation), - ) - - internalAPIMux.Handle( - QueryAccountByLocalpartPath, - httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart), - ) -} diff --git a/userapi/inthttp/server_logintoken.go b/userapi/inthttp/server_logintoken.go deleted file mode 100644 index dc116428b..000000000 --- a/userapi/inthttp/server_logintoken.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package inthttp - -import ( - "github.com/gorilla/mux" - - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/userapi/api" -) - -// addRoutesLoginToken adds routes for all login token API calls. -func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics bool) { - internalAPIMux.Handle( - PerformLoginTokenCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenCreation", enableMetrics, s.PerformLoginTokenCreation), - ) - - internalAPIMux.Handle( - PerformLoginTokenDeletionPath, - httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenDeletion", enableMetrics, s.PerformLoginTokenDeletion), - ) - - internalAPIMux.Handle( - QueryLoginTokenPath, - httputil.MakeInternalRPCAPI("UserAPIQueryLoginToken", enableMetrics, s.QueryLoginToken), - ) -} diff --git a/keyserver/producers/keychange.go b/userapi/producers/keychange.go similarity index 94% rename from keyserver/producers/keychange.go rename to userapi/producers/keychange.go index f86c34177..da6cea31a 100644 --- a/keyserver/producers/keychange.go +++ b/userapi/producers/keychange.go @@ -18,9 +18,9 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) @@ -28,8 +28,8 @@ import ( // KeyChange produces key change events for the sync API and federation sender to consume type KeyChange struct { Topic string - JetStream nats.JetStreamContext - DB storage.Database + JetStream JetStreamPublisher + DB storage.KeyChangeDatabase } // ProduceKeyChanges creates new change events for each key diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go index 51eaa9856..165de8994 100644 --- a/userapi/producers/syncapi.go +++ b/userapi/producers/syncapi.go @@ -19,13 +19,13 @@ type JetStreamPublisher interface { // SyncAPI produces messages for the Sync API server to consume. type SyncAPI struct { - db storage.Database + db storage.Notification producer JetStreamPublisher clientDataTopic string notificationDataTopic string } -func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { +func NewSyncAPI(db storage.UserDatabase, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { return &SyncAPI{ db: db, producer: js, diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index c22b7658f..278378861 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -90,7 +90,7 @@ type KeyBackup interface { type LoginToken interface { // CreateLoginToken generates a token, stores and returns it. The lifetime is - // determined by the loginTokenLifetime given to the Database constructor. + // determined by the loginTokenLifetime given to the UserDatabase constructor. CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) // RemoveLoginToken removes the named token (and may clean up other expired tokens). @@ -130,7 +130,7 @@ type Notification interface { DeleteOldNotifications(ctx context.Context) error } -type Database interface { +type UserDatabase interface { Account AccountData Device @@ -144,6 +144,78 @@ type Database interface { ThreePID } +type KeyChangeDatabase interface { + // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change. + // `userID` is the the user who has changed their keys in some way. + StoreKeyChange(ctx context.Context, userID string) (int64, error) +} + +type KeyDatabase interface { + KeyChangeDatabase + // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination + // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. + ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + + // StoreOneTimeKeys persists the given one-time keys. + StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + + // OneTimeKeysCount returns a count of all OTKs for this device. + OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. + DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + + // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). + // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. + // Returns an error if there was a problem storing the keys. + StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + + // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior + // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. + StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error + + // PrevIDsExists returns true if all prev IDs exist for this user. + PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) + + // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. + // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. + DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) + + // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying + // cross-signing signatures relating to that device. + DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error + + // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key + // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. + ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) + + // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). + // A to offset of types.OffsetNewest means no upper limit. + // Returns the offset of the latest key change. + KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) + + // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. + // If no domains are given, all user IDs with stale device lists are returned. + StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + + // MarkDeviceListStale sets the stale bit for this user to isStale. + MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error + + CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) + CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) + CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) + + StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error + StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + + DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, + ) error +} + type Statistics interface { UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 2a4777d74..057160374 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData( roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) - _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content) + // Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage + // when passing the data to trigger "NOT NULL" constraint + var data *json.RawMessage + if len(content) > 0 { + data = &content + } + _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, data) return } diff --git a/keyserver/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go similarity index 96% rename from keyserver/storage/postgres/cross_signing_keys_table.go rename to userapi/storage/postgres/cross_signing_keys_table.go index 1022157e8..c0ecbd303 100644 --- a/keyserver/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -21,8 +21,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/userapi/storage/postgres/cross_signing_sigs_table.go similarity index 96% rename from keyserver/storage/postgres/cross_signing_sigs_table.go rename to userapi/storage/postgres/cross_signing_sigs_table.go index 4536b7d80..b0117145c 100644 --- a/keyserver/storage/postgres/cross_signing_sigs_table.go +++ b/userapi/storage/postgres/cross_signing_sigs_table.go @@ -21,9 +21,9 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go b/userapi/storage/postgres/deltas/2022012016470000_key_changes.go similarity index 100% rename from keyserver/storage/postgres/deltas/2022012016470000_key_changes.go rename to userapi/storage/postgres/deltas/2022012016470000_key_changes.go diff --git a/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go similarity index 100% rename from keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go rename to userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go diff --git a/keyserver/storage/postgres/device_keys_table.go b/userapi/storage/postgres/device_keys_table.go similarity index 87% rename from keyserver/storage/postgres/device_keys_table.go rename to userapi/storage/postgres/device_keys_table.go index 2aa11c520..a9203857f 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/userapi/storage/postgres/device_keys_table.go @@ -23,8 +23,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var deviceKeysSchema = ` @@ -92,31 +92,16 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if err != nil { return nil, err } - if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { - return nil, err - } - if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { - return nil, err - } - if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil { - return nil, err - } - if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil { - return nil, err - } - if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL}, + {&s.selectDeviceKeysStmt, selectDeviceKeysSQL}, + {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL}, + {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL}, + {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL}, + {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, + {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL}, + {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL}, + }.Prepare(db) } func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 2dd216189..88f8839c5 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -81,7 +81,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent, session_id FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4" @@ -96,7 +96,7 @@ const deleteDevicesSQL = "" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, server_name, display_name, last_seen_ts, session_id FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6" @@ -160,7 +160,7 @@ func (s *devicesStatements) InsertDevice( if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil { return nil, fmt.Errorf("insertDeviceStmt: %w", err) } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -168,7 +168,19 @@ func (s *devicesStatements) InsertDevice( LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil +} + +func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, + localpart string, serverName gomatrixserverlib.ServerName, + accessToken string, displayName *string, ipAddr, userAgent string, + sessionID int64, +) (*api.Device, error) { + return s.InsertDevice(ctx, txn, id, localpart, serverName, accessToken, displayName, ipAddr, userAgent) } // deleteDevice removes a single device by id and user localpart. @@ -271,7 +283,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var lastseents sql.NullInt64 var displayName sql.NullString for rows.Next() { - if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents, &dev.SessionID); err != nil { return nil, err } if displayName.Valid { @@ -303,7 +315,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart( var lastseents sql.NullInt64 var id, displayname, ip, useragent sql.NullString for rows.Next() { - err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) + err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent, &dev.SessionID) if err != nil { return devices, err } diff --git a/userapi/storage/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go index 7b58f7bae..91a34c357 100644 --- a/userapi/storage/postgres/key_backup_table.go +++ b/userapi/storage/postgres/key_backup_table.go @@ -52,7 +52,7 @@ const updateBackupKeySQL = "" + const countKeysSQL = "" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" -const selectKeysSQL = "" + +const selectBackupKeysSQL = "" + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" @@ -83,7 +83,7 @@ func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, - {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysStmt, selectBackupKeysSQL}, {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, }.Prepare(db) diff --git a/keyserver/storage/postgres/key_changes_table.go b/userapi/storage/postgres/key_changes_table.go similarity index 90% rename from keyserver/storage/postgres/key_changes_table.go rename to userapi/storage/postgres/key_changes_table.go index c0e3429c7..a00494140 100644 --- a/keyserver/storage/postgres/key_changes_table.go +++ b/userapi/storage/postgres/key_changes_table.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var keyChangesSchema = ` @@ -66,7 +66,10 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { if err = executeMigration(context.Background(), db); err != nil { return nil, err } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertKeyChangeStmt, upsertKeyChangeSQL}, + {&s.selectKeyChangesStmt, selectKeyChangesSQL}, + }.Prepare(db) } func executeMigration(ctx context.Context, db *sql.DB) error { @@ -95,16 +98,6 @@ func executeMigration(ctx context.Context, db *sql.DB) error { return m.Up(ctx) } -func (s *keyChangesStatements) Prepare() (err error) { - if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { - return err - } - if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil { - return err - } - return nil -} - func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) return diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/userapi/storage/postgres/one_time_keys_table.go similarity index 89% rename from keyserver/storage/postgres/one_time_keys_table.go rename to userapi/storage/postgres/one_time_keys_table.go index 2117efcae..972a59147 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/userapi/storage/postgres/one_time_keys_table.go @@ -23,8 +23,8 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var oneTimeKeysSchema = ` @@ -49,7 +49,7 @@ const upsertKeysSQL = "" + " ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" + " DO UPDATE SET key_json = $6" -const selectKeysSQL = "" + +const selectOneTimeKeysSQL = "" + "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);" const selectKeysCountSQL = "" + @@ -84,25 +84,14 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { if err != nil { return nil, err } - if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil { - return nil, err - } - if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { - return nil, err - } - if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { - return nil, err - } - if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertKeysStmt, upsertKeysSQL}, + {&s.selectKeysStmt, selectOneTimeKeysSQL}, + {&s.selectKeysCountStmt, selectKeysCountSQL}, + {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL}, + {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL}, + {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL}, + }.Prepare(db) } func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { diff --git a/keyserver/storage/postgres/stale_device_lists.go b/userapi/storage/postgres/stale_device_lists.go similarity index 98% rename from keyserver/storage/postgres/stale_device_lists.go rename to userapi/storage/postgres/stale_device_lists.go index 248ddfb45..c823b58c6 100644 --- a/keyserver/storage/postgres/stale_device_lists.go +++ b/userapi/storage/postgres/stale_device_lists.go @@ -24,7 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 92dc48081..673d123ba 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -136,3 +136,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, }, nil } + +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) + if err != nil { + return nil, err + } + otk, err := NewPostgresOneTimeKeysTable(db) + if err != nil { + return nil, err + } + dk, err := NewPostgresDeviceKeysTable(db) + if err != nil { + return nil, err + } + kc, err := NewPostgresKeyChangesTable(db) + if err != nil { + return nil, err + } + sdl, err := NewPostgresStaleDeviceListsTable(db) + if err != nil { + return nil, err + } + csk, err := NewPostgresCrossSigningKeysTable(db) + if err != nil { + return nil, err + } + css, err := NewPostgresCrossSigningSigsTable(db) + if err != nil { + return nil, err + } + + return &shared.KeyDatabase{ + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, + CrossSigningKeysTable: csk, + CrossSigningSigsTable: css, + Writer: writer, + }, nil +} diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index f549dcef9..d3272a032 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -59,6 +59,17 @@ type Database struct { OpenIDTokenLifetimeMS int64 } +type KeyDatabase struct { + OneTimeKeysTable tables.OneTimeKeys + DeviceKeysTable tables.DeviceKeys + KeyChangesTable tables.KeyChanges + StaleDeviceListsTable tables.StaleDeviceLists + CrossSigningKeysTable tables.CrossSigningKeys + CrossSigningSigsTable tables.CrossSigningSigs + DB *sql.DB + Writer sqlutil.Writer +} + const ( // The length of generated device IDs deviceIDByteLength = 6 @@ -588,16 +599,42 @@ func (d *Database) CreateDevice( deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string, ) (dev *api.Device, returnErr error) { if deviceID != nil { - returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil { - return err - } + _, ok := d.Writer.(*sqlutil.ExclusiveWriter) + if ok { // we're using most likely using SQLite, so do things a little different + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error - dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) - return err - }) + devices, err := d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, "") + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + // No devices yet, only create a new one + if len(devices) == 0 { + dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) + return err + } + sessionID := devices[0].SessionID + 1 + + // Revoke existing tokens for this device + if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil { + return err + } + // Create a new device with the session ID incremented + dev, err = d.Devices.InsertDeviceWithSessionID(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent, sessionID) + return err + }) + } else { + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil { + return err + } + + dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) + return err + }) + } } else { // We generate device IDs in a loop in case its already taken. // We cap this at going round 5 times to ensure we don't spin forever @@ -618,7 +655,7 @@ func (d *Database) CreateDevice( } } } - return + return dev, returnErr } // generateDeviceID creates a new device id. Returns an error if failed to generate @@ -849,3 +886,227 @@ func (d *Database) DailyRoomsMessages( ) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) { return d.Stats.DailyRoomsMessages(ctx, nil, serverName) } + +// + +func (d *KeyDatabase) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms) +} + +func (d *KeyDatabase) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys) + return err + }) + return +} + +func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) +} + +func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) +} + +func (d *KeyDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { + count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs) + if err != nil { + return false, err + } + return count == len(prevIDs), nil +} + +func (d *KeyDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for _, userID := range clearUserIDs { + err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID) + if err != nil { + return err + } + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) +} + +func (d *KeyDatabase) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { + // work out the latest stream IDs for each user + userIDToStreamID := make(map[string]int64) + for _, k := range keys { + userIDToStreamID[k.UserID] = 0 + } + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for userID := range userIDToStreamID { + streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) + if err != nil { + return err + } + userIDToStreamID[userID] = streamID + } + // set the stream IDs for each key + for i := range keys { + k := keys[i] + userIDToStreamID[k.UserID]++ // start stream from 1 + k.StreamID = userIDToStreamID[k.UserID] + keys[i] = k + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) +} + +func (d *KeyDatabase) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) +} + +func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { + var result []api.OneTimeKeys + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for userID, deviceToAlgo := range userToDeviceToAlgorithm { + for deviceID, algo := range deviceToAlgo { + keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo) + if err != nil { + return err + } + if keyJSON != nil { + result = append(result, api.OneTimeKeys{ + UserID: userID, + DeviceID: deviceID, + KeyJSON: keyJSON, + }) + } + } + } + return nil + }) + return result, err +} + +func (d *KeyDatabase) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) { + err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID) + return err + }) + return +} + +func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { + return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset) +} + +// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. +// If no domains are given, all user IDs with stale device lists are returned. +func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) +} + +// MarkDeviceListStale sets the stale bit for this user to isStale. +func (d *KeyDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { + return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) + }) +} + +// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying +// cross-signing signatures relating to that device. +func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for _, deviceID := range deviceIDs { + if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err) + } + if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) + } + if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err) + } + } + return nil + }) +} + +// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. +func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { + keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) + if err != nil { + return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err) + } + results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} + for purpose, key := range keyMap { + keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode()) + result := gomatrixserverlib.CrossSigningKey{ + UserID: userID, + Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose}, + Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + keyID: key, + }, + } + sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID) + if err != nil { + continue + } + for sigUserID, forSigUserID := range sigMap { + if userID != sigUserID { + continue + } + if result.Signatures == nil { + result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + if _, ok := result.Signatures[sigUserID]; !ok { + result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + for sigKeyID, sigBytes := range forSigUserID { + result.Signatures[sigUserID][sigKeyID] = sigBytes + } + } + results[purpose] = result + } + return results, nil +} + +// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. +func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { + return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) +} + +// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. +func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { + return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID) +} + +// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. +func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for keyType, keyData := range keyMap { + if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { + return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err) + } + } + return nil + }) +} + +// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice. +func (d *KeyDatabase) StoreCrossSigningSigsForTarget( + ctx context.Context, + originUserID string, originKeyID gomatrixserverlib.KeyID, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, + signature gomatrixserverlib.Base64Bytes, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err) + } + return nil + }) +} + +// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore. +func (d *KeyDatabase) DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs) + }) +} diff --git a/keyserver/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go similarity index 96% rename from keyserver/storage/sqlite3/cross_signing_keys_table.go rename to userapi/storage/sqlite3/cross_signing_keys_table.go index e103d9883..10721fcc8 100644 --- a/keyserver/storage/sqlite3/cross_signing_keys_table.go +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -21,8 +21,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/userapi/storage/sqlite3/cross_signing_sigs_table.go similarity index 96% rename from keyserver/storage/sqlite3/cross_signing_sigs_table.go rename to userapi/storage/sqlite3/cross_signing_sigs_table.go index 7a153e8fb..2be00c9c1 100644 --- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go +++ b/userapi/storage/sqlite3/cross_signing_sigs_table.go @@ -21,9 +21,9 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go b/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go similarity index 100% rename from keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go rename to userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go diff --git a/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go similarity index 100% rename from keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go rename to userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/userapi/storage/sqlite3/device_keys_table.go similarity index 88% rename from keyserver/storage/sqlite3/device_keys_table.go rename to userapi/storage/sqlite3/device_keys_table.go index 73768da5b..15e69cc4c 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/userapi/storage/sqlite3/device_keys_table.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var deviceKeysSchema = ` @@ -86,28 +86,16 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if err != nil { return nil, err } - if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { - return nil, err - } - if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { - return nil, err - } - if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil { - return nil, err - } - if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL}, + {&s.selectDeviceKeysStmt, selectDeviceKeysSQL}, + {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL}, + {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL}, + {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL}, + // {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, // prepared at runtime + {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL}, + {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL}, + }.Prepare(db) } func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index c5db34bd7..65e17527d 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -65,7 +65,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent, session_id FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4" @@ -80,7 +80,7 @@ const deleteDevicesSQL = "" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, server_name, display_name, last_seen_ts, session_id FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6" @@ -151,7 +151,7 @@ func (s *devicesStatements) InsertDevice( if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { return nil, err } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -159,7 +159,36 @@ func (s *devicesStatements) InsertDevice( LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil +} + +func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, + localpart string, serverName gomatrixserverlib.ServerName, + accessToken string, displayName *string, ipAddr, userAgent string, + sessionID int64, +) (*api.Device, error) { + createdTimeMS := time.Now().UnixNano() / 1000000 + insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) + if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { + return nil, err + } + dev := &api.Device{ + ID: id, + UserID: userutil.MakeUserID(localpart, serverName), + AccessToken: accessToken, + SessionID: sessionID, + LastSeenTS: createdTimeMS, + LastSeenIP: ipAddr, + UserAgent: userAgent, + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil } func (s *devicesStatements) DeleteDevice( @@ -181,6 +210,7 @@ func (s *devicesStatements) DeleteDevices( if err != nil { return err } + defer internal.CloseAndLogIfError(ctx, prep, "DeleteDevices.StmtClose() failed") stmt := sqlutil.TxStmt(txn, prep) params := make([]interface{}, len(devices)+2) params[0] = localpart @@ -271,7 +301,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart( var lastseents sql.NullInt64 var id, displayname, ip, useragent sql.NullString for rows.Next() { - err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) + err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent, &dev.SessionID) if err != nil { return devices, err } @@ -317,7 +347,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var displayName sql.NullString var lastseents sql.NullInt64 for rows.Next() { - if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents, &dev.SessionID); err != nil { return nil, err } if displayName.Valid { diff --git a/userapi/storage/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go index 7883ffb19..ed2746310 100644 --- a/userapi/storage/sqlite3/key_backup_table.go +++ b/userapi/storage/sqlite3/key_backup_table.go @@ -52,7 +52,7 @@ const updateBackupKeySQL = "" + const countKeysSQL = "" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" -const selectKeysSQL = "" + +const selectBackupKeysSQL = "" + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" @@ -83,7 +83,7 @@ func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, - {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysStmt, selectBackupKeysSQL}, {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, }.Prepare(db) diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/userapi/storage/sqlite3/key_changes_table.go similarity index 90% rename from keyserver/storage/sqlite3/key_changes_table.go rename to userapi/storage/sqlite3/key_changes_table.go index 0c844d67a..923bb57eb 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/userapi/storage/sqlite3/key_changes_table.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var keyChangesSchema = ` @@ -65,7 +65,10 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { return nil, err } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertKeyChangeStmt, upsertKeyChangeSQL}, + {&s.selectKeyChangesStmt, selectKeyChangesSQL}, + }.Prepare(db) } func executeMigration(ctx context.Context, db *sql.DB) error { @@ -93,16 +96,6 @@ func executeMigration(ctx context.Context, db *sql.DB) error { return m.Up(ctx) } -func (s *keyChangesStatements) Prepare() (err error) { - if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { - return err - } - if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil { - return err - } - return nil -} - func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) return diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/userapi/storage/sqlite3/one_time_keys_table.go similarity index 89% rename from keyserver/storage/sqlite3/one_time_keys_table.go rename to userapi/storage/sqlite3/one_time_keys_table.go index 7a923d0e5..a992d399c 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/userapi/storage/sqlite3/one_time_keys_table.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var oneTimeKeysSchema = ` @@ -48,7 +48,7 @@ const upsertKeysSQL = "" + " ON CONFLICT (user_id, device_id, key_id, algorithm)" + " DO UPDATE SET key_json = $6" -const selectKeysSQL = "" + +const selectOneTimeKeysSQL = "" + "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" const selectKeysCountSQL = "" + @@ -83,25 +83,14 @@ func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { if err != nil { return nil, err } - if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil { - return nil, err - } - if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { - return nil, err - } - if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { - return nil, err - } - if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertKeysStmt, upsertKeysSQL}, + {&s.selectKeysStmt, selectOneTimeKeysSQL}, + {&s.selectKeysCountStmt, selectKeysCountSQL}, + {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL}, + {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL}, + {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL}, + }.Prepare(db) } func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/userapi/storage/sqlite3/stale_device_lists.go similarity index 98% rename from keyserver/storage/sqlite3/stale_device_lists.go rename to userapi/storage/sqlite3/stale_device_lists.go index fd76a6e3b..f078fc99f 100644 --- a/keyserver/storage/sqlite3/stale_device_lists.go +++ b/userapi/storage/sqlite3/stale_device_lists.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go index a1365c944..72b3ba49d 100644 --- a/userapi/storage/sqlite3/stats_table.go +++ b/userapi/storage/sqlite3/stats_table.go @@ -256,6 +256,7 @@ func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int if err != nil { return 0, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "allUsers.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) err = stmt.QueryRowContext(ctx, 1, 2, 3, 4, @@ -269,6 +270,7 @@ func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (res if err != nil { return 0, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "nonBridgedUsers.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) err = stmt.QueryRowContext(ctx, 1, 2, 3, @@ -286,6 +288,7 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) if err != nil { return nil, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "registeredUserByType.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) registeredAfter := time.Now().AddDate(0, 0, -30) diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 85a1f7063..0f3eeed1b 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -30,8 +30,8 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" ) -// NewDatabase creates a new accounts and profiles database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { +// NewUserDatabase creates a new accounts and profiles database +func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) if err != nil { return nil, err @@ -134,3 +134,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, }, nil } + +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) + if err != nil { + return nil, err + } + otk, err := NewSqliteOneTimeKeysTable(db) + if err != nil { + return nil, err + } + dk, err := NewSqliteDeviceKeysTable(db) + if err != nil { + return nil, err + } + kc, err := NewSqliteKeyChangesTable(db) + if err != nil { + return nil, err + } + sdl, err := NewSqliteStaleDeviceListsTable(db) + if err != nil { + return nil, err + } + csk, err := NewSqliteCrossSigningKeysTable(db) + if err != nil { + return nil, err + } + css, err := NewSqliteCrossSigningSigsTable(db) + if err != nil { + return nil, err + } + + return &shared.KeyDatabase{ + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, + CrossSigningKeysTable: csk, + CrossSigningSigsTable: css, + Writer: writer, + }, nil +} diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go index 42221e752..0329fb46a 100644 --- a/userapi/storage/storage.go +++ b/userapi/storage/storage.go @@ -29,15 +29,36 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/sqlite3" ) -// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) +// NewUserDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters -func NewUserAPIDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) { +func NewUserDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + serverName gomatrixserverlib.ServerName, + bcryptCost int, + openIDTokenLifetimeMS int64, + loginTokenLifetime time.Duration, + serverNoticesLocalpart string, +) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) default: return nil, fmt.Errorf("unexpected database type") } } + +// NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme) +// and sets postgres connection parameters. +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) { + switch { + case dbProperties.ConnectionString.IsSQLite(): + return sqlite3.NewKeyDatabase(base, dbProperties) + case dbProperties.ConnectionString.IsPostgres(): + return postgres.NewKeyDatabase(base, dbProperties) + default: + return nil, fmt.Errorf("unexpected database type") + } +} diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 23aafff03..f52e7e17d 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -4,9 +4,12 @@ import ( "context" "encoding/json" "fmt" + "reflect" + "sync" "testing" "time" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/stretchr/testify/assert" @@ -29,14 +32,14 @@ var ( ctx = context.Background() ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") if err != nil { - t.Fatalf("NewUserAPIDatabase returned %s", err) + t.Fatalf("NewUserDatabase returned %s", err) } return db, func() { close() @@ -47,7 +50,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun // Tests storing and getting account data func Test_AccountData(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -78,7 +81,7 @@ func Test_AccountData(t *testing.T) { // Tests the creation of accounts func Test_Accounts(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -158,7 +161,7 @@ func Test_Devices(t *testing.T) { accessToken := util.RandomString(16) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "") @@ -238,7 +241,7 @@ func Test_KeyBackup(t *testing.T) { room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() wantAuthData := json.RawMessage("my auth data") @@ -315,7 +318,7 @@ func Test_KeyBackup(t *testing.T) { func Test_LoginToken(t *testing.T) { alice := test.NewUser(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // create a new token @@ -347,7 +350,7 @@ func Test_OpenID(t *testing.T) { token := util.RandomString(24) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS @@ -368,7 +371,7 @@ func Test_Profile(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // create account, which also creates a profile @@ -417,7 +420,7 @@ func Test_Pusher(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() appID := util.RandomString(8) @@ -468,7 +471,7 @@ func Test_ThreePID(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() threePID := util.RandomString(8) medium := util.RandomString(8) @@ -507,7 +510,7 @@ func Test_Notification(t *testing.T) { room := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // generate some dummy notifications for i := 0; i < 10; i++ { @@ -571,3 +574,184 @@ func Test_Notification(t *testing.T) { assert.Equal(t, int64(0), total) }) } + +func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { + base, close := testrig.CreateBaseDendrite(t, dbType) + db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + if err != nil { + t.Fatalf("failed to create new database: %v", err) + } + return db, close +} + +func MustNotError(t *testing.T, err error) { + t.Helper() + if err == nil { + return + } + t.Fatalf("operation failed: %s", err) +} + +func TestKeyChanges(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + _, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeIDC { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) + } + if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +func TestKeyChangesNoDupes(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + if deviceChangeIDA == deviceChangeIDB { + t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) + } + deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeID { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) + } + if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +func TestKeyChangesUpperLimit(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + _, err = db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeIDB { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) + } + if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +var dbLock sync.Mutex +var deviceArray = []string{"AAA", "another_device"} + +// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, +// and that they are returned correctly when querying for device keys. +func TestDeviceKeysStreamIDGeneration(t *testing.T) { + var err error + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + alice := "@alice:TestDeviceKeysStreamIDGeneration" + bob := "@bob:TestDeviceKeysStreamIDGeneration" + msgs := []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 + }, + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: bob, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 as this is a different user + }, + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "another_device", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 2 as this is a 2nd device key + }, + } + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 1 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) + } + if msgs[1].StreamID != 1 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) + } + if msgs[2].StreamID != 2 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) + } + + // updating a device sets the next stream ID for that user + msgs = []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v2"}`), + }, + // StreamID: 3 + }, + } + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 3 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) + } + + dbLock.Lock() + defer dbLock.Unlock() + // Querying for device keys returns the latest stream IDs + msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false) + + if err != nil { + t.Fatalf("DeviceKeysForUser returned error: %s", err) + } + wantStreamIDs := map[string]int64{ + "AAA": 3, + "another_device": 2, + } + if len(msgs) != len(wantStreamIDs) { + t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) + } + for _, m := range msgs { + if m.StreamID != wantStreamIDs[m.DeviceID] { + t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) + } + } + }) +} diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go index 5d5d292e6..163e3e173 100644 --- a/userapi/storage/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -32,10 +32,10 @@ func NewUserAPIDatabase( openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string, -) (Database, error) { +) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index e14776cf3..693e73038 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -20,10 +20,10 @@ import ( "encoding/json" "time" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/types" ) @@ -44,6 +44,7 @@ type AccountsTable interface { type DevicesTable interface { InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) + InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, sessionID int64) (*api.Device, error) DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error @@ -144,3 +145,47 @@ const ( // uint32. AllNotifications NotificationFilter = (1 << 31) - 1 ) + +type OneTimeKeys interface { + SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. + // Returns an empty map if the key does not exist. + SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) + DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error +} + +type DeviceKeys interface { + SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error + SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) + CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) + SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) + DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error + DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error +} + +type KeyChanges interface { + InsertKeyChange(ctx context.Context, userID string) (int64, error) + // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets. + // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset. + SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) +} + +type StaleDeviceLists interface { + InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error + SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error +} + +type CrossSigningKeys interface { + SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) + UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error +} + +type CrossSigningSigs interface { + SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) + UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error +} diff --git a/keyserver/storage/tables/stale_device_lists_test.go b/userapi/storage/tables/stale_device_lists_test.go similarity index 94% rename from keyserver/storage/tables/stale_device_lists_test.go rename to userapi/storage/tables/stale_device_lists_test.go index 76d3baddd..b9bdafdaa 100644 --- a/keyserver/storage/tables/stale_device_lists_test.go +++ b/userapi/storage/tables/stale_device_lists_test.go @@ -4,15 +4,15 @@ import ( "context" "testing" + "github.com/matrix-org/dendrite/userapi/storage/postgres" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/keyserver/storage/postgres" - "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) { diff --git a/keyserver/types/storage.go b/userapi/types/storage.go similarity index 100% rename from keyserver/types/storage.go rename to userapi/types/storage.go diff --git a/userapi/userapi.go b/userapi/userapi.go index 183ca3123..826bd7213 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -17,40 +17,34 @@ package userapi import ( "time" - "github.com/gorilla/mux" + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/internal/pushgateway" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/consumers" "github.com/matrix-org/dendrite/userapi/internal" - "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/util" ) -// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions -// on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI, enableMetrics bool) { - inthttp.AddRoutes(router, intAPI, enableMetrics) -} - -// NewInternalAPI returns a concerete implementation of the internal API. Callers +// NewInternalAPI returns a concrete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - base *base.BaseDendrite, cfg *config.UserAPI, - appServices []config.ApplicationService, keyAPI keyapi.UserKeyAPI, - rsAPI rsapi.UserRoomserverAPI, pgClient pushgateway.Client, -) api.UserInternalAPI { + base *base.BaseDendrite, + rsAPI rsapi.UserRoomserverAPI, + fedClient fedsenderapi.KeyserverFederationAPI, +) *internal.UserInternalAPI { + cfg := &base.Cfg.UserAPI js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + appServices := base.Cfg.Derived.ApplicationServices - db, err := storage.NewUserAPIDatabase( + pgClient := base.PushGatewayHTTPClient() + + db, err := storage.NewUserDatabase( base, &cfg.AccountDatabase, cfg.Matrix.ServerName, @@ -63,6 +57,11 @@ func NewInternalAPI( logrus.WithError(err).Panicf("failed to connect to accounts db") } + keyDB, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + if err != nil { + logrus.WithError(err).Panicf("failed to connect to key db") + } + syncProducer := producers.NewSyncAPI( db, js, // TODO: user API should handle syncs for account data. Right now, @@ -72,17 +71,50 @@ func NewInternalAPI( cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData), cfg.Matrix.JetStream.Prefixed(jetstream.OutputNotificationData), ) + keyChangeProducer := &producers.KeyChange{ + Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), + JetStream: js, + DB: keyDB, + } userAPI := &internal.UserInternalAPI{ DB: db, + KeyDatabase: keyDB, SyncProducer: syncProducer, + KeyChangeProducer: keyChangeProducer, Config: cfg, AppServices: appServices, - KeyAPI: keyAPI, RSAPI: rsAPI, DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, PgClient: pgClient, - Cfg: cfg, + FedClient: fedClient, + } + + updater := internal.NewDeviceListUpdater(base.ProcessContext, keyDB, userAPI, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable + userAPI.Updater = updater + // Remove users which we don't share a room with anymore + if err := updater.CleanUp(); err != nil { + logrus.WithError(err).Error("failed to cleanup stale device lists") + } + + go func() { + if err := updater.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start device list updater") + } + }() + + dlConsumer := consumers.NewDeviceListUpdateConsumer( + base.ProcessContext, cfg, js, updater, + ) + if err := dlConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start device list consumer") + } + + sigConsumer := consumers.NewSigningKeyUpdateConsumer( + base.ProcessContext, cfg, js, userAPI, + ) + if err := sigConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start signing key consumer") } receiptConsumer := consumers.NewOutputReceiptEventConsumer( diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index dada56de4..01e491cb6 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -17,23 +17,22 @@ package userapi_test import ( "context" "fmt" - "net/http" "reflect" + "sync" "testing" "time" - "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/nats-io/nats.go" "golang.org/x/crypto/bcrypt" - "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" - "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" - "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage" ) @@ -43,32 +42,71 @@ const ( type apiTestOpts struct { loginTokenLifetime time.Duration + serverName string } -func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.Database, func()) { +type dummyProducer struct { + callCount sync.Map + t *testing.T +} + +func (d *dummyProducer) PublishMsg(msg *nats.Msg, opts ...nats.PubOpt) (*nats.PubAck, error) { + count, loaded := d.callCount.LoadOrStore(msg.Subject, 1) + if loaded { + c, ok := count.(int) + if !ok { + d.t.Fatalf("unexpected type: %T with value %q", c, c) + } + d.callCount.Store(msg.Subject, c+1) + d.t.Logf("Incrementing call counter for %s", msg.Subject) + } + return &nats.PubAck{}, nil +} + +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType, publisher producers.JetStreamPublisher) (api.UserInternalAPI, storage.UserDatabase, func()) { if opts.loginTokenLifetime == 0 { opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond } base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + sName := serverName + if opts.serverName != "" { + sName = gomatrixserverlib.ServerName(opts.serverName) + } + accountDB, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") + }, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") if err != nil { t.Fatalf("failed to create account DB: %s", err) } + keyDB, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("failed to create key DB: %s", err) + } + cfg := &config.UserAPI{ Matrix: &config.Global{ SigningIdentity: gomatrixserverlib.SigningIdentity{ - ServerName: serverName, + ServerName: sName, }, }, } + if publisher == nil { + publisher = &dummyProducer{t: t} + } + + syncProducer := producers.NewSyncAPI(accountDB, publisher, "client_data", "notification_data") + keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: publisher, Topic: "keychange"} return &internal.UserInternalAPI{ - DB: accountDB, - Config: cfg, + DB: accountDB, + KeyDatabase: keyDB, + Config: cfg, + SyncProducer: syncProducer, + KeyChangeProducer: keyChangeProducer, }, accountDB, func() { close() baseclose() @@ -129,7 +167,7 @@ func TestQueryProfile(t *testing.T) { } test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() _, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser) if err != nil { @@ -142,20 +180,7 @@ func TestQueryProfile(t *testing.T) { t.Fatalf("failed to set display name: %s", err) } - t.Run("HTTP API", func(t *testing.T) { - router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - userapi.AddInternalRoutes(router, userAPI, false) - apiURL, cancel := test.ListenAndServe(t, router, false) - defer cancel() - httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) - if err != nil { - t.Fatalf("failed to create HTTP client") - } - runCases(httpAPI, true) - }) - t.Run("Monolith", func(t *testing.T) { - runCases(userAPI, false) - }) + runCases(userAPI, false) }) } @@ -165,7 +190,7 @@ func TestQueryProfile(t *testing.T) { func TestPasswordlessLoginFails(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() _, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService) if err != nil { @@ -191,7 +216,7 @@ func TestLoginToken(t *testing.T) { t.Run("tokenLoginFlow", func(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() _, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser) if err != nil { @@ -241,7 +266,7 @@ func TestLoginToken(t *testing.T) { t.Run("expiredTokenIsNotReturned", func(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}, dbType) + userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}, dbType, nil) defer close() creq := api.PerformLoginTokenCreationRequest{ @@ -266,7 +291,7 @@ func TestLoginToken(t *testing.T) { t.Run("deleteWorks", func(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() creq := api.PerformLoginTokenCreationRequest{ @@ -297,7 +322,7 @@ func TestLoginToken(t *testing.T) { t.Run("deleteUnknownIsNoOp", func(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() dreq := api.PerformLoginTokenDeletionRequest{Token: "non-existent token"} var dresp api.PerformLoginTokenDeletionResponse @@ -315,7 +340,7 @@ func TestQueryAccountByLocalpart(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType) @@ -347,24 +372,306 @@ func TestQueryAccountByLocalpart(t *testing.T) { } } - t.Run("Monolith", func(t *testing.T) { - testCases(t, intAPI) - // also test tracing - testCases(t, &api.UserInternalAPITrace{Impl: intAPI}) - }) + testCases(t, intAPI) + }) +} - t.Run("HTTP API", func(t *testing.T) { - router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - userapi.AddInternalRoutes(router, intAPI, false) - apiURL, cancel := test.ListenAndServe(t, router, false) - defer cancel() +func TestAccountData(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) - userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5}) - if err != nil { - t.Fatalf("failed to create HTTP client: %s", err) + testCases := []struct { + name string + inputData *api.InputAccountDataRequest + wantErr bool + }{ + { + name: "not a local user", + inputData: &api.InputAccountDataRequest{UserID: "@notlocal:example.com"}, + wantErr: true, + }, + { + name: "local user missing datatype", + inputData: &api.InputAccountDataRequest{UserID: alice.ID}, + wantErr: true, + }, + { + name: "missing json", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: nil}, + wantErr: true, + }, + { + name: "with json", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}")}, + }, + { + name: "room data", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}"), RoomID: "!dummy:test"}, + }, + { + name: "ignored users", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.ignored_user_list", AccountData: []byte("{}")}, + }, + { + name: "m.fully_read", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.fully_read", AccountData: []byte("{}")}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType, nil) + defer close() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res := api.InputAccountDataResponse{} + err := intAPI.InputAccountData(ctx, tc.inputData, &res) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + + // query the data again and compare + queryRes := api.QueryAccountDataResponse{} + queryReq := api.QueryAccountDataRequest{ + UserID: tc.inputData.UserID, + DataType: tc.inputData.DataType, + RoomID: tc.inputData.RoomID, + } + err = intAPI.QueryAccountData(ctx, &queryReq, &queryRes) + if err != nil && !tc.wantErr { + t.Fatal(err) + } + // verify global data + if tc.inputData.RoomID == "" { + if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.GlobalAccountData[tc.inputData.DataType]) { + t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.GlobalAccountData[tc.inputData.DataType])) + } + } else { + // verify room data + if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]) { + t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType])) + } + } + }) + } + }) +} + +func TestDevices(t *testing.T) { + ctx := context.Background() + + dupeAccessToken := util.RandomString(8) + + displayName := "testing" + + creationTests := []struct { + name string + inputData *api.PerformDeviceCreationRequest + wantErr bool + wantNewDevID bool + }{ + { + name: "not a local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", ServerName: "notlocal"}, + wantErr: true, + }, + { + name: "implicit local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, DeviceDisplayName: &displayName}, + }, + { + name: "explicit local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + { + name: "dupe token - ok", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, + }, + { + name: "dupe token - not ok", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, + wantErr: true, + }, + { + name: "test3 second device", // used to test deletion later + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + { + name: "test3 third device", // used to test deletion later + wantNewDevID: true, + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + } + + deletionTests := []struct { + name string + inputData *api.PerformDeviceDeletionRequest + wantErr bool + wantDevices int + }{ + { + name: "deletion - not a local user", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test:notlocalhost"}, + wantErr: true, + }, + { + name: "deleting not existing devices should not error", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test", DeviceIDs: []string{"iDontExist"}}, + wantDevices: 1, + }, + { + name: "delete all devices", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test"}, + wantDevices: 0, + }, + { + name: "delete all devices", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test3:test"}, + wantDevices: 0, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType, nil) + defer close() + + for _, tc := range creationTests { + t.Run(tc.name, func(t *testing.T) { + res := api.PerformDeviceCreationResponse{} + deviceID := util.RandomString(8) + tc.inputData.DeviceID = &deviceID + if tc.wantNewDevID { + tc.inputData.DeviceID = nil + } + err := intAPI.PerformDeviceCreation(ctx, tc.inputData, &res) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + if !res.DeviceCreated { + return + } + + queryDevicesRes := api.QueryDevicesResponse{} + queryDevicesReq := api.QueryDevicesRequest{UserID: res.Device.UserID} + if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil { + t.Fatal(err) + } + // We only want to verify one device + if len(queryDevicesRes.Devices) > 1 { + return + } + res.Device.AccessToken = "" + + // At this point, there should only be one device + if !reflect.DeepEqual(*res.Device, queryDevicesRes.Devices[0]) { + t.Fatalf("expected device to be\n%#v, got \n%#v", *res.Device, queryDevicesRes.Devices[0]) + } + + newDisplayName := "new name" + if tc.inputData.DeviceDisplayName == nil { + updateRes := api.PerformDeviceUpdateResponse{} + updateReq := api.PerformDeviceUpdateRequest{ + RequestingUserID: fmt.Sprintf("@%s:%s", tc.inputData.Localpart, "test"), + DeviceID: deviceID, + DisplayName: &newDisplayName, + } + + if err = intAPI.PerformDeviceUpdate(ctx, &updateReq, &updateRes); err != nil { + t.Fatal(err) + } + } + + queryDeviceInfosRes := api.QueryDeviceInfosResponse{} + queryDeviceInfosReq := api.QueryDeviceInfosRequest{DeviceIDs: []string{*tc.inputData.DeviceID}} + if err = intAPI.QueryDeviceInfos(ctx, &queryDeviceInfosReq, &queryDeviceInfosRes); err != nil { + t.Fatal(err) + } + gotDisplayName := queryDeviceInfosRes.DeviceInfo[*tc.inputData.DeviceID].DisplayName + if tc.inputData.DeviceDisplayName != nil { + wantDisplayName := *tc.inputData.DeviceDisplayName + if wantDisplayName != gotDisplayName { + t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName) + } + } else { + wantDisplayName := newDisplayName + if wantDisplayName != gotDisplayName { + t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName) + } + } + }) + } + + for _, tc := range deletionTests { + t.Run(tc.name, func(t *testing.T) { + delRes := api.PerformDeviceDeletionResponse{} + err := intAPI.PerformDeviceDeletion(ctx, tc.inputData, &delRes) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + if tc.wantErr { + return + } + + queryDevicesRes := api.QueryDevicesResponse{} + queryDevicesReq := api.QueryDevicesRequest{UserID: tc.inputData.UserID} + if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil { + t.Fatal(err) + } + + if len(queryDevicesRes.Devices) != tc.wantDevices { + t.Fatalf("expected %d devices, got %d", tc.wantDevices, len(queryDevicesRes.Devices)) + } + + }) + } + }) +} + +// Tests that the session ID of a device is not reused when reusing the same device ID. +func TestDeviceIDReuse(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + publisher := &dummyProducer{t: t} + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType, publisher) + defer close() + + res := api.PerformDeviceCreationResponse{} + // create a first device + deviceID := util.RandomString(8) + req := api.PerformDeviceCreationRequest{Localpart: "alice", ServerName: "test", DeviceID: &deviceID, NoDeviceListUpdate: true} + err := intAPI.PerformDeviceCreation(ctx, &req, &res) + if err != nil { + t.Fatal(err) + } + + // Do the same request again, we expect a different sessionID + res2 := api.PerformDeviceCreationResponse{} + // Set NoDeviceListUpdate to false, to verify we don't send device list updates when + // reusing the same device ID + req.NoDeviceListUpdate = false + err = intAPI.PerformDeviceCreation(ctx, &req, &res2) + if err != nil { + t.Fatalf("expected no error, but got: %v", err) + } + + if res2.Device.SessionID == res.Device.SessionID { + t.Fatalf("expected a different session ID, but they are the same") + } + + publisher.callCount.Range(func(key, value any) bool { + if value != nil { + t.Fatalf("expected publisher to not get called, but got value %d for subject %s", value, key) } - testCases(t, userHTTPApi) - + return true }) }) } diff --git a/userapi/util/devices.go b/userapi/util/devices.go index c55fc7999..31617d8c1 100644 --- a/userapi/util/devices.go +++ b/userapi/util/devices.go @@ -19,7 +19,7 @@ type PusherDevice struct { } // GetPushDevices pushes to the configured devices of a local user. -func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { +func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.UserDatabase) ([]*PusherDevice, error) { pushers, err := db.GetPushers(ctx, localpart, serverName) if err != nil { return nil, fmt.Errorf("db.GetPushers: %w", err) diff --git a/userapi/util/notify.go b/userapi/util/notify.go index fc0ab39bf..08d1371d6 100644 --- a/userapi/util/notify.go +++ b/userapi/util/notify.go @@ -13,11 +13,11 @@ import ( ) // NotifyUserCountsAsync sends notifications to a local user's -// notification destinations. Database lookups run synchronously, but +// notification destinations. UserDatabase lookups run synchronously, but // a single goroutine is started when talking to the Push // gateways. There is no way to know when the background goroutine has // finished. -func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error { +func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.UserDatabase) error { pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db) if err != nil { return err diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index f1d20259c..421852d3f 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -79,7 +79,7 @@ func TestNotifyUserCountsAsync(t *testing.T) { defer close() base, _, _ := testrig.Base(nil) defer base.Close() - db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "test", bcrypt.MinCost, 0, 0, "") if err != nil { diff --git a/userapi/util/phonehomestats.go b/userapi/util/phonehomestats.go index 42c8f5d7c..21035e045 100644 --- a/userapi/util/phonehomestats.go +++ b/userapi/util/phonehomestats.go @@ -55,7 +55,7 @@ func StartPhoneHomeCollector(startTime time.Time, cfg *config.Dendrite, statsDB serverName: cfg.Global.ServerName, cfg: cfg, db: statsDB, - isMonolith: cfg.IsMonolith, + isMonolith: true, client: &http.Client{ Timeout: time.Second * 30, Transport: http.DefaultTransport, diff --git a/userapi/util/phonehomestats_test.go b/userapi/util/phonehomestats_test.go index 6e62210e8..5f626b5bc 100644 --- a/userapi/util/phonehomestats_test.go +++ b/userapi/util/phonehomestats_test.go @@ -21,7 +21,7 @@ func TestCollect(t *testing.T) { b, _, _ := testrig.Base(nil) connStr, closeDB := test.PrepareDBConnectionString(t, dbType) defer closeDB() - db, err := storage.NewUserAPIDatabase(b, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(b, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, 1000, 1000, "") if err != nil {