diff --git a/.github/ISSUE_TEMPLATE/BUG_REPORT.md b/.github/ISSUE_TEMPLATE/BUG_REPORT.md index f40c56609..a3b8c0754 100644 --- a/.github/ISSUE_TEMPLATE/BUG_REPORT.md +++ b/.github/ISSUE_TEMPLATE/BUG_REPORT.md @@ -17,7 +17,6 @@ see: https://www.matrix.org/security-disclosure-policy/ ### Background information - **Dendrite version or git SHA**: -- **Monolith or Polylith?**: - **SQLite3 or Postgres?**: - **Running in Docker?**: - **`go version`**: diff --git a/.github/codecov.yaml b/.github/codecov.yaml new file mode 100644 index 000000000..78122c990 --- /dev/null +++ b/.github/codecov.yaml @@ -0,0 +1,20 @@ +flag_management: + default_rules: + carryforward: true + +coverage: + status: + project: + default: + target: auto + threshold: 0% + base: auto + flags: + - unittests + patch: + default: + target: 75% + threshold: 0% + base: auto + flags: + - unittests \ No newline at end of file diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 55e4b354f..fa8fcfdaa 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -4,7 +4,15 @@ on: push: branches: - main + paths: + - '**.go' # only execute on changes to go files + - 'go.sum' # or dependency updates + - '.github/workflows/**' # or workflow changes pull_request: + paths: + - '**.go' + - 'go.sum' # or dependency updates + - '.github/workflows/**' release: types: [published] workflow_dispatch: @@ -25,7 +33,7 @@ jobs: - name: Install Go uses: actions/setup-go@v3 with: - go-version: 1.18 + go-version: "stable" cache: true - name: Install Node @@ -62,14 +70,14 @@ jobs: - name: Install Go uses: actions/setup-go@v3 with: - go-version: 1.18 + go-version: "stable" - name: golangci-lint uses: golangci/golangci-lint-action@v3 # run go test with go 1.19 test: - timeout-minutes: 5 - name: Unit tests (Go ${{ matrix.go }}) + timeout-minutes: 10 + name: Unit tests runs-on: ubuntu-latest # Service containers to run with `container-job` services: @@ -91,17 +99,21 @@ jobs: --health-interval 10s --health-timeout 5s --health-retries 5 - strategy: - fail-fast: false - matrix: - go: ["1.19"] steps: - uses: actions/checkout@v3 - name: Setup go uses: actions/setup-go@v3 with: - go-version: ${{ matrix.go }} - cache: true + go-version: "stable" + - uses: actions/cache@v3 + # manually set up caches, as they otherwise clash with different steps using setup-go with cache=true + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-stable-unit-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go-stable-unit- - name: Set up gotestfmt uses: gotesttools/gotestfmt-action@v2 with: @@ -122,7 +134,6 @@ jobs: strategy: fail-fast: false matrix: - go: ["1.19"] goos: ["linux"] goarch: ["amd64"] steps: @@ -130,15 +141,15 @@ jobs: - name: Setup go uses: actions/setup-go@v3 with: - go-version: ${{ matrix.go }} + go-version: "stable" - uses: actions/cache@v3 with: path: | ~/.cache/go-build ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-go-stable-${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} restore-keys: | - key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}- + key: ${{ runner.os }}-go-stable-${{ matrix.goos }}-${{ matrix.goarch }}- - name: Install dependencies x86 if: ${{ matrix.goarch == '386' }} run: sudo apt update && sudo apt-get install -y gcc-multilib @@ -156,23 +167,22 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: ["1.18", "1.19"] goos: ["windows"] goarch: ["amd64"] steps: - uses: actions/checkout@v3 - - name: Setup Go ${{ matrix.go }} + - name: Setup Go uses: actions/setup-go@v3 with: - go-version: ${{ matrix.go }} + go-version: "stable" - uses: actions/cache@v3 with: path: | ~/.cache/go-build ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-go-stable-${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} restore-keys: | - key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}- + key: ${{ runner.os }}-go-stable-${{ matrix.goos }}-${{ matrix.goarch }}- - name: Install dependencies run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc - env: @@ -194,6 +204,63 @@ jobs: with: jobs: ${{ toJSON(needs) }} + # run go test with different go versions + integration: + timeout-minutes: 20 + needs: initial-tests-done + name: Integration tests + runs-on: ubuntu-latest + # Service containers to run with `container-job` + services: + # Label used to access the service container + postgres: + # Docker Hub image + image: postgres:13-alpine + # Provide the password for postgres + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: dendrite + ports: + # Maps tcp port 5432 on service container to the host + - 5432:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + steps: + - uses: actions/checkout@v3 + - name: Setup go + uses: actions/setup-go@v3 + with: + go-version: "stable" + - name: Set up gotestfmt + uses: gotesttools/gotestfmt-action@v2 + with: + # Optional: pass GITHUB_TOKEN to avoid rate limiting. + token: ${{ secrets.GITHUB_TOKEN }} + - uses: actions/cache@v3 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-stable-test-race-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go-stable-test-race- + - run: go test -race -json -v -coverpkg=./... -coverprofile=cover.out $(go list ./... | grep -v /cmd/dendrite*) 2>&1 | gotestfmt + env: + POSTGRES_HOST: localhost + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: dendrite + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + flags: unittests + fail_ci_if_error: false + # run database upgrade tests upgrade_test: name: Upgrade tests @@ -205,7 +272,7 @@ jobs: - name: Setup go uses: actions/setup-go@v3 with: - go-version: "1.18" + go-version: "stable" cache: true - name: Build upgrade-tests run: go build ./cmd/dendrite-upgrade-tests @@ -223,7 +290,7 @@ jobs: - name: Setup go uses: actions/setup-go@v3 with: - go-version: "1.18" + go-version: "stable" cache: true - name: Build upgrade-tests run: go build ./cmd/dendrite-upgrade-tests @@ -243,18 +310,14 @@ jobs: - label: PostgreSQL postgres: postgres - - label: PostgreSQL, full HTTP APIs - postgres: postgres - api: full-http container: - image: matrixdotorg/sytest-dendrite:latest + image: matrixdotorg/sytest-dendrite volumes: - ${{ github.workspace }}:/src - /root/.cache/go-build:/github/home/.cache/go-build - /root/.cache/go-mod:/gopath/pkg/mod env: POSTGRES: ${{ matrix.postgres && 1}} - API: ${{ matrix.api && 1 }} SYTEST_BRANCH: ${{ github.head_ref }} CGO_ENABLED: ${{ matrix.cgo && 1 }} steps: @@ -302,11 +365,6 @@ jobs: - label: PostgreSQL postgres: Postgres cgo: 0 - - - label: PostgreSQL, full HTTP APIs - postgres: Postgres - api: full-http - cgo: 0 steps: # Env vars are set file a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on env to run Complement. # See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path @@ -349,7 +407,7 @@ jobs: (wget -O - "https://github.com/globekeeper/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break done # Build initial Dendrite image - - run: docker build --build-arg=CGO=${{ matrix.cgo }} -t complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . + - run: docker build --build-arg=CGO=${{ matrix.cgo }} -t complement-dendrite:${{ matrix.postgres }}${{ matrix.cgo }} -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . working-directory: dendrite env: DOCKER_BUILDKIT: 1 @@ -361,8 +419,8 @@ jobs: shell: bash name: Run Complement Tests env: - COMPLEMENT_BASE_IMAGE: complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} - API: ${{ matrix.api && 1 }} + COMPLEMENT_BASE_IMAGE: complement-dendrite:${{ matrix.postgres }}${{ matrix.cgo }} + COMPLEMENT_SHARE_ENV_PREFIX: COMPLEMENT_DENDRITE_ working-directory: complement integration-tests-done: @@ -372,6 +430,7 @@ jobs: initial-tests-done, sytest, complement, + integration ] runs-on: ubuntu-latest if: ${{ !cancelled() }} # Run this even if prior jobs were skipped diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 2e17539d8..0c3053a56 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -61,7 +61,6 @@ jobs: cache-to: type=gha,mode=max context: . build-args: FLAGS=-X github.com/matrix-org/dendrite/internal.branch=${{ env.BRANCH }} -X github.com/matrix-org/dendrite/internal.build=${{ env.BUILD }} - target: monolith platforms: ${{ env.PLATFORMS }} push: true tags: | @@ -77,7 +76,6 @@ jobs: cache-to: type=gha,mode=max context: . build-args: FLAGS=-X github.com/matrix-org/dendrite/internal.branch=${{ env.BRANCH }} -X github.com/matrix-org/dendrite/internal.build=${{ env.BUILD }} - target: monolith platforms: ${{ env.PLATFORMS }} push: true tags: | @@ -98,86 +96,6 @@ jobs: with: sarif_file: "trivy-results.sarif" - polylith: - name: Polylith image - runs-on: ubuntu-latest - permissions: - contents: read - packages: write - security-events: write # To upload Trivy sarif files - steps: - - name: Checkout - uses: actions/checkout@v3 - - name: Get release tag & build flags - if: github.event_name == 'release' # Only for GitHub releases - run: | - echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV - BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) - [ ${BRANCH} == "main" ] && BRANCH="" - echo "BRANCH=${BRANCH}" >> $GITHUB_ENV - - name: Set up QEMU - uses: docker/setup-qemu-action@v1 - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 - - name: Login to Docker Hub - uses: docker/login-action@v2 - with: - username: ${{ env.DOCKER_HUB_USER }} - password: ${{ secrets.DOCKER_TOKEN }} - - name: Login to GitHub Containers - uses: docker/login-action@v2 - with: - registry: ghcr.io - username: ${{ github.repository_owner }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Build main polylith image - if: github.ref_name == 'main' - id: docker_build_polylith - uses: docker/build-push-action@v3 - with: - cache-from: type=gha - cache-to: type=gha,mode=max - context: . - build-args: FLAGS=-X github.com/matrix-org/dendrite/internal.branch=${{ env.BRANCH }} -X github.com/matrix-org/dendrite/internal.build=${{ env.BUILD }} - target: polylith - platforms: ${{ env.PLATFORMS }} - push: true - tags: | - ${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} - ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} - - - name: Build release polylith image - if: github.event_name == 'release' # Only for GitHub releases - id: docker_build_polylith_release - uses: docker/build-push-action@v3 - with: - cache-from: type=gha - cache-to: type=gha,mode=max - context: . - build-args: FLAGS=-X github.com/matrix-org/dendrite/internal.branch=${{ env.BRANCH }} -X github.com/matrix-org/dendrite/internal.build=${{ env.BUILD }} - target: polylith - platforms: ${{ env.PLATFORMS }} - push: true - tags: | - ${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:latest - ${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }} - ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:latest - ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }} - - - name: Run Trivy vulnerability scanner - uses: aquasecurity/trivy-action@master - with: - image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} - format: "sarif" - output: "trivy-results.sarif" - - - name: Upload Trivy scan results to GitHub Security tab - uses: github/codeql-action/upload-sarif@v2 - with: - sarif_file: "trivy-results.sarif" - demo-pinecone: name: Pinecone demo image runs-on: ubuntu-latest diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml new file mode 100644 index 000000000..9df3cceae --- /dev/null +++ b/.github/workflows/gh-pages.yml @@ -0,0 +1,52 @@ +# Sample workflow for building and deploying a Jekyll site to GitHub Pages +name: Deploy GitHub Pages dependencies preinstalled + +on: + # Runs on pushes targeting the default branch + push: + branches: ["gh-pages"] + paths: + - 'docs/**' # only execute if we have docs changes + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages +permissions: + contents: read + pages: write + id-token: write + +# Allow one concurrent deployment +concurrency: + group: "pages" + cancel-in-progress: true + +jobs: + # Build job + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup Pages + uses: actions/configure-pages@v2 + - name: Build with Jekyll + uses: actions/jekyll-build-pages@v1 + with: + source: ./docs + destination: ./_site + - name: Upload artifact + uses: actions/upload-pages-artifact@v1 + + # Deployment job + deploy: + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + needs: build + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v1 diff --git a/.github/workflows/helm.yml b/.github/workflows/helm.yml new file mode 100644 index 000000000..7cdc369ba --- /dev/null +++ b/.github/workflows/helm.yml @@ -0,0 +1,39 @@ +name: Release Charts + +on: + push: + branches: + - main + paths: + - 'helm/**' # only execute if we have helm chart changes + +jobs: + release: + # depending on default permission settings for your org (contents being read-only or read-write for workloads), you will have to add permissions + # see: https://docs.github.com/en/actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token + permissions: + contents: write + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Configure Git + run: | + git config user.name "$GITHUB_ACTOR" + git config user.email "$GITHUB_ACTOR@users.noreply.github.com" + + - name: Install Helm + uses: azure/setup-helm@v3 + with: + version: v3.10.0 + + - name: Run chart-releaser + uses: helm/chart-releaser-action@v1.4.1 + env: + CR_TOKEN: "${{ secrets.GITHUB_TOKEN }}" + with: + config: helm/cr.yaml + charts_dir: helm/ diff --git a/.github/workflows/k8s.yml b/.github/workflows/k8s.yml new file mode 100644 index 000000000..fc5e8c906 --- /dev/null +++ b/.github/workflows/k8s.yml @@ -0,0 +1,90 @@ +name: k8s + +on: + push: + branches: ["main"] + paths: + - 'helm/**' # only execute if we have helm chart changes + pull_request: + branches: ["main"] + paths: + - 'helm/**' + +jobs: + lint: + name: Lint Helm chart + runs-on: ubuntu-latest + outputs: + changed: ${{ steps.list-changed.outputs.changed }} + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - uses: azure/setup-helm@v3 + with: + version: v3.10.0 + - uses: actions/setup-python@v4 + with: + python-version: 3.11 + check-latest: true + - uses: helm/chart-testing-action@v2.3.1 + - name: Get changed status + id: list-changed + run: | + changed=$(ct list-changed --config helm/ct.yaml --target-branch ${{ github.event.repository.default_branch }}) + if [[ -n "$changed" ]]; then + echo "::set-output name=changed::true" + fi + + - name: Run lint + run: ct lint --config helm/ct.yaml + + # only bother to run if lint step reports a change to the helm chart + install: + needs: + - lint + if: ${{ needs.lint.outputs.changed == 'true' }} + name: Install Helm charts + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + ref: ${{ inputs.checkoutCommit }} + - name: Install Kubernetes tools + uses: yokawasa/action-setup-kube-tools@v0.8.2 + with: + setup-tools: | + helmv3 + helm: "3.10.3" + - uses: actions/setup-python@v4 + with: + python-version: "3.10" + - name: Set up chart-testing + uses: helm/chart-testing-action@v2.3.1 + - name: Create k3d cluster + uses: nolar/setup-k3d-k3s@v1 + with: + version: v1.21 + - name: Remove node taints + run: | + kubectl taint --all=true nodes node.cloudprovider.kubernetes.io/uninitialized- || true + - name: Run chart-testing (install) + run: ct install --config helm/ct.yaml + + # Install the chart using helm directly and test with create-account + - name: Install chart + run: | + helm install --values helm/dendrite/ci/ct-postgres-sharedsecret-values.yaml dendrite helm/dendrite + - name: Wait for Postgres and Dendrite to be up + run: | + kubectl wait --for=condition=ready --timeout=90s pod -l app.kubernetes.io/name=postgresql || kubectl get pods -A + kubectl wait --for=condition=ready --timeout=90s pod -l app.kubernetes.io/name=dendrite || kubectl get pods -A + kubectl get pods -A + kubectl get services + kubectl get ingress + - name: Run create account + run: | + podName=$(kubectl get pods -l app.kubernetes.io/name=dendrite -o name) + kubectl exec "${podName}" -- /usr/bin/create-account -username alice -password somerandompassword \ No newline at end of file diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index ff4d47187..dff9b34c9 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -10,107 +10,43 @@ concurrency: cancel-in-progress: true jobs: - # run go test with different go versions - test: - timeout-minutes: 20 - name: Unit tests (Go ${{ matrix.go }}) - runs-on: ubuntu-latest - # Service containers to run with `container-job` - services: - # Label used to access the service container - postgres: - # Docker Hub image - image: postgres:13-alpine - # Provide the password for postgres - env: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres - POSTGRES_DB: dendrite - ports: - # Maps tcp port 5432 on service container to the host - - 5432:5432 - # Set health checks to wait until postgres has started - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - strategy: - fail-fast: false - matrix: - go: ["1.18", "1.19"] - steps: - - uses: actions/checkout@v3 - - name: Setup go - uses: actions/setup-go@v3 - with: - go-version: ${{ matrix.go }} - - name: Set up gotestfmt - uses: gotesttools/gotestfmt-action@v2 - with: - # Optional: pass GITHUB_TOKEN to avoid rate limiting. - token: ${{ secrets.GITHUB_TOKEN }} - - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go${{ matrix.go }}-test-race- - - run: go test -race -json -v -coverpkg=./... -coverprofile=cover.out $(go list ./... | grep -v /cmd/dendrite*) 2>&1 | gotestfmt - env: - POSTGRES_HOST: localhost - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres - POSTGRES_DB: dendrite - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 - - # Dummy step to gate other tests on without repeating the whole list - initial-tests-done: - name: Initial tests passed - needs: [test] - runs-on: ubuntu-latest - if: ${{ !cancelled() }} # Run this even if prior jobs were skipped - steps: - - name: Check initial tests passed - uses: re-actors/alls-green@release/v1 - with: - jobs: ${{ toJSON(needs) }} - # run Sytest in different variations sytest: timeout-minutes: 60 - needs: initial-tests-done name: "Sytest (${{ matrix.label }})" runs-on: ubuntu-latest strategy: fail-fast: false matrix: include: - - label: SQLite + - label: SQLite native - - label: SQLite, full HTTP APIs - api: full-http + - label: SQLite Cgo + cgo: 1 - label: PostgreSQL postgres: postgres - - - label: PostgreSQL, full HTTP APIs - postgres: postgres - api: full-http container: image: matrixdotorg/sytest-dendrite:latest volumes: - ${{ github.workspace }}:/src + - /root/.cache/go-build:/github/home/.cache/go-build + - /root/.cache/go-mod:/gopath/pkg/mod env: POSTGRES: ${{ matrix.postgres && 1}} - API: ${{ matrix.api && 1 }} SYTEST_BRANCH: ${{ github.head_ref }} RACE_DETECTION: 1 + COVER: 1 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 + - uses: actions/cache@v3 + with: + path: | + ~/.cache/go-build + /gopath/pkg/mod + key: ${{ runner.os }}-go-sytest-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go-sytest- - name: Run Sytest run: /bootstrap.sh dendrite working-directory: /src @@ -133,3 +69,192 @@ jobs: path: | /logs/results.tap /logs/**/*.log* + + sytest-coverage: + timeout-minutes: 5 + name: "Sytest Coverage" + runs-on: ubuntu-latest + needs: sytest # only run once Sytest is done + if: ${{ always() }} + steps: + - uses: actions/checkout@v3 + - name: Install Go + uses: actions/setup-go@v3 + with: + go-version: 'stable' + cache: true + - name: Download all artifacts + uses: actions/download-artifact@v3 + - name: Install gocovmerge + run: go install github.com/wadey/gocovmerge@latest + - name: Run gocovmerge + run: | + find -name 'integrationcover.log' -printf '"%p"\n' | xargs gocovmerge | grep -Ev 'relayapi|setup/mscs|api_trace' > sytest.cov + go tool cover -func=sytest.cov + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + files: ./sytest.cov + flags: sytest + fail_ci_if_error: true + + # run Complement + complement: + name: "Complement (${{ matrix.label }})" + timeout-minutes: 60 + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - label: SQLite native + cgo: 0 + + - label: SQLite Cgo + cgo: 1 + + - label: PostgreSQL + postgres: Postgres + cgo: 0 + steps: + # Env vars are set file a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on env to run Complement. + # See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path + - name: "Set Go Version" + run: | + echo "$GOROOT_1_17_X64/bin" >> $GITHUB_PATH + echo "~/go/bin" >> $GITHUB_PATH + - name: "Install Complement Dependencies" + # We don't need to install Go because it is included on the Ubuntu 20.04 image: + # See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md specifically GOROOT_1_17_X64 + run: | + sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev + go get -v github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest + - name: Run actions/checkout@v3 for dendrite + uses: actions/checkout@v3 + with: + path: dendrite + + # Attempt to check out the same branch of Complement as the PR. If it + # doesn't exist, fallback to main. + - name: Checkout complement + shell: bash + run: | + mkdir -p complement + # Attempt to use the version of complement which best matches the current + # build. Depending on whether this is a PR or release, etc. we need to + # use different fallbacks. + # + # 1. First check if there's a similarly named branch (GITHUB_HEAD_REF + # for pull requests, otherwise GITHUB_REF). + # 2. Attempt to use the base branch, e.g. when merging into release-vX.Y + # (GITHUB_BASE_REF for pull requests). + # 3. Use the default complement branch ("master"). + for BRANCH_NAME in "$GITHUB_HEAD_REF" "$GITHUB_BASE_REF" "${GITHUB_REF#refs/heads/}" "master"; do + # Skip empty branch names and merge commits. + if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then + continue + fi + (wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break + done + # Build initial Dendrite image + - run: docker build --build-arg=CGO=${{ matrix.cgo }} -t complement-dendrite:${{ matrix.postgres }}${{ matrix.cgo }} -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . + working-directory: dendrite + env: + DOCKER_BUILDKIT: 1 + + - name: Create post test script + run: | + cat < /tmp/posttest.sh + #!/bin/bash + mkdir -p /tmp/Complement/logs/\$2/\$1/ + docker cp \$1:/dendrite/complementcover.log /tmp/Complement/logs/\$2/\$1/ + EOF + + chmod +x /tmp/posttest.sh + # Run Complement + - run: | + set -o pipefail && + go test -v -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt + shell: bash + name: Run Complement Tests + env: + COMPLEMENT_BASE_IMAGE: complement-dendrite:${{ matrix.postgres }}${{ matrix.cgo }} + COMPLEMENT_SHARE_ENV_PREFIX: COMPLEMENT_DENDRITE_ + COMPLEMENT_DENDRITE_COVER: 1 + COMPLEMENT_POST_TEST_SCRIPT: /tmp/posttest.sh + working-directory: complement + + - name: Upload Complement logs + uses: actions/upload-artifact@v2 + if: ${{ always() }} + with: + name: Complement Logs - (Dendrite, ${{ join(matrix.*, ', ') }}) + path: | + /tmp/Complement/**/complementcover.log + + complement-coverage: + timeout-minutes: 5 + name: "Complement Coverage" + runs-on: ubuntu-latest + needs: complement # only run once Complement is done + if: ${{ always() }} + steps: + - uses: actions/checkout@v3 + - name: Install Go + uses: actions/setup-go@v3 + with: + go-version: 'stable' + cache: true + - name: Download all artifacts + uses: actions/download-artifact@v3 + - name: Install gocovmerge + run: go install github.com/wadey/gocovmerge@latest + - name: Run gocovmerge + run: | + find -name 'complementcover.log' -printf '"%p"\n' | xargs gocovmerge | grep -Ev 'relayapi|setup/mscs|api_trace' > complement.cov + go tool cover -func=complement.cov + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + files: ./complement.cov + flags: complement + fail_ci_if_error: true + + element_web: + timeout-minutes: 120 + runs-on: ubuntu-latest + steps: + - uses: tecolicom/actions-use-apt-tools@v1 + with: + # Our test suite includes some screenshot tests with unusual diacritics, which are + # supposed to be covered by STIXGeneral. + tools: fonts-stix + - uses: actions/checkout@v2 + with: + repository: matrix-org/matrix-react-sdk + - uses: actions/setup-node@v3 + with: + cache: 'yarn' + - name: Fetch layered build + run: scripts/ci/layered.sh + - name: Copy config + run: cp element.io/develop/config.json config.json + working-directory: ./element-web + - name: Build + env: + CI_PACKAGE: true + NODE_OPTIONS: "--openssl-legacy-provider" + run: yarn build + working-directory: ./element-web + - name: Edit Test Config + run: | + sed -i '/HOMESERVER/c\ HOMESERVER: "dendrite",' cypress.config.ts + - name: "Run cypress tests" + uses: cypress-io/github-action@v4.1.1 + with: + browser: chrome + start: npx serve -p 8080 ./element-web/webapp + wait-on: 'http://localhost:8080' + env: + PUPPETEER_SKIP_CHROMIUM_DOWNLOAD: true + TMPDIR: ${{ runner.temp }} diff --git a/.gitignore b/.gitignore index 662d3ae97..9f4212999 100644 --- a/.gitignore +++ b/.gitignore @@ -58,6 +58,7 @@ dendrite.yaml # Database files *.db +*.db-journal # Log files *.log* diff --git a/.vscode/launch.json b/.vscode/launch.json index 6142a8df0..715b42250 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -5,7 +5,7 @@ "type": "go", "request": "launch", "mode": "auto", - "program": "${workspaceFolder}/cmd/dendrite-monolith-server", + "program": "${workspaceFolder}/cmd/dendrite", "args": [ "-really-enable-open-registration", "-config", diff --git a/CHANGES.md b/CHANGES.md index f5a82cfe2..8052efd8a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,80 @@ # Changelog +## Dendrite 0.12.0 (2023-03-13) + +### Features + +- The userapi and keyserver have been merged (no actions needed regarding the database) +- The internal NATS JetStream server is now using logrus for logging (contributed by [dvob](https://github.com/dvob)) +- The roomserver database has been refactored to have separate interfaces when working with rooms and events. Also includes increased usage of the cache to avoid database round trips. (database is unchanged) +- The pinecone demo now shuts down more cleanly +- The Helm chart now has the ability to deploy a Grafana chart as well (contributed by [genofire](https://github.com/genofire)) +- Support for listening on unix sockets has been added (contributed by [cyberb](https://github.com/cyberb)) +- The internal NATS server was updated to v2.9.15 +- Initial support for `runtime/trace` has been added, to further track down long-running tasks + +### Fixes + +- The `session_id` is now correctly set when using SQLite +- An issue where device keys could be removed if a device ID is reused has been fixed +- A possible DoS issue related to relations has been fixed (reported by [sleroq](https://github.com/sleroq)) +- When backfilling events, errors are now ignored if we still could fetch events + +### Other + +- **⚠️ DEPRECATION: Polylith/HTTP API mode has been removed** +- The default endpoint to report usages stats to has been updated + +## Dendrite 0.11.1 (2023-02-10) + +**⚠️ DEPRECATION WARNING: This is the last release to have polylith and HTTP API mode. Future releases are monolith only.** + +### Features + +* Dendrite can now be compiled against Go 1.20 +* Initial store and forward support has been added +* A landing page showing that Dendrite is running has been added (contributed by [LukasLJL](https://github.com/LukasLJL)) + +### Fixes + +- `/sync` is now using significantly less database round trips when using Postgres, resulting in faster initial syncs, allowing larger accounts to login again +- Many under the hood pinecone improvements +- Publishing rooms is now possible again + +## Dendrite 0.11.0 (2023-01-20) + +The last three missing federation API Sytests have been fixed - bringing us to 100% server-server Synapse parity, with client-server parity at 93% 🎉 + +### Features + +* Added `/_dendrite/admin/purgeRoom/{roomID}` to clean up the database +* The default room version was updated to 10 (contributed by [FSG-Cat](https://github.com/FSG-Cat)) + +### Fixes + +* An oversight in the `create-config` binary, which now correctly sets the media path if specified (contributed by [BieHDC](https://github.com/BieHDC)) +* The Helm chart now uses the `$.Chart.AppVersion` as the default image version to pull, with the possibility to override it (contributed by [genofire](https://github.com/genofire)) + +## Dendrite 0.10.9 (2023-01-17) + +### Features + +* Stale device lists are now cleaned up on startup, removing entries for users the server doesn't share a room with anymore +* Dendrite now has its own Helm chart +* Guest access is now handled correctly (disallow joins, kick guests on revocation of guest access, as well as over federation) + +### Fixes + +* Push rules have seen several tweaks and fixes, which should, for example, fix notifications for `m.read_receipts` +* Outgoing presence will now correctly be sent to newly joined hosts +* Fixes the `/_dendrite/admin/resetPassword/{userID}` admin endpoint to use the correct variable +* Federated backfilling for medium/large rooms has been fixed +* `/login` causing wrong device list updates has been resolved +* `/sync` should now return the correct room summary heroes +* The default config options for `recaptcha_sitekey_class` and `recaptcha_form_field` are now set correctly +* `/messages` now omits empty `state` to be more spec compliant (contributed by [handlerug](https://github.com/handlerug)) +* `/sync` has been optimised to only query state events for history visibility if they are really needed + ## Dendrite 0.10.8 (2022-11-29) ### Features diff --git a/Dockerfile b/Dockerfile index 2f251dee8..d0db9de49 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,26 +1,31 @@ -FROM docker.io/golang:1.19-alpine AS base - -RUN apk --update --no-cache add bash build-base +# +# base installs required dependencies and runs go mod download to cache dependencies +# +FROM --platform=${BUILDPLATFORM} docker.io/golang:1.20-alpine AS base +RUN apk --update --no-cache add bash build-base curl WORKDIR /build COPY . /build -RUN mkdir -p bin -RUN go build -trimpath -o bin/ ./cmd/dendrite-monolith-server -RUN go build -trimpath -o bin/ ./cmd/create-account -RUN go build -trimpath -o bin/ ./cmd/generate-keys - +# +# Builds the Dendrite image containing all required binaries +# FROM alpine:latest -LABEL org.opencontainers.image.title="Dendrite (Monolith)" +RUN apk --update --no-cache add curl +LABEL org.opencontainers.image.title="Dendrite" LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go" LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite" LABEL org.opencontainers.image.licenses="Apache-2.0" LABEL org.opencontainers.image.vendor="The Matrix.org Foundation C.I.C." LABEL org.opencontainers.image.documentation="https://matrix-org.github.io/dendrite/" -COPY --from=base /build/bin/* /usr/bin/ +COPY --from=build /out/create-account /usr/bin/create-account +COPY --from=build /out/generate-config /usr/bin/generate-config +COPY --from=build /out/generate-keys /usr/bin/generate-keys +COPY --from=build /out/dendrite /usr/bin/dendrite VOLUME /etc/dendrite WORKDIR /etc/dendrite -ENTRYPOINT ["/usr/bin/dendrite-monolith-server"] \ No newline at end of file +ENTRYPOINT ["/usr/bin/dendrite"] +EXPOSE 8008 8448 \ No newline at end of file diff --git a/README.md b/README.md index dfef11bae..295203eb4 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,6 @@ This does not mean: - Dendrite is ready for massive homeserver deployments. There is no sharding of microservices (although it is possible to run them on separate machines) and there is no high-availability/clustering support. Currently, we expect Dendrite to function well for small (10s/100s of users) homeserver deployments as well as P2P Matrix nodes in-browser or on mobile devices. -In the future, we will be able to scale up to gigantic servers (equivalent to `matrix.org`) via polylith mode. If you have further questions, please take a look at [our FAQ](docs/FAQ.md) or join us in: @@ -72,10 +71,10 @@ $ ./bin/generate-keys --tls-cert server.crt --tls-key server.key # Copy and modify the config file - you'll need to set a server name and paths to the keys # at the very least, along with setting up the database connection strings. -$ cp dendrite-sample.monolith.yaml dendrite.yaml +$ cp dendrite-sample.yaml dendrite.yaml # Build and run the server: -$ ./bin/dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml +$ ./bin/dendrite --tls-cert server.crt --tls-key server.key --config dendrite.yaml # Create an user account (add -admin for an admin user). # Specify the localpart only, e.g. 'alice' for '@alice:domain.com' @@ -88,7 +87,7 @@ Then point your favourite Matrix client at `http://localhost:8008` or `https://l We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it -updates with CI. As of August 2022 we're at around 90% CS API coverage and 95% Federation coverage, though check +updates with CI. As of January 2023, we have 100% server-server parity with Synapse, and the client-server parity is at 93% , though check CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse servers such as matrix.org reasonably well, although there are still some missing features (like SSO and Third-party ID APIs). diff --git a/appservice/appservice.go b/appservice/appservice.go index b3c28dbde..5b1b93de2 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -21,30 +21,24 @@ import ( "sync" "time" - "github.com/gorilla/mux" "github.com/sirupsen/logrus" + "github.com/matrix-org/gomatrixserverlib" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/appservice/consumers" - "github.com/matrix-org/dendrite/appservice/inthttp" "github.com/matrix-org/dendrite/appservice/query" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) -// AddInternalRoutes registers HTTP handlers for internal API calls -func AddInternalRoutes(router *mux.Router, queryAPI appserviceAPI.AppServiceInternalAPI) { - inthttp.AddRoutes(queryAPI, router) -} - // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( base *base.BaseDendrite, - userAPI userapi.UserInternalAPI, + userAPI userapi.AppserviceUserAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ) appserviceAPI.AppServiceInternalAPI { client := &http.Client{ diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go new file mode 100644 index 000000000..de9f5aaf1 --- /dev/null +++ b/appservice/appservice_test.go @@ -0,0 +1,203 @@ +package appservice_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "reflect" + "regexp" + "strings" + "testing" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/userapi" + + "github.com/matrix-org/dendrite/test/testrig" +) + +func TestAppserviceInternalAPI(t *testing.T) { + + // Set expected results + existingProtocol := "irc" + wantLocationResponse := []api.ASLocationResponse{{Protocol: existingProtocol, Fields: []byte("{}")}} + wantUserResponse := []api.ASUserResponse{{Protocol: existingProtocol, Fields: []byte("{}")}} + wantProtocolResponse := api.ASProtocolResponse{Instances: []api.ProtocolInstance{{Fields: []byte("{}")}}} + wantProtocolResult := map[string]api.ASProtocolResponse{ + existingProtocol: wantProtocolResponse, + } + + // create a dummy AS url, handling some cases + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "location"): + // Check if we've got an existing protocol, if so, return a proper response. + if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { + if err := json.NewEncoder(w).Encode(wantLocationResponse); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + } + if err := json.NewEncoder(w).Encode([]api.ASLocationResponse{}); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + case strings.Contains(r.URL.Path, "user"): + if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { + if err := json.NewEncoder(w).Encode(wantUserResponse); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + } + if err := json.NewEncoder(w).Encode([]api.UserResponse{}); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + case strings.Contains(r.URL.Path, "protocol"): + if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { + if err := json.NewEncoder(w).Encode(wantProtocolResponse); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + } + if err := json.NewEncoder(w).Encode(nil); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + default: + t.Logf("hit location: %s", r.URL.Path) + } + })) + + // The test cases to run + runCases := func(t *testing.T, testAPI api.AppServiceInternalAPI) { + t.Run("UserIDExists", func(t *testing.T) { + testUserIDExists(t, testAPI, "@as-testing:test", true) + testUserIDExists(t, testAPI, "@as1-testing:test", false) + }) + + t.Run("AliasExists", func(t *testing.T) { + testAliasExists(t, testAPI, "@asroom-testing:test", true) + testAliasExists(t, testAPI, "@asroom1-testing:test", false) + }) + + t.Run("Locations", func(t *testing.T) { + testLocations(t, testAPI, existingProtocol, wantLocationResponse) + testLocations(t, testAPI, "abc", nil) + }) + + t.Run("User", func(t *testing.T) { + testUser(t, testAPI, existingProtocol, wantUserResponse) + testUser(t, testAPI, "abc", nil) + }) + + t.Run("Protocols", func(t *testing.T) { + testProtocol(t, testAPI, existingProtocol, wantProtocolResult) + testProtocol(t, testAPI, existingProtocol, wantProtocolResult) // tests the cache + testProtocol(t, testAPI, "", wantProtocolResult) // tests getting all protocols + testProtocol(t, testAPI, "abc", nil) + }) + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, closeBase := testrig.CreateBaseDendrite(t, dbType) + defer closeBase() + + // Create a dummy application service + base.Cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{ + { + ID: "someID", + URL: srv.URL, + ASToken: "", + HSToken: "", + SenderLocalpart: "senderLocalPart", + NamespaceMap: map[string][]config.ApplicationServiceNamespace{ + "users": {{RegexpObject: regexp.MustCompile("as-.*")}}, + "aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}}, + }, + Protocols: []string{existingProtocol}, + }, + } + + // Create required internal APIs + rsAPI := roomserver.NewInternalAPI(base) + usrAPI := userapi.NewInternalAPI(base, rsAPI, nil) + asAPI := appservice.NewInternalAPI(base, usrAPI, rsAPI) + + runCases(t, asAPI) + }) +} + +func testUserIDExists(t *testing.T, asAPI api.AppServiceInternalAPI, userID string, wantExists bool) { + ctx := context.Background() + userResp := &api.UserIDExistsResponse{} + + if err := asAPI.UserIDExists(ctx, &api.UserIDExistsRequest{ + UserID: userID, + }, userResp); err != nil { + t.Errorf("failed to get userID: %s", err) + } + if userResp.UserIDExists != wantExists { + t.Errorf("unexpected result for UserIDExists(%s): %v, expected %v", userID, userResp.UserIDExists, wantExists) + } +} + +func testAliasExists(t *testing.T, asAPI api.AppServiceInternalAPI, alias string, wantExists bool) { + ctx := context.Background() + aliasResp := &api.RoomAliasExistsResponse{} + + if err := asAPI.RoomAliasExists(ctx, &api.RoomAliasExistsRequest{ + Alias: alias, + }, aliasResp); err != nil { + t.Errorf("failed to get alias: %s", err) + } + if aliasResp.AliasExists != wantExists { + t.Errorf("unexpected result for RoomAliasExists(%s): %v, expected %v", alias, aliasResp.AliasExists, wantExists) + } +} + +func testLocations(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult []api.ASLocationResponse) { + ctx := context.Background() + locationResp := &api.LocationResponse{} + + if err := asAPI.Locations(ctx, &api.LocationRequest{ + Protocol: proto, + }, locationResp); err != nil { + t.Errorf("failed to get locations: %s", err) + } + if !reflect.DeepEqual(locationResp.Locations, wantResult) { + t.Errorf("unexpected result for Locations(%s): %+v, expected %+v", proto, locationResp.Locations, wantResult) + } +} + +func testUser(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult []api.ASUserResponse) { + ctx := context.Background() + userResp := &api.UserResponse{} + + if err := asAPI.User(ctx, &api.UserRequest{ + Protocol: proto, + }, userResp); err != nil { + t.Errorf("failed to get user: %s", err) + } + if !reflect.DeepEqual(userResp.Users, wantResult) { + t.Errorf("unexpected result for User(%s): %+v, expected %+v", proto, userResp.Users, wantResult) + } +} + +func testProtocol(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult map[string]api.ASProtocolResponse) { + ctx := context.Background() + protoResp := &api.ProtocolResponse{} + + if err := asAPI.Protocols(ctx, &api.ProtocolRequest{ + Protocol: proto, + }, protoResp); err != nil { + t.Errorf("failed to get Protocols: %s", err) + } + if !reflect.DeepEqual(protoResp.Protocols, wantResult) { + t.Errorf("unexpected result for Protocols(%s): %+v, expected %+v", proto, protoResp.Protocols[proto], wantResult) + } +} diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index ac68f4bd4..528de63e8 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -122,6 +122,7 @@ func (s *OutputRoomEventConsumer) onMessage( if len(output.NewRoomEvent.AddsStateEventIDs) > 0 { newEventID := output.NewRoomEvent.Event.EventID() eventsReq := &api.QueryEventsByIDRequest{ + RoomID: output.NewRoomEvent.Event.RoomID(), EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)), } eventsRes := &api.QueryEventsByIDResponse{} diff --git a/appservice/inthttp/client.go b/appservice/inthttp/client.go deleted file mode 100644 index f7f164877..000000000 --- a/appservice/inthttp/client.go +++ /dev/null @@ -1,84 +0,0 @@ -package inthttp - -import ( - "context" - "errors" - "net/http" - - "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/internal/httputil" -) - -// HTTP paths for the internal HTTP APIs -const ( - AppServiceRoomAliasExistsPath = "/appservice/RoomAliasExists" - AppServiceUserIDExistsPath = "/appservice/UserIDExists" - AppServiceLocationsPath = "/appservice/locations" - AppServiceUserPath = "/appservice/users" - AppServiceProtocolsPath = "/appservice/protocols" -) - -// httpAppServiceQueryAPI contains the URL to an appservice query API and a -// reference to a httpClient used to reach it -type httpAppServiceQueryAPI struct { - appserviceURL string - httpClient *http.Client -} - -// NewAppserviceClient creates a AppServiceQueryAPI implemented by talking -// to a HTTP POST API. -// If httpClient is nil an error is returned -func NewAppserviceClient( - appserviceURL string, - httpClient *http.Client, -) (api.AppServiceInternalAPI, error) { - if httpClient == nil { - return nil, errors.New("NewRoomserverAliasAPIHTTP: httpClient is ") - } - return &httpAppServiceQueryAPI{appserviceURL, httpClient}, nil -} - -// RoomAliasExists implements AppServiceQueryAPI -func (h *httpAppServiceQueryAPI) RoomAliasExists( - ctx context.Context, - request *api.RoomAliasExistsRequest, - response *api.RoomAliasExistsResponse, -) error { - return httputil.CallInternalRPCAPI( - "RoomAliasExists", h.appserviceURL+AppServiceRoomAliasExistsPath, - h.httpClient, ctx, request, response, - ) -} - -// UserIDExists implements AppServiceQueryAPI -func (h *httpAppServiceQueryAPI) UserIDExists( - ctx context.Context, - request *api.UserIDExistsRequest, - response *api.UserIDExistsResponse, -) error { - return httputil.CallInternalRPCAPI( - "UserIDExists", h.appserviceURL+AppServiceUserIDExistsPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpAppServiceQueryAPI) Locations(ctx context.Context, request *api.LocationRequest, response *api.LocationResponse) error { - return httputil.CallInternalRPCAPI( - "ASLocation", h.appserviceURL+AppServiceLocationsPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpAppServiceQueryAPI) User(ctx context.Context, request *api.UserRequest, response *api.UserResponse) error { - return httputil.CallInternalRPCAPI( - "ASUser", h.appserviceURL+AppServiceUserPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpAppServiceQueryAPI) Protocols(ctx context.Context, request *api.ProtocolRequest, response *api.ProtocolResponse) error { - return httputil.CallInternalRPCAPI( - "ASProtocols", h.appserviceURL+AppServiceProtocolsPath, - h.httpClient, ctx, request, response, - ) -} diff --git a/appservice/inthttp/server.go b/appservice/inthttp/server.go deleted file mode 100644 index ccf5c83d8..000000000 --- a/appservice/inthttp/server.go +++ /dev/null @@ -1,36 +0,0 @@ -package inthttp - -import ( - "github.com/gorilla/mux" - - "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/internal/httputil" -) - -// AddRoutes adds the AppServiceQueryAPI handlers to the http.ServeMux. -func AddRoutes(a api.AppServiceInternalAPI, internalAPIMux *mux.Router) { - internalAPIMux.Handle( - AppServiceRoomAliasExistsPath, - httputil.MakeInternalRPCAPI("AppserviceRoomAliasExists", a.RoomAliasExists), - ) - - internalAPIMux.Handle( - AppServiceUserIDExistsPath, - httputil.MakeInternalRPCAPI("AppserviceUserIDExists", a.UserIDExists), - ) - - internalAPIMux.Handle( - AppServiceProtocolsPath, - httputil.MakeInternalRPCAPI("AppserviceProtocols", a.Protocols), - ) - - internalAPIMux.Handle( - AppServiceLocationsPath, - httputil.MakeInternalRPCAPI("AppserviceLocations", a.Locations), - ) - - internalAPIMux.Handle( - AppServiceUserPath, - httputil.MakeInternalRPCAPI("AppserviceUser", a.User), - ) -} diff --git a/appservice/query/query.go b/appservice/query/query.go index 2348eab4b..0466f81d0 100644 --- a/appservice/query/query.go +++ b/appservice/query/query.go @@ -25,10 +25,10 @@ import ( "strings" "sync" - "github.com/opentracing/opentracing-go" log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/config" ) @@ -50,8 +50,8 @@ func (a *AppServiceQueryAPI) RoomAliasExists( request *api.RoomAliasExistsRequest, response *api.RoomAliasExistsResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "ApplicationServiceRoomAlias") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "ApplicationServiceRoomAlias") + defer trace.EndRegion() // Determine which application service should handle this request for _, appservice := range a.Cfg.Derived.ApplicationServices { @@ -117,8 +117,8 @@ func (a *AppServiceQueryAPI) UserIDExists( request *api.UserIDExistsRequest, response *api.UserIDExistsResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "ApplicationServiceUserID") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "ApplicationServiceUserID") + defer trace.EndRegion() // Determine which application service should handle this request for _, appservice := range a.Cfg.Derived.ApplicationServices { diff --git a/are-we-synapse-yet.list b/are-we-synapse-yet.list index 81c0f8049..585374738 100644 --- a/are-we-synapse-yet.list +++ b/are-we-synapse-yet.list @@ -936,4 +936,12 @@ fst Room state after a rejected message event is the same as before fst Room state after a rejected state event is the same as before fpb Federation publicRoom Name/topic keys are correct fed New federated private chats get full presence information (SYN-115) (10 subtests) -dvk Rejects invalid device keys \ No newline at end of file +dvk Rejects invalid device keys +rmv User can create and send/receive messages in a room with version 10 +rmv local user can join room with version 10 +rmv User can invite local user to room with version 10 +rmv remote user can join room with version 10 +rmv User can invite remote user to room with version 10 +rmv Remote user can backfill in a room with version 10 +rmv Can reject invites over federation for rooms with version 10 +rmv Can receive redactions from regular users over federation in room version 10 \ No newline at end of file diff --git a/build/dendritejs-pinecone/jsServer.go b/build/dendritejs-pinecone/jsServer.go deleted file mode 100644 index a2fc39d42..000000000 --- a/build/dendritejs-pinecone/jsServer.go +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build wasm -// +build wasm - -package main - -import ( - "bufio" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "syscall/js" -) - -// JSServer exposes an HTTP-like server interface which allows JS to 'send' requests to it. -type JSServer struct { - // The router which will service requests - Mux http.Handler -} - -// OnRequestFromJS is the function that JS will invoke when there is a new request. -// The JS function signature is: -// -// function(reqString: string): Promise<{result: string, error: string}> -// -// Usage is like: -// -// const res = await global._go_js_server.fetch(reqString); -// if (res.error) { -// // handle error: this is a 'network' error, not a non-2xx error. -// } -// const rawHttpResponse = res.result; -func (h *JSServer) OnRequestFromJS(this js.Value, args []js.Value) interface{} { - // we HAVE to spawn a new goroutine and return immediately or else Go will deadlock - // if this request blocks at all e.g for /sync calls - httpStr := args[0].String() - promise := js.Global().Get("Promise").New(js.FuncOf(func(pthis js.Value, pargs []js.Value) interface{} { - // The initial callback code for new Promise() is also called on the critical path, which is why - // we need to put this in an immediately invoked goroutine. - go func() { - resolve := pargs[0] - resStr, err := h.handle(httpStr) - errStr := "" - if err != nil { - errStr = err.Error() - } - resolve.Invoke(map[string]interface{}{ - "result": resStr, - "error": errStr, - }) - }() - return nil - })) - return promise -} - -// handle invokes the http.ServeMux for this request and returns the raw HTTP response. -func (h *JSServer) handle(httpStr string) (resStr string, err error) { - req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(httpStr))) - if err != nil { - return - } - w := httptest.NewRecorder() - - h.Mux.ServeHTTP(w, req) - - res := w.Result() - var resBuffer strings.Builder - err = res.Write(&resBuffer) - return resBuffer.String(), err -} - -// ListenAndServe registers a variable in JS-land with the given namespace. This variable is -// a function which JS-land can call to 'send' HTTP requests. The function is attached to -// a global object called "_go_js_server". See OnRequestFromJS for more info. -func (h *JSServer) ListenAndServe(namespace string) { - globalName := "_go_js_server" - // register a hook in JS-land for it to invoke stuff - server := js.Global().Get(globalName) - if !server.Truthy() { - server = js.Global().Get("Object").New() - js.Global().Set(globalName, server) - } - - server.Set(namespace, js.FuncOf(h.OnRequestFromJS)) - - fmt.Printf("Listening for requests from JS on function %s.%s\n", globalName, namespace) - // Block forever to mimic http.ListenAndServe - select {} -} diff --git a/build/dendritejs-pinecone/main.go b/build/dendritejs-pinecone/main.go deleted file mode 100644 index e070173aa..000000000 --- a/build/dendritejs-pinecone/main.go +++ /dev/null @@ -1,234 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build wasm -// +build wasm - -package main - -import ( - "crypto/ed25519" - "encoding/hex" - "fmt" - "syscall/js" - - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" - "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" - "github.com/matrix-org/dendrite/roomserver" - "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi" - - "github.com/matrix-org/gomatrixserverlib" - - "github.com/sirupsen/logrus" - - _ "github.com/matrix-org/go-sqlite3-js" - - pineconeConnections "github.com/matrix-org/pinecone/connections" - pineconeRouter "github.com/matrix-org/pinecone/router" - pineconeSessions "github.com/matrix-org/pinecone/sessions" -) - -var GitCommit string - -func init() { - fmt.Printf("[%s] dendrite.js starting...\n", GitCommit) -} - -const publicPeer = "wss://pinecone.matrix.org/public" -const keyNameEd25519 = "_go_ed25519_key" - -func readKeyFromLocalStorage() (key ed25519.PrivateKey, err error) { - localforage := js.Global().Get("localforage") - if !localforage.Truthy() { - err = fmt.Errorf("readKeyFromLocalStorage: no localforage") - return - } - // https://localforage.github.io/localForage/ - item, ok := await(localforage.Call("getItem", keyNameEd25519)) - if !ok || !item.Truthy() { - err = fmt.Errorf("readKeyFromLocalStorage: no key in localforage") - return - } - fmt.Println("Found key in localforage") - // extract []byte and make an ed25519 key - seed := make([]byte, 32, 32) - js.CopyBytesToGo(seed, item) - - return ed25519.NewKeyFromSeed(seed), nil -} - -func writeKeyToLocalStorage(key ed25519.PrivateKey) error { - localforage := js.Global().Get("localforage") - if !localforage.Truthy() { - return fmt.Errorf("writeKeyToLocalStorage: no localforage") - } - - // make a Uint8Array from the key's seed - seed := key.Seed() - jsSeed := js.Global().Get("Uint8Array").New(len(seed)) - js.CopyBytesToJS(jsSeed, seed) - // write it - localforage.Call("setItem", keyNameEd25519, jsSeed) - return nil -} - -// taken from https://go-review.googlesource.com/c/go/+/150917 - -// await waits until the promise v has been resolved or rejected and returns the promise's result value. -// The boolean value ok is true if the promise has been resolved, false if it has been rejected. -// If v is not a promise, v itself is returned as the value and ok is true. -func await(v js.Value) (result js.Value, ok bool) { - if v.Type() != js.TypeObject || v.Get("then").Type() != js.TypeFunction { - return v, true - } - done := make(chan struct{}) - onResolve := js.FuncOf(func(this js.Value, args []js.Value) interface{} { - result = args[0] - ok = true - close(done) - return nil - }) - defer onResolve.Release() - onReject := js.FuncOf(func(this js.Value, args []js.Value) interface{} { - result = args[0] - ok = false - close(done) - return nil - }) - defer onReject.Release() - v.Call("then", onResolve, onReject) - <-done - return -} - -func generateKey() ed25519.PrivateKey { - // attempt to look for a seed in JS-land and if it exists use it. - priv, err := readKeyFromLocalStorage() - if err == nil { - fmt.Println("Read key from localStorage") - return priv - } - // generate a new key - fmt.Println(err, " : Generating new ed25519 key") - _, priv, err = ed25519.GenerateKey(nil) - if err != nil { - logrus.Fatalf("Failed to generate ed25519 key: %s", err) - } - if err := writeKeyToLocalStorage(priv); err != nil { - fmt.Println("failed to write key to localStorage: ", err) - // non-fatal, we'll just have amnesia for a while - } - return priv -} - -func main() { - startup() - - // We want to block forever to let the fetch and libp2p handler serve the APIs - select {} -} - -func startup() { - sk := generateKey() - pk := sk.Public().(ed25519.PublicKey) - - pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false) - pSessions := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) - pManager := pineconeConnections.NewConnectionManager(pRouter) - pManager.AddPeer("wss://pinecone.matrix.org/public") - - cfg := &config.Dendrite{} - cfg.Defaults(true) - cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db" - cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db" - cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db" - cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db" - cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" - cfg.SyncAPI.Database.ConnectionString = "file:/idb/dendritejs_syncapi.db" - cfg.KeyServer.Database.ConnectionString = "file:/idb/dendritejs_e2ekey.db" - cfg.Global.JetStream.StoragePath = "file:/idb/dendritejs/" - cfg.Global.TrustedIDServers = []string{} - cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - cfg.Global.PrivateKey = sk - cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) - cfg.ClientAPI.RegistrationDisabled = false - cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true - - if err := cfg.Derive(); err != nil { - logrus.Fatalf("Failed to derive values from config: %s", err) - } - base := base.NewBaseDendrite(cfg, "Monolith") - defer base.Close() // nolint: errcheck - - federation := conn.CreateFederationClient(base, pSessions) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - - serverKeyAPI := &signing.YggdrasilKeys{} - keyRing := serverKeyAPI.KeyRing() - - rsAPI := roomserver.NewInternalAPI(base) - - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) - - asQuery := appservice.NewInternalAPI( - base, userAPI, rsAPI, - ) - rsAPI.SetAppserviceAPI(asQuery) - fedSenderAPI := federationapi.NewInternalAPI(base, federation, rsAPI, base.Caches, keyRing, true) - rsAPI.SetFederationAPI(fedSenderAPI, keyRing) - - monolith := setup.Monolith{ - Config: base.Cfg, - Client: conn.CreateClient(base, pSessions), - FedClient: federation, - KeyRing: keyRing, - - AppserviceAPI: asQuery, - FederationAPI: fedSenderAPI, - RoomserverAPI: rsAPI, - UserAPI: userAPI, - KeyAPI: keyAPI, - //ServerKeyAPI: serverKeyAPI, - ExtPublicRoomsProvider: rooms.NewPineconeRoomProvider(pRouter, pSessions, fedSenderAPI, federation), - } - monolith.AddAllPublicRoutes(base) - - httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() - httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) - httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) - httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - - p2pRouter := pSessions.Protocol("matrix").HTTP().Mux() - p2pRouter.Handle(httputil.PublicFederationPathPrefix, base.PublicFederationAPIMux) - p2pRouter.Handle(httputil.PublicMediaPathPrefix, base.PublicMediaAPIMux) - - // Expose the matrix APIs via fetch - for local traffic - go func() { - logrus.Info("Listening for service-worker fetch traffic") - s := JSServer{ - Mux: httpRouter, - } - s.ListenAndServe("fetch") - }() -} diff --git a/build/dendritejs-pinecone/main_noop.go b/build/dendritejs-pinecone/main_noop.go deleted file mode 100644 index 0cc7e47e5..000000000 --- a/build/dendritejs-pinecone/main_noop.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !wasm -// +build !wasm - -package main - -import "fmt" - -func main() { - fmt.Println("dendritejs: no-op when not compiling for WebAssembly") -} diff --git a/build/docker/Dockerfile.demo-pinecone b/build/docker/Dockerfile.demo-pinecone index facd1e3af..90f515167 100644 --- a/build/docker/Dockerfile.demo-pinecone +++ b/build/docker/Dockerfile.demo-pinecone @@ -17,6 +17,7 @@ RUN go build -trimpath -o bin/ ./cmd/create-account RUN go build -trimpath -o bin/ ./cmd/generate-keys FROM alpine:latest +RUN apk --update --no-cache add curl LABEL org.opencontainers.image.title="Dendrite (Pinecone demo)" LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go" LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite" diff --git a/build/docker/README.md b/build/docker/README.md index 7eb20d88f..b66cb864b 100644 --- a/build/docker/README.md +++ b/build/docker/README.md @@ -5,7 +5,6 @@ These are Docker images for Dendrite! They can be found on Docker Hub: - [matrixdotorg/dendrite-monolith](https://hub.docker.com/r/matrixdotorg/dendrite-monolith) for monolith deployments -- [matrixdotorg/dendrite-polylith](https://hub.docker.com/r/matrixdotorg/dendrite-polylith) for polylith deployments ## Dockerfiles @@ -15,7 +14,6 @@ repository, run: ``` docker build . --target monolith -t matrixdotorg/dendrite-monolith -docker build . --target polylith -t matrixdotorg/dendrite-monolith docker build . --target demo-pinecone -t matrixdotorg/dendrite-demo-pinecone docker build . --target demo-yggdrasil -t matrixdotorg/dendrite-demo-yggdrasil ``` @@ -25,7 +23,6 @@ docker build . --target demo-yggdrasil -t matrixdotorg/dendrite-demo-yggdrasil There are two sample `docker-compose` files: - `docker-compose.monolith.yml` which runs a monolith Dendrite deployment -- `docker-compose.polylith.yml` which runs a polylith Dendrite deployment ## Configuration @@ -51,9 +48,9 @@ docker run --rm --entrypoint="" \ The key files will now exist in your current working directory, and can be mounted into place. -## Starting Dendrite as a monolith deployment +## Starting Dendrite -Create your config based on the [`dendrite-sample.monolith.yaml`](https://github.com/matrix-org/dendrite/blob/main/dendrite-sample.monolith.yaml) sample configuration file. +Create your config based on the [`dendrite-sample.yaml`](https://github.com/matrix-org/dendrite/blob/main/dendrite-sample.yaml) sample configuration file. Then start the deployment: @@ -61,16 +58,6 @@ Then start the deployment: docker-compose -f docker-compose.monolith.yml up ``` -## Starting Dendrite as a polylith deployment - -Create your config based on the [`dendrite-sample.polylith.yaml`](https://github.com/matrix-org/dendrite/blob/main/dendrite-sample.polylith.yaml) sample configuration file. - -Then start the deployment: - -``` -docker-compose -f docker-compose.polylith.yml up -``` - ## Building the images The `build/docker/images-build.sh` script will build the base image, followed by diff --git a/build/docker/docker-compose.polylith.yml b/build/docker/docker-compose.polylith.yml deleted file mode 100644 index de0ab0aa2..000000000 --- a/build/docker/docker-compose.polylith.yml +++ /dev/null @@ -1,143 +0,0 @@ -version: "3.4" -services: - postgres: - hostname: postgres - image: postgres:14 - restart: always - volumes: - - ./postgres/create_db.sh:/docker-entrypoint-initdb.d/20-create_db.sh - # To persist your PostgreSQL databases outside of the Docker image, - # to prevent data loss, modify the following ./path_to path: - - ./path_to/postgresql:/var/lib/postgresql/data - environment: - POSTGRES_PASSWORD: itsasecret - POSTGRES_USER: dendrite - healthcheck: - test: ["CMD-SHELL", "pg_isready -U dendrite"] - interval: 5s - timeout: 5s - retries: 5 - networks: - - internal - - jetstream: - hostname: jetstream - image: nats:latest - command: | - --jetstream - --store_dir /var/lib/nats - --cluster_name Dendrite - volumes: - # To persist your NATS JetStream streams outside of the Docker image, - # prevent data loss, modify the following ./path_to path: - - ./path_to/nats:/var/lib/nats - networks: - - internal - - client_api: - hostname: client_api - image: matrixdotorg/dendrite-polylith:latest - command: clientapi - volumes: - - ./config:/etc/dendrite - depends_on: - - jetstream - - postgres - networks: - - internal - restart: unless-stopped - - media_api: - hostname: media_api - image: matrixdotorg/dendrite-polylith:latest - command: mediaapi - volumes: - - ./config:/etc/dendrite - - ./media:/var/dendrite/media - networks: - - internal - restart: unless-stopped - - sync_api: - hostname: sync_api - image: matrixdotorg/dendrite-polylith:latest - command: syncapi - volumes: - - ./config:/etc/dendrite - depends_on: - - jetstream - - postgres - networks: - - internal - restart: unless-stopped - - room_server: - hostname: room_server - image: matrixdotorg/dendrite-polylith:latest - command: roomserver - volumes: - - ./config:/etc/dendrite - depends_on: - - jetstream - - postgres - networks: - - internal - restart: unless-stopped - - federation_api: - hostname: federation_api - image: matrixdotorg/dendrite-polylith:latest - command: federationapi - volumes: - - ./config:/etc/dendrite - depends_on: - - jetstream - - postgres - networks: - - internal - restart: unless-stopped - - key_server: - hostname: key_server - image: matrixdotorg/dendrite-polylith:latest - command: keyserver - volumes: - - ./config:/etc/dendrite - depends_on: - - jetstream - - postgres - networks: - - internal - restart: unless-stopped - - user_api: - hostname: user_api - image: matrixdotorg/dendrite-polylith:latest - command: userapi - volumes: - - ./config:/etc/dendrite - depends_on: - - jetstream - - postgres - networks: - - internal - restart: unless-stopped - - appservice_api: - hostname: appservice_api - image: matrixdotorg/dendrite-polylith:latest - command: appservice - volumes: - - ./config:/etc/dendrite - networks: - - internal - depends_on: - - jetstream - - postgres - - room_server - - user_api - restart: unless-stopped - -networks: - internal: - attachable: true diff --git a/build/docker/images-build.sh b/build/docker/images-build.sh index d97a701ed..bf227968f 100755 --- a/build/docker/images-build.sh +++ b/build/docker/images-build.sh @@ -7,6 +7,5 @@ TAG=${1:-latest} echo "Building tag '${TAG}'" docker build . --target monolith -t matrixdotorg/dendrite-monolith:${TAG} -docker build . --target polylith -t matrixdotorg/dendrite-monolith:${TAG} docker build . --target demo-pinecone -t matrixdotorg/dendrite-demo-pinecone:${TAG} docker build . --target demo-yggdrasil -t matrixdotorg/dendrite-demo-yggdrasil:${TAG} \ No newline at end of file diff --git a/build/docker/images-pull.sh b/build/docker/images-pull.sh index f3f98ce7c..7772ca747 100755 --- a/build/docker/images-pull.sh +++ b/build/docker/images-pull.sh @@ -5,4 +5,3 @@ TAG=${1:-latest} echo "Pulling tag '${TAG}'" docker pull matrixdotorg/dendrite-monolith:${TAG} -docker pull matrixdotorg/dendrite-polylith:${TAG} \ No newline at end of file diff --git a/build/docker/images-push.sh b/build/docker/images-push.sh index 248fdee2b..d166d355a 100755 --- a/build/docker/images-push.sh +++ b/build/docker/images-push.sh @@ -5,4 +5,3 @@ TAG=${1:-latest} echo "Pushing tag '${TAG}'" docker push matrixdotorg/dendrite-monolith:${TAG} -docker push matrixdotorg/dendrite-polylith:${TAG} \ No newline at end of file diff --git a/build/gobind-pinecone/build.sh b/build/gobind-pinecone/build.sh deleted file mode 100644 index 0f1b1aab9..000000000 --- a/build/gobind-pinecone/build.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/sh - -TARGET="" - -while getopts "ai" option -do - case "$option" - in - a) gomobile bind -v -target android -trimpath -ldflags="-s -w" github.com/matrix-org/dendrite/build/gobind-pinecone ;; - i) gomobile bind -v -target ios -trimpath -ldflags="" github.com/matrix-org/dendrite/build/gobind-pinecone ;; - *) echo "No target specified, specify -a or -i"; exit 1 ;; - esac -done \ No newline at end of file diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go deleted file mode 100644 index 9100ebf0f..000000000 --- a/build/gobind-pinecone/monolith.go +++ /dev/null @@ -1,519 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gobind - -import ( - "context" - "crypto/ed25519" - "crypto/rand" - "crypto/tls" - "encoding/hex" - "fmt" - "io" - "net" - "net/http" - "os" - "path/filepath" - "strings" - "sync" - "time" - - "go.uber.org/atomic" - - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" - "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" - "github.com/matrix-org/dendrite/roomserver" - "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/userapi" - userapiAPI "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" - - pineconeConnections "github.com/matrix-org/pinecone/connections" - pineconeMulticast "github.com/matrix-org/pinecone/multicast" - pineconeRouter "github.com/matrix-org/pinecone/router" - pineconeEvents "github.com/matrix-org/pinecone/router/events" - pineconeSessions "github.com/matrix-org/pinecone/sessions" - "github.com/matrix-org/pinecone/types" - - _ "golang.org/x/mobile/bind" -) - -const ( - PeerTypeRemote = pineconeRouter.PeerTypeRemote - PeerTypeMulticast = pineconeRouter.PeerTypeMulticast - PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth - PeerTypeBonjour = pineconeRouter.PeerTypeBonjour -) - -type DendriteMonolith struct { - logger logrus.Logger - PineconeRouter *pineconeRouter.Router - PineconeMulticast *pineconeMulticast.Multicast - PineconeQUIC *pineconeSessions.Sessions - PineconeManager *pineconeConnections.ConnectionManager - StorageDirectory string - CacheDirectory string - listener net.Listener - httpServer *http.Server - processContext *process.ProcessContext - userAPI userapiAPI.UserInternalAPI -} - -func (m *DendriteMonolith) PublicKey() string { - return m.PineconeRouter.PublicKey().String() -} - -func (m *DendriteMonolith) BaseURL() string { - return fmt.Sprintf("http://%s", m.listener.Addr().String()) -} - -func (m *DendriteMonolith) PeerCount(peertype int) int { - return m.PineconeRouter.PeerCount(peertype) -} - -func (m *DendriteMonolith) SessionCount() int { - return len(m.PineconeQUIC.Protocol("matrix").Sessions()) -} - -type InterfaceInfo struct { - Name string - Index int - Mtu int - Up bool - Broadcast bool - Loopback bool - PointToPoint bool - Multicast bool - Addrs string -} - -type InterfaceRetriever interface { - CacheCurrentInterfaces() int - GetCachedInterface(index int) *InterfaceInfo -} - -func (m *DendriteMonolith) RegisterNetworkCallback(intfCallback InterfaceRetriever) { - callback := func() []pineconeMulticast.InterfaceInfo { - count := intfCallback.CacheCurrentInterfaces() - intfs := []pineconeMulticast.InterfaceInfo{} - for i := 0; i < count; i++ { - iface := intfCallback.GetCachedInterface(i) - if iface != nil { - intfs = append(intfs, pineconeMulticast.InterfaceInfo{ - Name: iface.Name, - Index: iface.Index, - Mtu: iface.Mtu, - Up: iface.Up, - Broadcast: iface.Broadcast, - Loopback: iface.Loopback, - PointToPoint: iface.PointToPoint, - Multicast: iface.Multicast, - Addrs: iface.Addrs, - }) - } - } - return intfs - } - m.PineconeMulticast.RegisterNetworkCallback(callback) -} - -func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) { - if enabled { - m.PineconeMulticast.Start() - } else { - m.PineconeMulticast.Stop() - m.DisconnectType(int(pineconeRouter.PeerTypeMulticast)) - } -} - -func (m *DendriteMonolith) SetStaticPeer(uri string) { - m.PineconeManager.RemovePeers() - for _, uri := range strings.Split(uri, ",") { - m.PineconeManager.AddPeer(strings.TrimSpace(uri)) - } -} - -func (m *DendriteMonolith) DisconnectType(peertype int) { - for _, p := range m.PineconeRouter.Peers() { - if int(peertype) == p.PeerType { - m.PineconeRouter.Disconnect(types.SwitchPortID(p.Port), nil) - } - } -} - -func (m *DendriteMonolith) DisconnectZone(zone string) { - for _, p := range m.PineconeRouter.Peers() { - if zone == p.Zone { - m.PineconeRouter.Disconnect(types.SwitchPortID(p.Port), nil) - } - } -} - -func (m *DendriteMonolith) DisconnectPort(port int) { - m.PineconeRouter.Disconnect(types.SwitchPortID(port), nil) -} - -func (m *DendriteMonolith) Conduit(zone string, peertype int) (*Conduit, error) { - l, r := net.Pipe() - conduit := &Conduit{conn: r, port: 0} - go func() { - conduit.portMutex.Lock() - defer conduit.portMutex.Unlock() - - logrus.Errorf("Attempting authenticated connect") - var err error - if conduit.port, err = m.PineconeRouter.Connect( - l, - pineconeRouter.ConnectionZone(zone), - pineconeRouter.ConnectionPeerType(peertype), - ); err != nil { - logrus.Errorf("Authenticated connect failed: %s", err) - _ = l.Close() - _ = r.Close() - _ = conduit.Close() - return - } - logrus.Infof("Authenticated connect succeeded (port %d)", conduit.port) - }() - return conduit, nil -} - -func (m *DendriteMonolith) RegisterUser(localpart, password string) (string, error) { - pubkey := m.PineconeRouter.PublicKey() - userID := userutil.MakeUserID( - localpart, - gomatrixserverlib.ServerName(hex.EncodeToString(pubkey[:])), - ) - userReq := &userapiAPI.PerformAccountCreationRequest{ - AccountType: userapiAPI.AccountTypeUser, - Localpart: localpart, - Password: password, - } - userRes := &userapiAPI.PerformAccountCreationResponse{} - if err := m.userAPI.PerformAccountCreation(context.Background(), userReq, userRes); err != nil { - return userID, fmt.Errorf("userAPI.PerformAccountCreation: %w", err) - } - return userID, nil -} - -func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, error) { - accessTokenBytes := make([]byte, 16) - n, err := rand.Read(accessTokenBytes) - if err != nil { - return "", fmt.Errorf("rand.Read: %w", err) - } - loginReq := &userapiAPI.PerformDeviceCreationRequest{ - Localpart: localpart, - DeviceID: &deviceID, - AccessToken: hex.EncodeToString(accessTokenBytes[:n]), - } - loginRes := &userapiAPI.PerformDeviceCreationResponse{} - if err := m.userAPI.PerformDeviceCreation(context.Background(), loginReq, loginRes); err != nil { - return "", fmt.Errorf("userAPI.PerformDeviceCreation: %w", err) - } - if !loginRes.DeviceCreated { - return "", fmt.Errorf("device was not created") - } - return loginRes.Device.AccessToken, nil -} - -// nolint:gocyclo -func (m *DendriteMonolith) Start() { - var sk ed25519.PrivateKey - var pk ed25519.PublicKey - - keyfile := filepath.Join(m.StorageDirectory, "p2p.pem") - if _, err := os.Stat(keyfile); os.IsNotExist(err) { - oldkeyfile := filepath.Join(m.StorageDirectory, "p2p.key") - if _, err = os.Stat(oldkeyfile); os.IsNotExist(err) { - if err = test.NewMatrixKey(keyfile); err != nil { - panic("failed to generate a new PEM key: " + err.Error()) - } - if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { - panic("failed to load PEM key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - } else { - if sk, err = os.ReadFile(oldkeyfile); err != nil { - panic("failed to read the old private key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - if err = test.SaveMatrixKey(keyfile, sk); err != nil { - panic("failed to convert the private key to PEM format: " + err.Error()) - } - } - } else { - if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { - panic("failed to load PEM key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - } - - pk = sk.Public().(ed25519.PublicKey) - - var err error - m.listener, err = net.Listen("tcp", "localhost:65432") - if err != nil { - panic(err) - } - - m.logger = logrus.Logger{ - Out: BindLogger{}, - } - m.logger.SetOutput(BindLogger{}) - logrus.SetOutput(BindLogger{}) - - pineconeEventChannel := make(chan pineconeEvents.Event) - m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) - m.PineconeRouter.EnableHopLimiting() - m.PineconeRouter.EnableWakeupBroadcasts() - m.PineconeRouter.Subscribe(pineconeEventChannel) - - m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"}) - m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter) - m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter, nil) - - prefix := hex.EncodeToString(pk) - cfg := &config.Dendrite{} - cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, - }) - cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) - cfg.Global.PrivateKey = sk - cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - cfg.Global.JetStream.InMemory = false - cfg.Global.JetStream.StoragePath = config.Path(filepath.Join(m.CacheDirectory, prefix)) - cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.MediaAPI.BasePath = config.Path(filepath.Join(m.CacheDirectory, "media")) - cfg.MediaAPI.AbsBasePath = config.Path(filepath.Join(m.CacheDirectory, "media")) - cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} - cfg.ClientAPI.RegistrationDisabled = false - cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true - cfg.SyncAPI.Fulltext.Enabled = true - cfg.SyncAPI.Fulltext.IndexPath = config.Path(filepath.Join(m.CacheDirectory, "search")) - if err = cfg.Derive(); err != nil { - panic(err) - } - - base := base.NewBaseDendrite(cfg, "Monolith") - defer base.Close() // nolint: errcheck - - federation := conn.CreateFederationClient(base, m.PineconeQUIC) - - serverKeyAPI := &signing.YggdrasilKeys{} - keyRing := serverKeyAPI.KeyRing() - - rsAPI := roomserver.NewInternalAPI(base) - - fsAPI := federationapi.NewInternalAPI( - base, federation, rsAPI, base.Caches, keyRing, true, - ) - - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) - m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(m.userAPI) - - asAPI := appservice.NewInternalAPI(base, m.userAPI, rsAPI) - - // The underlying roomserver implementation needs to be able to call the fedsender. - // This is different to rsAPI which can be the http client which doesn't need this dependency - rsAPI.SetFederationAPI(fsAPI, keyRing) - - userProvider := users.NewPineconeUserProvider(m.PineconeRouter, m.PineconeQUIC, m.userAPI, federation) - roomProvider := rooms.NewPineconeRoomProvider(m.PineconeRouter, m.PineconeQUIC, fsAPI, federation) - - monolith := setup.Monolith{ - Config: base.Cfg, - Client: conn.CreateClient(base, m.PineconeQUIC), - FedClient: federation, - KeyRing: keyRing, - - AppserviceAPI: asAPI, - FederationAPI: fsAPI, - RoomserverAPI: rsAPI, - UserAPI: m.userAPI, - KeyAPI: keyAPI, - ExtPublicRoomsProvider: roomProvider, - ExtUserDirectoryProvider: userProvider, - } - monolith.AddAllPublicRoutes(base) - - httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() - httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) - httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) - httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - httpRouter.HandleFunc("/pinecone", m.PineconeRouter.ManholeHandler) - - pMux := mux.NewRouter().SkipClean(true).UseEncodedPath() - pMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles) - pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) - pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - - pHTTP := m.PineconeQUIC.Protocol("matrix").HTTP() - pHTTP.Mux().Handle(users.PublicURL, pMux) - pHTTP.Mux().Handle(httputil.PublicFederationPathPrefix, pMux) - pHTTP.Mux().Handle(httputil.PublicMediaPathPrefix, pMux) - - // Build both ends of a HTTP multiplex. - h2s := &http2.Server{} - m.httpServer = &http.Server{ - Addr: ":0", - TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 30 * time.Second, - BaseContext: func(_ net.Listener) context.Context { - return context.Background() - }, - Handler: h2c.NewHandler(pMux, h2s), - } - - m.processContext = base.ProcessContext - - go func() { - m.logger.Info("Listening on ", cfg.Global.ServerName) - - switch m.httpServer.Serve(m.PineconeQUIC.Protocol("matrix")) { - case net.ErrClosed, http.ErrServerClosed: - m.logger.Info("Stopped listening on ", cfg.Global.ServerName) - default: - m.logger.Fatal(err) - } - }() - go func() { - logrus.Info("Listening on ", m.listener.Addr()) - - switch http.Serve(m.listener, httpRouter) { - case net.ErrClosed, http.ErrServerClosed: - m.logger.Info("Stopped listening on ", cfg.Global.ServerName) - default: - m.logger.Fatal(err) - } - }() - - go func(ch <-chan pineconeEvents.Event) { - eLog := logrus.WithField("pinecone", "events") - - for event := range ch { - switch e := event.(type) { - case pineconeEvents.PeerAdded: - case pineconeEvents.PeerRemoved: - case pineconeEvents.TreeParentUpdate: - case pineconeEvents.SnakeDescUpdate: - case pineconeEvents.TreeRootAnnUpdate: - case pineconeEvents.SnakeEntryAdded: - case pineconeEvents.SnakeEntryRemoved: - case pineconeEvents.BroadcastReceived: - eLog.Info("Broadcast received from: ", e.PeerID) - - req := &api.PerformWakeupServersRequest{ - ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, - } - res := &api.PerformWakeupServersResponse{} - if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { - logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) - } - case pineconeEvents.BandwidthReport: - default: - } - } - }(pineconeEventChannel) -} - -func (m *DendriteMonolith) Stop() { - m.processContext.ShutdownDendrite() - _ = m.listener.Close() - m.PineconeMulticast.Stop() - _ = m.PineconeQUIC.Close() - _ = m.PineconeRouter.Close() - m.processContext.WaitForComponentsToFinish() -} - -const MaxFrameSize = types.MaxFrameSize - -type Conduit struct { - closed atomic.Bool - conn net.Conn - port types.SwitchPortID - portMutex sync.Mutex -} - -func (c *Conduit) Port() int { - c.portMutex.Lock() - defer c.portMutex.Unlock() - return int(c.port) -} - -func (c *Conduit) Read(b []byte) (int, error) { - if c.closed.Load() { - return 0, io.EOF - } - return c.conn.Read(b) -} - -func (c *Conduit) ReadCopy() ([]byte, error) { - if c.closed.Load() { - return nil, io.EOF - } - var buf [65535 * 2]byte - n, err := c.conn.Read(buf[:]) - if err != nil { - return nil, err - } - return buf[:n], nil -} - -func (c *Conduit) Write(b []byte) (int, error) { - if c.closed.Load() { - return 0, io.EOF - } - return c.conn.Write(b) -} - -func (c *Conduit) Close() error { - if c.closed.Load() { - return io.ErrClosedPipe - } - c.closed.Store(true) - return c.conn.Close() -} diff --git a/build/gobind-pinecone/platform_ios.go b/build/gobind-pinecone/platform_ios.go deleted file mode 100644 index a89ebfcd0..000000000 --- a/build/gobind-pinecone/platform_ios.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build ios -// +build ios - -package gobind - -/* -#cgo CFLAGS: -x objective-c -#cgo LDFLAGS: -framework Foundation -#import -void Log(const char *text) { - NSString *nss = [NSString stringWithUTF8String:text]; - NSLog(@"%@", nss); -} -*/ -import "C" -import "unsafe" - -type BindLogger struct { -} - -func (nsl BindLogger) Write(p []byte) (n int, err error) { - p = append(p, 0) - cstr := (*C.char)(unsafe.Pointer(&p[0])) - C.Log(cstr) - return len(p), nil -} diff --git a/build/gobind-pinecone/platform_other.go b/build/gobind-pinecone/platform_other.go deleted file mode 100644 index 2793026b8..000000000 --- a/build/gobind-pinecone/platform_other.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !ios -// +build !ios - -package gobind - -import "log" - -type BindLogger struct{} - -func (nsl BindLogger) Write(p []byte) (n int, err error) { - log.Println(string(p)) - return len(p), nil -} diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 248b6c324..32af611ae 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" @@ -127,8 +126,8 @@ func (m *DendriteMonolith) Start() { cfg := &config.Dendrite{} cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: true, }) cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) cfg.Global.PrivateKey = sk @@ -149,7 +148,8 @@ func (m *DendriteMonolith) Start() { panic(err) } - base := base.NewBaseDendrite(cfg, "Monolith") + base := base.NewBaseDendrite(cfg) + base.ConfigureAdminEndpoints() m.processContext = base.ProcessContext defer base.Close() // nolint: errcheck @@ -164,9 +164,7 @@ func (m *DendriteMonolith) Start() { base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, federation) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) rsAPI.SetAppserviceAPI(asAPI) @@ -185,7 +183,6 @@ func (m *DendriteMonolith) Start() { FederationAPI: fsAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - KeyAPI: keyAPI, ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, ), @@ -193,9 +190,10 @@ func (m *DendriteMonolith) Start() { monolith.AddAllPublicRoutes(base) httpRouter := mux.NewRouter() - httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) + httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) + httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) yggRouter := mux.NewRouter() yggRouter.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 79422e645..70bbe8f95 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -16,13 +16,16 @@ RUN --mount=target=. \ --mount=type=cache,target=/root/.cache/go-build \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ - CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server + CGO_ENABLED=${CGO} go build -o /dendrite/dendrite ./cmd/dendrite && \ + CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite && \ + cp build/scripts/complement-cmd.sh /complement-cmd.sh WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem ENV SERVER_NAME=localhost ENV API=0 +ENV COVER=0 EXPOSE 8008 8448 # At runtime, generate TLS cert based on the CA now mounted at /ca @@ -30,4 +33,4 @@ EXPOSE 8008 8448 CMD ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ - exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} + exec /complement-cmd.sh diff --git a/build/scripts/ComplementLocal.Dockerfile b/build/scripts/ComplementLocal.Dockerfile index 3a019fc20..0b80cfc40 100644 --- a/build/scripts/ComplementLocal.Dockerfile +++ b/build/scripts/ComplementLocal.Dockerfile @@ -12,18 +12,20 @@ FROM golang:1.18-stretch RUN apt-get update && apt-get install -y sqlite3 ENV SERVER_NAME=localhost +ENV COVER=0 EXPOSE 8008 8448 WORKDIR /runtime # This script compiles Dendrite for us. RUN echo '\ #!/bin/bash -eux \n\ - if test -f "/runtime/dendrite-monolith-server"; then \n\ + if test -f "/runtime/dendrite" && test -f "/runtime/dendrite-cover"; then \n\ echo "Skipping compilation; binaries exist" \n\ exit 0 \n\ fi \n\ cd /dendrite \n\ - go build -v -o /runtime /dendrite/cmd/dendrite-monolith-server \n\ + go build -v -o /runtime /dendrite/cmd/dendrite \n\ + go test -c -cover -covermode=atomic -o /runtime/dendrite-cover -coverpkg "github.com/matrix-org/..." /dendrite/cmd/dendrite \n\ ' > compile.sh && chmod +x compile.sh # This script runs Dendrite for us. Must be run in the /runtime directory. @@ -33,7 +35,8 @@ RUN echo '\ ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\ - exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ + [ ${COVER} -eq 1 ] && exec ./dendrite-cover --test.coverprofile=integrationcover.log --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ + exec ./dendrite --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ ' > run.sh && chmod +x run.sh diff --git a/build/scripts/ComplementPostgres.Dockerfile b/build/scripts/ComplementPostgres.Dockerfile index 3faf43cc7..d4b6d3f75 100644 --- a/build/scripts/ComplementPostgres.Dockerfile +++ b/build/scripts/ComplementPostgres.Dockerfile @@ -34,13 +34,16 @@ RUN --mount=target=. \ --mount=type=cache,target=/root/.cache/go-build \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ - CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server + CGO_ENABLED=${CGO} go build -o /dendrite/dendrite ./cmd/dendrite && \ + CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite && \ + cp build/scripts/complement-cmd.sh /complement-cmd.sh WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem ENV SERVER_NAME=localhost ENV API=0 +ENV COVER=0 EXPOSE 8008 8448 @@ -51,4 +54,4 @@ CMD /build/run_postgres.sh && ./generate-keys --keysize 1024 --server $SERVER_NA # Bump max_open_conns up here in the global database config sed -i 's/max_open_conns:.*$/max_open_conns: 1990/g' dendrite.yaml && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ - exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} \ No newline at end of file + exec /complement-cmd.sh \ No newline at end of file diff --git a/build/scripts/complement-cmd.sh b/build/scripts/complement-cmd.sh new file mode 100755 index 000000000..52b063d01 --- /dev/null +++ b/build/scripts/complement-cmd.sh @@ -0,0 +1,20 @@ +#!/bin/bash -e + +# This script is intended to be used inside a docker container for Complement + +if [[ "${COVER}" -eq 1 ]]; then + echo "Running with coverage" + exec /dendrite/dendrite-cover \ + --really-enable-open-registration \ + --tls-cert server.crt \ + --tls-key server.key \ + --config dendrite.yaml \ + --test.coverprofile=complementcover.log +else + echo "Not running with coverage" + exec /dendrite/dendrite \ + --really-enable-open-registration \ + --tls-cert server.crt \ + --tls-key server.key \ + --config dendrite.yaml +fi diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go new file mode 100644 index 000000000..300d3a88a --- /dev/null +++ b/clientapi/admin_test.go @@ -0,0 +1,230 @@ +package clientapi + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/tidwall/gjson" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestAdminResetPassword(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + vhUser := &test.User{ID: "@vhuser:vh1"} + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + // add a vhost + base.Cfg.Global.VirtualHosts = append(base.Cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}, + }) + + rsAPI := roomserver.NewInternalAPI(base) + // Needed for changing the password/login + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil) + + // Create the users in the userapi and login + accessTokens := map[*test.User]string{ + aliceAdmin: "", + bob: "", + vhUser: "", + } + for u := range accessTokens { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + password := util.RandomString(8) + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: password, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ + "type": authtypes.LoginTypePassword, + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": u.ID, + }, + "password": password, + })) + rec := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("failed to login: %s", rec.Body.String()) + } + accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String() + } + + testCases := []struct { + name string + requestingUser *test.User + userID string + requestOpt test.HTTPRequestOpt + wantOK bool + withHeader bool + }{ + {name: "Missing auth", requestingUser: bob, wantOK: false, userID: bob.ID}, + {name: "Bob is denied access", requestingUser: bob, wantOK: false, withHeader: true, userID: bob.ID}, + {name: "Alice is allowed access", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(8), + })}, + {name: "missing userID does not call function", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: ""}, // this 404s + {name: "rejects empty password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": "", + })}, + {name: "rejects unknown server name", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:localhost", requestOpt: test.WithJSONBody(t, map[string]interface{}{})}, + {name: "rejects unknown user", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})}, + {name: "allows changing password for different vhost", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: vhUser.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(8), + })}, + {name: "rejects existing user, missing body", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID}, + {name: "rejects invalid userID", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "!notauserid:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})}, + {name: "rejects invalid json", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, `{invalidJSON}`)}, + {name: "rejects too weak password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(6), + })}, + {name: "rejects too long password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(513), + })}, + } + + for _, tc := range testCases { + tc := tc // ensure we don't accidentally only test the last test case + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID) + if tc.requestOpt != nil { + req = test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID, tc.requestOpt) + } + + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser]) + } + + rec := httptest.NewRecorder() + base.DendriteAdminMux.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func TestPurgeRoom(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t) + room := test.NewRoom(t, aliceAdmin, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite Bob + room.CreateAndInsert(t, aliceAdmin, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + fedClient := base.CreateFederationClient() + rsAPI := roomserver.NewInternalAPI(base) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + + // this starts the JetStream consumers + syncapi.AddPublicRoutes(base, userAPI, rsAPI) + federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) + rsAPI.SetFederationAPI(nil, nil) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil) + + // Create the users in the userapi and login + accessTokens := map[*test.User]string{ + aliceAdmin: "", + } + for u := range accessTokens { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + password := util.RandomString(8) + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: password, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ + "type": authtypes.LoginTypePassword, + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": u.ID, + }, + "password": password, + })) + rec := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("failed to login: %s", rec.Body.String()) + } + accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String() + } + + testCases := []struct { + name string + roomID string + wantOK bool + }{ + {name: "Can purge existing room", wantOK: true, roomID: room.ID}, + {name: "Can not purge non-existent room", wantOK: false, roomID: "!doesnotexist:localhost"}, + {name: "rejects invalid room ID", wantOK: false, roomID: "@doesnotexist:localhost"}, + } + + for _, tc := range testCases { + tc := tc // ensure we don't accidentally only test the last test case + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/purgeRoom/"+tc.roomID) + + req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin]) + + rec := httptest.NewRecorder() + base.DendriteAdminMux.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + + }) +} diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index 4bad1dcc2..a202bf10c 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/ratelimit" + "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -48,7 +49,7 @@ func TestLoginFromJSONReader(t *testing.T) { "password": "herpassword", "device_id": "adevice" }`, - WantUsername: "alice", + WantUsername: "@alice:example.com", WantDeviceID: "adevice", }, { @@ -179,7 +180,7 @@ func (ua *fakeUserInternalAPI) QueryAccountByPassword(ctx context.Context, req * return nil } res.Exists = true - res.Account = &uapi.Account{} + res.Account = &uapi.Account{UserID: userutil.MakeUserID(req.Localpart, req.ServerName)} return nil } diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index d37e93a66..14ebdd038 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -96,7 +96,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } } } else { - username = strings.ToLower(r.Username()) + username = r.Username() } if username == "" { return nil, &util.JSONResponse{ @@ -146,6 +146,8 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } } + // If we couldn't find the user by the lower cased localpart, try the provided + // localpart as is. if !res.Exists { err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ Localpart: localpart, @@ -170,6 +172,9 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } } } + // Set the user, so login.Username() can do the right thing + r.Identifier.User = res.Account.UserID + r.User = res.Account.UserID r.Login.User = username return &r.Login, nil } diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index bcaae0c3e..545e7abad 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -15,18 +15,18 @@ package clientapi import ( + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/routing" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/transactions" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" - userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) // AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component. @@ -39,7 +39,6 @@ func AddPublicRoutes( fsAPI federationAPI.ClientFederationAPI, userAPI userapi.ClientUserAPI, userDirectoryProvider userapi.QuerySearchProfilesAPI, - keyAPI keyserverAPI.ClientKeyAPI, extRoomsProvider api.ExtraPublicRoomsProvider, ) { cfg := &base.Cfg.ClientAPI @@ -58,13 +57,10 @@ func AddPublicRoutes( } routing.Setup( - base.PublicClientAPIMux, - base.PublicWellKnownAPIMux, - base.SynapseAdminMux, - base.DendriteAdminMux, + base, cfg, rsAPI, asAPI, userAPI, userDirectoryProvider, federation, - syncProducer, transactionsCache, fsAPI, keyAPI, + syncProducer, transactionsCache, fsAPI, extRoomsProvider, mscCfg, natsClient, ) } diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index be8073c33..a01f6b944 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -1,12 +1,14 @@ package routing import ( + "context" "encoding/json" "fmt" "net/http" "time" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/nats-io/nats.go" @@ -14,14 +16,13 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/api" ) -func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -54,7 +55,7 @@ func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi } } -func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -97,26 +98,77 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi } } -func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { +func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - serverName := cfg.Matrix.ServerName - localpart, ok := vars["localpart"] + roomID, ok := vars["roomID"] if !ok { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Expecting user localpart."), + JSON: jsonerror.MissingArgument("Expecting room ID."), } } - if l, s, err := cfg.Matrix.SplitLocalID('@', localpart); err == nil { - localpart, serverName = l, s + res := &roomserverAPI.PerformAdminPurgeRoomResponse{} + if err := rsAPI.PerformAdminPurgeRoom( + context.Background(), + &roomserverAPI.PerformAdminPurgeRoomRequest{ + RoomID: roomID, + }, + res, + ); err != nil { + return util.ErrorResponse(err) + } + if err := res.Error; err != nil { + return err.JSONResponse() + } + return util.JSONResponse{ + Code: 200, + JSON: res, + } +} + +func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.Device, userAPI api.ClientUserAPI) util.JSONResponse { + if req.Body == nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.Unknown("Missing request body"), + } + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + var localpart string + userID := vars["userID"] + localpart, serverName, err := cfg.Matrix.SplitLocalID('@', userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue(err.Error()), + } + } + accAvailableResp := &api.QueryAccountAvailabilityResponse{} + if err = userAPI.QueryAccountAvailability(req.Context(), &api.QueryAccountAvailabilityRequest{ + Localpart: localpart, + ServerName: serverName, + }, accAvailableResp); err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.InternalAPIError(req.Context(), err), + } + } + if accAvailableResp.Available { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.Unknown("User does not exist"), + } } request := struct { Password string `json:"password"` }{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + if err = json.NewDecoder(req.Body).Decode(&request); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()), @@ -128,13 +180,18 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap JSON: jsonerror.MissingArgument("Expecting non-empty password."), } } - updateReq := &userapi.PerformPasswordUpdateRequest{ + + if err = internal.ValidatePassword(request.Password); err != nil { + return *internal.PasswordResponse(err) + } + + updateReq := &api.PerformPasswordUpdateRequest{ Localpart: localpart, ServerName: serverName, Password: request.Password, LogoutDevices: true, } - updateRes := &userapi.PerformPasswordUpdateResponse{} + updateRes := &api.PerformPasswordUpdateResponse{} if err := userAPI.PerformPasswordUpdate(req.Context(), updateReq, updateRes); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -151,7 +208,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap } } -func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, natsClient *nats.Conn) util.JSONResponse { +func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *api.Device, natsClient *nats.Conn) util.JSONResponse { _, err := natsClient.RequestMsg(nats.NewMsg(cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex)), time.Second*10) if err != nil { logrus.WithError(err).Error("failed to publish nats message") @@ -197,7 +254,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien } } -func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) diff --git a/clientapi/routing/auth_fallback.go b/clientapi/routing/auth_fallback.go index ad870993e..f8d3684fe 100644 --- a/clientapi/routing/auth_fallback.go +++ b/clientapi/routing/auth_fallback.go @@ -15,11 +15,11 @@ package routing import ( + "fmt" "html/template" "net/http" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/util" ) @@ -101,14 +101,28 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s func AuthFallback( w http.ResponseWriter, req *http.Request, authType string, cfg *config.ClientAPI, -) *util.JSONResponse { - sessionID := req.URL.Query().Get("session") +) { + // We currently only support "m.login.recaptcha", so fail early if that's not requested + if authType == authtypes.LoginTypeRecaptcha { + if !cfg.RecaptchaEnabled { + writeHTTPMessage(w, req, + "Recaptcha login is disabled on this Homeserver", + http.StatusBadRequest, + ) + return + } + } else { + writeHTTPMessage(w, req, fmt.Sprintf("Unknown authtype %q", authType), http.StatusNotImplemented) + return + } + sessionID := req.URL.Query().Get("session") if sessionID == "" { - return writeHTTPMessage(w, req, + writeHTTPMessage(w, req, "Session ID not provided", http.StatusBadRequest, ) + return } serveRecaptcha := func() { @@ -130,70 +144,44 @@ func AuthFallback( if req.Method == http.MethodGet { // Handle Recaptcha - if authType == authtypes.LoginTypeRecaptcha { - if err := checkRecaptchaEnabled(cfg, w, req); err != nil { - return err - } - - serveRecaptcha() - return nil - } - return &util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown auth stage type"), - } + serveRecaptcha() + return } else if req.Method == http.MethodPost { // Handle Recaptcha - if authType == authtypes.LoginTypeRecaptcha { - if err := checkRecaptchaEnabled(cfg, w, req); err != nil { - return err - } - - clientIP := req.RemoteAddr - err := req.ParseForm() - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed") - res := jsonerror.InternalServerError() - return &res - } - - response := req.Form.Get(cfg.RecaptchaFormField) - if err := validateRecaptcha(cfg, response, clientIP); err != nil { - util.GetLogger(req.Context()).Error(err) - return err - } - - // Success. Add recaptcha as a completed login flow - sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) - - serveSuccess() - return nil + clientIP := req.RemoteAddr + err := req.ParseForm() + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed") + w.WriteHeader(http.StatusBadRequest) + serveRecaptcha() + return } - return &util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown auth stage type"), + response := req.Form.Get(cfg.RecaptchaFormField) + err = validateRecaptcha(cfg, response, clientIP) + switch err { + case ErrMissingResponse: + w.WriteHeader(http.StatusBadRequest) + serveRecaptcha() // serve the initial page again, instead of nothing + return + case ErrInvalidCaptcha: + w.WriteHeader(http.StatusUnauthorized) + serveRecaptcha() + return + case nil: + default: // something else failed + util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha") + serveRecaptcha() + return } - } - return &util.JSONResponse{ - Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), - } -} -// checkRecaptchaEnabled creates an error response if recaptcha is not usable on homeserver. -func checkRecaptchaEnabled( - cfg *config.ClientAPI, - w http.ResponseWriter, - req *http.Request, -) *util.JSONResponse { - if !cfg.RecaptchaEnabled { - return writeHTTPMessage(w, req, - "Recaptcha login is disabled on this Homeserver", - http.StatusBadRequest, - ) + // Success. Add recaptcha as a completed login flow + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) + + serveSuccess() + return } - return nil + writeHTTPMessage(w, req, "Bad method", http.StatusMethodNotAllowed) } // writeHTTPMessage writes the given header and message to the HTTP response writer. @@ -201,13 +189,10 @@ func checkRecaptchaEnabled( func writeHTTPMessage( w http.ResponseWriter, req *http.Request, message string, header int, -) *util.JSONResponse { +) { w.WriteHeader(header) _, err := w.Write([]byte(message)) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("w.Write failed") - res := jsonerror.InternalServerError() - return &res } - return nil } diff --git a/clientapi/routing/auth_fallback_test.go b/clientapi/routing/auth_fallback_test.go new file mode 100644 index 000000000..534581bdd --- /dev/null +++ b/clientapi/routing/auth_fallback_test.go @@ -0,0 +1,149 @@ +package routing + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test/testrig" +) + +func Test_AuthFallback(t *testing.T) { + base, _, _ := testrig.Base(nil) + defer base.Close() + + for _, useHCaptcha := range []bool{false, true} { + for _, recaptchaEnabled := range []bool{false, true} { + for _, wantErr := range []bool{false, true} { + t.Run(fmt.Sprintf("useHCaptcha(%v) - recaptchaEnabled(%v) - wantErr(%v)", useHCaptcha, recaptchaEnabled, wantErr), func(t *testing.T) { + // Set the defaults for each test + base.Cfg.ClientAPI.Defaults(config.DefaultOpts{Generate: true, SingleDatabase: true}) + base.Cfg.ClientAPI.RecaptchaEnabled = recaptchaEnabled + base.Cfg.ClientAPI.RecaptchaPublicKey = "pub" + base.Cfg.ClientAPI.RecaptchaPrivateKey = "priv" + if useHCaptcha { + base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = "https://hcaptcha.com/siteverify" + base.Cfg.ClientAPI.RecaptchaApiJsUrl = "https://js.hcaptcha.com/1/api.js" + base.Cfg.ClientAPI.RecaptchaFormField = "h-captcha-response" + base.Cfg.ClientAPI.RecaptchaSitekeyClass = "h-captcha" + } + cfgErrs := &config.ConfigErrors{} + base.Cfg.ClientAPI.Verify(cfgErrs) + if len(*cfgErrs) > 0 { + t.Fatalf("(hCaptcha=%v) unexpected config errors: %s", useHCaptcha, cfgErrs.Error()) + } + + req := httptest.NewRequest(http.MethodGet, "/?session=1337", nil) + rec := httptest.NewRecorder() + + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if !recaptchaEnabled { + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest) + } + if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" { + t.Fatalf("unexpected response body: %s", rec.Body.String()) + } + } else { + if !strings.Contains(rec.Body.String(), base.Cfg.ClientAPI.RecaptchaSitekeyClass) { + t.Fatalf("body does not contain %s: %s", base.Cfg.ClientAPI.RecaptchaSitekeyClass, rec.Body.String()) + } + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if wantErr { + _, _ = w.Write([]byte(`{"success":false}`)) + return + } + _, _ = w.Write([]byte(`{"success":true}`)) + })) + defer srv.Close() // nolint: errcheck + + base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL + + // check the result after sending the captcha + req = httptest.NewRequest(http.MethodPost, "/?session=1337", nil) + req.Form = url.Values{} + req.Form.Add(base.Cfg.ClientAPI.RecaptchaFormField, "someRandomValue") + rec = httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if recaptchaEnabled { + if !wantErr { + if rec.Code != http.StatusOK { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusOK) + } + if rec.Body.String() != successTemplate { + t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), successTemplate) + } + } else { + if rec.Code != http.StatusUnauthorized { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusUnauthorized) + } + wantString := "Authentication" + if !strings.Contains(rec.Body.String(), wantString) { + t.Fatalf("expected response to contain '%s', but didn't: %s", wantString, rec.Body.String()) + } + } + } else { + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest) + } + if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" { + t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), "successTemplate") + } + } + }) + } + } + } + + t.Run("unknown fallbacks are handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, "DoesNotExist", &base.Cfg.ClientAPI) + if rec.Code != http.StatusNotImplemented { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusNotImplemented) + } + }) + + t.Run("unknown methods are handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodDelete, "/?session=1337", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusMethodNotAllowed { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusMethodNotAllowed) + } + }) + + t.Run("missing session parameter is handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("missing session parameter is handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("missing 'response' is handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) + } + }) +} diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index c50e552bd..e371d9214 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -37,6 +37,7 @@ func JoinRoomByIDOrAlias( joinReq := roomserverAPI.PerformJoinRequest{ RoomIDOrAlias: roomIDOrAlias, UserID: device.UserID, + IsGuest: device.AccountType == api.AccountTypeGuest, Content: map[string]interface{}{}, } joinRes := roomserverAPI.PerformJoinResponse{} @@ -84,7 +85,14 @@ func JoinRoomByIDOrAlias( if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil { done <- jsonerror.InternalAPIError(req.Context(), err) } else if joinRes.Error != nil { - done <- joinRes.Error.JSONResponse() + if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest { + done <- util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg), + } + } else { + done <- joinRes.Error.JSONResponse() + } } else { done <- util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go new file mode 100644 index 000000000..1450ef4bd --- /dev/null +++ b/clientapi/routing/joinroom_test.go @@ -0,0 +1,156 @@ +package routing + +import ( + "bytes" + "context" + "net/http" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestJoinRoomByIDOrAlias(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t, test.WithAccountType(uapi.AccountTypeGuest)) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + rsAPI := roomserver.NewInternalAPI(base) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc + + // Create the users in the userapi + for _, u := range []*test.User{alice, bob, charlie} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: "someRandomPassword", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + } + + aliceDev := &uapi.Device{UserID: alice.ID} + bobDev := &uapi.Device{UserID: bob.ID} + charlieDev := &uapi.Device{UserID: charlie.ID, AccountType: uapi.AccountTypeGuest} + + // create a room with disabled guest access and invite Bob + resp := createRoom(ctx, createRoomRequest{ + Name: "testing", + IsDirect: true, + Topic: "testing", + Visibility: "public", + Preset: presetPublicChat, + RoomAliasName: "alias", + Invite: []string{bob.ID}, + GuestCanJoin: false, + }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) + crResp, ok := resp.JSON.(createRoomResponse) + if !ok { + t.Fatalf("response is not a createRoomResponse: %+v", resp) + } + + // create a room with guest access enabled and invite Charlie + resp = createRoom(ctx, createRoomRequest{ + Name: "testing", + IsDirect: true, + Topic: "testing", + Visibility: "public", + Preset: presetPublicChat, + Invite: []string{charlie.ID}, + GuestCanJoin: true, + }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) + crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse) + if !ok { + t.Fatalf("response is not a createRoomResponse: %+v", resp) + } + + // Dummy request + body := &bytes.Buffer{} + req, err := http.NewRequest(http.MethodPost, "/?server_name=test", body) + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + device *uapi.Device + roomID string + wantHTTP200 bool + }{ + { + name: "User can join successfully by alias", + device: bobDev, + roomID: crResp.RoomAlias, + wantHTTP200: true, + }, + { + name: "User can join successfully by roomID", + device: bobDev, + roomID: crResp.RoomID, + wantHTTP200: true, + }, + { + name: "join is forbidden if user is guest", + device: charlieDev, + roomID: crResp.RoomID, + }, + { + name: "room does not exist", + device: aliceDev, + roomID: "!doesnotexist:test", + }, + { + name: "user from different server", + device: &uapi.Device{UserID: "@wrong:server"}, + roomID: crResp.RoomAlias, + }, + { + name: "user doesn't exist locally", + device: &uapi.Device{UserID: "@doesnotexist:test"}, + roomID: crResp.RoomAlias, + }, + { + name: "invalid room ID", + device: aliceDev, + roomID: "invalidRoomID", + }, + { + name: "roomAlias does not exist", + device: aliceDev, + roomID: "#doesnotexist:test", + }, + { + name: "room with guest_access event", + device: charlieDev, + roomID: crRespWithGuestAccess.RoomID, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + joinResp := JoinRoomByIDOrAlias(req, tc.device, rsAPI, userAPI, tc.roomID) + if tc.wantHTTP200 && !joinResp.Is2xx() { + t.Fatalf("expected join room to succeed, but didn't: %+v", joinResp) + } + }) + } + }) +} diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index ca6ecefd2..5c41df2eb 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -21,9 +21,8 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" ) @@ -34,8 +33,8 @@ type crossSigningRequest struct { func UploadCrossSigningDeviceKeys( req *http.Request, userInteractiveAuth *auth.UserInteractive, - keyserverAPI api.ClientKeyAPI, device *userapi.Device, - accountAPI userapi.ClientUserAPI, cfg *config.ClientAPI, + keyserverAPI api.ClientKeyAPI, device *api.Device, + accountAPI api.ClientUserAPI, cfg *config.ClientAPI, ) util.JSONResponse { uploadReq := &crossSigningRequest{} uploadRes := &api.PerformUploadDeviceKeysResponse{} @@ -107,7 +106,7 @@ func UploadCrossSigningDeviceKeys( } } -func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { +func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse { uploadReq := &api.PerformUploadDeviceSignaturesRequest{} uploadRes := &api.PerformUploadDeviceSignaturesResponse{} diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 0c12b1117..3d60fcc3a 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -23,8 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/keyserver/api" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/api" ) type uploadKeysRequest struct { @@ -32,7 +31,7 @@ type uploadKeysRequest struct { OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"` } -func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { +func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse { var r uploadKeysRequest resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { @@ -106,7 +105,7 @@ func (r *queryKeysRequest) GetTimeout() time.Duration { return timeout } -func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { +func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse { var r queryKeysRequest resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { diff --git a/clientapi/routing/login_test.go b/clientapi/routing/login_test.go new file mode 100644 index 000000000..b72db9d8b --- /dev/null +++ b/clientapi/routing/login_test.go @@ -0,0 +1,149 @@ +package routing + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestLogin(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bobUser := &test.User{ID: "@bob:test", AccountType: uapi.AccountTypeUser} + charlie := &test.User{ID: "@Charlie:test", AccountType: uapi.AccountTypeUser} + vhUser := &test.User{ID: "@vhuser:vh1"} + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + base.Cfg.ClientAPI.RateLimiting.Enabled = false + // add a vhost + base.Cfg.Global.VirtualHosts = append(base.Cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}, + }) + + rsAPI := roomserver.NewInternalAPI(base) + // Needed for /login + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + Setup(base, &base.Cfg.ClientAPI, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, &base.Cfg.MSCs, nil) + + // Create password + password := util.RandomString(8) + + // create the users + for _, u := range []*test.User{aliceAdmin, bobUser, vhUser, charlie} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: password, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + if !userRes.AccountCreated { + t.Fatalf("account not created") + } + } + + testCases := []struct { + name string + userID string + wantOK bool + }{ + { + name: "aliceAdmin can login", + userID: aliceAdmin.ID, + wantOK: true, + }, + { + name: "bobUser can login", + userID: bobUser.ID, + wantOK: true, + }, + { + name: "vhuser can login", + userID: vhUser.ID, + wantOK: true, + }, + { + name: "bob with uppercase can login", + userID: "@Bob:test", + wantOK: true, + }, + { + name: "Charlie can login (existing uppercase)", + userID: charlie.ID, + wantOK: true, + }, + { + name: "Charlie can not login with lowercase userID", + userID: strings.ToLower(charlie.ID), + wantOK: false, + }, + } + + ctx := context.Background() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ + "type": authtypes.LoginTypePassword, + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": tc.userID, + }, + "password": password, + })) + rec := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(rec, req) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("failed to login: %s", rec.Body.String()) + } + + t.Logf("Response: %s", rec.Body.String()) + // get the response + resp := loginResponse{} + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + // everything OK + if !tc.wantOK && resp.AccessToken == "" { + return + } + if tc.wantOK && resp.AccessToken == "" { + t.Fatalf("expected accessToken after successful login but got none: %+v", resp) + } + + devicesResp := &uapi.QueryDevicesResponse{} + if err := userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: resp.UserID}, devicesResp); err != nil { + t.Fatal(err) + } + for _, dev := range devicesResp.Devices { + // We expect the userID on the device to be the same as resp.UserID + if dev.UserID != resp.UserID { + t.Fatalf("unexpected userID on device: %s", dev.UserID) + } + } + }) + } + }) +} diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 800e87512..ce0d150f0 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -9,6 +9,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/threepid" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -147,8 +148,8 @@ func Password( } // Check the new password strength. - if resErr = validatePassword(r.NewPassword); resErr != nil { - return *resErr + if err := internal.ValidatePassword(r.NewPassword); err != nil { + return *internal.PasswordResponse(err) } // Ask the user API to perform the password change. diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 7841b3b07..f86bbc8fd 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -57,7 +57,7 @@ func SendRedaction( } } - ev := roomserverAPI.GetEvent(req.Context(), rsAPI, eventID) + ev := roomserverAPI.GetEvent(req.Context(), rsAPI, roomID, eventID) if ev == nil { return util.JSONResponse{ Code: 400, diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index e1bb1555c..472920697 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -18,18 +18,19 @@ package routing import ( "context" "encoding/json" + "errors" "fmt" "io" "net" "net/http" "net/url" - "regexp" "sort" "strconv" "strings" "sync" "time" + "github.com/matrix-org/dendrite/internal" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/eventutil" @@ -60,12 +61,7 @@ var ( ) ) -const ( - minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based - maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 - maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain - sessionIDLength = 24 -) +const sessionIDLength = 24 // sessionsDict keeps track of completed auth stages for each session. // It shouldn't be passed by value because it contains a mutex. @@ -200,8 +196,7 @@ func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) { } var ( - sessions = newSessionsDict() - validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) + sessions = newSessionsDict() ) // registerRequest represents the submitted registration request. @@ -265,10 +260,9 @@ func newUserInteractiveResponse( // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register type registerResponse struct { - UserID string `json:"user_id"` - AccessToken string `json:"access_token,omitempty"` - HomeServer gomatrixserverlib.ServerName `json:"home_server"` - DeviceID string `json:"device_id,omitempty"` + UserID string `json:"user_id"` + AccessToken string `json:"access_token,omitempty"` + DeviceID string `json:"device_id,omitempty"` } // recaptchaResponse represents the HTTP response from a Google Recaptcha server @@ -279,83 +273,28 @@ type recaptchaResponse struct { ErrorCodes []int `json:"error-codes"` } -// validateUsername returns an error response if the username is invalid -func validateUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse { - // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 - if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)), - } - } else if !validUsernameRegex.MatchString(localpart) { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), - } - } else if localpart[0] == '_' { // Regex checks its not a zero length string - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"), - } - } - return nil -} - -// validateApplicationServiceUsername returns an error response if the username is invalid for an application service -func validateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse { - if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)), - } - } else if !validUsernameRegex.MatchString(localpart) { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), - } - } - return nil -} - -// validatePassword returns an error response if the password is invalid -func validatePassword(password string) *util.JSONResponse { - // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 - if len(password) > maxPasswordLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("'password' >%d characters", maxPasswordLength)), - } - } else if len(password) > 0 && len(password) < minPasswordLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), - } - } - return nil -} +var ( + ErrInvalidCaptcha = errors.New("invalid captcha response") + ErrMissingResponse = errors.New("captcha response is required") + ErrCaptchaDisabled = errors.New("captcha registration is disabled") +) // validateRecaptcha returns an error response if the captcha response is invalid func validateRecaptcha( cfg *config.ClientAPI, response string, clientip string, -) *util.JSONResponse { +) error { ip, _, _ := net.SplitHostPort(clientip) if !cfg.RecaptchaEnabled { - return &util.JSONResponse{ - Code: http.StatusConflict, - JSON: jsonerror.Unknown("Captcha registration is disabled"), - } + return ErrCaptchaDisabled } if response == "" { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Captcha response is required"), - } + return ErrMissingResponse } - // Make a POST request to Google's API to check the captcha response + // Make a POST request to the captcha provider API to check the captcha response resp, err := http.PostForm(cfg.RecaptchaSiteVerifyAPI, url.Values{ "secret": {cfg.RecaptchaPrivateKey}, @@ -365,10 +304,7 @@ func validateRecaptcha( ) if err != nil { - return &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: jsonerror.BadJSON("Error in requesting validation of captcha response"), - } + return err } // Close the request once we're finishing reading from it @@ -378,25 +314,16 @@ func validateRecaptcha( var r recaptchaResponse body, err := io.ReadAll(resp.Body) if err != nil { - return &util.JSONResponse{ - Code: http.StatusGatewayTimeout, - JSON: jsonerror.Unknown("Error in contacting captcha server" + err.Error()), - } + return err } err = json.Unmarshal(body, &r) if err != nil { - return &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: jsonerror.BadJSON("Error in unmarshaling captcha server's response: " + err.Error()), - } + return err } // Check that we received a "success" if !r.Success { - return &util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: jsonerror.BadJSON("Invalid captcha response. Please try again."), - } + return ErrInvalidCaptcha } return nil } @@ -528,8 +455,8 @@ func validateApplicationService( } // Check username application service is trying to register is valid - if err := validateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil { - return "", err + if err := internal.ValidateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil { + return "", internal.UsernameResponse(err) } // No errors, registration valid @@ -584,15 +511,12 @@ func Register( if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { return *resErr } - if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil { - r.Username, r.ServerName = l, d - } if req.URL.Query().Get("kind") == "guest" { return handleGuestRegistration(req, r, cfg, userAPI) } // Don't allow numeric usernames less than MAX_INT64. - if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil { + if _, err = strconv.ParseInt(r.Username, 10, 64); err == nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), @@ -604,7 +528,7 @@ func Register( ServerName: r.ServerName, } nres := &userapi.QueryNumericLocalpartResponse{} - if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { + if err = userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") return jsonerror.InternalServerError() } @@ -621,8 +545,8 @@ func Register( case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: // Spec-compliant case (the access_token is specified and the login type // is correctly set, so it's an appservice registration) - if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil { - return *resErr + if err = internal.ValidateApplicationServiceUsername(r.Username, r.ServerName); err != nil { + return *internal.UsernameResponse(err) } case accessTokenErr == nil: // Non-spec-compliant case (the access_token is specified but the login @@ -634,12 +558,12 @@ func Register( default: // Spec-compliant case (neither the access_token nor the login type are // specified, so it's a normal user registration) - if resErr := validateUsername(r.Username, r.ServerName); resErr != nil { - return *resErr + if err = internal.ValidateUsername(r.Username, r.ServerName); err != nil { + return *internal.UsernameResponse(err) } } - if resErr := validatePassword(r.Password); resErr != nil { - return *resErr + if err = internal.ValidatePassword(r.Password); err != nil { + return *internal.PasswordResponse(err) } logger := util.GetLogger(req.Context()) @@ -717,7 +641,6 @@ func handleGuestRegistration( JSON: registerResponse{ UserID: devRes.Device.UserID, AccessToken: devRes.Device.AccessToken, - HomeServer: res.Account.ServerName, DeviceID: devRes.Device.ID, }, } @@ -782,9 +705,18 @@ func handleRegistrationFlow( switch r.Auth.Type { case authtypes.LoginTypeRecaptcha: // Check given captcha response - resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) - if resErr != nil { - return *resErr + err := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) + switch err { + case ErrCaptchaDisabled: + return util.JSONResponse{Code: http.StatusForbidden, JSON: jsonerror.Unknown(err.Error())} + case ErrMissingResponse: + return util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error())} + case ErrInvalidCaptcha: + return util.JSONResponse{Code: http.StatusUnauthorized, JSON: jsonerror.BadJSON(err.Error())} + case nil: + default: + util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha") + return util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()} } // Add Recaptcha to the list of completed registration stages @@ -874,7 +806,7 @@ func handleApplicationServiceRegistration( // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), userAPI, r.Username, r.ServerName, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, + req.Context(), userAPI, r.Username, r.ServerName, "", "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, nil, ) } @@ -894,7 +826,7 @@ func checkAndCompleteFlow( if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( - req.Context(), userAPI, r.Username, r.ServerName, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, + req.Context(), userAPI, r.Username, r.ServerName, "", r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, threePid, ) } @@ -917,10 +849,10 @@ func checkAndCompleteFlow( func completeRegistration( ctx context.Context, userAPI userapi.ClientUserAPI, - username string, serverName gomatrixserverlib.ServerName, + username string, serverName gomatrixserverlib.ServerName, displayName string, password, appserviceID, ipAddr, userAgent, sessionID string, inhibitLogin eventutil.WeakBoolean, - displayName, deviceID *string, + deviceDisplayName, deviceID *string, accType userapi.AccountType, threePid *authtypes.ThreePID, ) util.JSONResponse { @@ -984,8 +916,7 @@ func completeRegistration( return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ - UserID: userutil.MakeUserID(username, accRes.Account.ServerName), - HomeServer: accRes.Account.ServerName, + UserID: userutil.MakeUserID(username, accRes.Account.ServerName), }, } } @@ -998,12 +929,28 @@ func completeRegistration( } } + if displayName != "" { + nameReq := userapi.PerformUpdateDisplayNameRequest{ + Localpart: username, + ServerName: serverName, + DisplayName: displayName, + } + var nameRes userapi.PerformUpdateDisplayNameResponse + err = userAPI.SetDisplayName(ctx, &nameReq, &nameRes) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("failed to set display name: " + err.Error()), + } + } + } + var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ Localpart: username, ServerName: serverName, AccessToken: token, - DeviceDisplayName: displayName, + DeviceDisplayName: deviceDisplayName, DeviceID: deviceID, IPAddr: ipAddr, UserAgent: userAgent, @@ -1018,7 +965,6 @@ func completeRegistration( result := registerResponse{ UserID: devRes.Device.UserID, AccessToken: devRes.Device.AccessToken, - HomeServer: accRes.Account.ServerName, DeviceID: devRes.Device.ID, } sessions.addCompletedRegistration(sessionID, result) @@ -1114,8 +1060,8 @@ func RegisterAvailable( } } - if err := validateUsername(username, domain); err != nil { - return *err + if err := internal.ValidateUsername(username, domain); err != nil { + return *internal.UsernameResponse(err) } // Check if this username is reserved by an application service @@ -1177,11 +1123,11 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien // downcase capitals ssrr.User = strings.ToLower(ssrr.User) - if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil { - return *resErr + if err = internal.ValidateUsername(ssrr.User, cfg.Matrix.ServerName); err != nil { + return *internal.UsernameResponse(err) } - if resErr := validatePassword(ssrr.Password); resErr != nil { - return *resErr + if err = internal.ValidatePassword(ssrr.Password); err != nil { + return *internal.PasswordResponse(err) } deviceID := "shared_secret_registration" @@ -1189,5 +1135,5 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien if ssrr.Admin { accType = userapi.AccountTypeAdmin } - return completeRegistration(req.Context(), userAPI, ssrr.User, cfg.Matrix.ServerName, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType, nil) + return completeRegistration(req.Context(), userAPI, ssrr.User, cfg.Matrix.ServerName, ssrr.DisplayName, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType, nil) } diff --git a/clientapi/routing/register_secret.go b/clientapi/routing/register_secret.go index 1a974b77a..f384b604a 100644 --- a/clientapi/routing/register_secret.go +++ b/clientapi/routing/register_secret.go @@ -18,12 +18,13 @@ import ( ) type SharedSecretRegistrationRequest struct { - User string `json:"username"` - Password string `json:"password"` - Nonce string `json:"nonce"` - MacBytes []byte - MacStr string `json:"mac"` - Admin bool `json:"admin"` + User string `json:"username"` + Password string `json:"password"` + Nonce string `json:"nonce"` + MacBytes []byte + MacStr string `json:"mac"` + Admin bool `json:"admin"` + DisplayName string `json:"displayname,omitempty"` } func NewSharedSecretRegistrationRequest(reader io.ReadCloser) (*SharedSecretRegistrationRequest, error) { diff --git a/clientapi/routing/register_secret_test.go b/clientapi/routing/register_secret_test.go index a2ed35853..ca265d237 100644 --- a/clientapi/routing/register_secret_test.go +++ b/clientapi/routing/register_secret_test.go @@ -10,7 +10,7 @@ import ( func TestSharedSecretRegister(t *testing.T) { // these values have come from a local synapse instance to ensure compatibility - jsonStr := []byte(`{"admin":false,"mac":"f1ba8d37123866fd659b40de4bad9b0f8965c565","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice"}`) + jsonStr := []byte(`{"admin":false,"mac":"f1ba8d37123866fd659b40de4bad9b0f8965c565","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice","displayname":"rabbit"}`) sharedSecret := "dendritetest" req, err := NewSharedSecretRegistrationRequest(io.NopCloser(bytes.NewBuffer(jsonStr))) diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 85846c7d6..2c8b2c275 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -15,12 +15,30 @@ package routing import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "reflect" "regexp" + "strings" "testing" "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" + "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/assert" ) var ( @@ -182,8 +200,8 @@ func TestValidationOfApplicationServices(t *testing.T) { // Set up a config fakeConfig := &config.Dendrite{} fakeConfig.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: true, }) fakeConfig.Global.ServerName = "localhost" fakeConfig.ClientAPI.Derived.ApplicationServices = []config.ApplicationService{fakeApplicationService} @@ -264,3 +282,378 @@ func TestSessionCleanUp(t *testing.T) { } }) } + +func Test_register(t *testing.T) { + testCases := []struct { + name string + kind string + password string + username string + loginType string + forceEmpty bool + registrationDisabled bool + guestsDisabled bool + enableRecaptcha bool + captchaBody string + wantResponse util.JSONResponse + }{ + { + name: "disallow guests", + kind: "guest", + guestsDisabled: true, + wantResponse: util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden(`Guest registration is disabled on "test"`), + }, + }, + { + name: "allow guests", + kind: "guest", + }, + { + name: "unknown login type", + loginType: "im.not.known", + wantResponse: util.JSONResponse{ + Code: http.StatusNotImplemented, + JSON: jsonerror.Unknown("unknown/unimplemented auth type"), + }, + }, + { + name: "disabled registration", + registrationDisabled: true, + wantResponse: util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden(`Registration is disabled on "test"`), + }, + }, + { + name: "successful registration, numeric ID", + username: "", + password: "someRandomPassword", + forceEmpty: true, + }, + { + name: "successful registration", + username: "success", + }, + { + name: "failing registration - user already exists", + username: "success", + wantResponse: util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.UserInUse("Desired user ID is already taken."), + }, + }, + { + name: "successful registration uppercase username", + username: "LOWERCASED", // this is going to be lower-cased + }, + { + name: "invalid username", + username: "#totalyNotValid", + wantResponse: *internal.UsernameResponse(internal.ErrUsernameInvalid), + }, + { + name: "numeric username is forbidden", + username: "1337", + wantResponse: util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), + }, + }, + { + name: "disabled recaptcha login", + loginType: authtypes.LoginTypeRecaptcha, + wantResponse: util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Unknown(ErrCaptchaDisabled.Error()), + }, + }, + { + name: "enabled recaptcha, no response defined", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + wantResponse: util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(ErrMissingResponse.Error()), + }, + }, + { + name: "invalid captcha response", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + captchaBody: `notvalid`, + wantResponse: util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: jsonerror.BadJSON(ErrInvalidCaptcha.Error()), + }, + }, + { + name: "valid captcha response", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + captchaBody: `success`, + }, + { + name: "captcha invalid from remote", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + captchaBody: `i should fail for other reasons`, + wantResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + rsAPI := roomserver.NewInternalAPI(base) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.enableRecaptcha { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatal(err) + } + response := r.Form.Get("response") + + // Respond with valid JSON or no JSON at all to test happy/error cases + switch response { + case "success": + json.NewEncoder(w).Encode(recaptchaResponse{Success: true}) + case "notvalid": + json.NewEncoder(w).Encode(recaptchaResponse{Success: false}) + default: + + } + })) + defer srv.Close() + base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL + } + + if err := base.Cfg.Derive(); err != nil { + t.Fatalf("failed to derive config: %s", err) + } + + base.Cfg.ClientAPI.RecaptchaEnabled = tc.enableRecaptcha + base.Cfg.ClientAPI.RegistrationDisabled = tc.registrationDisabled + base.Cfg.ClientAPI.GuestsDisabled = tc.guestsDisabled + + if tc.kind == "" { + tc.kind = "user" + } + if tc.password == "" && !tc.forceEmpty { + tc.password = "someRandomPassword" + } + if tc.username == "" && !tc.forceEmpty { + tc.username = "valid" + } + if tc.loginType == "" { + tc.loginType = "m.login.dummy" + } + + reg := registerRequest{ + Password: tc.password, + Username: tc.username, + } + + body := &bytes.Buffer{} + err := json.NewEncoder(body).Encode(reg) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/?kind=%s", tc.kind), body) + + resp := Register(req, userAPI, &base.Cfg.ClientAPI) + t.Logf("Resp: %+v", resp) + + // The first request should return a userInteractiveResponse + switch r := resp.JSON.(type) { + case userInteractiveResponse: + // Check that the flows are the ones we configured + if !reflect.DeepEqual(r.Flows, base.Cfg.Derived.Registration.Flows) { + t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, base.Cfg.Derived.Registration.Flows) + } + case *jsonerror.MatrixError: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse) + } + return + case registerResponse: + // this should only be possible on guest user registration, never for normal users + if tc.kind != "guest" { + t.Fatalf("got register response on first request: %+v", r) + } + // assert we've got a UserID, AccessToken and DeviceID + if r.UserID == "" { + t.Fatalf("missing userID in response") + } + if r.AccessToken == "" { + t.Fatalf("missing accessToken in response") + } + if r.DeviceID == "" { + t.Fatalf("missing deviceID in response") + } + return + default: + t.Logf("Got response: %T", resp.JSON) + } + + // If we reached this, we should have received a UIA response + uia, ok := resp.JSON.(userInteractiveResponse) + if !ok { + t.Fatalf("did not receive a userInteractiveResponse: %T", resp.JSON) + } + t.Logf("%+v", uia) + + // Register the user + reg.Auth = authDict{ + Type: authtypes.LoginType(tc.loginType), + Session: uia.Session, + } + + if tc.captchaBody != "" { + reg.Auth.Response = tc.captchaBody + } + + dummy := "dummy" + reg.DeviceID = &dummy + reg.InitialDisplayName = &dummy + reg.Type = authtypes.LoginType(tc.loginType) + + err = json.NewEncoder(body).Encode(reg) + if err != nil { + t.Fatal(err) + } + + req = httptest.NewRequest(http.MethodPost, "/", body) + + resp = Register(req, userAPI, &base.Cfg.ClientAPI) + + switch resp.JSON.(type) { + case *jsonerror.MatrixError: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) + } + return + case util.JSONResponse: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) + } + return + } + + rr, ok := resp.JSON.(registerResponse) + if !ok { + t.Fatalf("expected a registerresponse, got %T", resp.JSON) + } + + // validate the response + if tc.forceEmpty { + // when not supplying a username, one will be generated. Given this _SHOULD_ be + // the second user, set the username accordingly + reg.Username = "2" + } + wantUserID := strings.ToLower(fmt.Sprintf("@%s:%s", reg.Username, "test")) + if wantUserID != rr.UserID { + t.Fatalf("unexpected userID: %s, want %s", rr.UserID, wantUserID) + } + if rr.DeviceID != *reg.DeviceID { + t.Fatalf("unexpected deviceID: %s, want %s", rr.DeviceID, *reg.DeviceID) + } + if rr.AccessToken == "" { + t.Fatalf("missing accessToken in response") + } + }) + } + }) +} + +func TestRegisterUserWithDisplayName(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + base.Cfg.Global.ServerName = "server" + + rsAPI := roomserver.NewInternalAPI(base) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + deviceName, deviceID := "deviceName", "deviceID" + expectedDisplayName := "DisplayName" + response := completeRegistration( + base.Context(), + userAPI, + "user", + "server", + expectedDisplayName, + "password", + "", + "localhost", + "user agent", + "session", + false, + &deviceName, + &deviceID, + api.AccountTypeAdmin, + nil, + ) + + assert.Equal(t, http.StatusOK, response.Code) + + req := api.QueryProfileRequest{UserID: "@user:server"} + var res api.QueryProfileResponse + err := userAPI.QueryProfile(base.Context(), &req, &res) + assert.NoError(t, err) + assert.Equal(t, expectedDisplayName, res.DisplayName) + }) +} + +func TestRegisterAdminUsingSharedSecret(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + base.Cfg.Global.ServerName = "server" + sharedSecret := "dendritetest" + base.Cfg.ClientAPI.RegistrationSharedSecret = sharedSecret + + rsAPI := roomserver.NewInternalAPI(base) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + + expectedDisplayName := "rabbit" + jsonStr := []byte(`{"admin":true,"mac":"24dca3bba410e43fe64b9b5c28306693bf3baa9f","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice","displayname":"rabbit"}`) + req, err := NewSharedSecretRegistrationRequest(io.NopCloser(bytes.NewBuffer(jsonStr))) + assert.NoError(t, err) + if err != nil { + t.Fatalf("failed to read request: %s", err) + } + + r := NewSharedSecretRegistration(sharedSecret) + + // force the nonce to be known + r.nonces.Set(req.Nonce, true, cache.DefaultExpiration) + + _, err = r.IsValidMacLogin(req.Nonce, req.User, req.Password, req.Admin, req.MacBytes) + assert.NoError(t, err) + + body := &bytes.Buffer{} + err = json.NewEncoder(body).Encode(req) + assert.NoError(t, err) + ssrr := httptest.NewRequest(http.MethodPost, "/", body) + + response := handleSharedSecretRegistration( + &base.Cfg.ClientAPI, + userAPI, + r, + ssrr, + ) + assert.Equal(t, http.StatusOK, response.Code) + + profilReq := api.QueryProfileRequest{UserID: "@alice:server"} + var profileRes api.QueryProfileResponse + err = userAPI.QueryProfile(base.Context(), &profilReq, &profileRes) + assert.NoError(t, err) + assert.Equal(t, expectedDisplayName, profileRes.DisplayName) + }) +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index f74e1837e..bbb1e0f9f 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -18,8 +18,11 @@ import ( "context" "net/http" "strings" + "sync" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/setup/base" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/nats-io/nats.go" @@ -36,11 +39,9 @@ import ( federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/transactions" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" - userapi "github.com/matrix-org/dendrite/userapi/api" ) // Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client @@ -50,7 +51,7 @@ import ( // applied: // nolint: gocyclo func Setup( - publicAPIMux, wkMux, synapseAdminRouter, dendriteAdminRouter *mux.Router, + base *base.BaseDendrite, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, @@ -60,11 +61,17 @@ func Setup( syncProducer *producers.SyncAPIProducer, transactionsCache *transactions.Cache, federationSender federationAPI.ClientFederationAPI, - keyAPI keyserverAPI.ClientKeyAPI, extRoomsProvider api.ExtraPublicRoomsProvider, mscCfg *config.MSCs, natsClient *nats.Conn, ) { - prometheus.MustRegister(amtRegUsers, sendEventDuration) + publicAPIMux := base.PublicClientAPIMux + wkMux := base.PublicWellKnownAPIMux + synapseAdminRouter := base.SynapseAdminMux + dendriteAdminRouter := base.DendriteAdminMux + + if base.EnableMetrics { + prometheus.MustRegister(amtRegUsers, sendEventDuration) + } rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) rateLimitsFailedLogin := ratelimit.NewRtFailedLogin(&cfg.RtFailedLogin) @@ -159,6 +166,12 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/purgeRoom/{roomID}", + httputil.MakeAdminAPI("admin_purge_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return AdminPurgeRoom(req, cfg, device, rsAPI) + }), + ).Methods(http.MethodPost, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/resetPassword/{userID}", httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminResetPassword(req, cfg, device, userAPI) @@ -179,25 +192,31 @@ func Setup( dendriteAdminRouter.Handle("/admin/refreshDevices/{userID}", httputil.MakeAdminAPI("admin_refresh_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return AdminMarkAsStale(req, cfg, keyAPI) + return AdminMarkAsStale(req, cfg, userAPI) }), ).Methods(http.MethodPost, http.MethodOptions) // server notifications if cfg.Matrix.ServerNotices.Enabled { logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") - serverNotificationSender, err := getSenderDevice(context.Background(), rsAPI, userAPI, cfg) - if err != nil { - logrus.WithError(err).Fatal("unable to get account for sending sending server notices") - } + var serverNotificationSender *userapi.Device + var err error + notificationSenderOnce := &sync.Once{} synapseAdminRouter.Handle("/admin/v1/send_server_notice/{txnID}", httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + notificationSenderOnce.Do(func() { + serverNotificationSender, err = getSenderDevice(context.Background(), rsAPI, userAPI, cfg) + if err != nil { + logrus.WithError(err).Fatal("unable to get account for sending sending server notices") + } + }) // not specced, but ensure we're rate limiting requests to this endpoint if r := rateLimits.Limit(req, device); r != nil { return *r } - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + var vars map[string]string + vars, err = httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -213,6 +232,12 @@ func Setup( synapseAdminRouter.Handle("/admin/v1/send_server_notice", httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + notificationSenderOnce.Do(func() { + serverNotificationSender, err = getSenderDevice(context.Background(), rsAPI, userAPI, cfg) + if err != nil { + logrus.WithError(err).Fatal("unable to get account for sending sending server notices") + } + }) // not specced, but ensure we're rate limiting requests to this endpoint if r := rateLimits.Limit(req, device); r != nil { return *r @@ -644,9 +669,9 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) v3mux.Handle("/auth/{authType}/fallback/web", - httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { + httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) - return AuthFallback(w, req, vars["authType"], cfg) + AuthFallback(w, req, vars["authType"], cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) @@ -1358,11 +1383,11 @@ func Setup( // Cross-signing device keys postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, keyAPI, device, userAPI, cfg) + return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, userAPI, device, userAPI, cfg) }) postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadCrossSigningDeviceSignatures(req, keyAPI, device) + return UploadCrossSigningDeviceSignatures(req, userAPI, device) }, httputil.WithAllowGuests()) v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) @@ -1374,22 +1399,22 @@ func Setup( // Supplying a device ID is deprecated. v3mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadKeys(req, keyAPI, device) + return UploadKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/upload", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadKeys(req, keyAPI, device) + return UploadKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/query", httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return QueryKeys(req, keyAPI, device) + return QueryKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/claim", httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return ClaimKeys(req, keyAPI) + return ClaimKeys(req, userAPI) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index c8e239f29..772778680 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -25,10 +25,10 @@ import ( "io" "net/http" "os" - "regexp" "strings" "time" + "github.com/matrix-org/dendrite/internal" "github.com/tidwall/gjson" "github.com/sirupsen/logrus" @@ -58,15 +58,14 @@ Arguments: ` var ( - username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") - password = flag.String("password", "", "The password to associate with the account") - pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") - pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") - isAdmin = flag.Bool("admin", false, "Create an admin account") - resetPassword = flag.Bool("reset-password", false, "Deprecated") - serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.") - validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) - timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server") + username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") + password = flag.String("password", "", "The password to associate with the account") + pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") + pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") + isAdmin = flag.Bool("admin", false, "Create an admin account") + resetPassword = flag.Bool("reset-password", false, "Deprecated") + serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.") + timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server") ) var cl = http.Client{ @@ -95,20 +94,21 @@ func main() { os.Exit(1) } - if !validUsernameRegex.MatchString(*username) { - logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='") + if err := internal.ValidateUsername(*username, cfg.Global.ServerName); err != nil { + logrus.WithError(err).Error("Specified username is invalid") os.Exit(1) } - if len(fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) > 255 { - logrus.Fatalf("Username can not be longer than 255 characters: %s", fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) - } - pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin) if err != nil { logrus.Fatalln(err) } + if err = internal.ValidatePassword(pass); err != nil { + logrus.WithError(err).Error("Specified password is invalid") + os.Exit(1) + } + cl.Timeout = *timeout accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin) @@ -177,7 +177,7 @@ func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, a defer regResp.Body.Close() // nolint: errcheck if regResp.StatusCode < 200 || regResp.StatusCode >= 300 { body, _ = io.ReadAll(regResp.Body) - return "", fmt.Errorf(gjson.GetBytes(body, "error").Str) + return "", fmt.Errorf("got HTTP %d error from server: %s", regResp.StatusCode, string(body)) } r, err := io.ReadAll(regResp.Body) if err != nil { diff --git a/cmd/dendrite-demo-pinecone/README.md b/cmd/dendrite-demo-pinecone/README.md deleted file mode 100644 index d6dd95905..000000000 --- a/cmd/dendrite-demo-pinecone/README.md +++ /dev/null @@ -1,26 +0,0 @@ -# Pinecone Demo - -This is the Dendrite Pinecone demo! It's easy to get started. - -To run the homeserver, start at the root of the Dendrite repository and run: - -``` -go run ./cmd/dendrite-demo-pinecone -``` - -To connect to the static Pinecone peer used by the mobile demos run: - -``` -go run ./cmd/dendrite-demo-pinecone -peer wss://pinecone.matrix.org/public -``` - -The following command line arguments are accepted: - -* `-peer tcp://a.b.c.d:e` to specify a static Pinecone peer to connect to - you will need to supply this if you do not have another Pinecone node on your network -* `-port 12345` to specify a port to listen on for client connections - -Then point your favourite Matrix client to the homeserver URL`http://localhost:8008` (or whichever `-port` you specified), create an account and log in. - -If your peering connection is operational then you should see a `Connected TCP:` line in the log output. If not then try a different peer. - -Once logged in, you should be able to open the room directory or join a room by its ID. diff --git a/cmd/dendrite-demo-pinecone/conn/client.go b/cmd/dendrite-demo-pinecone/conn/client.go deleted file mode 100644 index a91434f62..000000000 --- a/cmd/dendrite-demo-pinecone/conn/client.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package conn - -import ( - "context" - "fmt" - "net" - "net/http" - "strings" - - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/gomatrixserverlib" - "nhooyr.io/websocket" - - pineconeRouter "github.com/matrix-org/pinecone/router" - pineconeSessions "github.com/matrix-org/pinecone/sessions" -) - -func ConnectToPeer(pRouter *pineconeRouter.Router, peer string) error { - var parent net.Conn - if strings.HasPrefix(peer, "ws://") || strings.HasPrefix(peer, "wss://") { - ctx := context.Background() - c, _, err := websocket.Dial(ctx, peer, nil) - if err != nil { - return fmt.Errorf("websocket.DefaultDialer.Dial: %w", err) - } - parent = websocket.NetConn(ctx, c, websocket.MessageBinary) - } else { - var err error - parent, err = net.Dial("tcp", peer) - if err != nil { - return fmt.Errorf("net.Dial: %w", err) - } - } - if parent == nil { - return fmt.Errorf("failed to wrap connection") - } - _, err := pRouter.Connect( - parent, - pineconeRouter.ConnectionZone("static"), - pineconeRouter.ConnectionPeerType(pineconeRouter.PeerTypeRemote), - pineconeRouter.ConnectionURI(peer), - ) - return err -} - -type RoundTripper struct { - inner *http.Transport -} - -func (y *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - req.URL.Scheme = "http" - return y.inner.RoundTrip(req) -} - -func createTransport(s *pineconeSessions.Sessions) *http.Transport { - proto := s.Protocol("matrix") - tr := &http.Transport{ - DisableKeepAlives: false, - Dial: proto.Dial, - DialContext: proto.DialContext, - DialTLS: proto.DialTLS, - DialTLSContext: proto.DialTLSContext, - } - tr.RegisterProtocol( - "matrix", &RoundTripper{ - inner: &http.Transport{ - DisableKeepAlives: false, - Dial: proto.Dial, - DialContext: proto.DialContext, - DialTLS: proto.DialTLS, - DialTLSContext: proto.DialTLSContext, - }, - }, - ) - return tr -} - -func CreateClient( - base *base.BaseDendrite, s *pineconeSessions.Sessions, -) *gomatrixserverlib.Client { - return gomatrixserverlib.NewClient( - gomatrixserverlib.WithTransport(createTransport(s)), - ) -} - -func CreateFederationClient( - base *base.BaseDendrite, s *pineconeSessions.Sessions, -) *gomatrixserverlib.FederationClient { - return gomatrixserverlib.NewFederationClient( - base.Cfg.Global.SigningIdentities(), - gomatrixserverlib.WithTransport(createTransport(s)), - ) -} diff --git a/cmd/dendrite-demo-pinecone/conn/ws.go b/cmd/dendrite-demo-pinecone/conn/ws.go deleted file mode 100644 index ed85abd51..000000000 --- a/cmd/dendrite-demo-pinecone/conn/ws.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package conn - -import ( - "io" - "net" - "time" - - "github.com/gorilla/websocket" -) - -func WrapWebSocketConn(c *websocket.Conn) *WebSocketConn { - return &WebSocketConn{c: c} -} - -type WebSocketConn struct { - r io.Reader - c *websocket.Conn -} - -func (c *WebSocketConn) Write(p []byte) (int, error) { - err := c.c.WriteMessage(websocket.BinaryMessage, p) - if err != nil { - return 0, err - } - return len(p), nil -} - -func (c *WebSocketConn) Read(p []byte) (int, error) { - for { - if c.r == nil { - // Advance to next message. - var err error - _, c.r, err = c.c.NextReader() - if err != nil { - return 0, err - } - } - n, err := c.r.Read(p) - if err == io.EOF { - // At end of message. - c.r = nil - if n > 0 { - return n, nil - } else { - // No data read, continue to next message. - continue - } - } - return n, err - } -} - -func (c *WebSocketConn) Close() error { - return c.c.Close() -} - -func (c *WebSocketConn) LocalAddr() net.Addr { - return c.c.LocalAddr() -} - -func (c *WebSocketConn) RemoteAddr() net.Addr { - return c.c.RemoteAddr() -} - -func (c *WebSocketConn) SetDeadline(t time.Time) error { - if err := c.SetReadDeadline(t); err != nil { - return err - } - if err := c.SetWriteDeadline(t); err != nil { - return err - } - return nil -} - -func (c *WebSocketConn) SetReadDeadline(t time.Time) error { - return c.c.SetReadDeadline(t) -} - -func (c *WebSocketConn) SetWriteDeadline(t time.Time) error { - return c.c.SetWriteDeadline(t) -} diff --git a/cmd/dendrite-demo-pinecone/defaults/defaults.go b/cmd/dendrite-demo-pinecone/defaults/defaults.go deleted file mode 100644 index c92493137..000000000 --- a/cmd/dendrite-demo-pinecone/defaults/defaults.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package defaults - -import "github.com/matrix-org/gomatrixserverlib" - -var DefaultServerNames = map[gomatrixserverlib.ServerName]struct{}{ - "3bf0258d23c60952639cc4c69c71d1508a7d43a0475d9000ff900a1848411ec7": {}, -} diff --git a/cmd/dendrite-demo-pinecone/embed/embed_elementweb.go b/cmd/dendrite-demo-pinecone/embed/embed_elementweb.go deleted file mode 100644 index d37362e21..000000000 --- a/cmd/dendrite-demo-pinecone/embed/embed_elementweb.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build elementweb -// +build elementweb - -package embed - -import ( - "fmt" - "io" - "net/http" - "regexp" - - "github.com/gorilla/mux" - "github.com/tidwall/sjson" -) - -// From within the Element Web directory: -// go run github.com/mjibson/esc -o /path/to/dendrite/internal/embed/fs_elementweb.go -private -pkg embed . - -var cssFile = regexp.MustCompile("\\.css$") -var jsFile = regexp.MustCompile("\\.js$") - -type mimeFixingHandler struct { - fs http.Handler -} - -func (h mimeFixingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ruri := r.RequestURI - fmt.Println(ruri) - switch { - case cssFile.MatchString(ruri): - w.Header().Set("Content-Type", "text/css") - case jsFile.MatchString(ruri): - w.Header().Set("Content-Type", "application/javascript") - default: - } - h.fs.ServeHTTP(w, r) -} - -func Embed(rootMux *mux.Router, listenPort int, serverName string) { - embeddedFS := _escFS(false) - embeddedServ := mimeFixingHandler{http.FileServer(embeddedFS)} - - rootMux.NotFoundHandler = embeddedServ - rootMux.HandleFunc("/config.json", func(w http.ResponseWriter, r *http.Request) { - url := fmt.Sprintf("http://%s:%d", r.Header("Host"), listenPort) - configFile, err := embeddedFS.Open("/config.sample.json") - if err != nil { - w.WriteHeader(500) - io.WriteString(w, "Couldn't open the file: "+err.Error()) - return - } - configFileInfo, err := configFile.Stat() - if err != nil { - w.WriteHeader(500) - io.WriteString(w, "Couldn't stat the file: "+err.Error()) - return - } - buf := make([]byte, configFileInfo.Size()) - n, err := configFile.Read(buf) - if err != nil { - w.WriteHeader(500) - io.WriteString(w, "Couldn't read the file: "+err.Error()) - return - } - if int64(n) != configFileInfo.Size() { - w.WriteHeader(500) - io.WriteString(w, "The returned file size didn't match what we expected") - return - } - js, _ := sjson.SetBytes(buf, "default_server_config.m\\.homeserver.base_url", url) - js, _ = sjson.SetBytes(js, "default_server_config.m\\.homeserver.server_name", serverName) - js, _ = sjson.SetBytes(js, "brand", fmt.Sprintf("Element %s", serverName)) - js, _ = sjson.SetBytes(js, "disable_guests", true) - js, _ = sjson.SetBytes(js, "disable_3pid_login", true) - js, _ = sjson.DeleteBytes(js, "welcomeUserId") - _, _ = w.Write(js) - }) - - fmt.Println("*-------------------------------*") - fmt.Println("| This build includes Element Web! |") - fmt.Println("*-------------------------------*") - fmt.Println("Point your browser to:", url) - fmt.Println() -} diff --git a/cmd/dendrite-demo-pinecone/embed/embed_other.go b/cmd/dendrite-demo-pinecone/embed/embed_other.go deleted file mode 100644 index 94360fce6..000000000 --- a/cmd/dendrite-demo-pinecone/embed/embed_other.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !elementweb -// +build !elementweb - -package embed - -import "github.com/gorilla/mux" - -func Embed(_ *mux.Router, _ int, _ string) { - -} diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go deleted file mode 100644 index 421b17d56..000000000 --- a/cmd/dendrite-demo-pinecone/main.go +++ /dev/null @@ -1,332 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "context" - "crypto/ed25519" - "crypto/tls" - "encoding/hex" - "flag" - "fmt" - "net" - "net/http" - "os" - "path/filepath" - "strings" - "time" - - "github.com/gorilla/mux" - "github.com/gorilla/websocket" - "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/embed" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" - "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" - "github.com/matrix-org/dendrite/roomserver" - "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/userapi" - "github.com/matrix-org/gomatrixserverlib" - - pineconeConnections "github.com/matrix-org/pinecone/connections" - pineconeMulticast "github.com/matrix-org/pinecone/multicast" - pineconeRouter "github.com/matrix-org/pinecone/router" - pineconeEvents "github.com/matrix-org/pinecone/router/events" - pineconeSessions "github.com/matrix-org/pinecone/sessions" - - "github.com/sirupsen/logrus" -) - -var ( - instanceName = flag.String("name", "dendrite-p2p-pinecone", "the name of this P2P demo instance") - instancePort = flag.Int("port", 8008, "the port that the client API will listen on") - instancePeer = flag.String("peer", "", "the static Pinecone peers to connect to, comma separated-list") - instanceListen = flag.String("listen", ":0", "the port Pinecone peers can connect to") - instanceDir = flag.String("dir", ".", "the directory to store the databases in (if --config not specified)") -) - -// nolint:gocyclo -func main() { - flag.Parse() - internal.SetupPprof() - - var pk ed25519.PublicKey - var sk ed25519.PrivateKey - - // iterate through the cli args and check if the config flag was set - configFlagSet := false - for _, arg := range os.Args { - if arg == "--config" || arg == "-config" { - configFlagSet = true - break - } - } - - cfg := &config.Dendrite{} - - // use custom config if config flag is set - if configFlagSet { - cfg = setup.ParseFlags(true) - sk = cfg.Global.PrivateKey - pk = sk.Public().(ed25519.PublicKey) - } else { - keyfile := filepath.Join(*instanceDir, *instanceName) + ".pem" - if _, err := os.Stat(keyfile); os.IsNotExist(err) { - oldkeyfile := *instanceName + ".key" - if _, err = os.Stat(oldkeyfile); os.IsNotExist(err) { - if err = test.NewMatrixKey(keyfile); err != nil { - panic("failed to generate a new PEM key: " + err.Error()) - } - if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { - panic("failed to load PEM key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - } else { - if sk, err = os.ReadFile(oldkeyfile); err != nil { - panic("failed to read the old private key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - if err := test.SaveMatrixKey(keyfile, sk); err != nil { - panic("failed to convert the private key to PEM format: " + err.Error()) - } - } - } else { - var err error - if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { - panic("failed to load PEM key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - } - - pk = sk.Public().(ed25519.PublicKey) - - cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, - }) - cfg.Global.PrivateKey = sk - cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", filepath.Join(*instanceDir, *instanceName))) - cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", filepath.Join(*instanceDir, *instanceName))) - cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", filepath.Join(*instanceDir, *instanceName))) - cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", filepath.Join(*instanceDir, *instanceName))) - cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(*instanceDir, *instanceName))) - cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(*instanceDir, *instanceName))) - cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", filepath.Join(*instanceDir, *instanceName))) - cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} - cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName))) - cfg.ClientAPI.RegistrationDisabled = false - cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true - cfg.MediaAPI.BasePath = config.Path(*instanceDir) - cfg.SyncAPI.Fulltext.Enabled = true - cfg.SyncAPI.Fulltext.IndexPath = config.Path(*instanceDir) - if err := cfg.Derive(); err != nil { - panic(err) - } - } - - cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) - cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - - base := base.NewBaseDendrite(cfg, "Monolith") - defer base.Close() // nolint: errcheck - - pineconeEventChannel := make(chan pineconeEvents.Event) - pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) - pRouter.EnableHopLimiting() - pRouter.EnableWakeupBroadcasts() - pRouter.Subscribe(pineconeEventChannel) - - pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) - pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) - pManager := pineconeConnections.NewConnectionManager(pRouter, nil) - pMulticast.Start() - if instancePeer != nil && *instancePeer != "" { - for _, peer := range strings.Split(*instancePeer, ",") { - pManager.AddPeer(strings.Trim(peer, " \t\r\n")) - } - } - - go func() { - listener, err := net.Listen("tcp", *instanceListen) - if err != nil { - panic(err) - } - - fmt.Println("Listening on", listener.Addr()) - - for { - conn, err := listener.Accept() - if err != nil { - logrus.WithError(err).Error("listener.Accept failed") - continue - } - - port, err := pRouter.Connect( - conn, - pineconeRouter.ConnectionPeerType(pineconeRouter.PeerTypeRemote), - ) - if err != nil { - logrus.WithError(err).Error("pSwitch.Connect failed") - continue - } - - fmt.Println("Inbound connection", conn.RemoteAddr(), "is connected to port", port) - } - }() - - federation := conn.CreateFederationClient(base, pQUIC) - - serverKeyAPI := &signing.YggdrasilKeys{} - keyRing := serverKeyAPI.KeyRing() - - rsComponent := roomserver.NewInternalAPI(base) - rsAPI := rsComponent - fsAPI := federationapi.NewInternalAPI( - base, federation, rsAPI, base.Caches, keyRing, true, - ) - - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) - - asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) - - rsComponent.SetFederationAPI(fsAPI, keyRing) - - userProvider := users.NewPineconeUserProvider(pRouter, pQUIC, userAPI, federation) - roomProvider := rooms.NewPineconeRoomProvider(pRouter, pQUIC, fsAPI, federation) - - monolith := setup.Monolith{ - Config: base.Cfg, - Client: conn.CreateClient(base, pQUIC), - FedClient: federation, - KeyRing: keyRing, - - AppserviceAPI: asAPI, - FederationAPI: fsAPI, - RoomserverAPI: rsAPI, - UserAPI: userAPI, - KeyAPI: keyAPI, - ExtPublicRoomsProvider: roomProvider, - ExtUserDirectoryProvider: userProvider, - } - monolith.AddAllPublicRoutes(base) - - wsUpgrader := websocket.Upgrader{ - CheckOrigin: func(_ *http.Request) bool { - return true - }, - } - httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() - httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) - httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) - httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - httpRouter.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { - c, err := wsUpgrader.Upgrade(w, r, nil) - if err != nil { - logrus.WithError(err).Error("Failed to upgrade WebSocket connection") - return - } - conn := conn.WrapWebSocketConn(c) - if _, err = pRouter.Connect( - conn, - pineconeRouter.ConnectionZone("websocket"), - pineconeRouter.ConnectionPeerType(pineconeRouter.PeerTypeRemote), - ); err != nil { - logrus.WithError(err).Error("Failed to connect WebSocket peer to Pinecone switch") - } - }) - httpRouter.HandleFunc("/pinecone", pRouter.ManholeHandler) - embed.Embed(httpRouter, *instancePort, "Pinecone Demo") - - pMux := mux.NewRouter().SkipClean(true).UseEncodedPath() - pMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles) - pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) - pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - - pHTTP := pQUIC.Protocol("matrix").HTTP() - pHTTP.Mux().Handle(users.PublicURL, pMux) - pHTTP.Mux().Handle(httputil.PublicFederationPathPrefix, pMux) - pHTTP.Mux().Handle(httputil.PublicMediaPathPrefix, pMux) - - // Build both ends of a HTTP multiplex. - httpServer := &http.Server{ - Addr: ":0", - TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 60 * time.Second, - BaseContext: func(_ net.Listener) context.Context { - return context.Background() - }, - Handler: pMux, - } - - go func() { - pubkey := pRouter.PublicKey() - logrus.Info("Listening on ", hex.EncodeToString(pubkey[:])) - logrus.Fatal(httpServer.Serve(pQUIC.Protocol("matrix"))) - }() - go func() { - httpBindAddr := fmt.Sprintf(":%d", *instancePort) - logrus.Info("Listening on ", httpBindAddr) - logrus.Fatal(http.ListenAndServe(httpBindAddr, httpRouter)) - }() - - go func(ch <-chan pineconeEvents.Event) { - eLog := logrus.WithField("pinecone", "events") - - for event := range ch { - switch e := event.(type) { - case pineconeEvents.PeerAdded: - case pineconeEvents.PeerRemoved: - case pineconeEvents.TreeParentUpdate: - case pineconeEvents.SnakeDescUpdate: - case pineconeEvents.TreeRootAnnUpdate: - case pineconeEvents.SnakeEntryAdded: - case pineconeEvents.SnakeEntryRemoved: - case pineconeEvents.BroadcastReceived: - eLog.Info("Broadcast received from: ", e.PeerID) - - req := &api.PerformWakeupServersRequest{ - ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, - } - res := &api.PerformWakeupServersResponse{} - if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { - logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) - } - case pineconeEvents.BandwidthReport: - default: - } - } - }(pineconeEventChannel) - - base.WaitForShutdown() -} diff --git a/cmd/dendrite-demo-pinecone/rooms/rooms.go b/cmd/dendrite-demo-pinecone/rooms/rooms.go deleted file mode 100644 index 0ac705cc1..000000000 --- a/cmd/dendrite-demo-pinecone/rooms/rooms.go +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package rooms - -import ( - "context" - "sync" - "time" - - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/defaults" - "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - - pineconeRouter "github.com/matrix-org/pinecone/router" - pineconeSessions "github.com/matrix-org/pinecone/sessions" -) - -type PineconeRoomProvider struct { - r *pineconeRouter.Router - s *pineconeSessions.Sessions - fedSender api.FederationInternalAPI - fedClient *gomatrixserverlib.FederationClient -} - -func NewPineconeRoomProvider( - r *pineconeRouter.Router, - s *pineconeSessions.Sessions, - fedSender api.FederationInternalAPI, - fedClient *gomatrixserverlib.FederationClient, -) *PineconeRoomProvider { - p := &PineconeRoomProvider{ - r: r, - s: s, - fedSender: fedSender, - fedClient: fedClient, - } - return p -} - -func (p *PineconeRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { - list := map[gomatrixserverlib.ServerName]struct{}{} - for k := range defaults.DefaultServerNames { - list[k] = struct{}{} - } - for _, k := range p.r.Peers() { - list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{} - } - return bulkFetchPublicRoomsFromServers( - context.Background(), p.fedClient, - gomatrixserverlib.ServerName(p.r.PublicKey().String()), list, - ) -} - -// bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. -// Returns a list of public rooms. -func bulkFetchPublicRoomsFromServers( - ctx context.Context, fedClient *gomatrixserverlib.FederationClient, - origin gomatrixserverlib.ServerName, - homeservers map[gomatrixserverlib.ServerName]struct{}, -) (publicRooms []gomatrixserverlib.PublicRoom) { - limit := 200 - // follow pipeline semantics, see https://blog.golang.org/pipelines for more info. - // goroutines send rooms to this channel - roomCh := make(chan gomatrixserverlib.PublicRoom, int(limit)) - // signalling channel to tell goroutines to stop sending rooms and quit - done := make(chan bool) - // signalling to say when we can close the room channel - var wg sync.WaitGroup - wg.Add(len(homeservers)) - // concurrently query for public rooms - reqctx, reqcancel := context.WithTimeout(ctx, time.Second*5) - for hs := range homeservers { - go func(homeserverDomain gomatrixserverlib.ServerName) { - defer wg.Done() - util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") - fres, err := fedClient.GetPublicRooms(reqctx, origin, homeserverDomain, int(limit), "", false, "") - if err != nil { - util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn( - "bulkFetchPublicRoomsFromServers: failed to query hs", - ) - return - } - for _, room := range fres.Chunk { - // atomically send a room or stop - select { - case roomCh <- room: - case <-done: - case <-reqctx.Done(): - util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Info("Interrupted whilst sending rooms") - return - } - } - }(hs) - } - - select { - case <-time.After(5 * time.Second): - default: - wg.Wait() - } - reqcancel() - close(done) - close(roomCh) - - for room := range roomCh { - publicRooms = append(publicRooms, room) - } - - return publicRooms -} diff --git a/cmd/dendrite-demo-pinecone/users/users.go b/cmd/dendrite-demo-pinecone/users/users.go deleted file mode 100644 index fc66bf299..000000000 --- a/cmd/dendrite-demo-pinecone/users/users.go +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package users - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "sync" - "time" - - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/defaults" - userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - - pineconeRouter "github.com/matrix-org/pinecone/router" - pineconeSessions "github.com/matrix-org/pinecone/sessions" -) - -type PineconeUserProvider struct { - r *pineconeRouter.Router - s *pineconeSessions.Sessions - userAPI userapi.QuerySearchProfilesAPI - fedClient *gomatrixserverlib.FederationClient -} - -const PublicURL = "/_matrix/p2p/profiles" - -func NewPineconeUserProvider( - r *pineconeRouter.Router, - s *pineconeSessions.Sessions, - userAPI userapi.QuerySearchProfilesAPI, - fedClient *gomatrixserverlib.FederationClient, -) *PineconeUserProvider { - p := &PineconeUserProvider{ - r: r, - s: s, - userAPI: userAPI, - fedClient: fedClient, - } - return p -} - -func (p *PineconeUserProvider) FederatedUserProfiles(w http.ResponseWriter, r *http.Request) { - req := &userapi.QuerySearchProfilesRequest{Limit: 25} - res := &userapi.QuerySearchProfilesResponse{} - if err := clienthttputil.UnmarshalJSONRequest(r, &req); err != nil { - w.WriteHeader(400) - return - } - if err := p.userAPI.QuerySearchProfiles(r.Context(), req, res); err != nil { - w.WriteHeader(400) - return - } - j, err := json.Marshal(res) - if err != nil { - w.WriteHeader(400) - return - } - w.WriteHeader(200) - _, _ = w.Write(j) -} - -func (p *PineconeUserProvider) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error { - list := map[gomatrixserverlib.ServerName]struct{}{} - for k := range defaults.DefaultServerNames { - list[k] = struct{}{} - } - for _, k := range p.r.Peers() { - list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{} - } - res.Profiles = bulkFetchUserDirectoriesFromServers(context.Background(), req, p.fedClient, list) - return nil -} - -// bulkFetchUserDirectoriesFromServers fetches users from the list of homeservers. -// Returns a list of user profiles. -func bulkFetchUserDirectoriesFromServers( - ctx context.Context, req *userapi.QuerySearchProfilesRequest, - fedClient *gomatrixserverlib.FederationClient, - homeservers map[gomatrixserverlib.ServerName]struct{}, -) (profiles []authtypes.Profile) { - jsonBody, err := json.Marshal(req) - if err != nil { - return nil - } - - limit := 200 - // follow pipeline semantics, see https://blog.golang.org/pipelines for more info. - // goroutines send rooms to this channel - profileCh := make(chan authtypes.Profile, int(limit)) - // signalling channel to tell goroutines to stop sending rooms and quit - done := make(chan bool) - // signalling to say when we can close the room channel - var wg sync.WaitGroup - wg.Add(len(homeservers)) - // concurrently query for public rooms - reqctx, reqcancel := context.WithTimeout(ctx, time.Second*5) - for hs := range homeservers { - go func(homeserverDomain gomatrixserverlib.ServerName) { - defer wg.Done() - util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for users") - - jsonBodyReader := bytes.NewBuffer(jsonBody) - httpReq, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("matrix://%s%s", homeserverDomain, PublicURL), jsonBodyReader) - if err != nil { - util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn( - "bulkFetchUserDirectoriesFromServers: failed to create request", - ) - } - res := &userapi.QuerySearchProfilesResponse{} - if err = fedClient.DoRequestAndParseResponse(reqctx, httpReq, res); err != nil { - util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn( - "bulkFetchUserDirectoriesFromServers: failed to query hs", - ) - return - } - for _, profile := range res.Profiles { - profile.ServerName = string(homeserverDomain) - // atomically send a room or stop - select { - case profileCh <- profile: - case <-done: - case <-reqctx.Done(): - util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Info("Interrupted whilst sending profiles") - return - } - } - }(hs) - } - - select { - case <-time.After(5 * time.Second): - default: - wg.Wait() - } - reqcancel() - close(done) - close(profileCh) - - for profile := range profileCh { - profiles = append(profiles, profile) - } - - return profiles -} diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 1226496c3..d759c6a73 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -39,7 +39,6 @@ import ( "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" @@ -117,8 +116,8 @@ func main() { cfg = setup.ParseFlags(true) } else { cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: true, }) cfg.Global.PrivateKey = sk cfg.Global.JetStream.StoragePath = config.Path(filepath.Join(*instanceDir, *instanceName)) @@ -143,7 +142,8 @@ func main() { cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - base := base.NewBaseDendrite(cfg, "Monolith") + base := base.NewBaseDendrite(cfg) + base.ConfigureAdminEndpoints() defer base.Close() // nolint: errcheck ygg, err := yggconn.Setup(sk, *instanceName, ".", *instancePeer, *instanceListen) @@ -156,15 +156,11 @@ func main() { serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - - rsComponent := roomserver.NewInternalAPI( + rsAPI := roomserver.NewInternalAPI( base, ) - rsAPI := rsComponent - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, federation) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) rsAPI.SetAppserviceAPI(asAPI) @@ -172,7 +168,7 @@ func main() { base, federation, rsAPI, base.Caches, keyRing, true, ) - rsComponent.SetFederationAPI(fsAPI, keyRing) + rsAPI.SetFederationAPI(fsAPI, keyRing) monolith := setup.Monolith{ Config: base.Cfg, @@ -184,7 +180,6 @@ func main() { FederationAPI: fsAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - KeyAPI: keyAPI, ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, ), @@ -195,9 +190,10 @@ func main() { } httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() - httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) + httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) + httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) embed.Embed(httpRouter, *instancePort, "Yggdrasil Demo") yggRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go deleted file mode 100644 index 62e004474..000000000 --- a/cmd/dendrite-monolith-server/main.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "flag" - "log" - "os" - - "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/keyserver" - "github.com/matrix-org/dendrite/roomserver" - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/mscs" - "github.com/matrix-org/dendrite/userapi" - uapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/sirupsen/logrus" -) - -var ( - httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") - httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") - apiBindAddr = flag.String("api-bind-address", "localhost:18008", "The HTTP listening port for the internal HTTP APIs (if -api is enabled)") - certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") - keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS") - enableHTTPAPIs = flag.Bool("api", false, "Use HTTP APIs instead of short-circuiting (warning: exposes API endpoints!)") - traceInternal = os.Getenv("DENDRITE_TRACE_INTERNAL") == "1" -) - -func main() { - cfg := setup.ParseFlags(true) - httpAddr := config.HTTPAddress("http://" + *httpBindAddr) - for _, logging := range cfg.Logging { - if logging.Type == "std" { - level, err := logrus.ParseLevel(logging.Level) - if err != nil { - log.Fatal(err) - } - logrus.SetLevel(level) - logrus.SetFormatter(&logrus.JSONFormatter{}) - } - } - httpsAddr := config.HTTPAddress("https://" + *httpsBindAddr) - httpAPIAddr := httpAddr - options := []basepkg.BaseDendriteOptions{} - if *enableHTTPAPIs { - logrus.Warnf("DANGER! The -api option is enabled, exposing internal APIs on %q!", *apiBindAddr) - httpAPIAddr = config.HTTPAddress("http://" + *apiBindAddr) - // If the HTTP APIs are enabled then we need to update the Listen - // statements in the configuration so that we know where to find - // the API endpoints. They'll listen on the same port as the monolith - // itself. - cfg.AppServiceAPI.InternalAPI.Connect = httpAPIAddr - cfg.ClientAPI.InternalAPI.Connect = httpAPIAddr - cfg.FederationAPI.InternalAPI.Connect = httpAPIAddr - cfg.KeyServer.InternalAPI.Connect = httpAPIAddr - cfg.MediaAPI.InternalAPI.Connect = httpAPIAddr - cfg.RoomServer.InternalAPI.Connect = httpAPIAddr - cfg.SyncAPI.InternalAPI.Connect = httpAPIAddr - cfg.UserAPI.InternalAPI.Connect = httpAPIAddr - options = append(options, basepkg.UseHTTPAPIs) - } - - base := basepkg.NewBaseDendrite(cfg, "Monolith", options...) - defer base.Close() // nolint: errcheck - - federation := base.CreateFederationClient() - - rsImpl := roomserver.NewInternalAPI(base) - // call functions directly on the impl unless running in HTTP mode - rsAPI := rsImpl - if base.UseHTTPAPIs { - roomserver.AddInternalRoutes(base.InternalAPIMux, rsImpl) - rsAPI = base.RoomserverHTTPClient() - } - if traceInternal { - rsAPI = &api.RoomserverInternalAPITrace{ - Impl: rsAPI, - } - } - - fsAPI := federationapi.NewInternalAPI( - base, federation, rsAPI, base.Caches, nil, false, - ) - fsImplAPI := fsAPI - if base.UseHTTPAPIs { - federationapi.AddInternalRoutes(base.InternalAPIMux, fsAPI) - fsAPI = base.FederationAPIHTTPClient() - } - keyRing := fsAPI.KeyRing() - - keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) - keyAPI := keyImpl - if base.UseHTTPAPIs { - keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI) - keyAPI = base.KeyServerHTTPClient() - } - - pgClient := base.PushGatewayHTTPClient() - userImpl := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient) - userAPI := userImpl - if base.UseHTTPAPIs { - userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) - userAPI = base.UserAPIClient() - } - if traceInternal { - userAPI = &uapi.UserInternalAPITrace{ - Impl: userAPI, - } - } - - // TODO: This should use userAPI, not userImpl, but the appservice setup races with - // the listeners and panics at startup if it tries to create appservice accounts - // before the listeners are up. - asAPI := appservice.NewInternalAPI(base, userImpl, rsAPI) - if base.UseHTTPAPIs { - appservice.AddInternalRoutes(base.InternalAPIMux, asAPI) - asAPI = base.AppserviceHTTPClient() - } - - // The underlying roomserver implementation needs to be able to call the fedsender. - // This is different to rsAPI which can be the http client which doesn't need this - // dependency. Other components also need updating after their dependencies are up. - rsImpl.SetFederationAPI(fsAPI, keyRing) - rsImpl.SetAppserviceAPI(asAPI) - rsImpl.SetUserAPI(userAPI) - keyImpl.SetUserAPI(userAPI) - - monolith := setup.Monolith{ - Config: base.Cfg, - Client: base.CreateClient(), - FedClient: federation, - KeyRing: keyRing, - - AppserviceAPI: asAPI, - // always use the concrete impl here even in -http mode because adding public routes - // must be done on the concrete impl not an HTTP client else fedapi will call itself - FederationAPI: fsImplAPI, - RoomserverAPI: rsAPI, - UserAPI: userAPI, - KeyAPI: keyAPI, - } - monolith.AddAllPublicRoutes(base) - - if len(base.Cfg.MSCs.MSCs) > 0 { - if err := mscs.Enable(base, &monolith); err != nil { - logrus.WithError(err).Fatalf("Failed to enable MSCs") - } - } - - // Expose the matrix APIs directly rather than putting them under a /api path. - go func() { - base.SetupAndServeHTTP( - httpAPIAddr, // internal API - httpAddr, // external API - nil, nil, // TLS settings - ) - }() - // Handle HTTPS if certificate and key are provided - if *certFile != "" && *keyFile != "" { - go func() { - base.SetupAndServeHTTP( - basepkg.NoListener, // internal API - httpsAddr, // external API - certFile, keyFile, // TLS settings - ) - }() - } - - // We want to block forever to let the HTTP and HTTPS handler serve the APIs - base.WaitForShutdown() -} diff --git a/cmd/dendrite-polylith-multi/main.go b/cmd/dendrite-polylith-multi/main.go deleted file mode 100644 index c6a560b19..000000000 --- a/cmd/dendrite-polylith-multi/main.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "flag" - "os" - "strings" - - "github.com/matrix-org/dendrite/cmd/dendrite-polylith-multi/personalities" - "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/sirupsen/logrus" -) - -type entrypoint func(base *base.BaseDendrite, cfg *config.Dendrite) - -func main() { - cfg := setup.ParseFlags(false) - - component := "" - if flag.NFlag() > 0 { - component = flag.Arg(0) // ./dendrite-polylith-multi --config=... clientapi - } else if len(os.Args) > 1 { - component = os.Args[1] // ./dendrite-polylith-multi clientapi - } - - components := map[string]entrypoint{ - "appservice": personalities.Appservice, - "clientapi": personalities.ClientAPI, - "federationapi": personalities.FederationAPI, - "keyserver": personalities.KeyServer, - "mediaapi": personalities.MediaAPI, - "roomserver": personalities.RoomServer, - "syncapi": personalities.SyncAPI, - "userapi": personalities.UserAPI, - } - - start, ok := components[component] - if !ok { - if component == "" { - logrus.Errorf("No component specified") - logrus.Info("The first argument on the command line must be the name of the component to run") - } else { - logrus.Errorf("Unknown component %q specified", component) - } - - var list []string - for c := range components { - list = append(list, c) - } - logrus.Infof("Valid components: %s", strings.Join(list, ", ")) - - os.Exit(1) - } - - logrus.Infof("Starting %q component", component) - - base := base.NewBaseDendrite(cfg, component, base.PolylithMode) // TODO - defer base.Close() // nolint: errcheck - - go start(base, cfg) - base.WaitForShutdown() -} diff --git a/cmd/dendrite-polylith-multi/personalities/appservice.go b/cmd/dendrite-polylith-multi/personalities/appservice.go deleted file mode 100644 index 4f74434a4..000000000 --- a/cmd/dendrite-polylith-multi/personalities/appservice.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package personalities - -import ( - "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/setup/base" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func Appservice(base *base.BaseDendrite, cfg *config.Dendrite) { - userAPI := base.UserAPIClient() - rsAPI := base.RoomserverHTTPClient() - - intAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) - appservice.AddInternalRoutes(base.InternalAPIMux, intAPI) - - base.SetupAndServeHTTP( - base.Cfg.AppServiceAPI.InternalAPI.Listen, // internal listener - basepkg.NoListener, // external listener - nil, nil, - ) -} diff --git a/cmd/dendrite-polylith-multi/personalities/clientapi.go b/cmd/dendrite-polylith-multi/personalities/clientapi.go deleted file mode 100644 index a5d69d07c..000000000 --- a/cmd/dendrite-polylith-multi/personalities/clientapi.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package personalities - -import ( - "github.com/matrix-org/dendrite/clientapi" - "github.com/matrix-org/dendrite/internal/transactions" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func ClientAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { - federation := base.CreateFederationClient() - - asQuery := base.AppserviceHTTPClient() - rsAPI := base.RoomserverHTTPClient() - fsAPI := base.FederationAPIHTTPClient() - userAPI := base.UserAPIClient() - keyAPI := base.KeyServerHTTPClient() - - clientapi.AddPublicRoutes( - base, federation, rsAPI, asQuery, - transactions.New(), fsAPI, userAPI, userAPI, - keyAPI, nil, - ) - - base.SetupAndServeHTTP( - base.Cfg.ClientAPI.InternalAPI.Listen, - base.Cfg.ClientAPI.ExternalAPI.Listen, - nil, nil, - ) -} diff --git a/cmd/dendrite-polylith-multi/personalities/federationapi.go b/cmd/dendrite-polylith-multi/personalities/federationapi.go deleted file mode 100644 index 6377ce9e3..000000000 --- a/cmd/dendrite-polylith-multi/personalities/federationapi.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package personalities - -import ( - "github.com/matrix-org/dendrite/federationapi" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func FederationAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { - userAPI := base.UserAPIClient() - federation := base.CreateFederationClient() - rsAPI := base.RoomserverHTTPClient() - keyAPI := base.KeyServerHTTPClient() - fsAPI := federationapi.NewInternalAPI(base, federation, rsAPI, base.Caches, nil, true) - keyRing := fsAPI.KeyRing() - - federationapi.AddPublicRoutes( - base, - userAPI, federation, keyRing, - rsAPI, fsAPI, keyAPI, nil, - ) - - federationapi.AddInternalRoutes(base.InternalAPIMux, fsAPI) - - base.SetupAndServeHTTP( - base.Cfg.FederationAPI.InternalAPI.Listen, - base.Cfg.FederationAPI.ExternalAPI.Listen, - nil, nil, - ) -} diff --git a/cmd/dendrite-polylith-multi/personalities/keyserver.go b/cmd/dendrite-polylith-multi/personalities/keyserver.go deleted file mode 100644 index f8aa57b86..000000000 --- a/cmd/dendrite-polylith-multi/personalities/keyserver.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package personalities - -import ( - "github.com/matrix-org/dendrite/keyserver" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func KeyServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) { - fsAPI := base.FederationAPIHTTPClient() - intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) - intAPI.SetUserAPI(base.UserAPIClient()) - - keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI) - - base.SetupAndServeHTTP( - base.Cfg.KeyServer.InternalAPI.Listen, // internal listener - basepkg.NoListener, // external listener - nil, nil, - ) -} diff --git a/cmd/dendrite-polylith-multi/personalities/mediaapi.go b/cmd/dendrite-polylith-multi/personalities/mediaapi.go deleted file mode 100644 index 69d5fd5a8..000000000 --- a/cmd/dendrite-polylith-multi/personalities/mediaapi.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package personalities - -import ( - "github.com/matrix-org/dendrite/mediaapi" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func MediaAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { - userAPI := base.UserAPIClient() - client := base.CreateClient() - - mediaapi.AddPublicRoutes( - base, userAPI, client, - ) - - base.SetupAndServeHTTP( - base.Cfg.MediaAPI.InternalAPI.Listen, - base.Cfg.MediaAPI.ExternalAPI.Listen, - nil, nil, - ) -} diff --git a/cmd/dendrite-polylith-multi/personalities/roomserver.go b/cmd/dendrite-polylith-multi/personalities/roomserver.go deleted file mode 100644 index 1deb51ce0..000000000 --- a/cmd/dendrite-polylith-multi/personalities/roomserver.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package personalities - -import ( - "github.com/matrix-org/dendrite/roomserver" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func RoomServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) { - asAPI := base.AppserviceHTTPClient() - fsAPI := base.FederationAPIHTTPClient() - rsAPI := roomserver.NewInternalAPI(base) - rsAPI.SetFederationAPI(fsAPI, fsAPI.KeyRing()) - rsAPI.SetAppserviceAPI(asAPI) - roomserver.AddInternalRoutes(base.InternalAPIMux, rsAPI) - - base.SetupAndServeHTTP( - base.Cfg.RoomServer.InternalAPI.Listen, // internal listener - basepkg.NoListener, // external listener - nil, nil, - ) -} diff --git a/cmd/dendrite-polylith-multi/personalities/syncapi.go b/cmd/dendrite-polylith-multi/personalities/syncapi.go deleted file mode 100644 index 41637fe1d..000000000 --- a/cmd/dendrite-polylith-multi/personalities/syncapi.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package personalities - -import ( - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/syncapi" -) - -func SyncAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { - userAPI := base.UserAPIClient() - - rsAPI := base.RoomserverHTTPClient() - - syncapi.AddPublicRoutes( - base, - userAPI, rsAPI, - base.KeyServerHTTPClient(), - ) - - base.SetupAndServeHTTP( - base.Cfg.SyncAPI.InternalAPI.Listen, - base.Cfg.SyncAPI.ExternalAPI.Listen, - nil, nil, - ) -} diff --git a/cmd/dendrite-polylith-multi/personalities/userapi.go b/cmd/dendrite-polylith-multi/personalities/userapi.go deleted file mode 100644 index 3fe5a43d7..000000000 --- a/cmd/dendrite-polylith-multi/personalities/userapi.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package personalities - -import ( - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi" -) - -func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { - userAPI := userapi.NewInternalAPI( - base, &cfg.UserAPI, cfg.Derived.ApplicationServices, - base.KeyServerHTTPClient(), base.RoomserverHTTPClient(), - base.PushGatewayHTTPClient(), - ) - - userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) - - base.SetupAndServeHTTP( - base.Cfg.UserAPI.InternalAPI.Listen, // internal listener - basepkg.NoListener, // external listener - nil, nil, - ) -} diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index 75446d18c..174a80a3e 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -7,6 +7,7 @@ import ( "flag" "fmt" "io" + "io/ioutil" "log" "net/http" "os" @@ -44,6 +45,10 @@ var ( const HEAD = "HEAD" +// The binary was renamed after v0.11.1, so everything after that should use the new name +var binaryChangeVersion, _ = semver.NewVersion("v0.11.1") +var latest, _ = semver.NewVersion("v6.6.6") // Dummy version, used as "HEAD" + // Embed the Dockerfile to use when building dendrite versions. // We cannot use the dockerfile associated with the repo with each version sadly due to changes in // Docker versions. Specifically, earlier Dendrite versions are incompatible with newer Docker clients @@ -53,14 +58,16 @@ const HEAD = "HEAD" const DockerfilePostgreSQL = `FROM golang:1.18-stretch as build RUN apt-get update && apt-get install -y postgresql WORKDIR /build +ARG BINARY # Copy the build context to the repo as this is the right dendrite code. This is different to the # Complement Dockerfile which wgets a branch. COPY . . -RUN go build ./cmd/dendrite-monolith-server +RUN go build ./cmd/${BINARY} RUN go build ./cmd/generate-keys RUN go build ./cmd/generate-config +RUN go build ./cmd/create-account RUN ./generate-config --ci > dendrite.yaml RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key @@ -86,24 +93,27 @@ done \n\ \n\ sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml \n\ PARAMS="--tls-cert server.crt --tls-key server.key --config dendrite.yaml" \n\ -./dendrite-monolith-server --really-enable-open-registration ${PARAMS} || ./dendrite-monolith-server ${PARAMS} \n\ +./${BINARY} --really-enable-open-registration ${PARAMS} || ./${BINARY} ${PARAMS} \n\ ' > run_dendrite.sh && chmod +x run_dendrite.sh ENV SERVER_NAME=localhost +ENV BINARY=dendrite EXPOSE 8008 8448 -CMD /build/run_dendrite.sh ` +CMD /build/run_dendrite.sh` const DockerfileSQLite = `FROM golang:1.18-stretch as build RUN apt-get update && apt-get install -y postgresql WORKDIR /build +ARG BINARY # Copy the build context to the repo as this is the right dendrite code. This is different to the # Complement Dockerfile which wgets a branch. COPY . . -RUN go build ./cmd/dendrite-monolith-server +RUN go build ./cmd/${BINARY} RUN go build ./cmd/generate-keys RUN go build ./cmd/generate-config +RUN go build ./cmd/create-account RUN ./generate-config --ci > dendrite.yaml RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key @@ -115,10 +125,11 @@ RUN sed -i "s%connection_string:.file:%connection_string: file:\/var\/lib\/postg RUN echo '\ sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml \n\ PARAMS="--tls-cert server.crt --tls-key server.key --config dendrite.yaml" \n\ -./dendrite-monolith-server --really-enable-open-registration ${PARAMS} || ./dendrite-monolith-server ${PARAMS} \n\ +./${BINARY} --really-enable-open-registration ${PARAMS} || ./${BINARY} ${PARAMS} \n\ ' > run_dendrite.sh && chmod +x run_dendrite.sh ENV SERVER_NAME=localhost +ENV BINARY=dendrite EXPOSE 8008 8448 CMD /build/run_dendrite.sh ` @@ -179,7 +190,7 @@ func downloadArchive(cli *http.Client, tmpDir, archiveURL string, dockerfile []b } // buildDendrite builds Dendrite on the branchOrTagName given. Returns the image ID or an error -func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir, branchOrTagName string) (string, error) { +func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir string, branchOrTagName, binary string) (string, error) { var tarball *bytes.Buffer var err error // If a custom HEAD location is given, use that, else pull from github. Mostly useful for CI @@ -213,6 +224,9 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir, log.Printf("%s: Building version %s\n", branchOrTagName, branchOrTagName) res, err := dockerClient.ImageBuild(context.Background(), tarball, types.ImageBuildOptions{ Tags: []string{"dendrite-upgrade"}, + BuildArgs: map[string]*string{ + "BINARY": &binary, + }, }) if err != nil { return "", fmt.Errorf("failed to start building image: %s", err) @@ -269,7 +283,7 @@ func getAndSortVersionsFromGithub(httpClient *http.Client) (semVers []*semver.Ve return semVers, nil } -func calculateVersions(cli *http.Client, from, to string, direct bool) []string { +func calculateVersions(cli *http.Client, from, to string, direct bool) []*semver.Version { semvers, err := getAndSortVersionsFromGithub(cli) if err != nil { log.Fatalf("failed to collect semvers from github: %s", err) @@ -317,28 +331,25 @@ func calculateVersions(cli *http.Client, from, to string, direct bool) []string } semvers = semvers[:i+1] } - var versions []string - for _, sv := range semvers { - versions = append(versions, sv.Original()) - } + if to == HEAD { - versions = append(versions, HEAD) + semvers = append(semvers, latest) } if direct { - versions = []string{versions[0], versions[len(versions)-1]} + semvers = []*semver.Version{semvers[0], semvers[len(semvers)-1]} } - return versions + return semvers } -func buildDendriteImages(httpClient *http.Client, dockerClient *client.Client, baseTempDir string, concurrency int, branchOrTagNames []string) map[string]string { +func buildDendriteImages(httpClient *http.Client, dockerClient *client.Client, baseTempDir string, concurrency int, versions []*semver.Version) map[string]string { // concurrently build all versions, this can be done in any order. The mutex protects the map branchToImageID := make(map[string]string) var mu sync.Mutex var wg sync.WaitGroup wg.Add(concurrency) - ch := make(chan string, len(branchOrTagNames)) - for _, branchName := range branchOrTagNames { + ch := make(chan *semver.Version, len(versions)) + for _, branchName := range versions { ch <- branchName } close(ch) @@ -346,11 +357,13 @@ func buildDendriteImages(httpClient *http.Client, dockerClient *client.Client, b for i := 0; i < concurrency; i++ { go func() { defer wg.Done() - for branchName := range ch { + for version := range ch { + branchName, binary := versionToBranchAndBinary(version) + log.Printf("Building version %s with binary %s", branchName, binary) tmpDir := baseTempDir + alphaNumerics.ReplaceAllString(branchName, "") - imgID, err := buildDendrite(httpClient, dockerClient, tmpDir, branchName) + imgID, err := buildDendrite(httpClient, dockerClient, tmpDir, branchName, binary) if err != nil { - log.Fatalf("%s: failed to build dendrite image: %s", branchName, err) + log.Fatalf("%s: failed to build dendrite image: %s", version, err) } mu.Lock() branchToImageID[branchName] = imgID @@ -362,13 +375,14 @@ func buildDendriteImages(httpClient *http.Client, dockerClient *client.Client, b return branchToImageID } -func runImage(dockerClient *client.Client, volumeName, version, imageID string) (csAPIURL, containerID string, err error) { - log.Printf("%s: running image %s\n", version, imageID) +func runImage(dockerClient *client.Client, volumeName string, branchNameToImageID map[string]string, version *semver.Version) (csAPIURL, containerID string, err error) { + branchName, binary := versionToBranchAndBinary(version) + imageID := branchNameToImageID[branchName] ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) defer cancel() body, err := dockerClient.ContainerCreate(ctx, &container.Config{ Image: imageID, - Env: []string{"SERVER_NAME=hs1"}, + Env: []string{"SERVER_NAME=hs1", fmt.Sprintf("BINARY=%s", binary)}, Labels: map[string]string{ dendriteUpgradeTestLabel: "yes", }, @@ -381,7 +395,7 @@ func runImage(dockerClient *client.Client, volumeName, version, imageID string) Target: "/var/lib/postgresql/9.6/main", }, }, - }, nil, nil, "dendrite_upgrade_test_"+version) + }, nil, nil, "dendrite_upgrade_test_"+branchName) if err != nil { return "", "", fmt.Errorf("failed to ContainerCreate: %s", err) } @@ -448,8 +462,8 @@ func destroyContainer(dockerClient *client.Client, containerID string) { } } -func loadAndRunTests(dockerClient *client.Client, volumeName, v string, branchToImageID map[string]string) error { - csAPIURL, containerID, err := runImage(dockerClient, volumeName, v, branchToImageID[v]) +func loadAndRunTests(dockerClient *client.Client, volumeName string, v *semver.Version, branchToImageID map[string]string) error { + csAPIURL, containerID, err := runImage(dockerClient, volumeName, branchToImageID, v) if err != nil { return fmt.Errorf("failed to run container for branch %v: %v", v, err) } @@ -458,12 +472,65 @@ func loadAndRunTests(dockerClient *client.Client, volumeName, v string, branchTo if err = runTests(csAPIURL, v); err != nil { return fmt.Errorf("failed to run tests on version %s: %s", v, err) } + + err = testCreateAccount(dockerClient, v, containerID) + if err != nil { + return err + } return nil } -func verifyTests(dockerClient *client.Client, volumeName string, versions []string, branchToImageID map[string]string) error { +// test that create-account is working +func testCreateAccount(dockerClient *client.Client, version *semver.Version, containerID string) error { + branchName, _ := versionToBranchAndBinary(version) + createUser := strings.ToLower("createaccountuser-" + branchName) + log.Printf("%s: Creating account %s with create-account\n", branchName, createUser) + + respID, err := dockerClient.ContainerExecCreate(context.Background(), containerID, types.ExecConfig{ + AttachStderr: true, + AttachStdout: true, + Cmd: []string{ + "/build/create-account", + "-username", createUser, + "-password", "someRandomPassword", + }, + }) + if err != nil { + return fmt.Errorf("failed to ContainerExecCreate: %w", err) + } + + response, err := dockerClient.ContainerExecAttach(context.Background(), respID.ID, types.ExecStartCheck{}) + if err != nil { + return fmt.Errorf("failed to attach to container: %w", err) + } + defer response.Close() + + data, err := ioutil.ReadAll(response.Reader) + if err != nil { + return err + } + + if !bytes.Contains(data, []byte("AccessToken")) { + return fmt.Errorf("failed to create-account: %s", string(data)) + } + return nil +} + +func versionToBranchAndBinary(version *semver.Version) (branchName, binary string) { + binary = "dendrite-monolith-server" + branchName = version.Original() + if version.GreaterThan(binaryChangeVersion) { + binary = "dendrite" + if version.Equal(latest) { + branchName = HEAD + } + } + return +} + +func verifyTests(dockerClient *client.Client, volumeName string, versions []*semver.Version, branchToImageID map[string]string) error { lastVer := versions[len(versions)-1] - csAPIURL, containerID, err := runImage(dockerClient, volumeName, lastVer, branchToImageID[lastVer]) + csAPIURL, containerID, err := runImage(dockerClient, volumeName, branchToImageID, lastVer) if err != nil { return fmt.Errorf("failed to run container for branch %v: %v", lastVer, err) } diff --git a/cmd/dendrite-upgrade-tests/tests.go b/cmd/dendrite-upgrade-tests/tests.go index 5c9589df2..03438bd4d 100644 --- a/cmd/dendrite-upgrade-tests/tests.go +++ b/cmd/dendrite-upgrade-tests/tests.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/Masterminds/semver/v3" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" ) @@ -22,7 +23,8 @@ type user struct { // - register alice and bob with branch name muxed into the localpart // - create a DM room for the 2 users and exchange messages // - create/join a public #global room and exchange messages -func runTests(baseURL, branchName string) error { +func runTests(baseURL string, v *semver.Version) error { + branchName, _ := versionToBranchAndBinary(v) // register 2 users users := []user{ { @@ -164,15 +166,16 @@ func runTests(baseURL, branchName string) error { } // verifyTestsRan checks that the HS has the right rooms/messages -func verifyTestsRan(baseURL string, branchNames []string) error { +func verifyTestsRan(baseURL string, versions []*semver.Version) error { log.Println("Verifying tests....") // check we can login as all users var resp *gomatrix.RespLogin - for _, branchName := range branchNames { + for _, version := range versions { client, err := gomatrix.NewClient(baseURL, "", "") if err != nil { return err } + branchName, _ := versionToBranchAndBinary(version) userLocalparts := []string{ "alice" + branchName, "bob" + branchName, @@ -224,7 +227,7 @@ func verifyTestsRan(baseURL string, branchNames []string) error { msgCount += 1 } } - wantMsgCount := len(branchNames) * 4 + wantMsgCount := len(versions) * 4 if msgCount != wantMsgCount { return fmt.Errorf("got %d messages in global room, want %d", msgCount, wantMsgCount) } diff --git a/cmd/dendrite-monolith-server/Dockerfile.dev b/cmd/dendrite/Dockerfile.dev similarity index 68% rename from cmd/dendrite-monolith-server/Dockerfile.dev rename to cmd/dendrite/Dockerfile.dev index 7fbf6c667..281efa69c 100644 --- a/cmd/dendrite-monolith-server/Dockerfile.dev +++ b/cmd/dendrite/Dockerfile.dev @@ -5,4 +5,4 @@ COPY dendrite-monolith-server /usr/bin/ VOLUME /etc/dendrite WORKDIR /etc/dendrite -ENTRYPOINT ["/usr/bin/dendrite-monolith-server"] +ENTRYPOINT ["/usr/bin/dendrite"] \ No newline at end of file diff --git a/cmd/dendrite-monolith-server/build_dev.sh b/cmd/dendrite/build_dev.sh old mode 100755 new mode 100644 similarity index 100% rename from cmd/dendrite-monolith-server/build_dev.sh rename to cmd/dendrite/build_dev.sh diff --git a/cmd/dendrite/main.go b/cmd/dendrite/main.go new file mode 100644 index 000000000..1ae348cfa --- /dev/null +++ b/cmd/dendrite/main.go @@ -0,0 +1,125 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "flag" + "io/fs" + + "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/setup" + basepkg "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/mscs" + "github.com/matrix-org/dendrite/userapi" +) + +var ( + unixSocket = flag.String("unix-socket", "", + "EXPERIMENTAL(unstable): The HTTP listening unix socket for the server (disables http[s]-bind-address feature)", + ) + unixSocketPermission = flag.Int("unix-socket-permission", 0755, + "EXPERIMENTAL(unstable): The HTTP listening unix socket permission for the server", + ) + httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") + httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") + certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") + keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS") +) + +func main() { + cfg := setup.ParseFlags(true) + httpAddr := config.ServerAddress{} + httpsAddr := config.ServerAddress{} + if *unixSocket == "" { + http, err := config.HTTPAddress("http://" + *httpBindAddr) + if err != nil { + logrus.WithError(err).Fatalf("Failed to parse http address") + } + httpAddr = http + https, err := config.HTTPAddress("https://" + *httpsBindAddr) + if err != nil { + logrus.WithError(err).Fatalf("Failed to parse https address") + } + httpsAddr = https + } else { + httpAddr = config.UnixSocketAddress(*unixSocket, fs.FileMode(*unixSocketPermission)) + } + + options := []basepkg.BaseDendriteOptions{} + + base := basepkg.NewBaseDendrite(cfg, options...) + defer base.Close() // nolint: errcheck + + federation := base.CreateFederationClient() + + rsAPI := roomserver.NewInternalAPI(base) + + fsAPI := federationapi.NewInternalAPI( + base, federation, rsAPI, base.Caches, nil, false, + ) + + keyRing := fsAPI.KeyRing() + + userAPI := userapi.NewInternalAPI(base, rsAPI, federation) + + asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + + // The underlying roomserver implementation needs to be able to call the fedsender. + // This is different to rsAPI which can be the http client which doesn't need this + // dependency. Other components also need updating after their dependencies are up. + rsAPI.SetFederationAPI(fsAPI, keyRing) + rsAPI.SetAppserviceAPI(asAPI) + rsAPI.SetUserAPI(userAPI) + + monolith := setup.Monolith{ + Config: base.Cfg, + Client: base.CreateClient(), + FedClient: federation, + KeyRing: keyRing, + + AppserviceAPI: asAPI, + // always use the concrete impl here even in -http mode because adding public routes + // must be done on the concrete impl not an HTTP client else fedapi will call itself + FederationAPI: fsAPI, + RoomserverAPI: rsAPI, + UserAPI: userAPI, + } + monolith.AddAllPublicRoutes(base) + + if len(base.Cfg.MSCs.MSCs) > 0 { + if err := mscs.Enable(base, &monolith); err != nil { + logrus.WithError(err).Fatalf("Failed to enable MSCs") + } + } + + // Expose the matrix APIs directly rather than putting them under a /api path. + go func() { + base.SetupAndServeHTTP(httpAddr, nil, nil) + }() + // Handle HTTPS if certificate and key are provided + if *unixSocket == "" && *certFile != "" && *keyFile != "" { + go func() { + base.SetupAndServeHTTP(httpsAddr, certFile, keyFile) + }() + } + + // We want to block forever to let the HTTP and HTTPS handler serve the APIs + base.WaitForShutdown() +} diff --git a/cmd/dendrite-monolith-server/main_test.go b/cmd/dendrite/main_test.go similarity index 93% rename from cmd/dendrite-monolith-server/main_test.go rename to cmd/dendrite/main_test.go index efa1a926c..d51bc7434 100644 --- a/cmd/dendrite-monolith-server/main_test.go +++ b/cmd/dendrite/main_test.go @@ -9,7 +9,7 @@ import ( ) // This is an instrumented main, used when running integration tests (sytest) with code coverage. -// Compile: go test -c -race -cover -covermode=atomic -o monolith.debug -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server +// Compile: go test -c -race -cover -covermode=atomic -o monolith.debug -coverpkg "github.com/matrix-org/..." ./cmd/dendrite // Run the monolith: ./monolith.debug -test.coverprofile=/somewhere/to/dump/integrationcover.out DEVEL --config dendrite.yaml // Generate HTML with coverage: go tool cover -html=/somewhere/where/there/is/integrationcover.out -o cover.html // Source: https://dzone.com/articles/measuring-integration-test-coverage-rate-in-pouchc diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index 33b18c471..86b302346 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -18,7 +18,6 @@ func main() { dbURI := flag.String("db", "", "The DB URI to use for all components (PostgreSQL only)") dirPath := flag.String("dir", "./", "The folder to use for paths (like SQLite databases, media storage)") normalise := flag.String("normalise", "", "Normalise an existing configuration file by adding new/missing options and defaults") - polylith := flag.Bool("polylith", false, "Generate a config that makes sense for polylith deployments") flag.Parse() var cfg *config.Dendrite @@ -27,14 +26,14 @@ func main() { Version: config.Version, } cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: !*polylith, + Generate: true, + SingleDatabase: true, }) if *serverName != "" { cfg.Global.ServerName = gomatrixserverlib.ServerName(*serverName) } uri := config.DataSource(*dbURI) - if *polylith || uri.IsSQLite() || uri == "" { + if uri.IsSQLite() || uri == "" { for name, db := range map[string]*config.DatabaseOptions{ "federationapi": &cfg.FederationAPI.Database, "keyserver": &cfg.KeyServer.Database, @@ -54,6 +53,9 @@ func main() { } else { cfg.Global.DatabaseOptions.ConnectionString = uri } + cfg.MediaAPI.BasePath = config.Path(filepath.Join(*dirPath, "media")) + cfg.Global.JetStream.StoragePath = config.Path(*dirPath) + cfg.SyncAPI.Fulltext.IndexPath = config.Path(filepath.Join(*dirPath, "searchindex")) cfg.Logging = []config.LogrusHook{ { Type: "file", @@ -67,6 +69,7 @@ func main() { cfg.AppServiceAPI.DisableTLSValidation = true cfg.ClientAPI.RateLimiting.Enabled = false cfg.FederationAPI.DisableTLSValidation = false + cfg.FederationAPI.DisableHTTPKeepalives = true // don't hit matrix.org when running tests!!! cfg.FederationAPI.KeyPerspectives = config.KeyPerspectives{} cfg.MediaAPI.BasePath = config.Path(filepath.Join(*dirPath, "media")) @@ -92,7 +95,7 @@ func main() { } } else { var err error - if cfg, err = config.Load(*normalise, !*polylith); err != nil { + if cfg, err = config.Load(*normalise); err != nil { panic(err) } } diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index f8bb130c7..a9cc80cb7 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -40,7 +40,7 @@ func main() { Level: "error", }) cfg.ClientAPI.RegistrationDisabled = true - base := base.NewBaseDendrite(cfg, "ResolveState", base.DisableMetrics) + base := base.NewBaseDendrite(cfg, base.DisableMetrics) args := flag.Args() fmt.Println("Room version", *roomVersion) @@ -62,9 +62,10 @@ func main() { panic(err) } - stateres := state.NewStateResolution(roomserverDB, &types.RoomInfo{ + roomInfo := &types.RoomInfo{ RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion), - }) + } + stateres := state.NewStateResolution(roomserverDB, roomInfo) if *difference { if len(snapshotNIDs) != 2 { @@ -87,7 +88,7 @@ func main() { } var eventEntries []types.Event - eventEntries, err = roomserverDB.Events(ctx, eventNIDs) + eventEntries, err = roomserverDB.Events(ctx, roomInfo, eventNIDs) if err != nil { panic(err) } @@ -145,7 +146,7 @@ func main() { } fmt.Println("Fetching", len(eventNIDMap), "state events") - eventEntries, err := roomserverDB.Events(ctx, eventNIDs) + eventEntries, err := roomserverDB.Events(ctx, roomInfo, eventNIDs) if err != nil { panic(err) } @@ -165,7 +166,7 @@ func main() { } fmt.Println("Fetching", len(authEventIDs), "auth events") - authEventEntries, err := roomserverDB.EventsFromIDs(ctx, authEventIDs) + authEventEntries, err := roomserverDB.EventsFromIDs(ctx, roomInfo, authEventIDs) if err != nil { panic(err) } diff --git a/dendrite-sample.polylith.yaml b/dendrite-sample.polylith.yaml deleted file mode 100644 index ecc3f4051..000000000 --- a/dendrite-sample.polylith.yaml +++ /dev/null @@ -1,417 +0,0 @@ -# This is the Dendrite configuration file. -# -# The configuration is split up into sections - each Dendrite component has a -# configuration section, in addition to the "global" section which applies to -# all components. - -# The version of the configuration file. -version: 2 - -# Global Matrix configuration. This configuration applies to all components. -global: - # The domain name of this homeserver. - server_name: localhost - - # The path to the signing private key file, used to sign requests and events. - # Note that this is NOT the same private key as used for TLS! To generate a - # signing key, use "./bin/generate-keys --private-key matrix_key.pem". - private_key: matrix_key.pem - - # The paths and expiry timestamps (as a UNIX timestamp in millisecond precision) - # to old signing keys that were formerly in use on this domain name. These - # keys will not be used for federation request or event signing, but will be - # provided to any other homeserver that asks when trying to verify old events. - old_private_keys: - # If the old private key file is available: - # - private_key: old_matrix_key.pem - # expired_at: 1601024554498 - # If only the public key (in base64 format) and key ID are known: - # - public_key: mn59Kxfdq9VziYHSBzI7+EDPDcBS2Xl7jeUdiiQcOnM= - # key_id: ed25519:mykeyid - # expired_at: 1601024554498 - - # How long a remote server can cache our server signing key before requesting it - # again. Increasing this number will reduce the number of requests made by other - # servers for our key but increases the period that a compromised key will be - # considered valid by other homeservers. - key_validity_period: 168h0m0s - - # Configuration for in-memory caches. Caches can often improve performance by - # keeping frequently accessed items (like events, identifiers etc.) in memory - # rather than having to read them from the database. - cache: - # The estimated maximum size for the global cache in bytes, or in terabytes, - # gigabytes, megabytes or kilobytes when the appropriate 'tb', 'gb', 'mb' or - # 'kb' suffix is specified. Note that this is not a hard limit, nor is it a - # memory limit for the entire process. A cache that is too small may ultimately - # provide little or no benefit. - max_size_estimated: 1gb - - # The maximum amount of time that a cache entry can live for in memory before - # it will be evicted and/or refreshed from the database. Lower values result in - # easier admission of new cache entries but may also increase database load in - # comparison to higher values, so adjust conservatively. Higher values may make - # it harder for new items to make it into the cache, e.g. if new rooms suddenly - # become popular. - max_age: 1h - - # The server name to delegate server-server communications to, with optional port - # e.g. localhost:443 - well_known_server_name: "" - - # The server name to delegate client-server communications to, with optional port - # e.g. localhost:443 - well_known_client_name: "" - - # Lists of domains that the server will trust as identity servers to verify third - # party identifiers such as phone numbers and email addresses. - trusted_third_party_id_servers: - - matrix.org - - vector.im - - # Disables federation. Dendrite will not be able to communicate with other servers - # in the Matrix federation and the federation API will not be exposed. - disable_federation: false - - # Configures the handling of presence events. Inbound controls whether we receive - # presence events from other servers, outbound controls whether we send presence - # events for our local users to other servers. - presence: - enable_inbound: false - enable_outbound: false - - # Configures phone-home statistics reporting. These statistics contain the server - # name, number of active users and some information on your deployment config. - # We use this information to understand how Dendrite is being used in the wild. - report_stats: - enabled: false - endpoint: https://matrix.org/report-usage-stats/push - - # Server notices allows server admins to send messages to all users on the server. - server_notices: - enabled: false - # The local part, display name and avatar URL (as a mxc:// URL) for the user that - # will send the server notices. These are visible to all users on the deployment. - local_part: "_server" - display_name: "Server Alerts" - avatar_url: "" - # The room name to be used when sending server notices. This room name will - # appear in user clients. - room_name: "Server Alerts" - - # Configuration for NATS JetStream - jetstream: - # A list of NATS Server addresses to connect to. If none are specified, an - # internal NATS server will be started automatically when running Dendrite in - # monolith mode. For polylith deployments, it is required to specify the address - # of at least one NATS Server node. - addresses: - - hostname:4222 - - # Disable the validation of TLS certificates of NATS. This is - # not recommended in production since it may allow NATS traffic - # to be sent to an insecure endpoint. - disable_tls_validation: false - - # The prefix to use for stream names for this homeserver - really only useful - # if you are running more than one Dendrite server on the same NATS deployment. - topic_prefix: Dendrite - - # Configuration for Prometheus metric collection. - metrics: - enabled: false - basic_auth: - username: metrics - password: metrics - - # Optional DNS cache. The DNS cache may reduce the load on DNS servers if there - # is no local caching resolver available for use. - dns_cache: - enabled: false - cache_size: 256 - cache_lifetime: "5m" # 5 minutes; https://pkg.go.dev/time@master#ParseDuration - -# Configuration for the Appservice API. -app_service_api: - internal_api: - listen: http://[::]:7777 # The listen address for incoming API requests - connect: http://app_service_api:7777 # The connect address for other components to use - - # Disable the validation of TLS certificates of appservices. This is - # not recommended in production since it may allow appservice traffic - # to be sent to an insecure endpoint. - disable_tls_validation: false - - # Appservice configuration files to load into this homeserver. - config_files: - # - /path/to/appservice_registration.yaml - -# Configuration for the Client API. -client_api: - internal_api: - listen: http://[::]:7771 # The listen address for incoming API requests - connect: http://client_api:7771 # The connect address for other components to use - external_api: - listen: http://[::]:8071 - - # Prevents new users from being able to register on this homeserver, except when - # using the registration shared secret below. - registration_disabled: true - - # Prevents new guest accounts from being created. Guest registration is also - # disabled implicitly by setting 'registration_disabled' above. - guests_disabled: true - - # If set, allows registration by anyone who knows the shared secret, regardless - # of whether registration is otherwise disabled. - registration_shared_secret: "" - - # Whether to require reCAPTCHA for registration. If you have enabled registration - # then this is HIGHLY RECOMMENDED to reduce the risk of your homeserver being used - # for coordinated spam attacks. - enable_registration_captcha: false - - # Settings for ReCAPTCHA. - recaptcha_public_key: "" - recaptcha_private_key: "" - recaptcha_bypass_secret: "" - - # To use hcaptcha.com instead of ReCAPTCHA, set the following parameters, otherwise just keep them empty. - # recaptcha_siteverify_api: "https://hcaptcha.com/siteverify" - # recaptcha_api_js_url: "https://js.hcaptcha.com/1/api.js" - # recaptcha_form_field: "h-captcha-response" - # recaptcha_sitekey_class: "h-captcha" - - - # TURN server information that this homeserver should send to clients. - turn: - turn_user_lifetime: "5m" - turn_uris: - # - turn:turn.server.org?transport=udp - # - turn:turn.server.org?transport=tcp - turn_shared_secret: "" - # If your TURN server requires static credentials, then you will need to enter - # them here instead of supplying a shared secret. Note that these credentials - # will be visible to clients! - # turn_username: "" - # turn_password: "" - - # Settings for rate-limited endpoints. Rate limiting kicks in after the threshold - # number of "slots" have been taken by requests from a specific host. Each "slot" - # will be released after the cooloff time in milliseconds. Server administrators - # and appservice users are exempt from rate limiting by default. - rate_limiting: - enabled: true - threshold: 20 - cooloff_ms: 500 - exempt_user_ids: - # - "@user:domain.com" - -# Configuration for the Federation API. -federation_api: - internal_api: - listen: http://[::]:7772 # The listen address for incoming API requests - connect: http://federation_api:7772 # The connect address for other components to use - external_api: - listen: http://[::]:8072 - database: - connection_string: postgresql://username:password@hostname/dendrite_federationapi?sslmode=disable - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 - - # How many times we will try to resend a failed transaction to a specific server. The - # backoff is 2**x seconds, so 1 = 2 seconds, 2 = 4 seconds, 3 = 8 seconds etc. Once - # the max retries are exceeded, Dendrite will no longer try to send transactions to - # that server until it comes back to life and connects to us again. - send_max_retries: 16 - - # Disable the validation of TLS certificates of remote federated homeservers. Do not - # enable this option in production as it presents a security risk! - disable_tls_validation: false - - # Disable HTTP keepalives, which also prevents connection reuse. Dendrite will typically - # keep HTTP connections open to remote hosts for 5 minutes as they can be reused much - # more quickly than opening new connections each time. Disabling keepalives will close - # HTTP connections immediately after a successful request but may result in more CPU and - # memory being used on TLS handshakes for each new connection instead. - disable_http_keepalives: false - - # Perspective keyservers to use as a backup when direct key fetches fail. This may - # be required to satisfy key requests for servers that are no longer online when - # joining some rooms. - key_perspectives: - - server_name: matrix.org - keys: - - key_id: ed25519:auto - public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw - - key_id: ed25519:a_RXGa - public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ - - # This option will control whether Dendrite will prefer to look up keys directly - # or whether it should try perspective servers first, using direct fetches as a - # last resort. - prefer_direct_fetch: false - -# Configuration for the Key Server (for end-to-end encryption). -key_server: - internal_api: - listen: http://[::]:7779 # The listen address for incoming API requests - connect: http://key_server:7779 # The connect address for other components to use - database: - connection_string: postgresql://username:password@hostname/dendrite_keyserver?sslmode=disable - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 - -# Configuration for the Media API. -media_api: - internal_api: - listen: http://[::]:7774 # The listen address for incoming API requests - connect: http://media_api:7774 # The connect address for other components to use - external_api: - listen: http://[::]:8074 - database: - connection_string: postgresql://username:password@hostname/dendrite_mediaapi?sslmode=disable - max_open_conns: 5 - max_idle_conns: 2 - conn_max_lifetime: -1 - - # Storage path for uploaded media. May be relative or absolute. - base_path: ./media_store - - # The maximum allowed file size (in bytes) for media uploads to this homeserver - # (0 = unlimited). If using a reverse proxy, ensure it allows requests at least - #this large (e.g. the client_max_body_size setting in nginx). - max_file_size_bytes: 10485760 - - # Whether to dynamically generate thumbnails if needed. - dynamic_thumbnails: false - - # The maximum number of simultaneous thumbnail generators to run. - max_thumbnail_generators: 10 - - # A list of thumbnail sizes to be generated for media content. - thumbnail_sizes: - - width: 32 - height: 32 - method: crop - - width: 96 - height: 96 - method: crop - - width: 640 - height: 480 - method: scale - -# Configuration for enabling experimental MSCs on this homeserver. -mscs: - mscs: - # - msc2836 # (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836) - # - msc2946 # (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946) - database: - connection_string: postgresql://username:password@hostname/dendrite_mscs?sslmode=disable - max_open_conns: 5 - max_idle_conns: 2 - conn_max_lifetime: -1 - -# Configuration for the Room Server. -room_server: - internal_api: - listen: http://[::]:7770 # The listen address for incoming API requests - connect: http://room_server:7770 # The connect address for other components to use - database: - connection_string: postgresql://username:password@hostname/dendrite_roomserver?sslmode=disable - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 - -# Configuration for the Sync API. -sync_api: - internal_api: - listen: http://[::]:7773 # The listen address for incoming API requests - connect: http://sync_api:7773 # The connect address for other components to use - external_api: - listen: http://[::]:8073 - database: - connection_string: postgresql://username:password@hostname/dendrite_syncapi?sslmode=disable - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 - - # Configuration for the full-text search engine. - search: - # Whether or not search is enabled. - enabled: false - - # The path where the search index will be created in. - index_path: "./searchindex" - - # The language most likely to be used on the server - used when indexing, to - # ensure the returned results match expectations. A full list of possible languages - # can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang - language: "en" - - # This option controls which HTTP header to inspect to find the real remote IP - # address of the client. This is likely required if Dendrite is running behind - # a reverse proxy server. - # real_ip_header: X-Real-IP - -# Configuration for the User API. -user_api: - internal_api: - listen: http://[::]:7781 # The listen address for incoming API requests - connect: http://user_api:7781 # The connect address for other components to use - account_database: - connection_string: postgresql://username:password@hostname/dendrite_userapi?sslmode=disable - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 - - # The cost when hashing passwords on registration/login. Default: 10. Min: 4, Max: 31 - # See https://pkg.go.dev/golang.org/x/crypto/bcrypt for more information. - # Setting this lower makes registration/login consume less CPU resources at the cost - # of security should the database be compromised. Setting this higher makes registration/login - # consume more CPU resources but makes it harder to brute force password hashes. This value - # can be lowered if performing tests or on embedded Dendrite instances (e.g WASM builds). - bcrypt_cost: 10 - - # The length of time that a token issued for a relying party from - # /_matrix/client/r0/user/{userId}/openid/request_token endpoint - # is considered to be valid in milliseconds. - # The default lifetime is 3600000ms (60 minutes). - # openid_token_lifetime_ms: 3600000 - - # Users who register on this homeserver will automatically be joined to the rooms listed under "auto_join_rooms" option. - # By default, any room aliases included in this list will be created as a publicly joinable room - # when the first user registers for the homeserver. If the room already exists, - # make certain it is a publicly joinable room, i.e. the join rule of the room must be set to 'public'. - # As Spaces are just rooms under the hood, Space aliases may also be used. - auto_join_rooms: - # - "#main:matrix.org" - -# Configuration for Opentracing. -# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on -# how this works and how to set it up. -tracing: - enabled: false - jaeger: - serviceName: "" - disabled: false - rpc_metrics: false - tags: [] - sampler: null - reporter: null - headers: null - baggage_restrictions: null - throttler: null - -# Logging configuration. The "std" logging type controls the logs being sent to -# stdout. The "file" logging type controls logs being written to a log folder on -# the disk. Supported log levels are "debug", "info", "warn", "error". -logging: - - type: std - level: info - - type: file - level: info - params: - path: ./logs diff --git a/dendrite-sample.monolith.yaml b/dendrite-sample.yaml similarity index 98% rename from dendrite-sample.monolith.yaml rename to dendrite-sample.yaml index d86e9da94..6b3ea74f2 100644 --- a/dendrite-sample.monolith.yaml +++ b/dendrite-sample.yaml @@ -38,7 +38,7 @@ global: # Global database connection pool, for PostgreSQL monolith deployments only. If # this section is populated then you can omit the "database" blocks in all other - # sections. For polylith deployments, or monolith deployments using SQLite databases, + # sections. For monolith deployments using SQLite databases, # you must configure the "database" block for each component instead. database: connection_string: postgresql://username:password@hostname/dendrite?sslmode=disable @@ -95,7 +95,7 @@ global: # We use this information to understand how Dendrite is being used in the wild. report_stats: enabled: false - endpoint: https://matrix.org/report-usage-stats/push + endpoint: https://panopticon.matrix.org/push # Server notices allows server admins to send messages to all users on the server. server_notices: @@ -113,8 +113,7 @@ global: jetstream: # A list of NATS Server addresses to connect to. If none are specified, an # internal NATS server will be started automatically when running Dendrite in - # monolith mode. For polylith deployments, it is required to specify the address - # of at least one NATS Server node. + # monolith mode. addresses: # - localhost:4222 diff --git a/docs/FAQ.md b/docs/FAQ.md index ca72b151d..2899aa982 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -6,6 +6,12 @@ permalink: /faq # FAQ +## Why does Dendrite exist? + +Dendrite aims to provide a matrix compatible server that has low resource usage compared to [Synapse](https://github.com/matrix-org/synapse). +It also aims to provide more flexibility when scaling either up or down. +Dendrite's code is also very easy to hack on which makes it suitable for experimenting with new matrix features such as peer-to-peer. + ## Is Dendrite stable? Mostly, although there are still bugs and missing features. If you are a confident power user and you are happy to spend some time debugging things when they go wrong, then please try out Dendrite. If you are a community, organisation or business that demands stability and uptime, then Dendrite is not for you yet - please install Synapse instead. @@ -29,10 +35,9 @@ possible to migrate an existing Synapse deployment to Dendrite. No, Dendrite has a very different database schema to Synapse and the two are not interchangeable. -## Should I run a monolith or a polylith deployment? +## Can I configure which port Dendrite listens on? -Monolith deployments are always preferred where possible, and at this time, are far better tested than polylith deployments are. The only reason to consider a polylith deployment is if you wish to run different Dendrite components on separate physical machines, but this is an advanced configuration which we don't -recommend. +Yes, use the cli flag `-http-bind-address`. ## I've installed Dendrite but federation isn't working @@ -42,6 +47,10 @@ Check the [Federation Tester](https://federationtester.matrix.org). You need at * A valid TLS certificate for that DNS name * Either DNS SRV records or well-known files +## Whenever I try to connect from Element it says unable to connect to homeserver + +Check that your dendrite instance is running. Otherwise this is most likely due to a reverse proxy misconfiguration. + ## Does Dendrite work with my favourite client? It should do, although we are aware of some minor issues: @@ -49,6 +58,10 @@ It should do, although we are aware of some minor issues: * **Element Android**: registration does not work, but logging in with an existing account does * **Hydrogen**: occasionally sync can fail due to gaps in the `since` parameter, but clearing the cache fixes this +## Is there a public instance of Dendrite I can try out? + +Use [dendrite.matrix.org](https://dendrite.matrix.org) which we officially support. + ## Does Dendrite support Space Summaries? Yes, [Space Summaries](https://github.com/matrix-org/matrix-spec-proposals/pull/2946) were merged into the Matrix Spec as of 2022-01-17 however, they are still treated as an MSC (Matrix Specification Change) in Dendrite. In order to enable Space Summaries in Dendrite, you must add the MSC to the MSC configuration section in the configuration YAML. If the MSC is not enabled, a user will typically see a perpetual loading icon on the summary page. See below for a demonstration of how to add to the Dendrite configuration: @@ -84,14 +97,46 @@ Remember to add the config file(s) to the `app_service_api` section of the confi Yes, you can do this by disabling federation - set `disable_federation` to `true` in the `global` section of the Dendrite configuration file. +## How can I migrate a room in order to change the internal ID? + +This can be done by performing a room upgrade. Use the command `/upgraderoom ` in Element to do this. + +## How do I reset somebody's password on my server? + +Use the admin endpoint [resetpassword](https://matrix-org.github.io/dendrite/administration/adminapi#post-_dendriteadminresetpassworduserid) + ## Should I use PostgreSQL or SQLite for my databases? Please use PostgreSQL wherever possible, especially if you are planning to run a homeserver that caters to more than a couple of users. +## What data needs to be kept if transferring/backing up Dendrite? + +The list of files that need to be stored is: +- matrix-key.pem +- dendrite.yaml +- the postgres or sqlite DB +- the media store +- the search index (although this can be regenerated) + +Note that this list may change / be out of date. We don't officially maintain instructions for migrations like this. + +## How can I prepare enough storage for media caches? + +This might be what you want: [matrix-media-repo](https://github.com/turt2live/matrix-media-repo) +We don't officially support this or any other dedicated media storage solutions. + +## Is there an upgrade guide for Dendrite? + +Run a newer docker image. We don't officially support deployments other than Docker. +Most of the time you should be able to just +- stop +- replace binary +- start + ## Dendrite is using a lot of CPU Generally speaking, you should expect to see some CPU spikes, particularly if you are joining or participating in large rooms. However, constant/sustained high CPU usage is not expected - if you are experiencing that, please join `#dendrite-dev:matrix.org` and let us know what you were doing when the -CPU usage shot up, or file a GitHub issue. If you can take a [CPU profile](PROFILING.md) then that would +CPU usage shot up, or file a GitHub issue. If you can take a [CPU profile](development/PROFILING.md) then that would be a huge help too, as that will help us to understand where the CPU time is going. ## Dendrite is using a lot of RAM @@ -99,9 +144,13 @@ be a huge help too, as that will help us to understand where the CPU time is goi As above with CPU usage, some memory spikes are expected if Dendrite is doing particularly heavy work at a given instant. However, if it is using more RAM than you expect for a long time, that's probably not expected. Join `#dendrite-dev:matrix.org` and let us know what you were doing when the memory usage -ballooned, or file a GitHub issue if you can. If you can take a [memory profile](PROFILING.md) then that +ballooned, or file a GitHub issue if you can. If you can take a [memory profile](development/PROFILING.md) then that would be a huge help too, as that will help us to understand where the memory usage is happening. +## Do I need to generate the self-signed certificate if I'm going to use a reverse proxy? + +No, if you already have a proper certificate from some provider, like Let's Encrypt, and use that on your reverse proxy, and the reverse proxy does TLS termination, then you’re good and can use HTTP to the dendrite process. + ## Dendrite is running out of PostgreSQL database connections You may need to revisit the connection limit of your PostgreSQL server and/or make changes to the `max_connections` lines in your Dendrite configuration. Be aware that each Dendrite component opens its own database connections and has its own connection limit, even in monolith mode! diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index c7ba43711..a61786c1d 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -1,7 +1,7 @@ GEM remote: https://rubygems.org/ specs: - activesupport (6.0.5) + activesupport (6.0.6.1) concurrent-ruby (~> 1.0, >= 1.0.2) i18n (>= 0.7, < 2) minitest (~> 5.1) @@ -14,8 +14,8 @@ GEM execjs coffee-script-source (1.11.1) colorator (1.1.0) - commonmarker (0.23.6) - concurrent-ruby (1.1.10) + commonmarker (0.23.7) + concurrent-ruby (1.2.0) dnsruby (1.61.9) simpleidn (~> 0.1) em-websocket (0.5.3) @@ -229,11 +229,11 @@ GEM jekyll (>= 3.5, < 5.0) jekyll-feed (~> 0.9) jekyll-seo-tag (~> 2.1) - minitest (5.15.0) + minitest (5.17.0) multipart-post (2.1.1) - nokogiri (1.13.9-arm64-darwin) + nokogiri (1.13.10-arm64-darwin) racc (~> 1.4) - nokogiri (1.13.9-x86_64-linux) + nokogiri (1.13.10-x86_64-linux) racc (~> 1.4) octokit (4.22.0) faraday (>= 0.9) @@ -241,7 +241,7 @@ GEM pathutil (0.16.2) forwardable-extended (~> 2.6) public_suffix (4.0.7) - racc (1.6.0) + racc (1.6.1) rb-fsevent (0.11.1) rb-inotify (0.10.1) ffi (~> 1.0) @@ -265,13 +265,13 @@ GEM thread_safe (0.3.6) typhoeus (1.4.0) ethon (>= 0.9.0) - tzinfo (1.2.10) + tzinfo (1.2.11) thread_safe (~> 0.1) unf (0.1.4) unf_ext unf_ext (0.0.8.1) unicode-display_width (1.8.0) - zeitwerk (2.5.4) + zeitwerk (2.6.6) PLATFORMS arm64-darwin-21 diff --git a/docs/INSTALL.md b/docs/INSTALL.md index add822108..ccfc58107 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -9,7 +9,5 @@ or alternatively, in the [installation](installation/) folder: 3. [Preparing database storage](installation/3_database.md) 4. [Generating signing keys](installation/4_signingkey.md) 5. [Installing as a monolith](installation/5_install_monolith.md) -6. [Installing as a polylith](installation/6_install_polylith.md) -7. [Populate the configuration](installation/7_configuration.md) -8. [Starting the monolith](installation/8_starting_monolith.md) -9. [Starting the polylith](installation/9_starting_polylith.md) +6. [Populate the configuration](installation/7_configuration.md) +7. [Starting the monolith](installation/8_starting_monolith.md) diff --git a/docs/administration/4_adminapi.md b/docs/administration/4_adminapi.md index 56e19a8b4..46cfac220 100644 --- a/docs/administration/4_adminapi.md +++ b/docs/administration/4_adminapi.md @@ -22,12 +22,12 @@ curl --header "Authorization: Bearer " -X `Help & About` -> `Advanced` -> `Access Token`. Be aware that an `access_token` allows a client to perform actions as an user and should be kept **secret**. -The user must be an administrator in the `account_accounts` table in order to use these endpoints. +The user must be an administrator in the `userapi_accounts` table in order to use these endpoints. -Existing user accounts can be set to administrative accounts by changing `account_type` to `3` in `account_accounts` +Existing user accounts can be set to administrative accounts by changing `account_type` to `3` in `userapi_accounts` ``` -UPDATE account_accounts SET account_type = 3 WHERE localpart = '$localpart'; +UPDATE userapi_accounts SET account_type = 3 WHERE localpart = '$localpart'; ``` Where `$localpart` is the username only (e.g. `alice`). @@ -38,13 +38,18 @@ This endpoint will instruct Dendrite to part all local users from the given `roo in the URL. It may take some time to complete. A JSON body will be returned containing the user IDs of all affected users. +If the room has an alias set (e.g. is published), the room's ID will not be visible in the URL, but it can +be found as the room's "internal ID" in Element Web (Settings -> Advanced) + ## GET `/_dendrite/admin/evacuateUser/{userID}` This endpoint will instruct Dendrite to part the given local `userID` in the URL from all rooms which they are currently joined. A JSON body will be returned containing the room IDs of all affected rooms. -## POST `/_dendrite/admin/resetPassword/{localpart}` +## POST `/_dendrite/admin/resetPassword/{userID}` + +Reset the password of a local user. Request body format: @@ -54,9 +59,6 @@ Request body format: } ``` -Reset the password of a local user. The `localpart` is the username only, i.e. if -the full user ID is `@alice:domain.com` then the local part is `alice`. - ## GET `/_dendrite/admin/fulltext/reindex` This endpoint instructs Dendrite to reindex all searchable events (`m.room.message`, `m.room.topic` and `m.room.name`). An empty JSON body will be returned immediately. diff --git a/docs/CONTRIBUTING.md b/docs/development/CONTRIBUTING.md similarity index 75% rename from docs/CONTRIBUTING.md rename to docs/development/CONTRIBUTING.md index 6ba05f46f..2aec4c363 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/development/CONTRIBUTING.md @@ -9,6 +9,28 @@ permalink: /development/contributing Everyone is welcome to contribute to Dendrite! We aim to make it as easy as possible to get started. +## Contribution types + +We are a small team maintaining a large project. As a result, we cannot merge every feature, even if it +is bug-free and useful, because we then commit to maintaining it indefinitely. We will always accept: + - bug fixes + - security fixes (please responsibly disclose via security@matrix.org *before* creating pull requests) + +We will accept the following with caveats: + - documentation fixes, provided they do not add additional instructions which can end up going out-of-date, + e.g example configs, shell commands. + - performance fixes, provided they do not add significantly more maintenance burden. + - additional functionality on existing features, provided the functionality is small and maintainable. + - additional functionality that, in its absence, would impact the ecosystem e.g spam and abuse mitigations + - test-only changes, provided they help improve coverage or test tricky code. + +The following items are at risk of not being accepted: + - Configuration or CLI changes, particularly ones which increase the overall configuration surface. + +The following items are unlikely to be accepted into a main Dendrite release for now: + - New MSC implementations. + - New features which are not in the specification. + ## Sign off We require that everyone who contributes to the project signs off their contributions @@ -35,7 +57,7 @@ to do so for future contributions. ## Getting up and running -See the [Installation](installation) section for information on how to build an +See the [Installation](../installation) section for information on how to build an instance of Dendrite. You will likely need this in order to test your changes. ## Code style @@ -75,7 +97,20 @@ comment. Please avoid doing this if you can. We also have unit tests which we run via: ```bash -go test --race ./... +DENDRITE_TEST_SKIP_NODB=1 go test --race ./... +``` + +This only runs SQLite database tests. If you wish to execute Postgres tests as well, you'll either need to +have Postgres installed locally (`createdb` will be used) or have a remote/containerized Postgres instance +available. + +To configure the connection to a remote Postgres, you can use the following enviroment variables: + +```bash +POSTGRES_USER=postgres +POSTGERS_PASSWORD=yourPostgresPassword +POSTGRES_HOST=localhost +POSTGRES_DB=postgres # the superuser database to use ``` In general, we like submissions that come with tests. Anything that proves that the @@ -116,7 +151,7 @@ significant amount of CPU and RAM. Once the code builds, run [Sytest](https://github.com/matrix-org/sytest) according to the guide in -[docs/sytest.md](https://github.com/matrix-org/dendrite/blob/main/docs/sytest.md#using-a-sytest-docker-image) +[docs/development/sytest.md](https://github.com/matrix-org/dendrite/blob/main/docs/development/sytest.md#using-a-sytest-docker-image) so you can see whether something is being broken and whether there are newly passing tests. diff --git a/docs/PROFILING.md b/docs/development/PROFILING.md similarity index 98% rename from docs/PROFILING.md rename to docs/development/PROFILING.md index f3b573472..57c37a900 100644 --- a/docs/PROFILING.md +++ b/docs/development/PROFILING.md @@ -15,7 +15,7 @@ Dendrite contains an embedded profiler called `pprof`, which is a part of the st To enable the profiler, start Dendrite with the `PPROFLISTEN` environment variable. This variable specifies which address and port to listen on, e.g. ``` -PPROFLISTEN=localhost:65432 ./bin/dendrite-monolith-server ... +PPROFLISTEN=localhost:65432 ./bin/dendrite ... ``` If pprof has been enabled successfully, a log line at startup will show that pprof is listening: diff --git a/docs/coverage.md b/docs/development/coverage.md similarity index 77% rename from docs/coverage.md rename to docs/development/coverage.md index 7a3b7cb9e..c4a8a1174 100644 --- a/docs/coverage.md +++ b/docs/development/coverage.md @@ -14,8 +14,8 @@ index 8f0e209c..ad057e52 100644 $output->diag( "Starting monolith server" ); my @command = ( -- $self->{bindir} . '/dendrite-monolith-server', -+ $self->{bindir} . '/dendrite-monolith-server', '--test.coverprofile=' . $self->{hs_dir} . '/integrationcover.log', "DEVEL", +- $self->{bindir} . '/dendrite', ++ $self->{bindir} . '/dendrite', '--test.coverprofile=' . $self->{hs_dir} . '/integrationcover.log', "DEVEL", '--config', $self->{paths}{config}, '--http-bind-address', $self->{bind_host} . ':' . $self->unsecure_port, '--https-bind-address', $self->{bind_host} . ':' . $self->secure_port, @@ -27,9 +27,9 @@ index f009332b..7ea79869 100755 echo >&2 "--- Building dendrite from source" cd /src mkdir -p $GOBIN --go install -v ./cmd/dendrite-monolith-server -+# go install -v ./cmd/dendrite-monolith-server -+go test -c -cover -covermode=atomic -o $GOBIN/dendrite-monolith-server -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server +-go install -v ./cmd/dendrite ++# go install -v ./cmd/dendrite ++go test -c -cover -covermode=atomic -o $GOBIN/dendrite -coverpkg "github.com/matrix-org/..." ./cmd/dendrite go install -v ./cmd/generate-keys cd - ``` @@ -57,22 +57,16 @@ github.com/matrix-org/util/unique.go:55: UniqueStrings 100.0% total: (statements) 53.7% ``` The total coverage for this run is the last line at the bottom. However, this value is misleading because Dendrite can run in many different configurations, -which will never be tested in a single test run (e.g sqlite or postgres, monolith or polylith). To get a more accurate value, additional processing is required +which will never be tested in a single test run (e.g sqlite or postgres). To get a more accurate value, additional processing is required to remove packages which will never be tested and extension MSCs: ```bash # These commands are all similar but change which package paths are _removed_ from the output. -# For Postgres (monolith) +# For Postgres go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|sqlite|setup/mscs|api_trace' > coverage.txt -# For Postgres (polylith) -go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'sqlite|setup/mscs|api_trace' > coverage.txt - -# For SQLite (monolith) +# For SQLite go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|postgres|setup/mscs|api_trace' > coverage.txt - -# For SQLite (polylith) -go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'postgres|setup/mscs|api_trace' > coverage.txt ``` A total value can then be calculated using: diff --git a/docs/sytest.md b/docs/development/sytest.md similarity index 100% rename from docs/sytest.md rename to docs/development/sytest.md diff --git a/docs/tracing/opentracing.md b/docs/development/tracing/opentracing.md similarity index 100% rename from docs/tracing/opentracing.md rename to docs/development/tracing/opentracing.md diff --git a/docs/tracing/setup.md b/docs/development/tracing/setup.md similarity index 86% rename from docs/tracing/setup.md rename to docs/development/tracing/setup.md index 06f89bf85..cef1089e4 100644 --- a/docs/tracing/setup.md +++ b/docs/development/tracing/setup.md @@ -46,10 +46,10 @@ tracing: param: 1 ``` -then run the monolith server with `--api true` to use polylith components which do tracing spans: +then run the monolith server: ``` -./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml --api true +./dendrite --tls-cert server.crt --tls-key server.key --config dendrite.yaml ``` ## Checking traces diff --git a/docs/installation/10_starting_polylith.md b/docs/installation/10_starting_polylith.md deleted file mode 100644 index 0c2e2af2b..000000000 --- a/docs/installation/10_starting_polylith.md +++ /dev/null @@ -1,73 +0,0 @@ ---- -title: Starting the polylith -parent: Installation -has_toc: true -nav_order: 10 -permalink: /installation/start/polylith ---- - -# Starting the polylith - -Once you have completed all of the preparation and installation steps, -you can start your Dendrite polylith deployment by starting the various components -using the `dendrite-polylith-multi` personalities. - -## Start the reverse proxy - -Ensure that your reverse proxy is started and is proxying the correct -endpoints to the correct components. Software such as [NGINX](https://www.nginx.com) or -[HAProxy](http://www.haproxy.org) can be used for this purpose. A [sample configuration -for NGINX](https://github.com/matrix-org/dendrite/blob/main/docs/nginx/polylith-sample.conf) -is provided. - -## Starting the components - -Each component must be started individually: - -### Client API - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml clientapi -``` - -### Sync API - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml syncapi -``` - -### Media API - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml mediaapi -``` - -### Federation API - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml federationapi -``` - -### Roomserver - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml roomserver -``` - -### Appservice API - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml appservice -``` - -### User API - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml userapi -``` - -### Key server - -```bash -./dendrite-polylith-multi -config /path/to/dendrite.yaml keyserver -``` diff --git a/docs/installation/1_planning.md b/docs/installation/1_planning.md index 3aa5b4d85..36d90abda 100644 --- a/docs/installation/1_planning.md +++ b/docs/installation/1_planning.md @@ -16,12 +16,6 @@ Users can run Dendrite in one of two modes which dictate how these components ar server with generally low overhead. This mode dramatically simplifies deployment complexity and offers the best balance between performance and resource usage for low-to-mid volume deployments. -* **Polylith mode** runs all components in isolated processes. Components communicate through an external NATS - server and HTTP APIs, which incur considerable overhead. While this mode allows for more granular control of - resources dedicated toward individual processes, given the additional communications overhead, it is only - necessary for very large deployments. - -Given our current state of development, **we recommend monolith mode** for all deployments. ## Databases @@ -85,21 +79,15 @@ If using the PostgreSQL database engine, you should install PostgreSQL 12 or lat ### NATS Server -Monolith deployments come with a built-in [NATS Server](https://github.com/nats-io/nats-server) and -therefore do not need this to be manually installed. If you are planning a monolith installation, you +Dendrite comes with a built-in [NATS Server](https://github.com/nats-io/nats-server) and +therefore does not need this to be manually installed. If you are planning a monolith installation, you do not need to do anything. -Polylith deployments, however, currently need a standalone NATS Server installation with JetStream -enabled. - -To do so, follow the [NATS Server installation instructions](https://docs.nats.io/running-a-nats-service/introduction/installation) and then [start your NATS deployment](https://docs.nats.io/running-a-nats-service/introduction/running). JetStream must be enabled, either by passing the `-js` flag to `nats-server`, -or by specifying the `store_dir` option in the the `jetstream` configuration. ### Reverse proxy A reverse proxy such as [Caddy](https://caddyserver.com), [NGINX](https://www.nginx.com) or -[HAProxy](http://www.haproxy.org) is required for polylith deployments and is useful for monolith -deployments. Configuring those is not covered in this documentation, although sample configurations +[HAProxy](http://www.haproxy.org) is useful for deployments. Configuring those is not covered in this documentation, although sample configurations for [Caddy](https://github.com/matrix-org/dendrite/blob/main/docs/caddy) and [NGINX](https://github.com/matrix-org/dendrite/blob/main/docs/nginx) are provided. diff --git a/docs/installation/3_build.md b/docs/installation/3_build.md index aed2080db..824c81d37 100644 --- a/docs/installation/3_build.md +++ b/docs/installation/3_build.md @@ -28,11 +28,11 @@ The resulting binaries will be placed in the `bin` subfolder. You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`: ```sh -go install ./cmd/dendrite-monolith-server +go install ./cmd/dendrite ``` Alternatively, you can specify a custom path for the binary to be written to using `go build`: ```sh -go build -o /usr/local/bin/ ./cmd/dendrite-monolith-server +go build -o /usr/local/bin/ ./cmd/dendrite ``` diff --git a/docs/installation/5_install_monolith.md b/docs/installation/5_install_monolith.md index 7de066cf7..901975a65 100644 --- a/docs/installation/5_install_monolith.md +++ b/docs/installation/5_install_monolith.md @@ -11,11 +11,11 @@ permalink: /installation/install/monolith You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`: ```sh -go install ./cmd/dendrite-monolith-server +go install ./cmd/dendrite ``` Alternatively, you can specify a custom path for the binary to be written to using `go build`: ```sh -go build -o /usr/local/bin/ ./cmd/dendrite-monolith-server +go build -o /usr/local/bin/ ./cmd/dendrite ``` diff --git a/docs/installation/6_install_polylith.md b/docs/installation/6_install_polylith.md deleted file mode 100644 index ec4a77628..000000000 --- a/docs/installation/6_install_polylith.md +++ /dev/null @@ -1,34 +0,0 @@ ---- -title: Installing as a polylith -parent: Installation -has_toc: true -nav_order: 6 -permalink: /installation/install/polylith ---- - -# Installing as a polylith - -You can install the Dendrite polylith binary into `$GOPATH/bin` by using `go install`: - -```sh -go install ./cmd/dendrite-polylith-multi -``` - -Alternatively, you can specify a custom path for the binary to be written to using `go build`: - -```sh -go build -o /usr/local/bin/ ./cmd/dendrite-polylith-multi -``` - -The `dendrite-polylith-multi` binary is a "multi-personality" binary which can run as -any of the components depending on the supplied command line parameters. - -## Reverse proxy - -Polylith deployments require a reverse proxy in order to ensure that requests are -sent to the correct endpoint. You must ensure that a suitable reverse proxy is installed -and configured. - -Sample configurations are provided -for [Caddy](https://github.com/matrix-org/dendrite/blob/main/docs/caddy/polylith/Caddyfile) -and [NGINX](https://github.com/matrix-org/dendrite/blob/main/docs/nginx/polylith-sample.conf). \ No newline at end of file diff --git a/docs/installation/7_configuration.md b/docs/installation/7_configuration.md index 19958c92f..5f123bfca 100644 --- a/docs/installation/7_configuration.md +++ b/docs/installation/7_configuration.md @@ -7,11 +7,10 @@ permalink: /installation/configuration # Configuring Dendrite -A YAML configuration file is used to configure Dendrite. Sample configuration files are +A YAML configuration file is used to configure Dendrite. A sample configuration file is present in the top level of the Dendrite repository: * [`dendrite-sample.monolith.yaml`](https://github.com/matrix-org/dendrite/blob/main/dendrite-sample.monolith.yaml) -* [`dendrite-sample.polylith.yaml`](https://github.com/matrix-org/dendrite/blob/main/dendrite-sample.polylith.yaml) You will need to duplicate the sample, calling it `dendrite.yaml` for example, and then tailor it to your installation. At a minimum, you will need to populate the following @@ -46,10 +45,9 @@ global: ## JetStream configuration Monolith deployments can use the built-in NATS Server rather than running a standalone -server. If you are building a polylith deployment, or you want to use a standalone NATS -Server anyway, you can also configure that too. +server. If you want to use a standalone NATS Server anyway, you can also configure that too. -### Built-in NATS Server (monolith only) +### Built-in NATS Server In the `global` section, under the `jetstream` key, ensure that no server addresses are configured and set a `storage_path` to a persistent folder on the filesystem: @@ -63,7 +61,7 @@ global: topic_prefix: Dendrite ``` -### Standalone NATS Server (monolith and polylith) +### Standalone NATS Server To use a standalone NATS Server instance, you will need to configure `addresses` field to point to the port that your NATS Server is listening on: @@ -86,7 +84,7 @@ one address in the `addresses` field. Configuring database connections varies based on the [database configuration](database) that you chose. -### Global connection pool (monolith with a single PostgreSQL database only) +### Global connection pool If you are running a monolith deployment and want to use a single connection pool to a single PostgreSQL database, then you must uncomment and configure the `database` section @@ -109,7 +107,7 @@ override the `global` database configuration. ### Per-component connections (all other configurations) -If you are building a polylith deployment, are using SQLite databases or separate PostgreSQL +If you are are using SQLite databases or separate PostgreSQL databases per component, then you must instead configure the `database` sections under each of the component blocks ,e.g. under the `app_service_api`, `federation_api`, `key_server`, `media_api`, `mscs`, `room_server`, `sync_api` and `user_api` blocks. diff --git a/docs/installation/9_starting_monolith.md b/docs/installation/9_starting_monolith.md index 124477e73..d7e8c0b8b 100644 --- a/docs/installation/9_starting_monolith.md +++ b/docs/installation/9_starting_monolith.md @@ -9,10 +9,10 @@ permalink: /installation/start/monolith # Starting the monolith Once you have completed all of the preparation and installation steps, -you can start your Dendrite monolith deployment by starting the `dendrite-monolith-server`: +you can start your Dendrite monolith deployment by starting `dendrite`: ```bash -./dendrite-monolith-server -config /path/to/dendrite.yaml +./dendrite -config /path/to/dendrite.yaml ``` By default, Dendrite will listen HTTP on port 8008. If you want to change the addresses @@ -20,7 +20,7 @@ or ports that Dendrite listens on, you can use the `-http-bind-address` and `-https-bind-address` command line arguments: ```bash -./dendrite-monolith-server -config /path/to/dendrite.yaml \ +./dendrite -config /path/to/dendrite.yaml \ -http-bind-address 1.2.3.4:12345 \ -https-bind-address 1.2.3.4:54321 ``` diff --git a/docs/systemd/monolith-example.service b/docs/systemd/monolith-example.service index 237120ffb..8a948a3fa 100644 --- a/docs/systemd/monolith-example.service +++ b/docs/systemd/monolith-example.service @@ -11,7 +11,7 @@ Type=simple User=dendrite Group=dendrite WorkingDirectory=/opt/dendrite/ -ExecStart=/opt/dendrite/bin/dendrite-monolith-server +ExecStart=/opt/dendrite/bin/dendrite Restart=always LimitNOFILE=65535 diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 50d0339e4..e4c0b2714 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -18,6 +18,7 @@ type FederationInternalAPI interface { gomatrixserverlib.KeyDatabase ClientFederationAPI RoomserverFederationAPI + P2PFederationAPI QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) @@ -30,7 +31,6 @@ type FederationInternalAPI interface { request *PerformBroadcastEDURequest, response *PerformBroadcastEDUResponse, ) error - PerformWakeupServers( ctx context.Context, request *PerformWakeupServersRequest, @@ -71,6 +71,29 @@ type RoomserverFederationAPI interface { LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } +type P2PFederationAPI interface { + // Get the relay servers associated for the given server. + P2PQueryRelayServers( + ctx context.Context, + request *P2PQueryRelayServersRequest, + response *P2PQueryRelayServersResponse, + ) error + + // Add relay server associations to the given server. + P2PAddRelayServers( + ctx context.Context, + request *P2PAddRelayServersRequest, + response *P2PAddRelayServersResponse, + ) error + + // Remove relay server associations from the given server. + P2PRemoveRelayServers( + ctx context.Context, + request *P2PRemoveRelayServersRequest, + response *P2PRemoveRelayServersResponse, + ) error +} + // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError @@ -82,6 +105,7 @@ type KeyserverFederationAPI interface { // an interface for gmsl.FederationClient - contains functions called by federationapi only. type FederationClient interface { + P2PFederationClient gomatrixserverlib.KeyClient SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) @@ -110,6 +134,11 @@ type FederationClient interface { LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } +type P2PFederationClient interface { + P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) + P2PGetTransactionFromRelay(ctx context.Context, u gomatrixserverlib.UserID, prev gomatrixserverlib.RelayEntry, relayServer gomatrixserverlib.ServerName) (res gomatrixserverlib.RespGetRelayTransaction, err error) +} + // FederationClientError is returned from FederationClient methods in the event of a problem. type FederationClientError struct { Err string @@ -233,3 +262,27 @@ type InputPublicKeysRequest struct { type InputPublicKeysResponse struct { } + +type P2PQueryRelayServersRequest struct { + Server gomatrixserverlib.ServerName +} + +type P2PQueryRelayServersResponse struct { + RelayServers []gomatrixserverlib.ServerName +} + +type P2PAddRelayServersRequest struct { + Server gomatrixserverlib.ServerName + RelayServers []gomatrixserverlib.ServerName +} + +type P2PAddRelayServersResponse struct { +} + +type P2PRemoveRelayServersRequest struct { + Server gomatrixserverlib.ServerName + RelayServers []gomatrixserverlib.ServerName +} + +type P2PRemoveRelayServersResponse struct { +} diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 601257d4b..7d9df3d78 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -26,11 +26,11 @@ import ( "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/types" - "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/api" ) // KeyChangeConsumer consumes events that originate in key server. diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index d16af6626..378b96ba0 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/federationapi/queue" @@ -90,8 +91,10 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms msg := msgs[0] // Guaranteed to exist if onMessage is called receivedType := api.OutputType(msg.Header.Get(jetstream.RoomEventType)) - // Only handle events we care about - if receivedType != api.OutputTypeNewRoomEvent && receivedType != api.OutputTypeNewInboundPeek { + // Only handle events we care about, avoids unneeded unmarshalling + switch receivedType { + case api.OutputTypeNewRoomEvent, api.OutputTypeNewInboundPeek, api.OutputTypePurgeRoom: + default: return true } @@ -126,6 +129,14 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms return false } + case api.OutputTypePurgeRoom: + log.WithField("room_id", output.PurgeRoom.RoomID).Warn("Purging room from federation API") + if err := s.db.PurgeRoom(ctx, output.PurgeRoom.RoomID); err != nil { + logrus.WithField("room_id", output.PurgeRoom.RoomID).WithError(err).Error("Failed to purge room from federation API") + } else { + logrus.WithField("room_id", output.PurgeRoom.RoomID).Warn("Room purged from federation API") + } + default: log.WithField("type", output.Type).Debug( "roomserver output log: ignoring unknown output type", @@ -162,6 +173,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew // Finally, work out if there are any more events missing. if len(missingEventIDs) > 0 { eventsReq := &api.QueryEventsByIDRequest{ + RoomID: ore.Event.RoomID(), EventIDs: missingEventIDs, } eventsRes := &api.QueryEventsByIDResponse{} @@ -195,7 +207,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew } // If we added new hosts, inform them about our known presence events for this room - if len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil { + if s.cfg.Matrix.Presence.EnableOutbound && len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil { membership, _ := ore.Event.Membership() if membership == gomatrixserverlib.Join { s.sendPresence(ore.Event.RoomID(), addsJoinedHosts) @@ -232,7 +244,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew } func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []types.JoinedHost) { - joined := make([]gomatrixserverlib.ServerName, len(addedJoined)) + joined := make([]gomatrixserverlib.ServerName, 0, len(addedJoined)) for _, added := range addedJoined { joined = append(joined, added.ServerName) } @@ -472,7 +484,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents( // At this point the missing events are neither the event itself nor are // they present in our local database. Our only option is to fetch them // from the roomserver using the query API. - eventReq := api.QueryEventsByIDRequest{EventIDs: missing} + eventReq := api.QueryEventsByIDRequest{EventIDs: missing, RoomID: event.RoomID()} var eventResp api.QueryEventsByIDResponse if err := s.rsAPI.QueryEventsByID(s.ctx, &eventReq, &eventResp); err != nil { return nil, err diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 854251220..ec482659a 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -17,20 +17,17 @@ package federationapi import ( "time" - "github.com/gorilla/mux" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/federationapi/api" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/consumers" "github.com/matrix-org/dendrite/federationapi/internal" - "github.com/matrix-org/dendrite/federationapi/inthttp" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/internal/caching" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" @@ -41,21 +38,14 @@ import ( "github.com/matrix-org/dendrite/federationapi/routing" ) -// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions -// on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.FederationInternalAPI) { - inthttp.AddRoutes(intAPI, router) -} - // AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. func AddPublicRoutes( base *base.BaseDendrite, - userAPI userapi.UserInternalAPI, + userAPI userapi.FederationUserAPI, federation *gomatrixserverlib.FederationClient, keyRing gomatrixserverlib.JSONVerifier, rsAPI roomserverAPI.FederationRoomserverAPI, fedAPI federationAPI.FederationInternalAPI, - keyAPI keyserverAPI.FederationKeyAPI, servers federationAPI.ServersInRoomProvider, ) { cfg := &base.Cfg.FederationAPI @@ -85,12 +75,9 @@ func AddPublicRoutes( } routing.Setup( - base.PublicFederationAPIMux, - base.PublicKeyAPIMux, - base.PublicWellKnownAPIMux, - cfg, + base, rsAPI, f, keyRing, - federation, userAPI, keyAPI, mscCfg, + federation, userAPI, mscCfg, servers, producer, ) } @@ -116,7 +103,10 @@ func NewInternalAPI( _ = federationDB.RemoveAllServersFromBlacklist() } - stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1) + stats := statistics.NewStatistics( + federationDB, + cfg.FederationMaxRetries+1, + cfg.P2PFederationRetriesUntilAssumedOffline+1) js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go index cc03cdece..bb6ee8935 100644 --- a/federationapi/federationapi_keys_test.go +++ b/federationapi/federationapi_keys_test.go @@ -77,8 +77,8 @@ func TestMain(m *testing.M) { // API to work. cfg := &config.Dendrite{} cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: false, }) cfg.Global.ServerName = gomatrixserverlib.ServerName(s.name) cfg.Global.PrivateKey = testPriv @@ -109,7 +109,7 @@ func TestMain(m *testing.M) { ) // Finally, build the server key APIs. - sbase := base.NewBaseDendrite(cfg, "Monolith", base.DisableMetrics) + sbase := base.NewBaseDendrite(cfg, base.DisableMetrics) s.api = NewInternalAPI(sbase, s.fedclient, nil, s.cache, nil, true) } diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 68a06a033..57d4b9644 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -17,13 +17,13 @@ import ( "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/internal" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" + userapi "github.com/matrix-org/dendrite/userapi/api" ) type fedRoomserverAPI struct { @@ -230,9 +230,9 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { // Inject a keyserver key change event and ensure we try to send it out. If we don't, then the // federationapi is incorrectly waiting for an output room event to arrive to update the joined // hosts table. - key := keyapi.DeviceMessage{ - Type: keyapi.TypeDeviceKeyUpdate, - DeviceKeys: &keyapi.DeviceKeys{ + key := userapi.DeviceMessage{ + Type: userapi.TypeDeviceKeyUpdate, + DeviceKeys: &userapi.DeviceKeys{ UserID: joiningUser.ID, DeviceID: "MY_DEVICE", DisplayName: "BLARGLE", @@ -266,19 +266,19 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { _, privKey, _ := ed25519.GenerateKey(nil) cfg := &config.Dendrite{} cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: false, }) cfg.Global.KeyID = gomatrixserverlib.KeyID("ed25519:auto") cfg.Global.ServerName = gomatrixserverlib.ServerName("localhost") cfg.Global.PrivateKey = privKey cfg.Global.JetStream.InMemory = true - base := base.NewBaseDendrite(cfg, "Monolith") + b := base.NewBaseDendrite(cfg, base.DisableMetrics) keyRing := &test.NopJSONVerifier{} // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. - federationapi.AddPublicRoutes(base, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil, nil) - baseURL, cancel := test.ListenAndServe(t, base.PublicFederationAPIMux, true) + federationapi.AddPublicRoutes(b, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil) + baseURL, cancel := test.ListenAndServe(t, b.PublicFederationAPIMux, true) defer cancel() serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) diff --git a/federationapi/internal/api.go b/federationapi/internal/api.go index 14056eafc..99773a750 100644 --- a/federationapi/internal/api.go +++ b/federationapi/internal/api.go @@ -109,13 +109,14 @@ func NewFederationInternalAPI( func (a *FederationInternalAPI) isBlacklistedOrBackingOff(s gomatrixserverlib.ServerName) (*statistics.ServerStatistics, error) { stats := a.statistics.ForServer(s) - until, blacklisted := stats.BackoffInfo() - if blacklisted { + if stats.Blacklisted() { return stats, &api.FederationClientError{ Blacklisted: true, } } + now := time.Now() + until := stats.BackoffInfo() if until != nil && now.Before(*until) { return stats, &api.FederationClientError{ RetryAfter: time.Until(*until), @@ -163,7 +164,7 @@ func (a *FederationInternalAPI) doRequestIfNotBackingOffOrBlacklisted( RetryAfter: retryAfter, } } - stats.Success() + stats.Success(statistics.SendDirect) return res, nil } @@ -171,7 +172,7 @@ func (a *FederationInternalAPI) doRequestIfNotBlacklisted( s gomatrixserverlib.ServerName, request func() (interface{}, error), ) (interface{}, error) { stats := a.statistics.ForServer(s) - if _, blacklisted := stats.BackoffInfo(); blacklisted { + if blacklisted := stats.Blacklisted(); blacklisted { return stats, &api.FederationClientError{ Err: fmt.Sprintf("server %q is blacklisted", s), Blacklisted: true, diff --git a/federationapi/internal/federationclient_test.go b/federationapi/internal/federationclient_test.go new file mode 100644 index 000000000..49137e2d8 --- /dev/null +++ b/federationapi/internal/federationclient_test.go @@ -0,0 +1,202 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "fmt" + "testing" + + "github.com/matrix-org/dendrite/federationapi/queue" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + FailuresUntilAssumedOffline = 3 + FailuresUntilBlacklist = 8 +) + +func (t *testFedClient) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) { + t.queryKeysCalled = true + if t.shouldFail { + return gomatrixserverlib.RespQueryKeys{}, fmt.Errorf("Failure") + } + return gomatrixserverlib.RespQueryKeys{}, nil +} + +func (t *testFedClient) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) { + t.claimKeysCalled = true + if t.shouldFail { + return gomatrixserverlib.RespClaimKeys{}, fmt.Errorf("Failure") + } + return gomatrixserverlib.RespClaimKeys{}, nil +} + +func TestFederationClientQueryKeys(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil) + assert.Nil(t, err) + assert.True(t, fedClient.queryKeysCalled) +} + +func TestFederationClientQueryKeysBlacklisted(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + testDB.AddServerToBlacklist("server") + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil) + assert.NotNil(t, err) + assert.False(t, fedClient.queryKeysCalled) +} + +func TestFederationClientQueryKeysFailure(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{shouldFail: true} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil) + assert.NotNil(t, err) + assert.True(t, fedClient.queryKeysCalled) +} + +func TestFederationClientClaimKeys(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.ClaimKeys(context.Background(), "origin", "server", nil) + assert.Nil(t, err) + assert.True(t, fedClient.claimKeysCalled) +} + +func TestFederationClientClaimKeysBlacklisted(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + testDB.AddServerToBlacklist("server") + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.ClaimKeys(context.Background(), "origin", "server", nil) + assert.NotNil(t, err) + assert.False(t, fedClient.claimKeysCalled) +} diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index d86d07e03..dadb2b2b3 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -14,6 +14,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/consumers" + "github.com/matrix-org/dendrite/federationapi/statistics" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/version" ) @@ -24,6 +25,10 @@ func (r *FederationInternalAPI) PerformDirectoryLookup( request *api.PerformDirectoryLookupRequest, response *api.PerformDirectoryLookupResponse, ) (err error) { + if !r.shouldAttemptDirectFederation(request.ServerName) { + return fmt.Errorf("relay servers have no meaningful response for directory lookup.") + } + dir, err := r.federation.LookupRoomAlias( ctx, r.cfg.Matrix.ServerName, @@ -36,7 +41,7 @@ func (r *FederationInternalAPI) PerformDirectoryLookup( } response.RoomID = dir.RoomID response.ServerNames = dir.Servers - r.statistics.ForServer(request.ServerName).Success() + r.statistics.ForServer(request.ServerName).Success(statistics.SendDirect) return nil } @@ -116,8 +121,6 @@ func (r *FederationInternalAPI) PerformJoin( var httpErr gomatrix.HTTPError if ok := errors.As(lastErr, &httpErr); ok { httpErr.Message = string(httpErr.Contents) - // Clear the wrapped error, else serialising to JSON (in polylith mode) will fail - httpErr.WrappedError = nil response.LastError = &httpErr } else { response.LastError = &gomatrix.HTTPError{ @@ -144,6 +147,10 @@ func (r *FederationInternalAPI) performJoinUsingServer( supportedVersions []gomatrixserverlib.RoomVersion, unsigned map[string]interface{}, ) error { + if !r.shouldAttemptDirectFederation(serverName) { + return fmt.Errorf("relay servers have no meaningful response for join.") + } + _, origin, err := r.cfg.Matrix.SplitLocalID('@', userID) if err != nil { return err @@ -164,7 +171,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( r.statistics.ForServer(serverName).Failure() return fmt.Errorf("r.federation.MakeJoin: %w", err) } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) // Set all the fields to be what they should be, this should be a no-op // but it's possible that the remote server returned us something "odd" @@ -219,7 +226,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( r.statistics.ForServer(serverName).Failure() return fmt.Errorf("r.federation.SendJoin: %w", err) } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) // If the remote server returned an event in the "event" key of // the send_join request then we should use that instead. It may @@ -382,8 +389,6 @@ func (r *FederationInternalAPI) PerformOutboundPeek( var httpErr gomatrix.HTTPError if ok := errors.As(lastErr, &httpErr); ok { httpErr.Message = string(httpErr.Contents) - // Clear the wrapped error, else serialising to JSON (in polylith mode) will fail - httpErr.WrappedError = nil response.LastError = &httpErr } else { response.LastError = &gomatrix.HTTPError{ @@ -407,6 +412,10 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( serverName gomatrixserverlib.ServerName, supportedVersions []gomatrixserverlib.RoomVersion, ) error { + if !r.shouldAttemptDirectFederation(serverName) { + return fmt.Errorf("relay servers have no meaningful response for outbound peek.") + } + // create a unique ID for this peek. // for now we just use the room ID again. In future, if we ever // support concurrent peeks to the same room with different filters @@ -446,7 +455,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( r.statistics.ForServer(serverName).Failure() return fmt.Errorf("r.federation.Peek: %w", err) } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) // Work out if we support the room version that has been supplied in // the peek response. @@ -516,6 +525,10 @@ func (r *FederationInternalAPI) PerformLeave( // Try each server that we were provided until we land on one that // successfully completes the make-leave send-leave dance. for _, serverName := range request.ServerNames { + if !r.shouldAttemptDirectFederation(serverName) { + continue + } + // Try to perform a make_leave using the information supplied in the // request. respMakeLeave, err := r.federation.MakeLeave( @@ -585,7 +598,7 @@ func (r *FederationInternalAPI) PerformLeave( continue } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) return nil } @@ -616,6 +629,12 @@ func (r *FederationInternalAPI) PerformInvite( return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } + // TODO (devon): This should be allowed via a relay. Currently only transactions + // can be sent to relays. Would need to extend relays to handle invites. + if !r.shouldAttemptDirectFederation(destination) { + return fmt.Errorf("relay servers have no meaningful response for invite.") + } + logrus.WithFields(logrus.Fields{ "event_id": request.Event.EventID(), "user_id": *request.Event.StateKey(), @@ -682,12 +701,8 @@ func (r *FederationInternalAPI) PerformWakeupServers( func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) { for _, srv := range destinations { - // Check the statistics cache for the blacklist status to prevent hitting - // the database unnecessarily. - if r.queues.IsServerBlacklisted(srv) { - _ = r.db.RemoveServerFromBlacklist(srv) - } - r.queues.RetryServer(srv) + wasBlacklisted := r.statistics.ForServer(srv).MarkServerAlive() + r.queues.RetryServer(srv, wasBlacklisted) } } @@ -719,7 +734,9 @@ func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error { return fmt.Errorf("auth chain response is missing m.room.create event") } -func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder) gomatrixserverlib.RoomVersion { +func setDefaultRoomVersionFromJoinEvent( + joinEvent gomatrixserverlib.EventBuilder, +) gomatrixserverlib.RoomVersion { // if auth events are not event references we know it must be v3+ // we have to do these shenanigans to satisfy sytest, specifically for: // "Outbound federation rejects m.room.create events with an unknown room version" @@ -802,3 +819,61 @@ func federatedAuthProvider( return returning, nil } } + +// P2PQueryRelayServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) P2PQueryRelayServers( + ctx context.Context, + request *api.P2PQueryRelayServersRequest, + response *api.P2PQueryRelayServersResponse, +) error { + logrus.Infof("Getting relay servers for: %s", request.Server) + relayServers, err := r.db.P2PGetRelayServersForServer(ctx, request.Server) + if err != nil { + return err + } + + response.RelayServers = relayServers + return nil +} + +// P2PAddRelayServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) P2PAddRelayServers( + ctx context.Context, + request *api.P2PAddRelayServersRequest, + response *api.P2PAddRelayServersResponse, +) error { + logrus.Infof("Adding relay servers for: %s", request.Server) + err := r.db.P2PAddRelayServersForServer(ctx, request.Server, request.RelayServers) + if err != nil { + return err + } + + return nil +} + +// P2PRemoveRelayServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) P2PRemoveRelayServers( + ctx context.Context, + request *api.P2PRemoveRelayServersRequest, + response *api.P2PRemoveRelayServersResponse, +) error { + logrus.Infof("Adding relay servers for: %s", request.Server) + err := r.db.P2PRemoveRelayServersForServer(ctx, request.Server, request.RelayServers) + if err != nil { + return err + } + + return nil +} + +func (r *FederationInternalAPI) shouldAttemptDirectFederation( + destination gomatrixserverlib.ServerName, +) bool { + var shouldRelay bool + stats := r.statistics.ForServer(destination) + if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 { + shouldRelay = true + } + + return !shouldRelay +} diff --git a/federationapi/internal/perform_test.go b/federationapi/internal/perform_test.go new file mode 100644 index 000000000..e6e366f99 --- /dev/null +++ b/federationapi/internal/perform_test.go @@ -0,0 +1,231 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/queue" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type testFedClient struct { + api.FederationClient + queryKeysCalled bool + claimKeysCalled bool + shouldFail bool +} + +func (t *testFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) { + return gomatrixserverlib.RespDirectory{}, nil +} + +func TestPerformWakeupServers(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + testDB.AddServerToBlacklist(server) + testDB.SetServerAssumedOffline(context.Background(), server) + blacklisted, err := testDB.IsServerBlacklisted(server) + assert.NoError(t, err) + assert.True(t, blacklisted) + offline, err := testDB.IsServerAssumedOffline(context.Background(), server) + assert.NoError(t, err) + assert.True(t, offline) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.PerformWakeupServersRequest{ + ServerNames: []gomatrixserverlib.ServerName{server}, + } + res := api.PerformWakeupServersResponse{} + err = fedAPI.PerformWakeupServers(context.Background(), &req, &res) + assert.NoError(t, err) + + blacklisted, err = testDB.IsServerBlacklisted(server) + assert.NoError(t, err) + assert.False(t, blacklisted) + offline, err = testDB.IsServerAssumedOffline(context.Background(), server) + assert.NoError(t, err) + assert.False(t, offline) +} + +func TestQueryRelayServers(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} + err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers) + assert.NoError(t, err) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.P2PQueryRelayServersRequest{ + Server: server, + } + res := api.P2PQueryRelayServersResponse{} + err = fedAPI.P2PQueryRelayServers(context.Background(), &req, &res) + assert.NoError(t, err) + + assert.Equal(t, len(relayServers), len(res.RelayServers)) +} + +func TestRemoveRelayServers(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} + err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers) + assert.NoError(t, err) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.P2PRemoveRelayServersRequest{ + Server: server, + RelayServers: []gomatrixserverlib.ServerName{"relay1"}, + } + res := api.P2PRemoveRelayServersResponse{} + err = fedAPI.P2PRemoveRelayServers(context.Background(), &req, &res) + assert.NoError(t, err) + + finalRelays, err := testDB.P2PGetRelayServersForServer(context.Background(), server) + assert.NoError(t, err) + assert.Equal(t, 1, len(finalRelays)) + assert.Equal(t, gomatrixserverlib.ServerName("relay2"), finalRelays[0]) +} + +func TestPerformDirectoryLookup(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.PerformDirectoryLookupRequest{ + RoomAlias: "room", + ServerName: "server", + } + res := api.PerformDirectoryLookupResponse{} + err := fedAPI.PerformDirectoryLookup(context.Background(), &req, &res) + assert.NoError(t, err) +} + +func TestPerformDirectoryLookupRelaying(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + testDB.SetServerAssumedOffline(context.Background(), server) + testDB.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{"relay"}) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: server, + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.PerformDirectoryLookupRequest{ + RoomAlias: "room", + ServerName: server, + } + res := api.PerformDirectoryLookupResponse{} + err := fedAPI.PerformDirectoryLookup(context.Background(), &req, &res) + assert.Error(t, err) +} diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go deleted file mode 100644 index 6eefdc7cd..000000000 --- a/federationapi/inthttp/client.go +++ /dev/null @@ -1,512 +0,0 @@ -package inthttp - -import ( - "context" - "errors" - "net/http" - - "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" -) - -// HTTP paths for the internal HTTP API -const ( - FederationAPIQueryJoinedHostServerNamesInRoomPath = "/federationapi/queryJoinedHostServerNamesInRoom" - FederationAPIQueryServerKeysPath = "/federationapi/queryServerKeys" - - FederationAPIPerformDirectoryLookupRequestPath = "/federationapi/performDirectoryLookup" - FederationAPIPerformJoinRequestPath = "/federationapi/performJoinRequest" - FederationAPIPerformLeaveRequestPath = "/federationapi/performLeaveRequest" - FederationAPIPerformInviteRequestPath = "/federationapi/performInviteRequest" - FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest" - FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" - FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers" - - FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" - FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" - FederationAPIQueryKeysPath = "/federationapi/client/queryKeys" - FederationAPIBackfillPath = "/federationapi/client/backfill" - FederationAPILookupStatePath = "/federationapi/client/lookupState" - FederationAPILookupStateIDsPath = "/federationapi/client/lookupStateIDs" - FederationAPILookupMissingEventsPath = "/federationapi/client/lookupMissingEvents" - FederationAPIGetEventPath = "/federationapi/client/getEvent" - FederationAPILookupServerKeysPath = "/federationapi/client/lookupServerKeys" - FederationAPIEventRelationshipsPath = "/federationapi/client/msc2836eventRelationships" - FederationAPISpacesSummaryPath = "/federationapi/client/msc2946spacesSummary" - FederationAPIGetEventAuthPath = "/federationapi/client/getEventAuth" - - FederationAPIInputPublicKeyPath = "/federationapi/inputPublicKey" - FederationAPIQueryPublicKeyPath = "/federationapi/queryPublicKey" -) - -// NewFederationAPIClient creates a FederationInternalAPI implemented by talking to a HTTP POST API. -// If httpClient is nil an error is returned -func NewFederationAPIClient(federationSenderURL string, httpClient *http.Client, cache caching.ServerKeyCache) (api.FederationInternalAPI, error) { - if httpClient == nil { - return nil, errors.New("NewFederationInternalAPIHTTP: httpClient is ") - } - return &httpFederationInternalAPI{ - federationAPIURL: federationSenderURL, - httpClient: httpClient, - cache: cache, - }, nil -} - -type httpFederationInternalAPI struct { - federationAPIURL string - httpClient *http.Client - cache caching.ServerKeyCache -} - -// Handle an instruction to make_leave & send_leave with a remote server. -func (h *httpFederationInternalAPI) PerformLeave( - ctx context.Context, - request *api.PerformLeaveRequest, - response *api.PerformLeaveResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformLeave", h.federationAPIURL+FederationAPIPerformLeaveRequestPath, - h.httpClient, ctx, request, response, - ) -} - -// Handle sending an invite to a remote server. -func (h *httpFederationInternalAPI) PerformInvite( - ctx context.Context, - request *api.PerformInviteRequest, - response *api.PerformInviteResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformInvite", h.federationAPIURL+FederationAPIPerformInviteRequestPath, - h.httpClient, ctx, request, response, - ) -} - -// Handle starting a peek on a remote server. -func (h *httpFederationInternalAPI) PerformOutboundPeek( - ctx context.Context, - request *api.PerformOutboundPeekRequest, - response *api.PerformOutboundPeekResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformOutboundPeek", h.federationAPIURL+FederationAPIPerformOutboundPeekRequestPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryJoinedHostServerNamesInRoom implements FederationInternalAPI -func (h *httpFederationInternalAPI) QueryJoinedHostServerNamesInRoom( - ctx context.Context, - request *api.QueryJoinedHostServerNamesInRoomRequest, - response *api.QueryJoinedHostServerNamesInRoomResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryJoinedHostServerNamesInRoom", h.federationAPIURL+FederationAPIQueryJoinedHostServerNamesInRoomPath, - h.httpClient, ctx, request, response, - ) -} - -// Handle an instruction to make_join & send_join with a remote server. -func (h *httpFederationInternalAPI) PerformJoin( - ctx context.Context, - request *api.PerformJoinRequest, - response *api.PerformJoinResponse, -) { - if err := httputil.CallInternalRPCAPI( - "PerformJoinRequest", h.federationAPIURL+FederationAPIPerformJoinRequestPath, - h.httpClient, ctx, request, response, - ); err != nil { - response.LastError = &gomatrix.HTTPError{ - Message: err.Error(), - Code: 0, - WrappedError: err, - } - } -} - -// Handle an instruction to make_join & send_join with a remote server. -func (h *httpFederationInternalAPI) PerformDirectoryLookup( - ctx context.Context, - request *api.PerformDirectoryLookupRequest, - response *api.PerformDirectoryLookupResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformDirectoryLookup", h.federationAPIURL+FederationAPIPerformDirectoryLookupRequestPath, - h.httpClient, ctx, request, response, - ) -} - -// Handle an instruction to broadcast an EDU to all servers in rooms we are joined to. -func (h *httpFederationInternalAPI) PerformBroadcastEDU( - ctx context.Context, - request *api.PerformBroadcastEDURequest, - response *api.PerformBroadcastEDUResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformBroadcastEDU", h.federationAPIURL+FederationAPIPerformBroadcastEDUPath, - h.httpClient, ctx, request, response, - ) -} - -// Handle an instruction to remove the respective servers from being blacklisted. -func (h *httpFederationInternalAPI) PerformWakeupServers( - ctx context.Context, - request *api.PerformWakeupServersRequest, - response *api.PerformWakeupServersResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformWakeupServers", h.federationAPIURL+FederationAPIPerformWakeupServers, - h.httpClient, ctx, request, response, - ) -} - -type getUserDevices struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - UserID string -} - -func (h *httpFederationInternalAPI) GetUserDevices( - ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string, -) (gomatrixserverlib.RespUserDevices, error) { - return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError]( - "GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient, - ctx, &getUserDevices{ - S: s, - Origin: origin, - UserID: userID, - }, - ) -} - -type claimKeys struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - OneTimeKeys map[string]map[string]string -} - -func (h *httpFederationInternalAPI) ClaimKeys( - ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, -) (gomatrixserverlib.RespClaimKeys, error) { - return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError]( - "ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient, - ctx, &claimKeys{ - S: s, - Origin: origin, - OneTimeKeys: oneTimeKeys, - }, - ) -} - -type queryKeys struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - Keys map[string][]string -} - -func (h *httpFederationInternalAPI) QueryKeys( - ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string, -) (gomatrixserverlib.RespQueryKeys, error) { - return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError]( - "QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient, - ctx, &queryKeys{ - S: s, - Origin: origin, - Keys: keys, - }, - ) -} - -type backfill struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - RoomID string - Limit int - EventIDs []string -} - -func (h *httpFederationInternalAPI) Backfill( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, -) (gomatrixserverlib.Transaction, error) { - return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError]( - "Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient, - ctx, &backfill{ - S: s, - Origin: origin, - RoomID: roomID, - Limit: limit, - EventIDs: eventIDs, - }, - ) -} - -type lookupState struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - RoomID string - EventID string - RoomVersion gomatrixserverlib.RoomVersion -} - -func (h *httpFederationInternalAPI) LookupState( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, -) (gomatrixserverlib.RespState, error) { - return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError]( - "LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient, - ctx, &lookupState{ - S: s, - Origin: origin, - RoomID: roomID, - EventID: eventID, - RoomVersion: roomVersion, - }, - ) -} - -type lookupStateIDs struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - RoomID string - EventID string -} - -func (h *httpFederationInternalAPI) LookupStateIDs( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, -) (gomatrixserverlib.RespStateIDs, error) { - return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError]( - "LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient, - ctx, &lookupStateIDs{ - S: s, - Origin: origin, - RoomID: roomID, - EventID: eventID, - }, - ) -} - -type lookupMissingEvents struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - RoomID string - Missing gomatrixserverlib.MissingEvents - RoomVersion gomatrixserverlib.RoomVersion -} - -func (h *httpFederationInternalAPI) LookupMissingEvents( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, - missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, -) (res gomatrixserverlib.RespMissingEvents, err error) { - return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError]( - "LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient, - ctx, &lookupMissingEvents{ - S: s, - Origin: origin, - RoomID: roomID, - Missing: missing, - RoomVersion: roomVersion, - }, - ) -} - -type getEvent struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - EventID string -} - -func (h *httpFederationInternalAPI) GetEvent( - ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string, -) (gomatrixserverlib.Transaction, error) { - return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError]( - "GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient, - ctx, &getEvent{ - S: s, - Origin: origin, - EventID: eventID, - }, - ) -} - -type getEventAuth struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - RoomVersion gomatrixserverlib.RoomVersion - RoomID string - EventID string -} - -func (h *httpFederationInternalAPI) GetEventAuth( - ctx context.Context, origin, s gomatrixserverlib.ServerName, - roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, -) (gomatrixserverlib.RespEventAuth, error) { - return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError]( - "GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient, - ctx, &getEventAuth{ - S: s, - Origin: origin, - RoomVersion: roomVersion, - RoomID: roomID, - EventID: eventID, - }, - ) -} - -func (h *httpFederationInternalAPI) QueryServerKeys( - ctx context.Context, req *api.QueryServerKeysRequest, res *api.QueryServerKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryServerKeys", h.federationAPIURL+FederationAPIQueryServerKeysPath, - h.httpClient, ctx, req, res, - ) -} - -type lookupServerKeys struct { - S gomatrixserverlib.ServerName - KeyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp -} - -func (h *httpFederationInternalAPI) LookupServerKeys( - ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, -) ([]gomatrixserverlib.ServerKeys, error) { - return httputil.CallInternalProxyAPI[lookupServerKeys, []gomatrixserverlib.ServerKeys, *api.FederationClientError]( - "LookupServerKeys", h.federationAPIURL+FederationAPILookupServerKeysPath, h.httpClient, - ctx, &lookupServerKeys{ - S: s, - KeyRequests: keyRequests, - }, - ) -} - -type eventRelationships struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - Req gomatrixserverlib.MSC2836EventRelationshipsRequest - RoomVer gomatrixserverlib.RoomVersion -} - -func (h *httpFederationInternalAPI) MSC2836EventRelationships( - ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, - roomVersion gomatrixserverlib.RoomVersion, -) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { - return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError]( - "MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient, - ctx, &eventRelationships{ - S: s, - Origin: origin, - Req: r, - RoomVer: roomVersion, - }, - ) -} - -type spacesReq struct { - S gomatrixserverlib.ServerName - Origin gomatrixserverlib.ServerName - SuggestedOnly bool - RoomID string -} - -func (h *httpFederationInternalAPI) MSC2946Spaces( - ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, -) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { - return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError]( - "MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient, - ctx, &spacesReq{ - S: dst, - Origin: origin, - SuggestedOnly: suggestedOnly, - RoomID: roomID, - }, - ) -} - -func (s *httpFederationInternalAPI) KeyRing() *gomatrixserverlib.KeyRing { - // This is a bit of a cheat - we tell gomatrixserverlib that this API is - // both the key database and the key fetcher. While this does have the - // rather unfortunate effect of preventing gomatrixserverlib from handling - // key fetchers directly, we can at least reimplement this behaviour on - // the other end of the API. - return &gomatrixserverlib.KeyRing{ - KeyDatabase: s, - KeyFetchers: []gomatrixserverlib.KeyFetcher{}, - } -} - -func (s *httpFederationInternalAPI) FetcherName() string { - return "httpServerKeyInternalAPI" -} - -func (s *httpFederationInternalAPI) StoreKeys( - _ context.Context, - results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, -) error { - // Run in a background context - we don't want to stop this work just - // because the caller gives up waiting. - ctx := context.Background() - request := api.InputPublicKeysRequest{ - Keys: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult), - } - response := api.InputPublicKeysResponse{} - for req, res := range results { - request.Keys[req] = res - s.cache.StoreServerKey(req, res) - } - return s.InputPublicKeys(ctx, &request, &response) -} - -func (s *httpFederationInternalAPI) FetchKeys( - _ context.Context, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, -) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { - // Run in a background context - we don't want to stop this work just - // because the caller gives up waiting. - ctx := context.Background() - result := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) - request := api.QueryPublicKeysRequest{ - Requests: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp), - } - response := api.QueryPublicKeysResponse{ - Results: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult), - } - for req, ts := range requests { - if res, ok := s.cache.GetServerKey(req, ts); ok { - result[req] = res - continue - } - request.Requests[req] = ts - } - err := s.QueryPublicKeys(ctx, &request, &response) - if err != nil { - return nil, err - } - for req, res := range response.Results { - result[req] = res - s.cache.StoreServerKey(req, res) - } - return result, nil -} - -func (h *httpFederationInternalAPI) InputPublicKeys( - ctx context.Context, - request *api.InputPublicKeysRequest, - response *api.InputPublicKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "InputPublicKey", h.federationAPIURL+FederationAPIInputPublicKeyPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpFederationInternalAPI) QueryPublicKeys( - ctx context.Context, - request *api.QueryPublicKeysRequest, - response *api.QueryPublicKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryPublicKeys", h.federationAPIURL+FederationAPIQueryPublicKeyPath, - h.httpClient, ctx, request, response, - ) -} diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go deleted file mode 100644 index 21a070392..000000000 --- a/federationapi/inthttp/server.go +++ /dev/null @@ -1,257 +0,0 @@ -package inthttp - -import ( - "context" - "encoding/json" - "net/http" - "net/url" - - "github.com/gorilla/mux" - "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - - "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/httputil" -) - -// AddRoutes adds the FederationInternalAPI handlers to the http.ServeMux. -// nolint:gocyclo -func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { - internalAPIMux.Handle( - FederationAPIQueryJoinedHostServerNamesInRoomPath, - httputil.MakeInternalRPCAPI("FederationAPIQueryJoinedHostServerNamesInRoom", intAPI.QueryJoinedHostServerNamesInRoom), - ) - - internalAPIMux.Handle( - FederationAPIPerformInviteRequestPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformInvite", intAPI.PerformInvite), - ) - - internalAPIMux.Handle( - FederationAPIPerformLeaveRequestPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformLeave", intAPI.PerformLeave), - ) - - internalAPIMux.Handle( - FederationAPIPerformDirectoryLookupRequestPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformDirectoryLookupRequest", intAPI.PerformDirectoryLookup), - ) - - internalAPIMux.Handle( - FederationAPIPerformBroadcastEDUPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU), - ) - - internalAPIMux.Handle( - FederationAPIPerformWakeupServers, - httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", intAPI.PerformWakeupServers), - ) - - internalAPIMux.Handle( - FederationAPIPerformJoinRequestPath, - httputil.MakeInternalRPCAPI( - "FederationAPIPerformJoinRequest", - func(ctx context.Context, req *api.PerformJoinRequest, res *api.PerformJoinResponse) error { - intAPI.PerformJoin(ctx, req, res) - return nil - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIGetUserDevicesPath, - httputil.MakeInternalProxyAPI( - "FederationAPIGetUserDevices", - func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) { - res, err := intAPI.GetUserDevices(ctx, req.Origin, req.S, req.UserID) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIClaimKeysPath, - httputil.MakeInternalProxyAPI( - "FederationAPIClaimKeys", - func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) { - res, err := intAPI.ClaimKeys(ctx, req.Origin, req.S, req.OneTimeKeys) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIQueryKeysPath, - httputil.MakeInternalProxyAPI( - "FederationAPIQueryKeys", - func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) { - res, err := intAPI.QueryKeys(ctx, req.Origin, req.S, req.Keys) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIBackfillPath, - httputil.MakeInternalProxyAPI( - "FederationAPIBackfill", - func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) { - res, err := intAPI.Backfill(ctx, req.Origin, req.S, req.RoomID, req.Limit, req.EventIDs) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPILookupStatePath, - httputil.MakeInternalProxyAPI( - "FederationAPILookupState", - func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) { - res, err := intAPI.LookupState(ctx, req.Origin, req.S, req.RoomID, req.EventID, req.RoomVersion) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPILookupStateIDsPath, - httputil.MakeInternalProxyAPI( - "FederationAPILookupStateIDs", - func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) { - res, err := intAPI.LookupStateIDs(ctx, req.Origin, req.S, req.RoomID, req.EventID) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPILookupMissingEventsPath, - httputil.MakeInternalProxyAPI( - "FederationAPILookupMissingEvents", - func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) { - res, err := intAPI.LookupMissingEvents(ctx, req.Origin, req.S, req.RoomID, req.Missing, req.RoomVersion) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIGetEventPath, - httputil.MakeInternalProxyAPI( - "FederationAPIGetEvent", - func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) { - res, err := intAPI.GetEvent(ctx, req.Origin, req.S, req.EventID) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIGetEventAuthPath, - httputil.MakeInternalProxyAPI( - "FederationAPIGetEventAuth", - func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) { - res, err := intAPI.GetEventAuth(ctx, req.Origin, req.S, req.RoomVersion, req.RoomID, req.EventID) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIQueryServerKeysPath, - httputil.MakeInternalRPCAPI("FederationAPIQueryServerKeys", intAPI.QueryServerKeys), - ) - - internalAPIMux.Handle( - FederationAPILookupServerKeysPath, - httputil.MakeInternalProxyAPI( - "FederationAPILookupServerKeys", - func(ctx context.Context, req *lookupServerKeys) (*[]gomatrixserverlib.ServerKeys, error) { - res, err := intAPI.LookupServerKeys(ctx, req.S, req.KeyRequests) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPIEventRelationshipsPath, - httputil.MakeInternalProxyAPI( - "FederationAPIMSC2836EventRelationships", - func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { - res, err := intAPI.MSC2836EventRelationships(ctx, req.Origin, req.S, req.Req, req.RoomVer) - return &res, federationClientError(err) - }, - ), - ) - - internalAPIMux.Handle( - FederationAPISpacesSummaryPath, - httputil.MakeInternalProxyAPI( - "FederationAPIMSC2946SpacesSummary", - func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) { - res, err := intAPI.MSC2946Spaces(ctx, req.Origin, req.S, req.RoomID, req.SuggestedOnly) - return &res, federationClientError(err) - }, - ), - ) - - // TODO: Look at this shape - internalAPIMux.Handle(FederationAPIQueryPublicKeyPath, - httputil.MakeInternalAPI("FederationAPIQueryPublicKeys", func(req *http.Request) util.JSONResponse { - request := api.QueryPublicKeysRequest{} - response := api.QueryPublicKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - keys, err := intAPI.FetchKeys(req.Context(), request.Requests) - if err != nil { - return util.ErrorResponse(err) - } - response.Results = keys - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - - // TODO: Look at this shape - internalAPIMux.Handle(FederationAPIInputPublicKeyPath, - httputil.MakeInternalAPI("FederationAPIInputPublicKeys", func(req *http.Request) util.JSONResponse { - request := api.InputPublicKeysRequest{} - response := api.InputPublicKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.StoreKeys(req.Context(), request.Keys); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) -} - -func federationClientError(err error) error { - switch ferr := err.(type) { - case nil: - return nil - case api.FederationClientError: - return &ferr - case *api.FederationClientError: - return ferr - case gomatrix.HTTPError: - return &api.FederationClientError{ - Code: ferr.Code, - } - case *url.Error: // e.g. certificate error, unable to connect - return &api.FederationClientError{ - Err: ferr.Error(), - Code: 400, - } - default: - // We don't know what exactly failed, but we probably don't - // want to retry the request immediately in the device list updater - return &api.FederationClientError{ - Err: err.Error(), - Code: 400, - } - } -} diff --git a/federationapi/producers/syncapi.go b/federationapi/producers/syncapi.go index 7cce13a7d..6bcfafa39 100644 --- a/federationapi/producers/syncapi.go +++ b/federationapi/producers/syncapi.go @@ -41,7 +41,7 @@ type SyncAPIProducer struct { TopicSigningKeyUpdate string JetStream nats.JetStreamContext Config *config.FederationAPI - UserAPI userapi.UserInternalAPI + UserAPI userapi.FederationUserAPI } func (p *SyncAPIProducer) SendReceipt( diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index a4a87fe99..12e6db9fa 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -29,7 +29,7 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" ) @@ -70,7 +70,7 @@ type destinationQueue struct { // Send event adds the event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, receipt *shared.Receipt) { +func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, dbReceipt *receipt.Receipt) { if event == nil { logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) return @@ -84,8 +84,8 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re oq.pendingMutex.Lock() if len(oq.pendingPDUs) < maxPDUsInMemory { oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ - pdu: event, - receipt: receipt, + pdu: event, + dbReceipt: dbReceipt, }) } else { oq.overflowed.Store(true) @@ -101,7 +101,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re // sendEDU adds the EDU event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *shared.Receipt) { +func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, dbReceipt *receipt.Receipt) { if event == nil { logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination) return @@ -115,8 +115,8 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share oq.pendingMutex.Lock() if len(oq.pendingEDUs) < maxEDUsInMemory { oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ - edu: event, - receipt: receipt, + edu: event, + dbReceipt: dbReceipt, }) } else { oq.overflowed.Store(true) @@ -210,10 +210,10 @@ func (oq *destinationQueue) getPendingFromDatabase() { gotPDUs := map[string]struct{}{} gotEDUs := map[string]struct{}{} for _, pdu := range oq.pendingPDUs { - gotPDUs[pdu.receipt.String()] = struct{}{} + gotPDUs[pdu.dbReceipt.String()] = struct{}{} } for _, edu := range oq.pendingEDUs { - gotEDUs[edu.receipt.String()] = struct{}{} + gotEDUs[edu.dbReceipt.String()] = struct{}{} } overflowed := false @@ -371,7 +371,7 @@ func (oq *destinationQueue) backgroundSend() { // If we have pending PDUs or EDUs then construct a transaction. // Try sending the next transaction and see what happens. - terr := oq.nextTransaction(toSendPDUs, toSendEDUs) + terr, sendMethod := oq.nextTransaction(toSendPDUs, toSendEDUs) if terr != nil { // We failed to send the transaction. Mark it as a failure. _, blacklisted := oq.statistics.Failure() @@ -388,18 +388,19 @@ func (oq *destinationQueue) backgroundSend() { return } } else { - oq.handleTransactionSuccess(pduCount, eduCount) + oq.handleTransactionSuccess(pduCount, eduCount, sendMethod) } } } // nextTransaction creates a new transaction from the pending event // queue and sends it. -// Returns an error if the transaction wasn't sent. +// Returns an error if the transaction wasn't sent. And whether the success +// was to a relay server or not. func (oq *destinationQueue) nextTransaction( pdus []*queuedPDU, edus []*queuedEDU, -) error { +) (err error, sendMethod statistics.SendMethod) { // Create the transaction. t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus) logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) @@ -407,7 +408,52 @@ func (oq *destinationQueue) nextTransaction( // Try to send the transaction to the destination server. ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5) defer cancel() - _, err := oq.client.SendTransaction(ctx, t) + + relayServers := oq.statistics.KnownRelayServers() + hasRelayServers := len(relayServers) > 0 + shouldSendToRelays := oq.statistics.AssumedOffline() && hasRelayServers + if !shouldSendToRelays { + sendMethod = statistics.SendDirect + _, err = oq.client.SendTransaction(ctx, t) + } else { + // Try sending directly to the destination first in case they came back online. + sendMethod = statistics.SendDirect + _, err = oq.client.SendTransaction(ctx, t) + if err != nil { + // The destination is still offline, try sending to relays. + sendMethod = statistics.SendViaRelay + relaySuccess := false + logrus.Infof("Sending %q to relay servers: %v", t.TransactionID, relayServers) + // TODO : how to pass through actual userID here?!?!?!?! + userID, userErr := gomatrixserverlib.NewUserID("@user:"+string(oq.destination), false) + if userErr != nil { + return userErr, sendMethod + } + + // Attempt sending to each known relay server. + for _, relayServer := range relayServers { + _, relayErr := oq.client.P2PSendTransactionToRelay(ctx, *userID, t, relayServer) + if relayErr != nil { + err = relayErr + } else { + // If sending to one of the relay servers succeeds, consider the send successful. + relaySuccess = true + + // TODO : what about if the dest comes back online but can't see their relay? + // How do I sync with the dest in that case? + // Should change the database to have a "relay success" flag on events and if + // I see the node back online, maybe directly send through the backlog of events + // with "relay success"... could lead to duplicate events, but only those that + // I sent. And will lead to a much more consistent experience. + } + } + + // Clear the error if sending to any of the relay servers succeeded. + if relaySuccess { + err = nil + } + } + } switch errResponse := err.(type) { case nil: // Clean up the transaction in the database. @@ -427,7 +473,7 @@ func (oq *destinationQueue) nextTransaction( oq.transactionIDMutex.Lock() oq.transactionID = "" oq.transactionIDMutex.Unlock() - return nil + return nil, sendMethod case gomatrix.HTTPError: // Report that we failed to send the transaction and we // will retry again, subject to backoff. @@ -437,13 +483,13 @@ func (oq *destinationQueue) nextTransaction( // to a 400-ish error code := errResponse.Code logrus.Debug("Transaction failed with HTTP", code) - return err + return err, sendMethod default: logrus.WithFields(logrus.Fields{ "destination": oq.destination, logrus.ErrorKey: err, }).Debugf("Failed to send transaction %q", t.TransactionID) - return err + return err, sendMethod } } @@ -453,7 +499,7 @@ func (oq *destinationQueue) nextTransaction( func (oq *destinationQueue) createTransaction( pdus []*queuedPDU, edus []*queuedEDU, -) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) { +) (gomatrixserverlib.Transaction, []*receipt.Receipt, []*receipt.Receipt) { // If there's no projected transaction ID then generate one. If // the transaction succeeds then we'll set it back to "" so that // we generate a new one next time. If it fails, we'll preserve @@ -474,8 +520,8 @@ func (oq *destinationQueue) createTransaction( t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) t.TransactionID = oq.transactionID - var pduReceipts []*shared.Receipt - var eduReceipts []*shared.Receipt + var pduReceipts []*receipt.Receipt + var eduReceipts []*receipt.Receipt // Go through PDUs that we retrieved from the database, if any, // and add them into the transaction. @@ -487,7 +533,7 @@ func (oq *destinationQueue) createTransaction( // Append the JSON of the event, since this is a json.RawMessage type in the // gomatrixserverlib.Transaction struct t.PDUs = append(t.PDUs, pdu.pdu.JSON()) - pduReceipts = append(pduReceipts, pdu.receipt) + pduReceipts = append(pduReceipts, pdu.dbReceipt) } // Do the same for pending EDUS in the queue. @@ -497,7 +543,7 @@ func (oq *destinationQueue) createTransaction( continue } t.EDUs = append(t.EDUs, *edu.edu) - eduReceipts = append(eduReceipts, edu.receipt) + eduReceipts = append(eduReceipts, edu.dbReceipt) } return t, pduReceipts, eduReceipts @@ -530,10 +576,11 @@ func (oq *destinationQueue) blacklistDestination() { // handleTransactionSuccess updates the cached event queues as well as the success and // backoff information for this server. -func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) { +func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int, sendMethod statistics.SendMethod) { // If we successfully sent the transaction then clear out // the pending events and EDUs, and wipe our transaction ID. - oq.statistics.Success() + + oq.statistics.Success(sendMethod) oq.pendingMutex.Lock() defer oq.pendingMutex.Unlock() diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 75b1b36be..5d6b8d44c 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -30,7 +30,7 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" ) @@ -138,13 +138,13 @@ func NewOutgoingQueues( } type queuedPDU struct { - receipt *shared.Receipt - pdu *gomatrixserverlib.HeaderedEvent + dbReceipt *receipt.Receipt + pdu *gomatrixserverlib.HeaderedEvent } type queuedEDU struct { - receipt *shared.Receipt - edu *gomatrixserverlib.EDU + dbReceipt *receipt.Receipt + edu *gomatrixserverlib.EDU } func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue { @@ -374,24 +374,13 @@ func (oqs *OutgoingQueues) SendEDU( return nil } -// IsServerBlacklisted returns whether or not the provided server is currently -// blacklisted. -func (oqs *OutgoingQueues) IsServerBlacklisted(srv gomatrixserverlib.ServerName) bool { - return oqs.statistics.ForServer(srv).Blacklisted() -} - // RetryServer attempts to resend events to the given server if we had given up. -func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { +func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName, wasBlacklisted bool) { if oqs.disabled { return } - serverStatistics := oqs.statistics.ForServer(srv) - forceWakeup := serverStatistics.Blacklisted() - serverStatistics.RemoveBlacklist() - serverStatistics.ClearBackoff() - if queue := oqs.getQueue(srv); queue != nil { - queue.wakeQueueIfEventsPending(forceWakeup) + queue.wakeQueueIfEventsPending(wasBlacklisted) } } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index b2ec4b836..bccfb3428 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -18,7 +18,6 @@ import ( "context" "encoding/json" "fmt" - "sync" "testing" "time" @@ -26,13 +25,11 @@ import ( "gotest.tools/v3/poll" "github.com/matrix-org/gomatrixserverlib" - "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" @@ -57,7 +54,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase } } else { // Fake Database - db := createDatabase() + db := test.NewInMemoryFederationDatabase() b := struct { ProcessContext *process.ProcessContext }{ProcessContext: process.NewProcessContext()} @@ -65,242 +62,6 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase } } -func createDatabase() storage.Database { - return &fakeDatabase{ - pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), - pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent), - pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU), - associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), - associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), - } -} - -type fakeDatabase struct { - storage.Database - dbMutex sync.Mutex - pendingPDUServers map[gomatrixserverlib.ServerName]struct{} - pendingEDUServers map[gomatrixserverlib.ServerName]struct{} - blacklistedServers map[gomatrixserverlib.ServerName]struct{} - pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent - pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU - associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} - associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} -} - -var nidMutex sync.Mutex -var nid = int64(0) - -func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var event gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal([]byte(js), &event); err == nil { - nidMutex.Lock() - defer nidMutex.Unlock() - nid++ - receipt := shared.NewReceipt(nid) - d.pendingPDUs[&receipt] = &event - return &receipt, nil - } - - var edu gomatrixserverlib.EDU - if err := json.Unmarshal([]byte(js), &edu); err == nil { - nidMutex.Lock() - defer nidMutex.Unlock() - nid++ - receipt := shared.NewReceipt(nid) - d.pendingEDUs[&receipt] = &edu - return &receipt, nil - } - - return nil, errors.New("Failed to determine type of json to store") -} - -func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - pduCount := 0 - pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent) - if receipts, ok := d.associatedPDUs[serverName]; ok { - for receipt := range receipts { - if event, ok := d.pendingPDUs[receipt]; ok { - pdus[receipt] = event - pduCount++ - if pduCount == limit { - break - } - } - } - } - return pdus, nil -} - -func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - eduCount := 0 - edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU) - if receipts, ok := d.associatedEDUs[serverName]; ok { - for receipt := range receipts { - if event, ok := d.pendingEDUs[receipt]; ok { - edus[receipt] = event - eduCount++ - if eduCount == limit { - break - } - } - } - } - return edus, nil -} - -func (d *fakeDatabase) AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if _, ok := d.pendingPDUs[receipt]; ok { - for destination := range destinations { - if _, ok := d.associatedPDUs[destination]; !ok { - d.associatedPDUs[destination] = make(map[*shared.Receipt]struct{}) - } - d.associatedPDUs[destination][receipt] = struct{}{} - } - - return nil - } else { - return errors.New("PDU doesn't exist") - } -} - -func (d *fakeDatabase) AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if _, ok := d.pendingEDUs[receipt]; ok { - for destination := range destinations { - if _, ok := d.associatedEDUs[destination]; !ok { - d.associatedEDUs[destination] = make(map[*shared.Receipt]struct{}) - } - d.associatedEDUs[destination][receipt] = struct{}{} - } - - return nil - } else { - return errors.New("EDU doesn't exist") - } -} - -func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if pdus, ok := d.associatedPDUs[serverName]; ok { - for _, receipt := range receipts { - delete(pdus, receipt) - } - } - - return nil -} - -func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if edus, ok := d.associatedEDUs[serverName]; ok { - for _, receipt := range receipts { - delete(edus, receipt) - } - } - - return nil -} - -func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var count int64 - if pdus, ok := d.associatedPDUs[serverName]; ok { - count = int64(len(pdus)) - } - return count, nil -} - -func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var count int64 - if edus, ok := d.associatedEDUs[serverName]; ok { - count = int64(len(edus)) - } - return count, nil -} - -func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - servers := []gomatrixserverlib.ServerName{} - for server := range d.pendingPDUServers { - servers = append(servers, server) - } - return servers, nil -} - -func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - servers := []gomatrixserverlib.ServerName{} - for server := range d.pendingEDUServers { - servers = append(servers, server) - } - return servers, nil -} - -func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - d.blacklistedServers[serverName] = struct{}{} - return nil -} - -func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - delete(d.blacklistedServers, serverName) - return nil -} - -func (d *fakeDatabase) RemoveAllServersFromBlacklist() error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) - return nil -} - -func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - isBlacklisted := false - if _, ok := d.blacklistedServers[serverName]; ok { - isBlacklisted = true - } - - return isBlacklisted, nil -} - type stubFederationRoomServerAPI struct { rsapi.FederationRoomserverAPI } @@ -312,8 +73,10 @@ func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Cont type stubFederationClient struct { api.FederationClient - shouldTxSucceed bool - txCount atomic.Uint32 + shouldTxSucceed bool + shouldTxRelaySucceed bool + txCount atomic.Uint32 + txRelayCount atomic.Uint32 } func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) { @@ -326,6 +89,16 @@ func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixse return gomatrixserverlib.RespSend{}, result } +func (f *stubFederationClient) P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) { + var result error + if !f.shouldTxRelaySucceed { + result = fmt.Errorf("relay transaction failed") + } + + f.txRelayCount.Add(1) + return gomatrixserverlib.EmptyResp{}, result +} + func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent { t.Helper() content := `{"type":"m.room.message"}` @@ -341,15 +114,18 @@ func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU { return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping} } -func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { +func testSetup(failuresUntilBlacklist uint32, failuresUntilAssumedOffline uint32, shouldTxSucceed bool, shouldTxRelaySucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { db, processContext, close := mustCreateFederationDatabase(t, dbType, realDatabase) fc := &stubFederationClient{ - shouldTxSucceed: shouldTxSucceed, - txCount: *atomic.NewUint32(0), + shouldTxSucceed: shouldTxSucceed, + shouldTxRelaySucceed: shouldTxRelaySucceed, + txCount: *atomic.NewUint32(0), + txRelayCount: *atomic.NewUint32(0), } rs := &stubFederationRoomServerAPI{} - stats := statistics.NewStatistics(db, failuresUntilBlacklist) + + stats := statistics.NewStatistics(db, failuresUntilBlacklist, failuresUntilAssumedOffline) signingInfo := []*gomatrixserverlib.SigningIdentity{ { KeyID: "ed21019:auto", @@ -366,7 +142,7 @@ func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -395,7 +171,7 @@ func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -424,7 +200,7 @@ func TestSendPDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -454,7 +230,7 @@ func TestSendEDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -484,7 +260,7 @@ func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -535,7 +311,7 @@ func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -586,7 +362,7 @@ func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -618,7 +394,7 @@ func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -650,7 +426,7 @@ func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -684,7 +460,7 @@ func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -718,7 +494,7 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -752,8 +528,8 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) { poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) + wasBlacklisted := dest.statistics.MarkServerAlive() + queues.RetryServer(destination, wasBlacklisted) checkRetry := func(log poll.LogT) poll.Result { data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) @@ -769,7 +545,7 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -803,8 +579,8 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) { poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) + wasBlacklisted := dest.statistics.MarkServerAlive() + queues.RetryServer(destination, wasBlacklisted) checkRetry := func(log poll.LogT) poll.Result { data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) @@ -823,7 +599,7 @@ func TestSendPDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -867,7 +643,7 @@ func TestSendEDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -911,7 +687,7 @@ func TestSendPDUAndEDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -962,7 +738,7 @@ func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -1000,7 +776,7 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { destination := gomatrixserverlib.ServerName("remotehost") destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, dbType, true) // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. defer close() defer func() { @@ -1045,8 +821,8 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) + wasBlacklisted := dest.statistics.MarkServerAlive() + queues.RetryServer(destination, wasBlacklisted) checkRetry := func(log poll.LogT) poll.Result { pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) assert.NoError(t, dbErrPDU) @@ -1060,3 +836,147 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { poll.WaitOn(t, checkRetry, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) }) } + +func TestSendPDUMultipleFailuresAssumedOffline(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(7) + failuresUntilAssumedOffline := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilAssumedOffline { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be assumed offline") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUMultipleFailuresAssumedOffline(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(7) + failuresUntilAssumedOffline := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilAssumedOffline { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be assumed offline") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + failuresUntilAssumedOffline := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + relayServers := []gomatrixserverlib.ServerName{"relayserver"} + queues.statistics.ForServer(destination).AddRelayServers(relayServers) + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() >= 1 { + if fc.txRelayCount.Load() == 1 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more relay send attempts before checking database. Currently %d", fc.txRelayCount.Load()) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination) + assert.Equal(t, true, assumedOffline) +} + +func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + failuresUntilAssumedOffline := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + relayServers := []gomatrixserverlib.ServerName{"relayserver"} + queues.statistics.ForServer(destination).AddRelayServers(relayServers) + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() >= 1 { + if fc.txRelayCount.Load() == 1 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more relay send attempts before checking database. Currently %d", fc.txRelayCount.Load()) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination) + assert.Equal(t, true, assumedOffline) +} diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index ce8b06b70..871d26cd4 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -17,7 +17,7 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" - keyapi "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/tidwall/gjson" @@ -26,11 +26,11 @@ import ( // GetUserDevices for the given user id func GetUserDevices( req *http.Request, - keyAPI keyapi.FederationKeyAPI, + keyAPI api.FederationKeyAPI, userID string, ) util.JSONResponse { - var res keyapi.QueryDeviceMessagesResponse - if err := keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{ + var res api.QueryDeviceMessagesResponse + if err := keyAPI.QueryDeviceMessages(req.Context(), &api.QueryDeviceMessagesRequest{ UserID: userID, }, &res); err != nil { return util.ErrorResponse(err) @@ -40,12 +40,12 @@ func GetUserDevices( return jsonerror.InternalServerError() } - sigReq := &keyapi.QuerySignaturesRequest{ + sigReq := &api.QuerySignaturesRequest{ TargetIDs: map[string][]gomatrixserverlib.KeyID{ userID: {}, }, } - sigRes := &keyapi.QuerySignaturesResponse{} + sigRes := &api.QuerySignaturesResponse{} for _, dev := range res.Devices { sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID)) } diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go index 868785a9b..2f1f3baf6 100644 --- a/federationapi/routing/eventauth.go +++ b/federationapi/routing/eventauth.go @@ -36,7 +36,7 @@ func GetEventAuth( return *err } - event, resErr := fetchEvent(ctx, rsAPI, eventID) + event, resErr := fetchEvent(ctx, rsAPI, roomID, eventID) if resErr != nil { return *resErr } diff --git a/federationapi/routing/events.go b/federationapi/routing/events.go index 6168912bd..b41292415 100644 --- a/federationapi/routing/events.go +++ b/federationapi/routing/events.go @@ -20,10 +20,11 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/roomserver/api" ) // GetEvent returns the requested event @@ -38,7 +39,9 @@ func GetEvent( if err != nil { return *err } - event, err := fetchEvent(ctx, rsAPI, eventID) + // /_matrix/federation/v1/event/{eventId} doesn't have a roomID, we use an empty string, + // which results in `QueryEventsByID` to first get the event and use that to determine the roomID. + event, err := fetchEvent(ctx, rsAPI, "", eventID) if err != nil { return *err } @@ -60,21 +63,13 @@ func allowedToSeeEvent( rsAPI api.FederationRoomserverAPI, eventID string, ) *util.JSONResponse { - var authResponse api.QueryServerAllowedToSeeEventResponse - err := rsAPI.QueryServerAllowedToSeeEvent( - ctx, - &api.QueryServerAllowedToSeeEventRequest{ - EventID: eventID, - ServerName: origin, - }, - &authResponse, - ) + allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID) if err != nil { resErr := util.ErrorResponse(err) return &resErr } - if !authResponse.AllowedToSeeEvent { + if !allowed { resErr := util.MessageResponse(http.StatusForbidden, "server not allowed to see event") return &resErr } @@ -83,11 +78,11 @@ func allowedToSeeEvent( } // fetchEvent fetches the event without auth checks. Returns an error if the event cannot be found. -func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) { +func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, roomID, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) { var eventsResponse api.QueryEventsByIDResponse err := rsAPI.QueryEventsByID( ctx, - &api.QueryEventsByIDRequest{EventIDs: []string{eventID}}, + &api.QueryEventsByIDRequest{EventIDs: []string{eventID}, RoomID: roomID}, &eventsResponse, ) if err != nil { diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index dc262cfde..2885cc916 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -22,8 +22,8 @@ import ( clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" diff --git a/federationapi/routing/profile_test.go b/federationapi/routing/profile_test.go new file mode 100644 index 000000000..3b9d576bf --- /dev/null +++ b/federationapi/routing/profile_test.go @@ -0,0 +1,94 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "encoding/hex" + "io" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + fedAPI "github.com/matrix-org/dendrite/federationapi" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" +) + +type fakeUserAPI struct { + userAPI.FederationUserAPI +} + +func (u *fakeUserAPI) QueryProfile(ctx context.Context, req *userAPI.QueryProfileRequest, res *userAPI.QueryProfileResponse) error { + return nil +} + +func TestHandleQueryProfile(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() + base.PublicFederationAPIMux = fedMux + base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedClient := fakeFedClient{} + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true) + userapi := fakeUserAPI{} + r, ok := fedapi.(*fedInternal.FederationInternalAPI) + if !ok { + panic("This is a programming error.") + } + routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, &base.Cfg.MSCs, nil, nil) + + handler := fedMux.Get(routing.QueryProfileRouteName).GetHandler().ServeHTTP + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/profile?user_id="+url.QueryEscape("@user:"+string(testOrigin))) + type queryContent struct{} + content := queryContent{} + err := req.SetContent(content) + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk) + httpReq, err := req.HTTPRequest() + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + // vars := map[string]string{"room_alias": "#room:server"} + w := httptest.NewRecorder() + // httpReq = mux.SetURLVars(httpReq, vars) + handler(w, httpReq) + + res := w.Result() + data, _ := io.ReadAll(res.Body) + println(string(data)) + assert.Equal(t, 200, res.StatusCode) + }) +} diff --git a/federationapi/routing/query_test.go b/federationapi/routing/query_test.go new file mode 100644 index 000000000..d839a16b8 --- /dev/null +++ b/federationapi/routing/query_test.go @@ -0,0 +1,94 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "encoding/hex" + "io" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + fedAPI "github.com/matrix-org/dendrite/federationapi" + fedclient "github.com/matrix-org/dendrite/federationapi/api" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" +) + +type fakeFedClient struct { + fedclient.FederationClient +} + +func (f *fakeFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) { + return +} + +func TestHandleQueryDirectory(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() + base.PublicFederationAPIMux = fedMux + base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedClient := fakeFedClient{} + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true) + userapi := fakeUserAPI{} + r, ok := fedapi.(*fedInternal.FederationInternalAPI) + if !ok { + panic("This is a programming error.") + } + routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, &base.Cfg.MSCs, nil, nil) + + handler := fedMux.Get(routing.QueryDirectoryRouteName).GetHandler().ServeHTTP + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/directory?room_alias="+url.QueryEscape("#room:server")) + type queryContent struct{} + content := queryContent{} + err := req.SetContent(content) + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk) + httpReq, err := req.HTTPRequest() + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + // vars := map[string]string{"room_alias": "#room:server"} + w := httptest.NewRecorder() + // httpReq = mux.SetURLVars(httpReq, vars) + handler(w, httpReq) + + res := w.Result() + data, _ := io.ReadAll(res.Body) + println(string(data)) + assert.Equal(t, 200, res.StatusCode) + }) +} diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 0a3ab7a88..324740ddc 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -29,9 +29,9 @@ import ( "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -40,6 +40,12 @@ import ( "github.com/sirupsen/logrus" ) +const ( + SendRouteName = "Send" + QueryDirectoryRouteName = "QueryDirectory" + QueryProfileRouteName = "QueryProfile" +) + // Setup registers HTTP handlers with the given ServeMux. // The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly // path unescape twice (once from the router, once from MakeFedAPI). We need to have this enabled @@ -49,21 +55,26 @@ import ( // applied: // nolint: gocyclo func Setup( - fedMux, keyMux, wkMux *mux.Router, - cfg *config.FederationAPI, + base *base.BaseDendrite, rsAPI roomserverAPI.FederationRoomserverAPI, fsAPI *fedInternal.FederationInternalAPI, keys gomatrixserverlib.JSONVerifier, federation federationAPI.FederationClient, userAPI userapi.FederationUserAPI, - keyAPI keyserverAPI.FederationKeyAPI, mscCfg *config.MSCs, servers federationAPI.ServersInRoomProvider, producer *producers.SyncAPIProducer, ) { - prometheus.MustRegister( - pduCountTotal, eduCountTotal, - ) + fedMux := base.PublicFederationAPIMux + keyMux := base.PublicKeyAPIMux + wkMux := base.PublicWellKnownAPIMux + cfg := &base.Cfg.FederationAPI + + if base.EnableMetrics { + prometheus.MustRegister( + internal.PDUCountTotal, internal.EDUCountTotal, + ) + } v2keysmux := keyMux.PathPrefix("/v2").Subrouter() v1fedmux := fedMux.PathPrefix("/v1").Subrouter() @@ -128,10 +139,10 @@ func Setup( func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Send( httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), - cfg, rsAPI, keyAPI, keys, federation, mu, servers, producer, + cfg, rsAPI, userAPI, keys, federation, mu, servers, producer, ) }, - )).Methods(http.MethodPut, http.MethodOptions) + )).Methods(http.MethodPut, http.MethodOptions).Name(SendRouteName) v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, @@ -241,7 +252,7 @@ func Setup( httpReq, federation, cfg, rsAPI, fsAPI, ) }, - )).Methods(http.MethodGet) + )).Methods(http.MethodGet).Name(QueryDirectoryRouteName) v1fedmux.Handle("/query/profile", MakeFedAPI( "federation_query_profile", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, @@ -250,13 +261,13 @@ func Setup( httpReq, userAPI, cfg, ) }, - )).Methods(http.MethodGet) + )).Methods(http.MethodGet).Name(QueryProfileRouteName) v1fedmux.Handle("/user/devices/{userID}", MakeFedAPI( "federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetUserDevices( - httpReq, keyAPI, vars["userID"], + httpReq, userAPI, vars["userID"], ) }, )).Methods(http.MethodGet) @@ -481,14 +492,14 @@ func Setup( v1fedmux.Handle("/user/keys/claim", MakeFedAPI( "federation_keys_claim", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) + return ClaimOneTimeKeys(httpReq, request, userAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) v1fedmux.Handle("/user/keys/query", MakeFedAPI( "federation_keys_query", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) + return QueryDeviceKeys(httpReq, request, userAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index a146d85bd..82651719f 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -17,26 +17,20 @@ package routing import ( "context" "encoding/json" - "fmt" "net/http" "sync" "time" - "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/producers" - "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" - syncTypes "github.com/matrix-org/dendrite/syncapi/types" + userAPI "github.com/matrix-org/dendrite/userapi/api" ) const ( @@ -56,26 +50,6 @@ const ( MetricsWorkMissingPrevEvents = "missing_prev_events" ) -var ( - pduCountTotal = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: "dendrite", - Subsystem: "federationapi", - Name: "recv_pdus", - Help: "Number of incoming PDUs from remote servers with labels for success", - }, - []string{"status"}, // 'success' or 'total' - ) - eduCountTotal = prometheus.NewCounter( - prometheus.CounterOpts{ - Namespace: "dendrite", - Subsystem: "federationapi", - Name: "recv_edus", - Help: "Number of incoming EDUs from remote servers", - }, - ) -) - var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse // Send implements /_matrix/federation/v1/send/{txnID} @@ -85,7 +59,7 @@ func Send( txnID gomatrixserverlib.TransactionID, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, - keyAPI keyapi.FederationKeyAPI, + keyAPI userAPI.FederationUserAPI, keys gomatrixserverlib.JSONVerifier, federation federationAPI.FederationClient, mu *internal.MutexByRoom, @@ -123,18 +97,6 @@ func Send( defer close(ch) defer inFlightTxnsPerOrigin.Delete(index) - t := txnReq{ - rsAPI: rsAPI, - keys: keys, - ourServerName: cfg.Matrix.ServerName, - federation: federation, - servers: servers, - keyAPI: keyAPI, - roomsMu: mu, - producer: producer, - inboundPresenceEnabled: cfg.Matrix.Presence.EnableInbound, - } - var txnEvents struct { PDUs []json.RawMessage `json:"pdus"` EDUs []gomatrixserverlib.EDU `json:"edus"` @@ -155,16 +117,23 @@ func Send( } } - // TODO: Really we should have a function to convert FederationRequest to txnReq - t.PDUs = txnEvents.PDUs - t.EDUs = txnEvents.EDUs - t.Origin = request.Origin() - t.TransactionID = txnID - t.Destination = cfg.Matrix.ServerName + t := internal.NewTxnReq( + rsAPI, + keyAPI, + cfg.Matrix.ServerName, + keys, + mu, + producer, + cfg.Matrix.Presence.EnableInbound, + txnEvents.PDUs, + txnEvents.EDUs, + request.Origin(), + txnID, + cfg.Matrix.ServerName) util.GetLogger(httpReq.Context()).Debugf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs)) - resp, jsonErr := t.processTransaction(httpReq.Context()) + resp, jsonErr := t.ProcessTransaction(httpReq.Context()) if jsonErr != nil { util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed") return *jsonErr @@ -181,283 +150,3 @@ func Send( ch <- res return res } - -type txnReq struct { - gomatrixserverlib.Transaction - rsAPI api.FederationRoomserverAPI - keyAPI keyapi.FederationKeyAPI - ourServerName gomatrixserverlib.ServerName - keys gomatrixserverlib.JSONVerifier - federation txnFederationClient - roomsMu *internal.MutexByRoom - servers federationAPI.ServersInRoomProvider - producer *producers.SyncAPIProducer - inboundPresenceEnabled bool -} - -// A subset of FederationClient functionality that txn requires. Useful for testing. -type txnFederationClient interface { - LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - res gomatrixserverlib.RespState, err error, - ) - LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) - GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, - roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) -} - -func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - t.processEDUs(ctx) - }() - - results := make(map[string]gomatrixserverlib.PDUResult) - roomVersions := make(map[string]gomatrixserverlib.RoomVersion) - getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion { - if v, ok := roomVersions[roomID]; ok { - return v - } - verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID) - return "" - } - roomVersions[roomID] = verRes.RoomVersion - return verRes.RoomVersion - } - - for _, pdu := range t.PDUs { - pduCountTotal.WithLabelValues("total").Inc() - var header struct { - RoomID string `json:"room_id"` - } - if err := json.Unmarshal(pdu, &header); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to extract room ID from event") - // We don't know the event ID at this point so we can't return the - // failure in the PDU results - continue - } - roomVersion := getRoomVersion(header.RoomID) - event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) - if err != nil { - if _, ok := err.(gomatrixserverlib.BadJSONError); ok { - // Room version 6 states that homeservers should strictly enforce canonical JSON - // on PDUs. - // - // This enforces that the entire transaction is rejected if a single bad PDU is - // sent. It is unclear if this is the correct behaviour or not. - // - // See https://github.com/matrix-org/synapse/issues/7543 - return nil, &util.JSONResponse{ - Code: 400, - JSON: jsonerror.BadJSON("PDU contains bad JSON"), - } - } - util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) - continue - } - if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { - continue - } - if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { - results[event.EventID()] = gomatrixserverlib.PDUResult{ - Error: "Forbidden by server ACLs", - } - continue - } - if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { - util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) - results[event.EventID()] = gomatrixserverlib.PDUResult{ - Error: err.Error(), - } - continue - } - - // pass the event to the roomserver which will do auth checks - // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently - // discarded by the caller of this function - if err = api.SendEvents( - ctx, - t.rsAPI, - api.KindNew, - []*gomatrixserverlib.HeaderedEvent{ - event.Headered(roomVersion), - }, - t.Destination, - t.Origin, - api.DoNotSendToOtherServers, - nil, - true, - ); err != nil { - util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) - results[event.EventID()] = gomatrixserverlib.PDUResult{ - Error: err.Error(), - } - continue - } - - results[event.EventID()] = gomatrixserverlib.PDUResult{} - pduCountTotal.WithLabelValues("success").Inc() - } - - wg.Wait() - return &gomatrixserverlib.RespSend{PDUs: results}, nil -} - -// nolint:gocyclo -func (t *txnReq) processEDUs(ctx context.Context) { - for _, e := range t.EDUs { - eduCountTotal.Inc() - switch e.Type { - case gomatrixserverlib.MTyping: - // https://matrix.org/docs/spec/server_server/latest#typing-notifications - var typingPayload struct { - RoomID string `json:"room_id"` - UserID string `json:"user_id"` - Typing bool `json:"typing"` - } - if err := json.Unmarshal(e.Content, &typingPayload); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event") - continue - } - if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil { - continue - } else if serverName == t.ourServerName { - continue - } else if serverName != t.Origin { - continue - } - if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to JetStream") - } - case gomatrixserverlib.MDirectToDevice: - // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema - var directPayload gomatrixserverlib.ToDeviceMessage - if err := json.Unmarshal(e.Content, &directPayload); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events") - continue - } - if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil { - continue - } else if serverName == t.ourServerName { - continue - } else if serverName != t.Origin { - continue - } - for userID, byUser := range directPayload.Messages { - for deviceID, message := range byUser { - // TODO: check that the user and the device actually exist here - if err := t.producer.SendToDevice(ctx, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { - sentry.CaptureException(err) - util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ - "sender": directPayload.Sender, - "user_id": userID, - "device_id": deviceID, - }).Error("Failed to send send-to-device event to JetStream") - } - } - } - case gomatrixserverlib.MDeviceListUpdate: - if err := t.producer.SendDeviceListUpdate(ctx, e.Content, t.Origin); err != nil { - sentry.CaptureException(err) - util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate") - } - case gomatrixserverlib.MReceipt: - // https://matrix.org/docs/spec/server_server/r0.1.4#receipts - payload := map[string]types.FederationReceiptMRead{} - - if err := json.Unmarshal(e.Content, &payload); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal receipt event") - continue - } - - for roomID, receipt := range payload { - for userID, mread := range receipt.User { - _, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from receipt event sender") - continue - } - if t.Origin != domain { - util.GetLogger(ctx).Debugf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) - continue - } - if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil { - util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ - "sender": t.Origin, - "user_id": userID, - "room_id": roomID, - "events": mread.EventIDs, - }).Error("Failed to send receipt event to JetStream") - continue - } - } - } - case types.MSigningKeyUpdate: - if err := t.producer.SendSigningKeyUpdate(ctx, e.Content, t.Origin); err != nil { - sentry.CaptureException(err) - logrus.WithError(err).Errorf("Failed to process signing key update") - } - case gomatrixserverlib.MPresence: - if t.inboundPresenceEnabled { - if err := t.processPresence(ctx, e); err != nil { - logrus.WithError(err).Errorf("Failed to process presence update") - } - } - default: - util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") - } - } -} - -// processPresence handles m.receipt events -func (t *txnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) error { - payload := types.Presence{} - if err := json.Unmarshal(e.Content, &payload); err != nil { - return err - } - for _, content := range payload.Push { - if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil { - continue - } else if serverName == t.ourServerName { - continue - } else if serverName != t.Origin { - continue - } - presence, ok := syncTypes.PresenceFromString(content.Presence) - if !ok { - continue - } - if err := t.producer.SendPresence(ctx, content.UserID, presence, content.StatusMsg, content.LastActiveAgo); err != nil { - return err - } - } - return nil -} - -// processReceiptEvent sends receipt events to JetStream -func (t *txnReq) processReceiptEvent(ctx context.Context, - userID, roomID, receiptType string, - timestamp gomatrixserverlib.Timestamp, - eventIDs []string, -) error { - if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil { - return nil - } else if serverName == t.ourServerName { - return nil - } else if serverName != t.Origin { - return nil - } - // store every event - for _, eventID := range eventIDs { - if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil { - return fmt.Errorf("unable to set receipt event: %w", err) - } - } - - return nil -} diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index b8bfe0221..eed4e7e69 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -1,552 +1,87 @@ -package routing +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test import ( - "context" + "encoding/hex" "encoding/json" - "fmt" + "net/http/httptest" "testing" - "time" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/roomserver/api" + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + fedAPI "github.com/matrix-org/dendrite/federationapi" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" ) const ( - testOrigin = gomatrixserverlib.ServerName("kaer.morhen") - testDestination = gomatrixserverlib.ServerName("white.orchard") + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") ) -var ( - testRoomVersion = gomatrixserverlib.RoomVersionV1 - testData = []json.RawMessage{ - []byte(`{"auth_events":[],"content":{"creator":"@userid:kaer.morhen"},"depth":0,"event_id":"$0ok8ynDp7kjc95e3:kaer.morhen","hashes":{"sha256":"17kPoH+h0Dk4Omn7Sus0qMb6+oGcf+CZFEgDhv7UKWs"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"jP4a04f5/F10Pw95FPpdCyKAO44JOwUQ/MZOOeA/RTU1Dn+AHPMzGSaZnuGjRr/xQuADt+I3ctb5ZQfLKNzHDw"}},"state_key":"","type":"m.room.create"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"content":{"membership":"join"},"depth":1,"event_id":"$LEwEu0kxrtu5fOiS:kaer.morhen","hashes":{"sha256":"B7M88PhXf3vd1LaFtjQutFu4x/w7fHD28XKZ4sAsJTo"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"p2vqmuJn7ZBRImctSaKbXCAxCcBlIjPH9JHte1ouIUGy84gpu4eLipOvSBCLL26hXfC0Zrm4WUto6Hr+ohdrCg"}},"state_key":"@userid:kaer.morhen","type":"m.room.member"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"join_rule":"public"},"depth":2,"event_id":"$SMHlqUrNhhBBRLeN:kaer.morhen","hashes":{"sha256":"vIuJQvmMjrGxshAkj1SXe0C4RqvMbv4ZADDw9pFCWqQ"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"hBMsb3Qppo3RaqqAl4JyTgaiWEbW5hlckATky6PrHun+F3YM203TzG7w9clwuQU5F5pZoB1a6nw+to0hN90FAw"}},"state_key":"","type":"m.room.join_rules"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"history_visibility":"shared"},"depth":3,"event_id":"$6F1yGIbO0J7TM93h:kaer.morhen","hashes":{"sha256":"Mr23GKSlZW7UCCYLgOWawI2Sg6KIoMjUWO2TDenuOgw"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$SMHlqUrNhhBBRLeN:kaer.morhen",{"sha256":"SylzE8U02I+6eyEHgL+FlU0L5YdqrVp8OOlxKS9VQW0"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sHLKrFI3hKGrEJfpMVZSDS3LvLasQsy50CTsOwru9XTVxgRsPo6wozNtRVjxo1J3Rk18RC9JppovmQ5VR5EcDw"}},"state_key":"","type":"m.room.history_visibility"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"ban":50,"events":null,"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":null,"users_default":0},"depth":4,"event_id":"$UKNe10XzYzG0TeA9:kaer.morhen","hashes":{"sha256":"ngbP3yja9U5dlckKerUs/fSOhtKxZMCVvsfhPURSS28"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$6F1yGIbO0J7TM93h:kaer.morhen",{"sha256":"A4CucrKSoWX4IaJXhq02mBg1sxIyZEftbC+5p3fZAvk"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"zOmwlP01QL3yFchzuR9WHvogOoBZA3oVtNIF3lM0ZfDnqlSYZB9sns27G/4HVq0k7alaK7ZE3oGoCrVnMkPNCw"}},"state_key":"","type":"m.room.power_levels"}`), - // messages - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`), - } - testEvents = []*gomatrixserverlib.HeaderedEvent{} - testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) -) +type sendContent struct { + PDUs []json.RawMessage `json:"pdus"` + EDUs []gomatrixserverlib.EDU `json:"edus"` +} -func init() { - for _, j := range testData { - e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion) +func TestHandleSend(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() + base.PublicFederationAPIMux = fedMux + base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedapi := fedAPI.NewInternalAPI(base, nil, nil, nil, nil, true) + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + r, ok := fedapi.(*fedInternal.FederationInternalAPI) + if !ok { + panic("This is a programming error.") + } + routing.Setup(base, nil, r, keyRing, nil, nil, &base.Cfg.MSCs, nil, nil) + + handler := fedMux.Get(routing.SendRouteName).GetHandler().ServeHTTP + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("PUT", serverName, testOrigin, "/send/1234") + content := sendContent{} + err := req.SetContent(content) if err != nil { - panic("cannot load test data: " + err.Error()) + t.Fatalf("Error: %s", err.Error()) } - h := e.Headered(testRoomVersion) - testEvents = append(testEvents, h) - if e.StateKey() != nil { - testStateEvents[gomatrixserverlib.StateKeyTuple{ - EventType: e.Type(), - StateKey: *e.StateKey(), - }] = h + req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk) + httpReq, err := req.HTTPRequest() + if err != nil { + t.Fatalf("Error: %s", err.Error()) } - } + vars := map[string]string{"txnID": "1234"} + w := httptest.NewRecorder() + httpReq = mux.SetURLVars(httpReq, vars) + handler(w, httpReq) + + res := w.Result() + assert.Equal(t, 200, res.StatusCode) + }) } - -type testRoomserverAPI struct { - api.RoomserverInternalAPITrace - inputRoomEvents []api.InputRoomEvent - queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse - queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse - queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse -} - -func (t *testRoomserverAPI) InputRoomEvents( - ctx context.Context, - request *api.InputRoomEventsRequest, - response *api.InputRoomEventsResponse, -) error { - t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) - for _, ire := range request.InputRoomEvents { - fmt.Println("InputRoomEvents: ", ire.Event.EventID()) - } - return nil -} - -// Query the latest events and state for a room from the room server. -func (t *testRoomserverAPI) QueryLatestEventsAndState( - ctx context.Context, - request *api.QueryLatestEventsAndStateRequest, - response *api.QueryLatestEventsAndStateResponse, -) error { - r := t.queryLatestEventsAndState(request) - response.RoomExists = r.RoomExists - response.RoomVersion = testRoomVersion - response.LatestEvents = r.LatestEvents - response.StateEvents = r.StateEvents - response.Depth = r.Depth - return nil -} - -// Query the state after a list of events in a room from the room server. -func (t *testRoomserverAPI) QueryStateAfterEvents( - ctx context.Context, - request *api.QueryStateAfterEventsRequest, - response *api.QueryStateAfterEventsResponse, -) error { - response.RoomVersion = testRoomVersion - res := t.queryStateAfterEvents(request) - response.PrevEventsExist = res.PrevEventsExist - response.RoomExists = res.RoomExists - response.StateEvents = res.StateEvents - return nil -} - -// Query a list of events by event ID. -func (t *testRoomserverAPI) QueryEventsByID( - ctx context.Context, - request *api.QueryEventsByIDRequest, - response *api.QueryEventsByIDResponse, -) error { - res := t.queryEventsByID(request) - response.Events = res.Events - return nil -} - -// Query if a server is joined to a room -func (t *testRoomserverAPI) QueryServerJoinedToRoom( - ctx context.Context, - request *api.QueryServerJoinedToRoomRequest, - response *api.QueryServerJoinedToRoomResponse, -) error { - response.RoomExists = true - response.IsInRoom = true - return nil -} - -// Asks for the room version for a given room. -func (t *testRoomserverAPI) QueryRoomVersionForRoom( - ctx context.Context, - request *api.QueryRoomVersionForRoomRequest, - response *api.QueryRoomVersionForRoomResponse, -) error { - response.RoomVersion = testRoomVersion - return nil -} - -func (t *testRoomserverAPI) QueryServerBannedFromRoom( - ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse, -) error { - res.Banned = false - return nil -} - -type txnFedClient struct { - state map[string]gomatrixserverlib.RespState // event_id to response - stateIDs map[string]gomatrixserverlib.RespStateIDs // event_id to response - getEvent map[string]gomatrixserverlib.Transaction // event_id to response - getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) -} - -func (c *txnFedClient) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - res gomatrixserverlib.RespState, err error, -) { - fmt.Println("testFederationClient.LookupState", eventID) - r, ok := c.state[eventID] - if !ok { - err = fmt.Errorf("txnFedClient: no /state for event %s", eventID) - return - } - res = r - return -} -func (c *txnFedClient) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { - fmt.Println("testFederationClient.LookupStateIDs", eventID) - r, ok := c.stateIDs[eventID] - if !ok { - err = fmt.Errorf("txnFedClient: no /state_ids for event %s", eventID) - return - } - res = r - return -} -func (c *txnFedClient) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { - fmt.Println("testFederationClient.GetEvent", eventID) - r, ok := c.getEvent[eventID] - if !ok { - err = fmt.Errorf("txnFedClient: no /event for event ID %s", eventID) - return - } - res = r - return -} -func (c *txnFedClient) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, - roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) { - return c.getMissingEvents(missing) -} - -func mustCreateTransaction(rsAPI api.FederationRoomserverAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq { - t := &txnReq{ - rsAPI: rsAPI, - keys: &test.NopJSONVerifier{}, - federation: fedClient, - roomsMu: internal.NewMutexByRoom(), - } - t.PDUs = pdus - t.Origin = testOrigin - t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) - t.Destination = testDestination - return t -} - -func mustProcessTransaction(t *testing.T, txn *txnReq, pdusWithErrors []string) { - res, err := txn.processTransaction(context.Background()) - if err != nil { - t.Errorf("txn.processTransaction returned an error: %v", err) - return - } - if len(res.PDUs) != len(txn.PDUs) { - t.Errorf("txn.processTransaction did not return results for all PDUs, got %d want %d", len(res.PDUs), len(txn.PDUs)) - return - } -NextPDU: - for eventID, result := range res.PDUs { - if result.Error == "" { - continue - } - for _, eventIDWantError := range pdusWithErrors { - if eventID == eventIDWantError { - break NextPDU - } - } - t.Errorf("txn.processTransaction PDU %s returned an error %s", eventID, result.Error) - } -} - -/* -func fromStateTuples(tuples []gomatrixserverlib.StateKeyTuple, omitTuples []gomatrixserverlib.StateKeyTuple) (result []*gomatrixserverlib.HeaderedEvent) { -NextTuple: - for _, t := range tuples { - for _, o := range omitTuples { - if t == o { - break NextTuple - } - } - h, ok := testStateEvents[t] - if ok { - result = append(result, h) - } - } - return -} -*/ - -func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) { - for _, g := range got { - fmt.Println("GOT ", g.Event.EventID()) - } - if len(got) != len(want) { - t.Errorf("wrong number of InputRoomEvents: got %d want %d", len(got), len(want)) - return - } - for i := range got { - if got[i].Event.EventID() != want[i].EventID() { - t.Errorf("InputRoomEvents[%d] got %s want %s", i, got[i].Event.EventID(), want[i].EventID()) - } - } -} - -// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on -// to the roomserver. It's the most basic test possible. -func TestBasicTransaction(t *testing.T) { - rsAPI := &testRoomserverAPI{} - pdus := []json.RawMessage{ - testData[len(testData)-1], // a message event - } - txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) - mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) -} - -// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver -// as it does the auth check. -func TestTransactionFailAuthChecks(t *testing.T) { - rsAPI := &testRoomserverAPI{} - pdus := []json.RawMessage{ - testData[len(testData)-1], // a message event - } - txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) - mustProcessTransaction(t, txn, []string{}) - // expect message to be sent to the roomserver - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) -} - -// The purpose of this test is to make sure that when an event is received for which we do not know the prev_events, -// we request them from /get_missing_events. It works by setting PrevEventsExist=false in the roomserver query response, -// resulting in a call to /get_missing_events which returns the missing prev event. Both events should be processed in -// topological order and sent to the roomserver. -/* -func TestTransactionFetchMissingPrevEvents(t *testing.T) { - haveEvent := testEvents[len(testEvents)-3] - prevEvent := testEvents[len(testEvents)-2] - inputEvent := testEvents[len(testEvents)-1] - - var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions - rsAPI = &testRoomserverAPI{ - queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { - res := api.QueryEventsByIDResponse{} - for _, ev := range testEvents { - for _, id := range req.EventIDs { - if ev.EventID() == id { - res.Events = append(res.Events, ev) - } - } - } - return res - }, - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: true, - StateEvents: testEvents[:5], - } - }, - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - missingPrevEvent := []string{"missing_prev_event"} - if len(req.PrevEventIDs) == 1 { - switch req.PrevEventIDs[0] { - case haveEvent.EventID(): - missingPrevEvent = []string{} - case prevEvent.EventID(): - // we only have this event if we've been send prevEvent - if len(rsAPI.inputRoomEvents) == 1 && rsAPI.inputRoomEvents[0].Event.EventID() == prevEvent.EventID() { - missingPrevEvent = []string{} - } - } - } - - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: missingPrevEvent, - } - }, - queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { - return api.QueryLatestEventsAndStateResponse{ - RoomExists: true, - Depth: haveEvent.Depth(), - LatestEvents: []gomatrixserverlib.EventReference{ - haveEvent.EventReference(), - }, - StateEvents: fromStateTuples(req.StateToFetch, nil), - } - }, - } - - cli := &txnFedClient{ - getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) { - if !reflect.DeepEqual(missing.EarliestEvents, []string{haveEvent.EventID()}) { - t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, haveEvent.EventID()) - } - if !reflect.DeepEqual(missing.LatestEvents, []string{inputEvent.EventID()}) { - t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, inputEvent.EventID()) - } - return gomatrixserverlib.RespMissingEvents{ - Events: []*gomatrixserverlib.Event{ - prevEvent.Unwrap(), - }, - }, nil - }, - } - - pdus := []json.RawMessage{ - inputEvent.JSON(), - } - txn := mustCreateTransaction(rsAPI, cli, pdus) - mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{prevEvent, inputEvent}) -} - -// The purpose of this test is to check that when there are missing prev_events and we still haven't been able to fill -// in the hole with /get_missing_events that the state BEFORE the events we want to persist is fetched via /state_ids -// and /event. It works by setting PrevEventsExist=false in the roomserver query response, resulting in -// a call to /get_missing_events which returns 1 out of the 2 events it needs to fill in the gap. Synapse and Dendrite -// both give up after 1x /get_missing_events call, relying on requesting the state AFTER the missing event in order to -// continue. The DAG looks something like: -// FE GME TXN -// A ---> B ---> C ---> D -// TXN=event in the txn, GME=response to /get_missing_events, FE=roomserver's forward extremity. Should result in: -// - /state_ids?event=B is requested, then /event/B to get the state AFTER B. B is a state event. -// - state resolution is done to check C is allowed. -// This results in B being sent as an outlier FIRST, then C,D. -func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { - eventA := testEvents[len(testEvents)-5] - // this is also len(testEvents)-4 - eventB := testStateEvents[gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomPowerLevels, - StateKey: "", - }] - eventC := testEvents[len(testEvents)-3] - eventD := testEvents[len(testEvents)-2] - fmt.Println("a:", eventA.EventID()) - fmt.Println("b:", eventB.EventID()) - fmt.Println("c:", eventC.EventID()) - fmt.Println("d:", eventD.EventID()) - var rsAPI *testRoomserverAPI - rsAPI = &testRoomserverAPI{ - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - omitTuples := []gomatrixserverlib.StateKeyTuple{ - { - EventType: gomatrixserverlib.MRoomPowerLevels, - StateKey: "", - }, - } - askingForEvent := req.PrevEventIDs[0] - haveEventB := false - haveEventC := false - for _, ev := range rsAPI.inputRoomEvents { - switch ev.Event.EventID() { - case eventB.EventID(): - haveEventB = true - omitTuples = nil // include event B now - case eventC.EventID(): - haveEventC = true - } - } - prevEventExists := false - if askingForEvent == eventC.EventID() { - prevEventExists = haveEventC - } else if askingForEvent == eventB.EventID() { - prevEventExists = haveEventB - } - var stateEvents []*gomatrixserverlib.HeaderedEvent - if prevEventExists { - stateEvents = fromStateTuples(req.StateToFetch, omitTuples) - } - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: prevEventExists, - RoomExists: true, - StateEvents: stateEvents, - } - }, - - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - askingForEvent := req.PrevEventIDs[0] - haveEventB := false - haveEventC := false - for _, ev := range rsAPI.inputRoomEvents { - switch ev.Event.EventID() { - case eventB.EventID(): - haveEventB = true - case eventC.EventID(): - haveEventC = true - } - } - prevEventExists := false - if askingForEvent == eventC.EventID() { - prevEventExists = haveEventC - } else if askingForEvent == eventB.EventID() { - prevEventExists = haveEventB - } - - var missingPrevEvent []string - if !prevEventExists { - missingPrevEvent = []string{"test"} - } - - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: missingPrevEvent, - } - }, - - queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { - omitTuples := []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: ""}, - } - return api.QueryLatestEventsAndStateResponse{ - RoomExists: true, - Depth: eventA.Depth(), - LatestEvents: []gomatrixserverlib.EventReference{ - eventA.EventReference(), - }, - StateEvents: fromStateTuples(req.StateToFetch, omitTuples), - } - }, - queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { - var res api.QueryEventsByIDResponse - fmt.Println("queryEventsByID ", req.EventIDs) - for _, wantEventID := range req.EventIDs { - for _, ev := range testStateEvents { - // roomserver is missing the power levels event unless it's been sent to us recently as an outlier - if wantEventID == eventB.EventID() { - fmt.Println("Asked for pl event") - for _, inEv := range rsAPI.inputRoomEvents { - fmt.Println("recv ", inEv.Event.EventID()) - if inEv.Event.EventID() == wantEventID { - res.Events = append(res.Events, inEv.Event) - break - } - } - continue - } - if ev.EventID() == wantEventID { - res.Events = append(res.Events, ev) - } - } - } - return res - }, - } - // /state_ids for event B returns every state event but B (it's the state before) - var authEventIDs []string - var stateEventIDs []string - for _, ev := range testStateEvents { - if ev.EventID() == eventB.EventID() { - continue - } - // state res checks what auth events you give it, and this isn't a valid auth event - if ev.Type() != gomatrixserverlib.MRoomHistoryVisibility { - authEventIDs = append(authEventIDs, ev.EventID()) - } - stateEventIDs = append(stateEventIDs, ev.EventID()) - } - cli := &txnFedClient{ - stateIDs: map[string]gomatrixserverlib.RespStateIDs{ - eventB.EventID(): { - StateEventIDs: stateEventIDs, - AuthEventIDs: authEventIDs, - }, - }, - // /event for event B returns it - getEvent: map[string]gomatrixserverlib.Transaction{ - eventB.EventID(): { - PDUs: []json.RawMessage{ - eventB.JSON(), - }, - }, - }, - // /get_missing_events should be done exactly once - getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) { - if !reflect.DeepEqual(missing.EarliestEvents, []string{eventA.EventID()}) { - t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, eventA.EventID()) - } - if !reflect.DeepEqual(missing.LatestEvents, []string{eventD.EventID()}) { - t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, eventD.EventID()) - } - // just return event C, not event B so /state_ids logic kicks in as there will STILL be missing prev_events - return gomatrixserverlib.RespMissingEvents{ - Events: []*gomatrixserverlib.Event{ - eventC.Unwrap(), - }, - }, nil - }, - } - - pdus := []json.RawMessage{ - eventD.JSON(), - } - txn := mustCreateTransaction(rsAPI, cli, pdus) - mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD}) -} -*/ diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index 1d08d0a82..1120cf260 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -107,7 +107,7 @@ func getState( return nil, nil, err } - event, resErr := fetchEvent(ctx, rsAPI, eventID) + event, resErr := fetchEvent(ctx, rsAPI, roomID, eventID) if resErr != nil { return nil, nil, resErr } diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index 2ba99112c..e29e3b140 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -1,6 +1,7 @@ package statistics import ( + "context" "math" "math/rand" "sync" @@ -28,25 +29,30 @@ type Statistics struct { // just blacklist the host altogether? The backoff is exponential, // so the max time here to attempt is 2**failures seconds. FailuresUntilBlacklist uint32 + + // How many times should we tolerate consecutive failures before we + // mark the destination as offline. At this point we should attempt + // to send messages to the user's async relay servers if we know them. + FailuresUntilAssumedOffline uint32 } -func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics { +func NewStatistics( + db storage.Database, + failuresUntilBlacklist uint32, + failuresUntilAssumedOffline uint32, +) Statistics { return Statistics{ - DB: db, - FailuresUntilBlacklist: failuresUntilBlacklist, - backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), + DB: db, + FailuresUntilBlacklist: failuresUntilBlacklist, + FailuresUntilAssumedOffline: failuresUntilAssumedOffline, + backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), + servers: make(map[gomatrixserverlib.ServerName]*ServerStatistics), } } // ForServer returns server statistics for the given server name. If it // does not exist, it will create empty statistics and return those. func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics { - // If the map hasn't been initialised yet then do that. - if s.servers == nil { - s.mutex.Lock() - s.servers = make(map[gomatrixserverlib.ServerName]*ServerStatistics) - s.mutex.Unlock() - } // Look up if we have statistics for this server already. s.mutex.RLock() server, found := s.servers[serverName] @@ -55,8 +61,9 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS if !found { s.mutex.Lock() server = &ServerStatistics{ - statistics: s, - serverName: serverName, + statistics: s, + serverName: serverName, + knownRelayServers: []gomatrixserverlib.ServerName{}, } s.servers[serverName] = server s.mutex.Unlock() @@ -66,24 +73,49 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS } else { server.blacklisted.Store(blacklisted) } + assumedOffline, err := s.DB.IsServerAssumedOffline(context.Background(), serverName) + if err != nil { + logrus.WithError(err).Errorf("Failed to get assumed offline entry %q", serverName) + } else { + server.assumedOffline.Store(assumedOffline) + } + + knownRelayServers, err := s.DB.P2PGetRelayServersForServer(context.Background(), serverName) + if err != nil { + logrus.WithError(err).Errorf("Failed to get relay server list for %q", serverName) + } else { + server.relayMutex.Lock() + server.knownRelayServers = knownRelayServers + server.relayMutex.Unlock() + } } return server } +type SendMethod uint8 + +const ( + SendDirect SendMethod = iota + SendViaRelay +) + // ServerStatistics contains information about our interactions with a // remote federated host, e.g. how many times we were successful, how // many times we failed etc. It also manages the backoff time and black- // listing a remote host if it remains uncooperative. type ServerStatistics struct { - statistics *Statistics // - serverName gomatrixserverlib.ServerName // - blacklisted atomic.Bool // is the node blacklisted - backoffStarted atomic.Bool // is the backoff started - backoffUntil atomic.Value // time.Time until this backoff interval ends - backoffCount atomic.Uint32 // number of times BackoffDuration has been called - successCounter atomic.Uint32 // how many times have we succeeded? - backoffNotifier func() // notifies destination queue when backoff completes - notifierMutex sync.Mutex + statistics *Statistics // + serverName gomatrixserverlib.ServerName // + blacklisted atomic.Bool // is the node blacklisted + assumedOffline atomic.Bool // is the node assumed to be offline + backoffStarted atomic.Bool // is the backoff started + backoffUntil atomic.Value // time.Time until this backoff interval ends + backoffCount atomic.Uint32 // number of times BackoffDuration has been called + successCounter atomic.Uint32 // how many times have we succeeded? + backoffNotifier func() // notifies destination queue when backoff completes + notifierMutex sync.Mutex + knownRelayServers []gomatrixserverlib.ServerName + relayMutex sync.Mutex } const maxJitterMultiplier = 1.4 @@ -118,14 +150,22 @@ func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) { // attempt, which increases the sent counter and resets the idle and // failure counters. If a host was blacklisted at this point then // we will unblacklist it. -func (s *ServerStatistics) Success() { +// `relay` specifies whether the success was to the actual destination +// or one of their relay servers. +func (s *ServerStatistics) Success(method SendMethod) { s.cancel() s.backoffCount.Store(0) - s.successCounter.Inc() - if s.statistics.DB != nil { - if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { - logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) + // NOTE : Sending to the final destination vs. a relay server has + // slightly different semantics. + if method == SendDirect { + s.successCounter.Inc() + if s.blacklisted.Load() && s.statistics.DB != nil { + if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) + } } + + s.removeAssumedOffline() } } @@ -144,7 +184,18 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { // start a goroutine which will wait out the backoff and // unset the backoffStarted flag when done. if s.backoffStarted.CompareAndSwap(false, true) { - if s.backoffCount.Inc() >= s.statistics.FailuresUntilBlacklist { + backoffCount := s.backoffCount.Inc() + + if backoffCount >= s.statistics.FailuresUntilAssumedOffline { + s.assumedOffline.CompareAndSwap(false, true) + if s.statistics.DB != nil { + if err := s.statistics.DB.SetServerAssumedOffline(context.Background(), s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to set %q as assumed offline", s.serverName) + } + } + } + + if backoffCount >= s.statistics.FailuresUntilBlacklist { s.blacklisted.Store(true) if s.statistics.DB != nil { if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil { @@ -162,13 +213,21 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { s.backoffUntil.Store(until) s.statistics.backoffMutex.Lock() - defer s.statistics.backoffMutex.Unlock() s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished) + s.statistics.backoffMutex.Unlock() } return s.backoffUntil.Load().(time.Time), false } +// MarkServerAlive removes the assumed offline and blacklisted statuses from this server. +// Returns whether the server was blacklisted before this point. +func (s *ServerStatistics) MarkServerAlive() bool { + s.removeAssumedOffline() + wasBlacklisted := s.removeBlacklist() + return wasBlacklisted +} + // ClearBackoff stops the backoff timer for this destination if it is running // and removes the timer from the backoffTimers map. func (s *ServerStatistics) ClearBackoff() { @@ -196,13 +255,13 @@ func (s *ServerStatistics) backoffFinished() { } // BackoffInfo returns information about the current or previous backoff. -// Returns the last backoffUntil time and whether the server is currently blacklisted or not. -func (s *ServerStatistics) BackoffInfo() (*time.Time, bool) { +// Returns the last backoffUntil time. +func (s *ServerStatistics) BackoffInfo() *time.Time { until, ok := s.backoffUntil.Load().(time.Time) if ok { - return &until, s.blacklisted.Load() + return &until } - return nil, s.blacklisted.Load() + return nil } // Blacklisted returns true if the server is blacklisted and false @@ -211,10 +270,33 @@ func (s *ServerStatistics) Blacklisted() bool { return s.blacklisted.Load() } -// RemoveBlacklist removes the blacklisted status from the server. -func (s *ServerStatistics) RemoveBlacklist() { +// AssumedOffline returns true if the server is assumed offline and false +// otherwise. +func (s *ServerStatistics) AssumedOffline() bool { + return s.assumedOffline.Load() +} + +// removeBlacklist removes the blacklisted status from the server. +// Returns whether the server was blacklisted. +func (s *ServerStatistics) removeBlacklist() bool { + var wasBlacklisted bool + + if s.Blacklisted() { + wasBlacklisted = true + _ = s.statistics.DB.RemoveServerFromBlacklist(s.serverName) + } s.cancel() s.backoffCount.Store(0) + + return wasBlacklisted +} + +// removeAssumedOffline removes the assumed offline status from the server. +func (s *ServerStatistics) removeAssumedOffline() { + if s.AssumedOffline() { + _ = s.statistics.DB.RemoveServerAssumedOffline(context.Background(), s.serverName) + } + s.assumedOffline.Store(false) } // SuccessCount returns the number of successful requests. This is @@ -222,3 +304,46 @@ func (s *ServerStatistics) RemoveBlacklist() { func (s *ServerStatistics) SuccessCount() uint32 { return s.successCounter.Load() } + +// KnownRelayServers returns the list of relay servers associated with this +// server. +func (s *ServerStatistics) KnownRelayServers() []gomatrixserverlib.ServerName { + s.relayMutex.Lock() + defer s.relayMutex.Unlock() + return s.knownRelayServers +} + +func (s *ServerStatistics) AddRelayServers(relayServers []gomatrixserverlib.ServerName) { + seenSet := make(map[gomatrixserverlib.ServerName]bool) + uniqueList := []gomatrixserverlib.ServerName{} + for _, srv := range relayServers { + if seenSet[srv] { + continue + } + seenSet[srv] = true + uniqueList = append(uniqueList, srv) + } + + err := s.statistics.DB.P2PAddRelayServersForServer(context.Background(), s.serverName, uniqueList) + if err != nil { + logrus.WithError(err).Errorf("Failed to add relay servers for %q. Servers: %v", s.serverName, uniqueList) + return + } + + for _, newServer := range uniqueList { + alreadyKnown := false + knownRelayServers := s.KnownRelayServers() + for _, srv := range knownRelayServers { + if srv == newServer { + alreadyKnown = true + } + } + if !alreadyKnown { + { + s.relayMutex.Lock() + s.knownRelayServers = append(s.knownRelayServers, newServer) + s.relayMutex.Unlock() + } + } + } +} diff --git a/federationapi/statistics/statistics_test.go b/federationapi/statistics/statistics_test.go index 6aa997f44..183b9aa0c 100644 --- a/federationapi/statistics/statistics_test.go +++ b/federationapi/statistics/statistics_test.go @@ -4,17 +4,26 @@ import ( "math" "testing" "time" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + FailuresUntilAssumedOffline = 3 + FailuresUntilBlacklist = 8 ) func TestBackoff(t *testing.T) { - stats := NewStatistics(nil, 7) + stats := NewStatistics(nil, FailuresUntilBlacklist, FailuresUntilAssumedOffline) server := ServerStatistics{ statistics: &stats, serverName: "test.com", } // Start by checking that counting successes works. - server.Success() + server.Success(SendDirect) if successes := server.SuccessCount(); successes != 1 { t.Fatalf("Expected success count 1, got %d", successes) } @@ -31,9 +40,8 @@ func TestBackoff(t *testing.T) { // side effects since a backoff is already in progress. If it does // then we'll fail. until, blacklisted := server.Failure() - - // Get the duration. - _, blacklist := server.BackoffInfo() + blacklist := server.Blacklisted() + assumedOffline := server.AssumedOffline() duration := time.Until(until) // Unset the backoff, or otherwise our next call will think that @@ -41,16 +49,43 @@ func TestBackoff(t *testing.T) { server.cancel() server.backoffStarted.Store(false) + if i >= stats.FailuresUntilAssumedOffline { + if !assumedOffline { + t.Fatalf("Backoff %d should have resulted in assuming the destination was offline but didn't", i) + } + } + + // Check if we should be assumed offline by now. + if i >= stats.FailuresUntilAssumedOffline { + if !assumedOffline { + t.Fatalf("Backoff %d should have resulted in assumed offline but didn't", i) + } else { + t.Logf("Backoff %d is assumed offline as expected", i) + } + } else { + if assumedOffline { + t.Fatalf("Backoff %d should not have resulted in assumed offline but did", i) + } else { + t.Logf("Backoff %d is not assumed offline as expected", i) + } + } + // Check if we should be blacklisted by now. if i >= stats.FailuresUntilBlacklist { if !blacklist { t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i) } else if blacklist != blacklisted { - t.Fatalf("BackoffInfo and Failure returned different blacklist values") + t.Fatalf("Blacklisted and Failure returned different blacklist values") } else { t.Logf("Backoff %d is blacklisted as expected", i) continue } + } else { + if blacklist { + t.Fatalf("Backoff %d should not have resulted in blacklist but did", i) + } else { + t.Logf("Backoff %d is not blacklisted as expected", i) + } } // Check if the duration is what we expect. @@ -69,3 +104,14 @@ func TestBackoff(t *testing.T) { } } } + +func TestRelayServersListing(t *testing.T) { + stats := NewStatistics(test.NewInMemoryFederationDatabase(), FailuresUntilBlacklist, FailuresUntilAssumedOffline) + server := ServerStatistics{statistics: &stats} + server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"}) + relayServers := server.KnownRelayServers() + assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers) + server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"}) + relayServers = server.KnownRelayServers() + assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers) +} diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index b15b8bfae..4f5300af1 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -20,11 +20,12 @@ import ( "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/types" ) type Database interface { + P2PDatabase gomatrixserverlib.KeyDatabase UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error) @@ -34,19 +35,16 @@ type Database interface { // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) - StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) + StoreJSON(ctx context.Context, js string) (*receipt.Receipt, error) - GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) - GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) + GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) + GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) - AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error - AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error + AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt) error + AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error - CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error - CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error - - GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) - GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) + CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error + CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) @@ -57,6 +55,18 @@ type Database interface { RemoveAllServersFromBlacklist() error IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) + // Adds the server to the list of assumed offline servers. + // If the server already exists in the table, nothing happens and returns success. + SetServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + // Removes the server from the list of assumed offline servers. + // If the server doesn't exist in the table, nothing happens and returns success. + RemoveServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + // Purges all entries from the assumed offline table. + RemoveAllServersAssumedOffline(ctx context.Context) error + // Gets whether the provided server is present in the table. + // If it is present, returns true. If not, returns false. + IsServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) (bool, error) + AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) @@ -74,4 +84,24 @@ type Database interface { GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) // DeleteExpiredEDUs cleans up expired EDUs DeleteExpiredEDUs(ctx context.Context) error + + PurgeRoom(ctx context.Context, roomID string) error +} + +type P2PDatabase interface { + // Stores the given list of servers as relay servers for the provided destination server. + // Providing duplicates will only lead to a single entry and won't lead to an error. + P2PAddRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + + // Get the list of relay servers associated with the provided destination server. + // If no entry exists in the table, an empty list is returned and does not result in an error. + P2PGetRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + + // Deletes any entries for the provided destination server that match the provided relayServers list. + // If any of the provided servers don't match an entry, nothing happens and no error is returned. + P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + + // Deletes all entries for the provided destination server. + // If the destination server doesn't exist in the table, nothing happens and no error is returned. + P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error } diff --git a/federationapi/storage/postgres/assumed_offline_table.go b/federationapi/storage/postgres/assumed_offline_table.go new file mode 100644 index 000000000..5695d2e54 --- /dev/null +++ b/federationapi/storage/postgres/assumed_offline_table.go @@ -0,0 +1,107 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const assumedOfflineSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_assumed_offline( + -- The assumed offline server name + server_name TEXT PRIMARY KEY NOT NULL +); +` + +const insertAssumedOfflineSQL = "" + + "INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectAssumedOfflineSQL = "" + + "SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAllAssumedOfflineSQL = "" + + "TRUNCATE federationsender_assumed_offline" + +type assumedOfflineStatements struct { + db *sql.DB + insertAssumedOfflineStmt *sql.Stmt + selectAssumedOfflineStmt *sql.Stmt + deleteAssumedOfflineStmt *sql.Stmt + deleteAllAssumedOfflineStmt *sql.Stmt +} + +func NewPostgresAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) { + s = &assumedOfflineStatements{ + db: db, + } + _, err = db.Exec(assumedOfflineSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL}, + {&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL}, + {&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL}, + {&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL}, + }.Prepare(db) +} + +func (s *assumedOfflineStatements) InsertAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) SelectAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is assume offline, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *assumedOfflineStatements) DeleteAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) DeleteAllAssumedOffline( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/postgres/inbound_peeks_table.go b/federationapi/storage/postgres/inbound_peeks_table.go index df5c60761..ad2afcb15 100644 --- a/federationapi/storage/postgres/inbound_peeks_table.go +++ b/federationapi/storage/postgres/inbound_peeks_table.go @@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectInboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER by creation_ts" const renewInboundPeekSQL = "" + "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteInboundPeekSQL = "" + - "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteInboundPeeksSQL = "" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" @@ -74,25 +74,15 @@ func NewPostgresInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err er return } - if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil { - return - } - if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil { - return - } - if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil { - return - } - if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertInboundPeekStmt, insertInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeeksStmt, selectInboundPeeksSQL}, + {&s.renewInboundPeekStmt, renewInboundPeekSQL}, + {&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL}, + {&s.deleteInboundPeekStmt, deleteInboundPeekSQL}, + }.Prepare(db) } func (s *inboundPeeksStatements) InsertInboundPeek( diff --git a/federationapi/storage/postgres/outbound_peeks_table.go b/federationapi/storage/postgres/outbound_peeks_table.go index c22d893f7..5df684318 100644 --- a/federationapi/storage/postgres/outbound_peeks_table.go +++ b/federationapi/storage/postgres/outbound_peeks_table.go @@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectOutboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts" const renewOutboundPeekSQL = "" + "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteOutboundPeekSQL = "" + - "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteOutboundPeeksSQL = "" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1" @@ -74,25 +74,14 @@ func NewPostgresOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err return } - if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil { - return - } - if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil { - return - } - if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil { - return - } - if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertOutboundPeekStmt, insertOutboundPeekSQL}, + {&s.selectOutboundPeekStmt, selectOutboundPeekSQL}, + {&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL}, + {&s.renewOutboundPeekStmt, renewOutboundPeekSQL}, + {&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL}, + {&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL}, + }.Prepare(db) } func (s *outboundPeeksStatements) InsertOutboundPeek( diff --git a/federationapi/storage/postgres/queue_edus_table.go b/federationapi/storage/postgres/queue_edus_table.go index d6507e13b..8870dc88d 100644 --- a/federationapi/storage/postgres/queue_edus_table.go +++ b/federationapi/storage/postgres/queue_edus_table.go @@ -62,10 +62,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_edus" + " WHERE json_nid = $1" -const selectQueueEDUCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_edus" + - " WHERE server_name = $1" - const selectQueueServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_edus" @@ -81,7 +77,6 @@ type queueEDUsStatements struct { deleteQueueEDUStmt *sql.Stmt selectQueueEDUStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt - selectQueueEDUCountStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt selectExpiredEDUsStmt *sql.Stmt deleteExpiredEDUsStmt *sql.Stmt @@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error { {&s.deleteQueueEDUStmt, deleteQueueEDUSQL}, {&s.selectQueueEDUStmt, selectQueueEDUSQL}, {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, - {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL}, {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, @@ -186,21 +180,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( return count, err } -func (s *queueEDUsStatements) SelectQueueEDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queueEDUsStatements) SelectQueueEDUServerNames( ctx context.Context, txn *sql.Tx, ) ([]gomatrixserverlib.ServerName, error) { diff --git a/federationapi/storage/postgres/queue_pdus_table.go b/federationapi/storage/postgres/queue_pdus_table.go index 38ac5a6eb..3b0bef9af 100644 --- a/federationapi/storage/postgres/queue_pdus_table.go +++ b/federationapi/storage/postgres/queue_pdus_table.go @@ -58,10 +58,6 @@ const selectQueuePDUReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE json_nid = $1" -const selectQueuePDUsCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_pdus" + - " WHERE server_name = $1" - const selectQueuePDUServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" @@ -71,7 +67,6 @@ type queuePDUsStatements struct { deleteQueuePDUsStmt *sql.Stmt selectQueuePDUsStmt *sql.Stmt selectQueuePDUReferenceJSONCountStmt *sql.Stmt - selectQueuePDUsCountStmt *sql.Stmt selectQueuePDUServerNamesStmt *sql.Stmt } @@ -95,9 +90,6 @@ func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil { return } - if s.selectQueuePDUsCountStmt, err = s.db.Prepare(selectQueuePDUsCountSQL); err != nil { - return - } if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil { return } @@ -146,21 +138,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( return count, err } -func (s *queuePDUsStatements) SelectQueuePDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, diff --git a/federationapi/storage/postgres/relay_servers_table.go b/federationapi/storage/postgres/relay_servers_table.go new file mode 100644 index 000000000..f7267978f --- /dev/null +++ b/federationapi/storage/postgres/relay_servers_table.go @@ -0,0 +1,137 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayServersSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_relay_servers ( + -- The destination server name + server_name TEXT NOT NULL, + -- The relay server name for a given destination + relay_server_name TEXT NOT NULL, + UNIQUE (server_name, relay_server_name) +); + +CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx + ON federationsender_relay_servers (server_name); +` + +const insertRelayServersSQL = "" + + "INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +const selectRelayServersSQL = "" + + "SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1" + +const deleteRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name = ANY($2)" + +const deleteAllRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1" + +type relayServersStatements struct { + db *sql.DB + insertRelayServersStmt *sql.Stmt + selectRelayServersStmt *sql.Stmt + deleteRelayServersStmt *sql.Stmt + deleteAllRelayServersStmt *sql.Stmt +} + +func NewPostgresRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) { + s = &relayServersStatements{ + db: db, + } + _, err = db.Exec(relayServersSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertRelayServersStmt, insertRelayServersSQL}, + {&s.selectRelayServersStmt, selectRelayServersSQL}, + {&s.deleteRelayServersStmt, deleteRelayServersSQL}, + {&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL}, + }.Prepare(db) +} + +func (s *relayServersStatements) InsertRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + for _, relayServer := range relayServers { + stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil { + return err + } + } + return nil +} + +func (s *relayServersStatements) SelectRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt) + rows, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var relayServer string + if err = rows.Scan(&relayServer); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(relayServer)) + } + return result, nil +} + +func (s *relayServersStatements) DeleteRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt) + _, err := stmt.ExecContext(ctx, serverName, pq.Array(relayServers)) + return err +} + +func (s *relayServersStatements) DeleteAllRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName); err != nil { + return err + } + return nil +} diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index fe84e932e..b81f128e7 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -62,6 +62,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + assumedOffline, err := NewPostgresAssumedOfflineTable(d.db) + if err != nil { + return nil, err + } + relayServers, err := NewPostgresRelayServersTable(d.db) + if err != nil { + return nil, err + } inboundPeeks, err := NewPostgresInboundPeeksTable(d.db) if err != nil { return nil, err @@ -104,6 +112,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, FederationQueueEDUs: queueEDUs, FederationQueueJSON: queueJSON, FederationBlacklist: blacklist, + FederationAssumedOffline: assumedOffline, + FederationRelayServers: relayServers, FederationInboundPeeks: inboundPeeks, FederationOutboundPeeks: outboundPeeks, NotaryServerKeysJSON: notaryJSON, diff --git a/federationapi/storage/shared/receipt/receipt.go b/federationapi/storage/shared/receipt/receipt.go new file mode 100644 index 000000000..b347269c1 --- /dev/null +++ b/federationapi/storage/shared/receipt/receipt.go @@ -0,0 +1,42 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// A Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. +// We don't actually export the NIDs but we need the caller to be able +// to pass them back so that we can clean up if the transaction sends +// successfully. + +package receipt + +import "fmt" + +// Receipt is a wrapper type used to represent a nid that corresponds to a unique row entry +// in some database table. +// The internal nid value cannot be modified after a Receipt has been created. +// This guarantees a receipt will always refer to the same table entry that it was created +// to represent. +type Receipt struct { + nid int64 +} + +func NewReceipt(nid int64) Receipt { + return Receipt{nid: nid} +} + +func (r *Receipt) GetNID() int64 { + return r.nid +} + +func (r *Receipt) String() string { + return fmt.Sprintf("%d", r.nid) +} diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 1e1ea9e17..6769637bc 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -20,6 +20,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/storage/tables" "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal/caching" @@ -37,6 +38,8 @@ type Database struct { FederationQueueJSON tables.FederationQueueJSON FederationJoinedHosts tables.FederationJoinedHosts FederationBlacklist tables.FederationBlacklist + FederationAssumedOffline tables.FederationAssumedOffline + FederationRelayServers tables.FederationRelayServers FederationOutboundPeeks tables.FederationOutboundPeeks FederationInboundPeeks tables.FederationInboundPeeks NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON @@ -44,22 +47,6 @@ type Database struct { ServerSigningKeys tables.FederationServerSigningKeys } -// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. -// We don't actually export the NIDs but we need the caller to be able -// to pass them back so that we can clean up if the transaction sends -// successfully. -type Receipt struct { - nid int64 -} - -func NewReceipt(nid int64) Receipt { - return Receipt{nid: nid} -} - -func (r *Receipt) String() string { - return fmt.Sprintf("%d", r.nid) -} - // UpdateRoom updates the joined hosts for a room and returns what the joined // hosts were before the update, or nil if this was a duplicate message. // This is called when we receive a message from kafka, so we pass in @@ -113,11 +100,18 @@ func (d *Database) GetJoinedHosts( // GetAllJoinedHosts returns the currently joined hosts for // all rooms known to the federation sender. // Returns an error if something goes wrong. -func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { +func (d *Database) GetAllJoinedHosts( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) } -func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { +func (d *Database) GetJoinedHostsForRooms( + ctx context.Context, + roomIDs []string, + excludeSelf, + excludeBlacklisted bool, +) ([]gomatrixserverlib.ServerName, error) { servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted) if err != nil { return nil, err @@ -139,7 +133,7 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, // metadata entries. func (d *Database) StoreJSON( ctx context.Context, js string, -) (*Receipt, error) { +) (*receipt.Receipt, error) { var nid int64 var err error _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -149,18 +143,21 @@ func (d *Database) StoreJSON( if err != nil { return nil, fmt.Errorf("d.insertQueueJSON: %w", err) } - return &Receipt{ - nid: nid, - }, nil + newReceipt := receipt.NewReceipt(nid) + return &newReceipt, nil } -func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { +func (d *Database) AddServerToBlacklist( + serverName gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName) }) } -func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { +func (d *Database) RemoveServerFromBlacklist( + serverName gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName) }) @@ -172,51 +169,166 @@ func (d *Database) RemoveAllServersFromBlacklist() error { }) } -func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { +func (d *Database) IsServerBlacklisted( + serverName gomatrixserverlib.ServerName, +) (bool, error) { return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName) } -func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) SetServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.InsertAssumedOffline(ctx, txn, serverName) + }) +} + +func (d *Database) RemoveServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.DeleteAssumedOffline(ctx, txn, serverName) + }) +} + +func (d *Database) RemoveAllServersAssumedOffline( + ctx context.Context, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.DeleteAllAssumedOffline(ctx, txn) + }) +} + +func (d *Database) IsServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (bool, error) { + return d.FederationAssumedOffline.SelectAssumedOffline(ctx, nil, serverName) +} + +func (d *Database) P2PAddRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationRelayServers.InsertRelayServers(ctx, txn, serverName, relayServers) + }) +} + +func (d *Database) P2PGetRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + return d.FederationRelayServers.SelectRelayServers(ctx, nil, serverName) +} + +func (d *Database) P2PRemoveRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationRelayServers.DeleteRelayServers(ctx, txn, serverName, relayServers) + }) +} + +func (d *Database) P2PRemoveAllRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationRelayServers.DeleteAllRelayServers(ctx, txn, serverName) + }) +} + +func (d *Database) AddOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) RenewOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationOutboundPeeks.RenewOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { +func (d *Database) GetOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID, + peekID string, +) (*types.OutboundPeek, error) { return d.FederationOutboundPeeks.SelectOutboundPeek(ctx, nil, serverName, roomID, peekID) } -func (d *Database) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) { +func (d *Database) GetOutboundPeeks( + ctx context.Context, + roomID string, +) ([]types.OutboundPeek, error) { return d.FederationOutboundPeeks.SelectOutboundPeeks(ctx, nil, roomID) } -func (d *Database) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) AddInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationInboundPeeks.InsertInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) RenewInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationInboundPeeks.RenewInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) { +func (d *Database) GetInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, +) (*types.InboundPeek, error) { return d.FederationInboundPeeks.SelectInboundPeek(ctx, nil, serverName, roomID, peekID) } -func (d *Database) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) { +func (d *Database) GetInboundPeeks( + ctx context.Context, + roomID string, +) ([]types.InboundPeek, error) { return d.FederationInboundPeeks.SelectInboundPeeks(ctx, nil, roomID) } -func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { +func (d *Database) UpdateNotaryKeys( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + serverKeys gomatrixserverlib.ServerKeys, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { validUntil := serverKeys.ValidUntilTS // Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid. @@ -251,7 +363,9 @@ func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserv } func (d *Database) GetNotaryKeys( - ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID, + ctx context.Context, + serverName gomatrixserverlib.ServerName, + optKeyIDs []gomatrixserverlib.KeyID, ) (sks []gomatrixserverlib.ServerKeys, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { sks, err = d.NotaryServerKeysMetadata.SelectKeys(ctx, txn, serverName, optKeyIDs) @@ -259,3 +373,18 @@ func (d *Database) GetNotaryKeys( }) return sks, err } + +func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge joined hosts: %w", err) + } + if err := d.FederationInboundPeeks.DeleteInboundPeeks(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge inbound peeks: %w", err) + } + if err := d.FederationOutboundPeeks.DeleteOutboundPeeks(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge outbound peeks: %w", err) + } + return nil + }) +} diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index c796d2f8f..cff1ade6f 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -22,6 +22,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" ) @@ -41,7 +42,7 @@ var defaultExpireEDUTypes = map[string]time.Duration{ func (d *Database) AssociateEDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - receipt *Receipt, + dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration, ) error { @@ -62,12 +63,12 @@ func (d *Database) AssociateEDUWithDestinations( var err error for destination := range destinations { err = d.FederationQueueEDUs.InsertQueueEDU( - ctx, // context - txn, // SQL transaction - eduType, // EDU type for coalescing - destination, // destination server name - receipt.nid, // NID from the federationapi_queue_json table - expiresAt, // The timestamp this EDU will expire + ctx, // context + txn, // SQL transaction + eduType, // EDU type for coalescing + destination, // destination server name + dbReceipt.GetNID(), // NID from the federationapi_queue_json table + expiresAt, // The timestamp this EDU will expire ) } return err @@ -81,10 +82,10 @@ func (d *Database) GetPendingEDUs( serverName gomatrixserverlib.ServerName, limit int, ) ( - edus map[*Receipt]*gomatrixserverlib.EDU, + edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error, ) { - edus = make(map[*Receipt]*gomatrixserverlib.EDU) + edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU) err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { nids, err := d.FederationQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit) if err != nil { @@ -94,7 +95,8 @@ func (d *Database) GetPendingEDUs( retrieve := make([]int64, 0, len(nids)) for _, nid := range nids { if edu, ok := d.Cache.GetFederationQueuedEDU(nid); ok { - edus[&Receipt{nid}] = edu + newReceipt := receipt.NewReceipt(nid) + edus[&newReceipt] = edu } else { retrieve = append(retrieve, nid) } @@ -110,7 +112,8 @@ func (d *Database) GetPendingEDUs( if err := json.Unmarshal(blob, &event); err != nil { return fmt.Errorf("json.Unmarshal: %w", err) } - edus[&Receipt{nid}] = &event + newReceipt := receipt.NewReceipt(nid) + edus[&newReceipt] = &event d.Cache.StoreFederationQueuedEDU(nid, &event) } @@ -124,7 +127,7 @@ func (d *Database) GetPendingEDUs( func (d *Database) CleanEDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, - receipts []*Receipt, + receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { return errors.New("expected receipt") @@ -132,7 +135,7 @@ func (d *Database) CleanEDUs( nids := make([]int64, len(receipts)) for i := range receipts { - nids[i] = receipts[i].nid + nids[i] = receipts[i].GetNID() } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -162,15 +165,6 @@ func (d *Database) CleanEDUs( }) } -// GetPendingEDUCount returns the number of EDUs waiting to be -// sent for a given servername. -func (d *Database) GetPendingEDUCount( - ctx context.Context, - serverName gomatrixserverlib.ServerName, -) (int64, error) { - return d.FederationQueueEDUs.SelectQueueEDUCount(ctx, nil, serverName) -} - // GetPendingServerNames returns the server names that have EDUs // waiting to be sent. func (d *Database) GetPendingEDUServerNames( diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index dc37d7507..854e00553 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" ) @@ -30,17 +31,17 @@ import ( func (d *Database) AssociatePDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - receipt *Receipt, + dbReceipt *receipt.Receipt, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error for destination := range destinations { err = d.FederationQueuePDUs.InsertQueuePDU( - ctx, // context - txn, // SQL transaction - "", // transaction ID - destination, // destination server name - receipt.nid, // NID from the federationapi_queue_json table + ctx, // context + txn, // SQL transaction + "", // transaction ID + destination, // destination server name + dbReceipt.GetNID(), // NID from the federationapi_queue_json table ) } return err @@ -54,7 +55,7 @@ func (d *Database) GetPendingPDUs( serverName gomatrixserverlib.ServerName, limit int, ) ( - events map[*Receipt]*gomatrixserverlib.HeaderedEvent, + events map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error, ) { // Strictly speaking this doesn't need to be using the writer @@ -62,7 +63,7 @@ func (d *Database) GetPendingPDUs( // a guarantee of transactional isolation, it's actually useful // to know in SQLite mode that nothing else is trying to modify // the database. - events = make(map[*Receipt]*gomatrixserverlib.HeaderedEvent) + events = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent) err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { nids, err := d.FederationQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit) if err != nil { @@ -72,7 +73,8 @@ func (d *Database) GetPendingPDUs( retrieve := make([]int64, 0, len(nids)) for _, nid := range nids { if event, ok := d.Cache.GetFederationQueuedPDU(nid); ok { - events[&Receipt{nid}] = event + newReceipt := receipt.NewReceipt(nid) + events[&newReceipt] = event } else { retrieve = append(retrieve, nid) } @@ -88,7 +90,8 @@ func (d *Database) GetPendingPDUs( if err := json.Unmarshal(blob, &event); err != nil { return fmt.Errorf("json.Unmarshal: %w", err) } - events[&Receipt{nid}] = &event + newReceipt := receipt.NewReceipt(nid) + events[&newReceipt] = &event d.Cache.StoreFederationQueuedPDU(nid, &event) } @@ -103,7 +106,7 @@ func (d *Database) GetPendingPDUs( func (d *Database) CleanPDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, - receipts []*Receipt, + receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { return errors.New("expected receipt") @@ -111,7 +114,7 @@ func (d *Database) CleanPDUs( nids := make([]int64, len(receipts)) for i := range receipts { - nids[i] = receipts[i].nid + nids[i] = receipts[i].GetNID() } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -141,15 +144,6 @@ func (d *Database) CleanPDUs( }) } -// GetPendingPDUCount returns the number of PDUs waiting to be -// sent for a given servername. -func (d *Database) GetPendingPDUCount( - ctx context.Context, - serverName gomatrixserverlib.ServerName, -) (int64, error) { - return d.FederationQueuePDUs.SelectQueuePDUCount(ctx, nil, serverName) -} - // GetPendingServerNames returns the server names that have PDUs // waiting to be sent. func (d *Database) GetPendingPDUServerNames( diff --git a/federationapi/storage/sqlite3/assumed_offline_table.go b/federationapi/storage/sqlite3/assumed_offline_table.go new file mode 100644 index 000000000..ff2afb4da --- /dev/null +++ b/federationapi/storage/sqlite3/assumed_offline_table.go @@ -0,0 +1,107 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const assumedOfflineSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_assumed_offline( + -- The assumed offline server name + server_name TEXT PRIMARY KEY NOT NULL +); +` + +const insertAssumedOfflineSQL = "" + + "INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectAssumedOfflineSQL = "" + + "SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAllAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline" + +type assumedOfflineStatements struct { + db *sql.DB + insertAssumedOfflineStmt *sql.Stmt + selectAssumedOfflineStmt *sql.Stmt + deleteAssumedOfflineStmt *sql.Stmt + deleteAllAssumedOfflineStmt *sql.Stmt +} + +func NewSQLiteAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) { + s = &assumedOfflineStatements{ + db: db, + } + _, err = db.Exec(assumedOfflineSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL}, + {&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL}, + {&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL}, + {&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL}, + }.Prepare(db) +} + +func (s *assumedOfflineStatements) InsertAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) SelectAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is assume offline, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *assumedOfflineStatements) DeleteAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) DeleteAllAssumedOffline( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/sqlite3/inbound_peeks_table.go b/federationapi/storage/sqlite3/inbound_peeks_table.go index ad3c4a6dd..8c3567934 100644 --- a/federationapi/storage/sqlite3/inbound_peeks_table.go +++ b/federationapi/storage/sqlite3/inbound_peeks_table.go @@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectInboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER BY creation_ts" const renewInboundPeekSQL = "" + "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteInboundPeekSQL = "" + - "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteInboundPeeksSQL = "" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" @@ -74,25 +74,15 @@ func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err erro return } - if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil { - return - } - if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil { - return - } - if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil { - return - } - if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertInboundPeekStmt, insertInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeeksStmt, selectInboundPeeksSQL}, + {&s.renewInboundPeekStmt, renewInboundPeekSQL}, + {&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL}, + {&s.deleteInboundPeekStmt, deleteInboundPeekSQL}, + }.Prepare(db) } func (s *inboundPeeksStatements) InsertInboundPeek( diff --git a/federationapi/storage/sqlite3/outbound_peeks_table.go b/federationapi/storage/sqlite3/outbound_peeks_table.go index e29026fab..33f452b68 100644 --- a/federationapi/storage/sqlite3/outbound_peeks_table.go +++ b/federationapi/storage/sqlite3/outbound_peeks_table.go @@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectOutboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts" const renewOutboundPeekSQL = "" + "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteOutboundPeekSQL = "" + - "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteOutboundPeeksSQL = "" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1" @@ -74,25 +74,14 @@ func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err er return } - if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil { - return - } - if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil { - return - } - if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil { - return - } - if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertOutboundPeekStmt, insertOutboundPeekSQL}, + {&s.selectOutboundPeekStmt, selectOutboundPeekSQL}, + {&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL}, + {&s.renewOutboundPeekStmt, renewOutboundPeekSQL}, + {&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL}, + {&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL}, + }.Prepare(db) } func (s *outboundPeeksStatements) InsertOutboundPeek( diff --git a/federationapi/storage/sqlite3/queue_edus_table.go b/federationapi/storage/sqlite3/queue_edus_table.go index 8e7e7901f..0dc914328 100644 --- a/federationapi/storage/sqlite3/queue_edus_table.go +++ b/federationapi/storage/sqlite3/queue_edus_table.go @@ -63,10 +63,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_edus" + " WHERE json_nid = $1" -const selectQueueEDUCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_edus" + - " WHERE server_name = $1" - const selectQueueServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_edus" @@ -82,7 +78,6 @@ type queueEDUsStatements struct { // deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic selectQueueEDUStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt - selectQueueEDUCountStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt selectExpiredEDUsStmt *sql.Stmt deleteExpiredEDUsStmt *sql.Stmt @@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error { {&s.insertQueueEDUStmt, insertQueueEDUSQL}, {&s.selectQueueEDUStmt, selectQueueEDUSQL}, {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, - {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL}, {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, @@ -198,21 +192,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( return count, err } -func (s *queueEDUsStatements) SelectQueueEDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queueEDUsStatements) SelectQueueEDUServerNames( ctx context.Context, txn *sql.Tx, ) ([]gomatrixserverlib.ServerName, error) { diff --git a/federationapi/storage/sqlite3/queue_pdus_table.go b/federationapi/storage/sqlite3/queue_pdus_table.go index e818585a5..aee8b03d6 100644 --- a/federationapi/storage/sqlite3/queue_pdus_table.go +++ b/federationapi/storage/sqlite3/queue_pdus_table.go @@ -66,10 +66,6 @@ const selectQueuePDUsReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE json_nid = $1" -const selectQueuePDUsCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_pdus" + - " WHERE server_name = $1" - const selectQueuePDUsServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" @@ -79,7 +75,6 @@ type queuePDUsStatements struct { selectQueueNextTransactionIDStmt *sql.Stmt selectQueuePDUsStmt *sql.Stmt selectQueueReferenceJSONCountStmt *sql.Stmt - selectQueuePDUsCountStmt *sql.Stmt selectQueueServerNamesStmt *sql.Stmt // deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic } @@ -107,9 +102,6 @@ func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil { return } - if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil { - return - } if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil { return } @@ -179,21 +171,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( return count, err } -func (s *queuePDUsStatements) SelectQueuePDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, diff --git a/federationapi/storage/sqlite3/relay_servers_table.go b/federationapi/storage/sqlite3/relay_servers_table.go new file mode 100644 index 000000000..27c3cca2c --- /dev/null +++ b/federationapi/storage/sqlite3/relay_servers_table.go @@ -0,0 +1,148 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayServersSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_relay_servers ( + -- The destination server name + server_name TEXT NOT NULL, + -- The relay server name for a given destination + relay_server_name TEXT NOT NULL, + UNIQUE (server_name, relay_server_name) +); + +CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx + ON federationsender_relay_servers (server_name); +` + +const insertRelayServersSQL = "" + + "INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +const selectRelayServersSQL = "" + + "SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1" + +const deleteRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name IN ($2)" + +const deleteAllRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1" + +type relayServersStatements struct { + db *sql.DB + insertRelayServersStmt *sql.Stmt + selectRelayServersStmt *sql.Stmt + // deleteRelayServersStmt *sql.Stmt - prepared at runtime due to variadic + deleteAllRelayServersStmt *sql.Stmt +} + +func NewSQLiteRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) { + s = &relayServersStatements{ + db: db, + } + _, err = db.Exec(relayServersSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertRelayServersStmt, insertRelayServersSQL}, + {&s.selectRelayServersStmt, selectRelayServersSQL}, + {&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL}, + }.Prepare(db) +} + +func (s *relayServersStatements) InsertRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + for _, relayServer := range relayServers { + stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil { + return err + } + } + return nil +} + +func (s *relayServersStatements) SelectRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt) + rows, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var relayServer string + if err = rows.Scan(&relayServer); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(relayServer)) + } + return result, nil +} + +func (s *relayServersStatements) DeleteRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + deleteSQL := strings.Replace(deleteRelayServersSQL, "($2)", sqlutil.QueryVariadicOffset(len(relayServers), 1), 1) + deleteStmt, err := s.db.Prepare(deleteSQL) + if err != nil { + return err + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + params := make([]interface{}, len(relayServers)+1) + params[0] = serverName + for i, v := range relayServers { + params[i+1] = v + } + + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *relayServersStatements) DeleteAllRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName); err != nil { + return err + } + return nil +} diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index d13b5defc..1e7e41a2c 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -1,5 +1,4 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -61,6 +60,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db) + if err != nil { + return nil, err + } + relayServers, err := NewSQLiteRelayServersTable(d.db) + if err != nil { + return nil, err + } outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) if err != nil { return nil, err @@ -103,6 +110,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, FederationQueueEDUs: queueEDUs, FederationQueueJSON: queueJSON, FederationBlacklist: blacklist, + FederationAssumedOffline: assumedOffline, + FederationRelayServers: relayServers, FederationOutboundPeeks: outboundPeeks, FederationInboundPeeks: inboundPeeks, NotaryServerKeysJSON: notaryKeys, diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index f7408fa9f..1d2a13e81 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -2,16 +2,17 @@ package storage_test import ( "context" + "reflect" "testing" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/stretchr/testify/assert" - "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" ) func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { @@ -80,3 +81,263 @@ func TestExpireEDUs(t *testing.T) { assert.Equal(t, 2, len(data)) }) } + +func TestOutboundPeeking(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + + // Add outbound peek + if err := db.AddOutboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + outboundPeek1, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if outboundPeek1.PeekID != peekID { + t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID) + } + if outboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID) + } + if outboundPeek1.ServerName != serverName { + t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName) + } + if outboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval) + } + // Renew the peek + if err = db.RenewOutboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + outboundPeek2, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(outboundPeek1, outboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if outboundPeek1.ServerName != outboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName) + } + if outboundPeek1.RoomID != outboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID) + } + + // insert some peeks + peekIDs := []string{peekID} + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = db.AddOutboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + outboundPeeks, err := db.GetOutboundPeeks(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if len(outboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) + } + gotPeekIDs := make([]string, 0, len(outboundPeeks)) + for _, p := range outboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) + } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) + }) +} + +func TestInboundPeeking(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + + // Add inbound peek + if err := db.AddInboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + inboundPeek1, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if inboundPeek1.PeekID != peekID { + t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID) + } + if inboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID) + } + if inboundPeek1.ServerName != serverName { + t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName) + } + if inboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval) + } + // Renew the peek + if err = db.RenewInboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + inboundPeek2, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(inboundPeek1, inboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if inboundPeek1.ServerName != inboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName) + } + if inboundPeek1.RoomID != inboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID) + } + + // insert some peeks + peekIDs := []string{peekID} + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = db.AddInboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + inboundPeeks, err := db.GetInboundPeeks(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if len(inboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) + } + gotPeekIDs := make([]string, 0, len(inboundPeeks)) + for _, p := range inboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) + } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) + }) +} + +func TestServersAssumedOffline(t *testing.T) { + server1 := gomatrixserverlib.ServerName("server1") + server2 := gomatrixserverlib.ServerName("server2") + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + + // Set server1 & server2 as assumed offline. + err := db.SetServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + err = db.SetServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + + // Ensure both servers are assumed offline. + isOffline, err := db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.True(t, isOffline) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + assert.True(t, isOffline) + + // Set server1 as not assumed offline. + err = db.RemoveServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + + // Ensure both servers have correct state. + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.False(t, isOffline) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + assert.True(t, isOffline) + + // Re-set server1 as assumed offline. + err = db.SetServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + + // Ensure server1 is assumed offline. + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.True(t, isOffline) + + err = db.RemoveAllServersAssumedOffline(context.Background()) + assert.Nil(t, err) + + // Ensure both servers have correct state. + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.False(t, isOffline) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + assert.False(t, isOffline) + }) +} + +func TestRelayServersStored(t *testing.T) { + server := gomatrixserverlib.ServerName("server") + relayServer1 := gomatrixserverlib.ServerName("relayserver1") + relayServer2 := gomatrixserverlib.ServerName("relayserver2") + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + + err := db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) + assert.Nil(t, err) + + relayServers, err := db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Equal(t, relayServer1, relayServers[0]) + + err = db.P2PRemoveRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) + assert.Nil(t, err) + + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Zero(t, len(relayServers)) + + err = db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1, relayServer2}) + assert.Nil(t, err) + + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Equal(t, relayServer1, relayServers[0]) + assert.Equal(t, relayServer2, relayServers[1]) + + err = db.P2PRemoveAllRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Zero(t, len(relayServers)) + }) +} diff --git a/federationapi/storage/tables/inbound_peeks_table_test.go b/federationapi/storage/tables/inbound_peeks_table_test.go new file mode 100644 index 000000000..e5d898b3a --- /dev/null +++ b/federationapi/storage/tables/inbound_peeks_table_test.go @@ -0,0 +1,149 @@ +package tables_test + +import ( + "context" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" +) + +func mustCreateInboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationInboundPeeks, func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + var tab tables.FederationInboundPeeks + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresInboundPeeksTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteInboundPeeksTable(db) + } + if err != nil { + t.Fatalf("failed to create table: %s", err) + } + return tab, close +} + +func TestInboundPeeksTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateInboundpeeksTable(t, dbType) + defer closeDB() + + // Insert a peek + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + if err := tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + inboundPeek1, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if inboundPeek1.PeekID != peekID { + t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID) + } + if inboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID) + } + if inboundPeek1.ServerName != serverName { + t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName) + } + if inboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval) + } + + // Renew the peek + if err = tab.RenewInboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + inboundPeek2, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(inboundPeek1, inboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if inboundPeek1.ServerName != inboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName) + } + if inboundPeek1.RoomID != inboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID) + } + + // delete the peek + if err = tab.DeleteInboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil { + t.Fatal(err) + } + + // There should be no peek anymore + peek, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if peek != nil { + t.Fatalf("got a peek which should be deleted: %+v", peek) + } + + // insert some peeks + var peekIDs []string + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + inboundPeeks, err := tab.SelectInboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(inboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) + } + gotPeekIDs := make([]string, 0, len(inboundPeeks)) + for _, p := range inboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) + } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) + + // And delete them again + if err = tab.DeleteInboundPeeks(ctx, nil, room.ID); err != nil { + t.Fatal(err) + } + + // they should be gone now + inboundPeeks, err = tab.SelectInboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(inboundPeeks) > 0 { + t.Fatal("got inbound peeks which should be deleted") + } + + }) +} diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 9f4e86a6e..762504e45 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -28,7 +28,6 @@ type FederationQueuePDUs interface { InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) - SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) } @@ -38,7 +37,6 @@ type FederationQueueEDUs interface { DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) - SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error) DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error @@ -51,6 +49,19 @@ type FederationQueueJSON interface { SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) } +type FederationQueueTransactions interface { + InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error + DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) +} + +type FederationTransactionJSON interface { + InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) + DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error + SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) +} + type FederationJoinedHosts interface { InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error @@ -68,6 +79,20 @@ type FederationBlacklist interface { DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error } +type FederationAssumedOffline interface { + InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error) + DeleteAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + DeleteAllAssumedOffline(ctx context.Context, txn *sql.Tx) error +} + +type FederationRelayServers interface { + InsertRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + SelectRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + DeleteRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + DeleteAllRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error +} + type FederationOutboundPeeks interface { InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) diff --git a/federationapi/storage/tables/outbound_peeks_table_test.go b/federationapi/storage/tables/outbound_peeks_table_test.go new file mode 100644 index 000000000..a460af09d --- /dev/null +++ b/federationapi/storage/tables/outbound_peeks_table_test.go @@ -0,0 +1,148 @@ +package tables_test + +import ( + "context" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" +) + +func mustCreateOutboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationOutboundPeeks, func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + var tab tables.FederationOutboundPeeks + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresOutboundPeeksTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteOutboundPeeksTable(db) + } + if err != nil { + t.Fatalf("failed to create table: %s", err) + } + return tab, close +} + +func TestOutboundPeeksTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateOutboundpeeksTable(t, dbType) + defer closeDB() + + // Insert a peek + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + if err := tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + outboundPeek1, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if outboundPeek1.PeekID != peekID { + t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID) + } + if outboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID) + } + if outboundPeek1.ServerName != serverName { + t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName) + } + if outboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval) + } + + // Renew the peek + if err = tab.RenewOutboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + outboundPeek2, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(outboundPeek1, outboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if outboundPeek1.ServerName != outboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName) + } + if outboundPeek1.RoomID != outboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID) + } + + // delete the peek + if err = tab.DeleteOutboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil { + t.Fatal(err) + } + + // There should be no peek anymore + peek, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if peek != nil { + t.Fatalf("got a peek which should be deleted: %+v", peek) + } + + // insert some peeks + var peekIDs []string + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + outboundPeeks, err := tab.SelectOutboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(outboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) + } + gotPeekIDs := make([]string, 0, len(outboundPeeks)) + for _, p := range outboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) + } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) + + // And delete them again + if err = tab.DeleteOutboundPeeks(ctx, nil, room.ID); err != nil { + t.Fatal(err) + } + + // they should be gone now + outboundPeeks, err = tab.SelectOutboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(outboundPeeks) > 0 { + t.Fatal("got outbound peeks which should be deleted") + } + }) +} diff --git a/federationapi/storage/tables/relay_servers_table_test.go b/federationapi/storage/tables/relay_servers_table_test.go new file mode 100644 index 000000000..b41211551 --- /dev/null +++ b/federationapi/storage/tables/relay_servers_table_test.go @@ -0,0 +1,224 @@ +package tables_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + server1 = "server1" + server2 = "server2" + server3 = "server3" + server4 = "server4" +) + +type RelayServersDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.FederationRelayServers +} + +func mustCreateRelayServersTable( + t *testing.T, + dbType test.DBType, +) (database RelayServersDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.FederationRelayServers + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresRelayServersTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteRelayServersTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = RelayServersDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func Equal(a, b []gomatrixserverlib.ServerName) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} + +func TestShouldInsertRelayServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldInsertRelayServersWithDuplicates(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + insertRelayServers := []gomatrixserverlib.ServerName{server2, server2, server2, server3, server2} + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + // Insert the same list again, this shouldn't fail and should have no effect. + err = db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldGetRelayServersUnknownDestination(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + + // Query relay servers for a destination that doesn't exist in the table. + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, []gomatrixserverlib.ServerName{}) { + t.Fatalf("Expected: %v \nActual: %v", []gomatrixserverlib.ServerName{}, relayServers) + } + }) +} + +func TestShouldDeleteCorrectRelayServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + relayServers1 := []gomatrixserverlib.ServerName{server2, server3} + relayServers2 := []gomatrixserverlib.ServerName{server1, server3, server4} + + err := db.Table.InsertRelayServers(ctx, nil, server1, relayServers1) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertRelayServers(ctx, nil, server2, relayServers2) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + err = db.Table.DeleteRelayServers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2}) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) + } + err = db.Table.DeleteRelayServers(ctx, nil, server2, []gomatrixserverlib.ServerName{server1, server4}) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server2, err.Error()) + } + + expectedRelayServers := []gomatrixserverlib.ServerName{server3} + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldDeleteAllRelayServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + err = db.Table.DeleteAllRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) + } + + expectedRelayServers1 := []gomatrixserverlib.ServerName{} + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers1) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers) + } + relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} diff --git a/go.mod b/go.mod index 6d00e80dc..80d8b73f9 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ require ( github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979 github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/Masterminds/semver/v3 v3.1.1 - github.com/blevesearch/bleve/v2 v2.3.4 + github.com/blevesearch/bleve/v2 v2.3.6 github.com/codeclysm/extract v2.2.0+incompatible github.com/dgraph-io/ristretto v0.1.1 github.com/docker/docker v20.10.19+incompatible @@ -16,18 +16,16 @@ require ( github.com/google/go-cmp v0.5.9 github.com/google/uuid v1.3.0 github.com/gorilla/mux v1.8.0 - github.com/gorilla/websocket v1.5.0 github.com/kardianos/minwinsvc v1.0.2 github.com/lib/pq v1.10.7 github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 - github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 - github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 - github.com/mattn/go-sqlite3 v1.14.15 - github.com/nats-io/nats-server/v2 v2.9.8 - github.com/nats-io/nats.go v1.20.0 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230131183213-122f1e0e3fa1 + github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 + github.com/mattn/go-sqlite3 v1.14.16 + github.com/nats-io/nats-server/v2 v2.9.15 + github.com/nats-io/nats.go v1.24.0 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79 @@ -37,69 +35,64 @@ require ( github.com/prometheus/client_golang v1.13.0 github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.1 - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.14.4 github.com/tidwall/sjson v1.2.5 github.com/uber/jaeger-client-go v2.30.0+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.6 go.uber.org/atomic v1.10.0 - golang.org/x/crypto v0.1.0 - golang.org/x/image v0.1.0 + golang.org/x/crypto v0.6.0 + golang.org/x/image v0.5.0 golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e - golang.org/x/net v0.1.0 - golang.org/x/term v0.1.0 + golang.org/x/term v0.5.0 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 gotest.tools/v3 v3.4.0 modernc.org/sqlite v1.19.3 - nhooyr.io/websocket v1.8.7 ) -require github.com/matryer/is v1.4.0 +require ( + github.com/MFAshby/stdemuxerhook v1.0.0 + github.com/matryer/is v1.4.0 +) require ( github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect github.com/Microsoft/go-winio v0.5.2 // indirect - github.com/RoaringBitmap/roaring v1.2.1 // indirect + github.com/RoaringBitmap/roaring v1.2.3 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/bits-and-blooms/bitset v1.3.3 // indirect - github.com/blevesearch/bleve_index_api v1.0.3 // indirect - github.com/blevesearch/geo v0.1.14 // indirect + github.com/bits-and-blooms/bitset v1.5.0 // indirect + github.com/blevesearch/bleve_index_api v1.0.5 // indirect + github.com/blevesearch/geo v0.1.17 // indirect github.com/blevesearch/go-porterstemmer v1.0.3 // indirect github.com/blevesearch/gtreap v0.1.1 // indirect github.com/blevesearch/mmap-go v1.0.4 // indirect - github.com/blevesearch/scorch_segment_api/v2 v2.1.2 // indirect - github.com/blevesearch/segment v0.9.0 // indirect + github.com/blevesearch/scorch_segment_api/v2 v2.1.4 // indirect + github.com/blevesearch/segment v0.9.1 // indirect github.com/blevesearch/snowballstem v0.9.0 // indirect - github.com/blevesearch/upsidedown_store_api v1.0.1 // indirect - github.com/blevesearch/vellum v1.0.8 // indirect - github.com/blevesearch/zapx/v11 v11.3.5 // indirect - github.com/blevesearch/zapx/v12 v12.3.5 // indirect - github.com/blevesearch/zapx/v13 v13.3.5 // indirect - github.com/blevesearch/zapx/v14 v14.3.5 // indirect - github.com/blevesearch/zapx/v15 v15.3.5 // indirect + github.com/blevesearch/upsidedown_store_api v1.0.2 // indirect + github.com/blevesearch/vellum v1.0.9 // indirect + github.com/blevesearch/zapx/v11 v11.3.7 // indirect + github.com/blevesearch/zapx/v12 v12.3.7 // indirect + github.com/blevesearch/zapx/v13 v13.3.7 // indirect + github.com/blevesearch/zapx/v14 v14.3.7 // indirect + github.com/blevesearch/zapx/v15 v15.3.8 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/docker/distribution v2.8.1+incompatible // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.0 // indirect - github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect github.com/gogo/protobuf v1.1.1 // indirect github.com/golang/geo v0.0.0-20210211234256-740aa86cb551 // indirect github.com/golang/glog v1.0.0 // indirect - github.com/golang/mock v1.6.0 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/golang/snappy v0.0.4 // indirect - github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/h2non/filetype v1.1.3 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/juju/errors v1.0.0 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect - github.com/klauspost/compress v1.15.11 // indirect + github.com/klauspost/compress v1.16.0 // indirect github.com/kr/pretty v0.3.1 // indirect - github.com/lucas-clemente/quic-go v0.30.0 // indirect - github.com/marten-seemann/qtls-go1-18 v0.1.3 // indirect - github.com/marten-seemann/qtls-go1-19 v0.1.1 // indirect github.com/mattn/go-isatty v0.0.16 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect github.com/minio/highwayhash v1.0.2 // indirect @@ -111,8 +104,6 @@ require ( github.com/nats-io/jwt/v2 v2.3.0 // indirect github.com/nats-io/nkeys v0.3.0 // indirect github.com/nats-io/nuid v1.0.1 // indirect - github.com/onsi/ginkgo/v2 v2.3.0 // indirect - github.com/onsi/gomega v1.22.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect @@ -121,13 +112,13 @@ require ( github.com/prometheus/procfs v0.8.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa // indirect github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect go.etcd.io/bbolt v1.3.6 // indirect - golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 // indirect golang.org/x/mod v0.6.0 // indirect - golang.org/x/sys v0.1.0 // indirect - golang.org/x/text v0.4.0 // indirect - golang.org/x/time v0.1.0 // indirect + golang.org/x/net v0.7.0 // indirect + golang.org/x/sys v0.5.0 // indirect + golang.org/x/text v0.7.0 // indirect + golang.org/x/time v0.3.0 // indirect golang.org/x/tools v0.2.0 // indirect google.golang.org/protobuf v1.28.1 // indirect gopkg.in/macaroon.v2 v2.1.0 // indirect diff --git a/go.sum b/go.sum index 11e1a5de7..ab56e676f 100644 --- a/go.sum +++ b/go.sum @@ -43,14 +43,15 @@ github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20O github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/HdrHistogram/hdrhistogram-go v1.1.2 h1:5IcZpTvzydCQeHzK4Ef/D5rrSqwxob0t8PQPMybUNFM= github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo= +github.com/MFAshby/stdemuxerhook v1.0.0 h1:1XFGzakrsHMv76AeanPDL26NOgwjPl/OUxbGhJthwMc= +github.com/MFAshby/stdemuxerhook v1.0.0/go.mod h1:nLMI9FUf9Hz98n+yAXsTMUR4RZQy28uCTLG1Fzvj/uY= github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/Microsoft/go-winio v0.5.2 h1:a9IhgEQBCUEk6QCdml9CiJGhAws+YwffDHEMp1VMrpA= github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= github.com/RoaringBitmap/roaring v0.4.7/go.mod h1:8khRDP4HmeXns4xIj9oGrKSz7XTQiJx2zgh7AcNke4w= -github.com/RoaringBitmap/roaring v0.9.4/go.mod h1:icnadbWcNyfEHlYdr+tDlOTih1Bf/h+rzPpv4sbomAA= -github.com/RoaringBitmap/roaring v1.2.1 h1:58/LJlg/81wfEHd5L9qsHduznOIhyv4qb1yWcSvVq9A= -github.com/RoaringBitmap/roaring v1.2.1/go.mod h1:icnadbWcNyfEHlYdr+tDlOTih1Bf/h+rzPpv4sbomAA= +github.com/RoaringBitmap/roaring v1.2.3 h1:yqreLINqIrX22ErkKI0vY47/ivtJr6n+kMhVOVmhWBY= +github.com/RoaringBitmap/roaring v1.2.3/go.mod h1:plvDsJQpxOC5bw8LRteu/MLWHsHez/3y6cubLI4/1yE= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= @@ -68,51 +69,45 @@ github.com/anacrolix/missinggo v1.2.1 h1:0IE3TqX5y5D0IxeMwTyIgqdDew4QrzcXaaEnJQy github.com/anacrolix/missinggo v1.2.1/go.mod h1:J5cMhif8jPmFoC3+Uvob3OXXNIhOUikzMt+uUjeM21Y= github.com/anacrolix/missinggo/perf v1.0.0/go.mod h1:ljAFWkBuzkO12MQclXzZrosP5urunoLS0Cbvb4V0uMQ= github.com/anacrolix/tagflag v0.0.0-20180109131632-2146c8d41bf0/go.mod h1:1m2U/K6ZT+JZG0+bdMK6qauP49QT4wE5pmhJXOKKCHw= -github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bits-and-blooms/bitset v1.2.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= -github.com/bits-and-blooms/bitset v1.3.3 h1:R1XWiopGiXf66xygsiLpzLo67xEYvMkHw3w+rCOSAwg= -github.com/bits-and-blooms/bitset v1.3.3/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= -github.com/blevesearch/bleve/v2 v2.3.4 h1:SSb7/cwGzo85LWX1jchIsXM8ZiNNMX3shT5lROM63ew= -github.com/blevesearch/bleve/v2 v2.3.4/go.mod h1:Ot0zYum8XQRfPcwhae8bZmNyYubynsoMjVvl1jPqL30= -github.com/blevesearch/bleve_index_api v1.0.3 h1:DDSWaPXOZZJ2BB73ZTWjKxydAugjwywcqU+91AAqcAg= -github.com/blevesearch/bleve_index_api v1.0.3/go.mod h1:fiwKS0xLEm+gBRgv5mumf0dhgFr2mDgZah1pqv1c1M4= -github.com/blevesearch/geo v0.1.13/go.mod h1:cRIvqCdk3cgMhGeHNNe6yPzb+w56otxbfo1FBJfR2Pc= -github.com/blevesearch/geo v0.1.14 h1:TTDpJN6l9ck/cUYbXSn4aCElNls0Whe44rcQKsB7EfU= -github.com/blevesearch/geo v0.1.14/go.mod h1:cRIvqCdk3cgMhGeHNNe6yPzb+w56otxbfo1FBJfR2Pc= -github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:9eJDeqxJ3E7WnLebQUlPD7ZjSce7AnDb9vjGmMCbD0A= +github.com/bits-and-blooms/bitset v1.5.0 h1:NpE8frKRLGHIcEzkR+gZhiioW1+WbYV6fKwD6ZIpQT8= +github.com/bits-and-blooms/bitset v1.5.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= +github.com/blevesearch/bleve/v2 v2.3.6 h1:NlntUHcV5CSWIhpugx4d/BRMGCiaoI8ZZXrXlahzNq4= +github.com/blevesearch/bleve/v2 v2.3.6/go.mod h1:JM2legf1cKVkdV8Ehu7msKIOKC0McSw0Q16Fmv9vsW4= +github.com/blevesearch/bleve_index_api v1.0.5 h1:Lc986kpC4Z0/n1g3gg8ul7H+lxgOQPcXb9SxvQGu+tw= +github.com/blevesearch/bleve_index_api v1.0.5/go.mod h1:YXMDwaXFFXwncRS8UobWs7nvo0DmusriM1nztTlj1ms= +github.com/blevesearch/geo v0.1.17 h1:AguzI6/5mHXapzB0gE9IKWo+wWPHZmXZoscHcjFgAFA= +github.com/blevesearch/geo v0.1.17/go.mod h1:uRMGWG0HJYfWfFJpK3zTdnnr1K+ksZTuWKhXeSokfnM= github.com/blevesearch/go-porterstemmer v1.0.3 h1:GtmsqID0aZdCSNiY8SkuPJ12pD4jI+DdXTAn4YRcHCo= github.com/blevesearch/go-porterstemmer v1.0.3/go.mod h1:angGc5Ht+k2xhJdZi511LtmxuEf0OVpvUUNrwmM1P7M= -github.com/blevesearch/goleveldb v1.0.1/go.mod h1:WrU8ltZbIp0wAoig/MHbrPCXSOLpe79nz5lv5nqfYrQ= github.com/blevesearch/gtreap v0.1.1 h1:2JWigFrzDMR+42WGIN/V2p0cUvn4UP3C4Q5nmaZGW8Y= github.com/blevesearch/gtreap v0.1.1/go.mod h1:QaQyDRAT51sotthUWAH4Sj08awFSSWzgYICSZ3w0tYk= -github.com/blevesearch/mmap-go v1.0.2/go.mod h1:ol2qBqYaOUsGdm7aRMRrYGgPvnwLe6Y+7LMvAB5IbSA= github.com/blevesearch/mmap-go v1.0.4 h1:OVhDhT5B/M1HNPpYPBKIEJaD0F3Si+CrEKULGCDPWmc= github.com/blevesearch/mmap-go v1.0.4/go.mod h1:EWmEAOmdAS9z/pi/+Toxu99DnsbhG1TIxUoRmJw/pSs= -github.com/blevesearch/scorch_segment_api/v2 v2.1.2 h1:TAte9VZLWda5WAVlZTTZ+GCzEHqGJb4iB2aiZSA6Iv8= -github.com/blevesearch/scorch_segment_api/v2 v2.1.2/go.mod h1:rvoQXZGq8drq7vXbNeyiRzdEOwZkjkiYGf1822i6CRA= -github.com/blevesearch/segment v0.9.0 h1:5lG7yBCx98or7gK2cHMKPukPZ/31Kag7nONpoBt22Ac= -github.com/blevesearch/segment v0.9.0/go.mod h1:9PfHYUdQCgHktBgvtUOF4x+pc4/l8rdH0u5spnW85UQ= -github.com/blevesearch/snowball v0.6.1/go.mod h1:ZF0IBg5vgpeoUhnMza2v0A/z8m1cWPlwhke08LpNusg= +github.com/blevesearch/scorch_segment_api/v2 v2.1.4 h1:LmGmo5twU3gV+natJbKmOktS9eMhokPGKWuR+jX84vk= +github.com/blevesearch/scorch_segment_api/v2 v2.1.4/go.mod h1:PgVnbbg/t1UkgezPDu8EHLi1BHQ17xUwsFdU6NnOYS0= +github.com/blevesearch/segment v0.9.1 h1:+dThDy+Lvgj5JMxhmOVlgFfkUtZV2kw49xax4+jTfSU= +github.com/blevesearch/segment v0.9.1/go.mod h1:zN21iLm7+GnBHWTao9I+Au/7MBiL8pPFtJBJTsk6kQw= github.com/blevesearch/snowballstem v0.9.0 h1:lMQ189YspGP6sXvZQ4WZ+MLawfV8wOmPoD/iWeNXm8s= github.com/blevesearch/snowballstem v0.9.0/go.mod h1:PivSj3JMc8WuaFkTSRDW2SlrulNWPl4ABg1tC/hlgLs= -github.com/blevesearch/upsidedown_store_api v1.0.1 h1:1SYRwyoFLwG3sj0ed89RLtM15amfX2pXlYbFOnF8zNU= -github.com/blevesearch/upsidedown_store_api v1.0.1/go.mod h1:MQDVGpHZrpe3Uy26zJBf/a8h0FZY6xJbthIMm8myH2Q= -github.com/blevesearch/vellum v1.0.8 h1:iMGh4lfxza4BnWO/UJTMPlI3HsK9YawjPv+TteVa9ck= -github.com/blevesearch/vellum v1.0.8/go.mod h1:+cpRi/tqq49xUYSQN2P7A5zNSNrS+MscLeeaZ3J46UA= -github.com/blevesearch/zapx/v11 v11.3.5 h1:eBQWQ7huA+mzm0sAGnZDwgGGli7S45EO+N+ObFWssbI= -github.com/blevesearch/zapx/v11 v11.3.5/go.mod h1:5UdIa/HRMdeRCiLQOyFESsnqBGiip7vQmYReA9toevU= -github.com/blevesearch/zapx/v12 v12.3.5 h1:5pX2hU+R1aZihT7ac1dNWh1n4wqkIM9pZzWp0ANED9s= -github.com/blevesearch/zapx/v12 v12.3.5/go.mod h1:ANcthYRZQycpbRut/6ArF5gP5HxQyJqiFcuJCBju/ss= -github.com/blevesearch/zapx/v13 v13.3.5 h1:eJ3gbD+Nu8p36/O6lhfdvWQ4pxsGYSuTOBrLLPVWJ74= -github.com/blevesearch/zapx/v13 v13.3.5/go.mod h1:FV+dRnScFgKnRDIp08RQL4JhVXt1x2HE3AOzqYa6fjo= -github.com/blevesearch/zapx/v14 v14.3.5 h1:hEvVjZaagFCvOUJrlFQ6/Z6Jjy0opM3g7TMEo58TwP4= -github.com/blevesearch/zapx/v14 v14.3.5/go.mod h1:954A/eKFb+pg/ncIYWLWCKY+mIjReM9FGTGIO2Wu1cU= -github.com/blevesearch/zapx/v15 v15.3.5 h1:NVD0qq8vRk66ImJn1KloXT5ckqPDUZT7VbVJs9jKlac= -github.com/blevesearch/zapx/v15 v15.3.5/go.mod h1:QMUh2hXCaYIWFKPYGavq/Iga2zbHWZ9DZAa9uFbWyvg= +github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMGZzVrdmaozG2MfoB+A= +github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ= +github.com/blevesearch/vellum v1.0.9 h1:PL+NWVk3dDGPCV0hoDu9XLLJgqU4E5s/dOeEJByQ2uQ= +github.com/blevesearch/vellum v1.0.9/go.mod h1:ul1oT0FhSMDIExNjIxHqJoGpVrBpKCdgDQNxfqgJt7k= +github.com/blevesearch/zapx/v11 v11.3.7 h1:Y6yIAF/DVPiqZUA/jNgSLXmqewfzwHzuwfKyfdG+Xaw= +github.com/blevesearch/zapx/v11 v11.3.7/go.mod h1:Xk9Z69AoAWIOvWudNDMlxJDqSYGf90LS0EfnaAIvXCA= +github.com/blevesearch/zapx/v12 v12.3.7 h1:DfQ6rsmZfEK4PzzJJRXjiM6AObG02+HWvprlXQ1Y7eI= +github.com/blevesearch/zapx/v12 v12.3.7/go.mod h1:SgEtYIBGvM0mgIBn2/tQE/5SdrPXaJUaT/kVqpAPxm0= +github.com/blevesearch/zapx/v13 v13.3.7 h1:igIQg5eKmjw168I7av0Vtwedf7kHnQro/M+ubM4d2l8= +github.com/blevesearch/zapx/v13 v13.3.7/go.mod h1:yyrB4kJ0OT75UPZwT/zS+Ru0/jYKorCOOSY5dBzAy+s= +github.com/blevesearch/zapx/v14 v14.3.7 h1:gfe+fbWslDWP/evHLtp/GOvmNM3sw1BbqD7LhycBX20= +github.com/blevesearch/zapx/v14 v14.3.7/go.mod h1:9J/RbOkqZ1KSjmkOes03AkETX7hrXT0sFMpWH4ewC4w= +github.com/blevesearch/zapx/v15 v15.3.8 h1:q4uMngBHzL1IIhRc8AJUEkj6dGOE3u1l3phLu7hq8uk= +github.com/blevesearch/zapx/v15 v15.3.8/go.mod h1:m7Y6m8soYUvS7MjN9eKlz1xrLCcmqfFadmu7GhWIrLY= github.com/bradfitz/iter v0.0.0-20140124041915-454541ec3da2/go.mod h1:PyRFw1Lt2wKX4ZVSQ2mk+PeDa1rxyObEDlApuIsUKuo= github.com/bradfitz/iter v0.0.0-20190303215204-33e6a9893b0c/go.mod h1:PyRFw1Lt2wKX4ZVSQ2mk+PeDa1rxyObEDlApuIsUKuo= github.com/bradfitz/iter v0.0.0-20191230175014-e8f45d346db8 h1:GKTyiRCL6zVf5wWaqKnf+7Qs6GbEPfd4iMOitWzXJx8= @@ -128,12 +123,6 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/codeclysm/extract v2.2.0+incompatible h1:q3wyckoA30bhUSiwdQezMqVhwd8+WGE64/GL//LtUhI= github.com/codeclysm/extract v2.2.0+incompatible/go.mod h1:2nhFMPHiU9At61hz+12bfrlpXSUrOnK+wR+KlGO4Uks= -github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= -github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= -github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/couchbase/ghistogram v0.1.0/go.mod h1:s1Jhy76zqfEecpNWJfWUiKZookAFaiGOEoyzgHt9i7k= -github.com/couchbase/moss v0.2.0/go.mod h1:9MaHIaRuy9pvLPUJxB8sh8OrLfyDczECVL37grCIubs= -github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -162,13 +151,8 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7 github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k= github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/getsentry/sentry-go v0.14.0 h1:rlOBkuFZRKKdUnKO+0U3JclRDQKlRu5vVQtkWSQvC70= github.com/getsentry/sentry-go v0.14.0/go.mod h1:RZPJKSw+adu8PBNygiri/A98FqVr2HtRckJk9XVxJ9I= -github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= -github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= -github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8= github.com/glycerine/go-unsnap-stream v0.0.0-20180323001048-9f0cb55181dd/go.mod h1:/20jfyN9Y5QPEAprSgKAUr+glWDY39ZiUEAYOEv5dsE= github.com/glycerine/goconvey v0.0.0-20180728074245-46e3a41ad493/go.mod h1:Ogl1Tioa0aV7gstGFO7KhffUsb9M4ydbEbbxpcEDc24= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= @@ -183,23 +167,7 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= -github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= -github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= -github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= -github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= -github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= -github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJWXmqUsHwfTRRkQ= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= -github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= -github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= -github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= -github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= -github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= -github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= -github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= -github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk= github.com/gogo/protobuf v1.1.1 h1:72R+M5VuhED/KujmZVcIquuo8mBgX4oVda//DQb3PXo= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang-jwt/jwt/v4 v4.4.1 h1:pC5DB52sCeK48Wlb9oPcdhnjkz1TKt1D/P7WKJ0kUcQ= @@ -220,8 +188,6 @@ github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= -github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= -github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -240,7 +206,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo= @@ -260,7 +225,6 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -270,8 +234,6 @@ github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hf github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= -github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -280,25 +242,16 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m github.com/gopherjs/gopherjs v0.0.0-20181103185306-d547d1d9531e/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= -github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/h2non/filetype v1.1.3 h1:FKkx9QbD7HR/zjK1Ia5XiBsq9zdLi5Kf3zGyFTAFkGg= github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huandu/xstrings v1.0.0 h1:pO2K/gKgKaat5LdpAhxhluX2GPQMaI3W5FUz/I/UnWk= github.com/huandu/xstrings v1.0.0/go.mod h1:4qWG/gcEcfX4z/mBDHJ++3ReCw9ibxbsNJbcucJdbSo= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= -github.com/json-iterator/go v0.0.0-20171115153421-f7279a603ede/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= -github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -316,9 +269,8 @@ github.com/kardianos/minwinsvc v1.0.2/go.mod h1:LUZNYhNmxujx2tR7FbdxqYJ9XDDoCd3M github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= -github.com/klauspost/compress v1.15.11 h1:Lcadnb3RKGin4FYM/orgq0qde+nc15E5Cbqg4B9Sx9c= -github.com/klauspost/compress v1.15.11/go.mod h1:QPwzmACJjUTFsnSHH934V6woptycfrDDJnH7hvFVbGM= +github.com/klauspost/compress v1.16.0 h1:iULayQNOReoYUe+1qtKOqw9CwJv3aNQu8ivo7lw1HU4= +github.com/klauspost/compress v1.16.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -329,44 +281,30 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= -github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lucas-clemente/quic-go v0.30.0 h1:nwLW0h8ahVQ5EPTIM7uhl/stHqQDea15oRlYKZmw2O0= -github.com/lucas-clemente/quic-go v0.30.0/go.mod h1:ssOrRsOmdxa768Wr78vnh2B8JozgLsMzG/g+0qEC7uk= -github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= -github.com/marten-seemann/qtls-go1-18 v0.1.3 h1:R4H2Ks8P6pAtUagjFty2p7BVHn3XiwDAl7TTQf5h7TI= -github.com/marten-seemann/qtls-go1-18 v0.1.3/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= -github.com/marten-seemann/qtls-go1-19 v0.1.1 h1:mnbxeq3oEyQxQXwI4ReCgW9DPoPR94sNlqWoDZnjRIE= -github.com/marten-seemann/qtls-go1-19 v0.1.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e h1:DP5RC0Z3XdyBEW5dKt8YPeN6vZbm6OzVaGVp7f1BQRM= github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw2QV3YD/fRrzEDPNGgTlJlvXY0EHHnT87wF3OA= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 h1:jVvlCGs6OosCdvw9MkfiVnTVnIt7vKMHg/F6th9BtSo= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= -github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= -github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= -github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230131183213-122f1e0e3fa1 h1:JSw0nmjMrgBmoM2aQsa78LTpI5BnuD9+vOiEQ4Qo0qw= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230131183213-122f1e0e3fa1/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= +github.com/matrix-org/util v0.0.0-20221111132719-399730281e66/go.mod h1:iBI1foelCqA09JJgPV0FYz4qA5dUXYOxMi57FxKBdd4= github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE= github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= -github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= +github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 h1:I0XW9+e1XWDxdcEniV4rQAIOPUGDq67JSCiRCgGCZLI= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= -github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/moby/term v0.0.0-20220808134915-39b0c02b01ae h1:O4SWKdcHVCvYqyDV+9CJA1fcDN2L11Bule0iFy3YlAI= github.com/moby/term v0.0.0-20220808134915-39b0c02b01ae/go.mod h1:E2VnQOmVuvZB6UYnnDB0qG5Nq/1tD9acaOpo6xmt0Kw= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -385,10 +323,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= -github.com/nats-io/nats-server/v2 v2.9.8 h1:jgxZsv+A3Reb3MgwxaINcNq/za8xZInKhDg9Q0cGN1o= -github.com/nats-io/nats-server/v2 v2.9.8/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= -github.com/nats-io/nats.go v1.20.0 h1:T8JJnQfVSdh1CzGiwAOv5hEobYCBho/0EupGznYw0oM= -github.com/nats-io/nats.go v1.20.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA= +github.com/nats-io/nats-server/v2 v2.9.15 h1:MuwEJheIwpvFgqvbs20W8Ish2azcygjf4Z0liVu2I4c= +github.com/nats-io/nats-server/v2 v2.9.15/go.mod h1:QlCTy115fqpx4KSOPFIxSV7DdI6OxtZsGOL1JLdeRlE= +github.com/nats-io/nats.go v1.24.0 h1:CRiD8L5GOQu/DcfkmgBcTTIQORMwizF+rPk6T0RaHVQ= +github.com/nats-io/nats.go v1.24.0/go.mod h1:dVQF+BK3SzUZpwyzHedXsvH3EO38aVKuOPkkHlv5hXA= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= @@ -400,13 +338,6 @@ github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79 h1:Dmx8g2747UTVPzSkmohk84S3g/uWqd6+f4SSLPhLcfA= github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79/go.mod h1:E26fwEtRNigBfFfHDWsklmo0T7Ixbg0XXgck+Hq4O9k= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo/v2 v2.3.0 h1:kUMoxMoQG3ogk/QWyKh3zibV7BKZ+xBpWil1cTylVqc= -github.com/onsi/ginkgo/v2 v2.3.0/go.mod h1:Eew0uilEqZmIEZr8JrvYlvOM7Rr6xzTmMV8AyFNU9d0= -github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= -github.com/onsi/gomega v1.22.1 h1:pY8O4lBfsHKZHM/6nrxkhVPUznOlIu3quZcKP/M20KI= -github.com/onsi/gomega v1.22.1/go.mod h1:x6n7VNe4hw0vkyYUM4mjIXx3JbLiPaBPNgB7PRQ1tuM= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 h1:rc3tiVYb5z54aKaDfakKn0dDjIyPpTtszkjuMzyt7ec= @@ -415,9 +346,6 @@ github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+ github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= -github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= -github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= -github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwbMiyQg= github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= @@ -458,7 +386,6 @@ github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa/go.mod h1:qq github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46/go.mod h1:uAQ5PCi+MFsC7HjREoAz1BU+Mq60+05gifQSsHSDG/8= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= @@ -468,12 +395,7 @@ github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0 github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= -github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= -github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= -github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= -github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= -github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= @@ -483,19 +405,19 @@ github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= @@ -503,22 +425,14 @@ github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaO github.com/uber/jaeger-client-go v2.30.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/uber/jaeger-lib v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVKhn2Um6rjCsSsg= github.com/uber/jaeger-lib v2.4.1+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= -github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= -github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= -github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= -github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/yggdrasil-network/yggdrasil-go v0.4.6 h1:GALUDV9QPz/5FVkbazpkTc9EABHufA556JwUJZr41j4= github.com/yggdrasil-network/yggdrasil-go v0.4.6/go.mod h1:PBMoAOvQjA9geNEeGyMXA9QgCS6Bu+9V+1VkWM84wpw= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= go.etcd.io/bbolt v1.3.6 h1:/ecaJf0sk1l4l6V4awd65v2C3ILy7MSj+s/x1ADCIMU= go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= @@ -530,7 +444,6 @@ go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -539,8 +452,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= -golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= +golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= +golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -554,13 +467,11 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 h1:QfTh0HpN6hlw6D3vu8DAwC8pBIwikq0AI1evdm+FksE= -golang.org/x/exp v0.0.0-20221031165847-c99f073a8326/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.1.0 h1:r8Oj8ZA2Xy12/b5KZYj3tuv7NG/fBz3TwQVvpJ9l8Rk= -golang.org/x/image v0.1.0/go.mod h1:iyPr49SD/G/TBxYVB/9RRtGUT5eNbo2u4NamWeQcD5c= +golang.org/x/image v0.5.0 h1:5JMiNunQeQw++mMOz48/ISeNu3Iweh/JaZU8ZLqHRrI= +golang.org/x/image v0.5.0/go.mod h1:FVC7BI/5Ym8R25iw5OLsgshdUBbT1h5jZTpA+mvAdZ4= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -581,13 +492,11 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -616,13 +525,12 @@ golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81R golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0= -golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= +golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -640,14 +548,10 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181221143128-b4a75ba826a6/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -664,7 +568,6 @@ golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -684,9 +587,7 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -697,12 +598,12 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw= -golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -710,13 +611,13 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.1.0 h1:xYY+Bajn2a7VBmTM5GikTmnK8ZuX8YgnQCqZpbBNtmA= -golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -761,7 +662,6 @@ golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= -golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE= golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= @@ -857,18 +757,15 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/h2non/bimg.v1 v1.1.9 h1:wZIUbeOnwr37Ta4aofhIv8OI8v4ujpjXC9mXnAGpQjM= gopkg.in/h2non/bimg.v1 v1.1.9/go.mod h1:PgsZL7dLwUbsGm1NYps320GxGgvQNTnecMCZqxV11So= gopkg.in/h2non/gock.v1 v1.1.2 h1:jBbHXgGBK/AoPVfJh5x4r/WxIrElvbLel8TCZkkZJoY= gopkg.in/macaroon.v2 v2.1.0 h1:HZcsjBCzq9t0eBPMKqTN/uSN6JOm78ZJ2INbqcBQOUI= gopkg.in/macaroon.v2 v2.1.0/go.mod h1:OUb+TQP/OP0WOerC2Jp/3CwhIKyIa9kQjuc7H24e6/o= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= @@ -909,8 +806,6 @@ modernc.org/tcl v1.15.0 h1:oY+JeD11qVVSgVvodMJsu7Edf8tr5E/7tuhF5cNYz34= modernc.org/token v1.0.1 h1:A3qvTqOwexpfZZeyI0FeGPDlSWX5pjZu9hF4lU+EKWg= modernc.org/token v1.0.1/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= modernc.org/z v1.7.0 h1:xkDw/KepgEjeizO2sNco+hqYkU12taxQFqPEmgm1GWE= -nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g= -nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= diff --git a/helm/cr.yaml b/helm/cr.yaml new file mode 100644 index 000000000..884c2b46b --- /dev/null +++ b/helm/cr.yaml @@ -0,0 +1,2 @@ +release-name-template: "helm-{{ .Name }}-{{ .Version }}" +pages-index-path: docs/index.yaml \ No newline at end of file diff --git a/helm/ct.yaml b/helm/ct.yaml new file mode 100644 index 000000000..af706fa3d --- /dev/null +++ b/helm/ct.yaml @@ -0,0 +1,7 @@ +remote: origin +target-branch: main +chart-repos: + - bitnami=https://charts.bitnami.com/bitnami +chart-dirs: + - helm +validate-maintainers: false \ No newline at end of file diff --git a/helm/dendrite/.helm-docs/about.gotmpl b/helm/dendrite/.helm-docs/about.gotmpl new file mode 100644 index 000000000..a92c6be42 --- /dev/null +++ b/helm/dendrite/.helm-docs/about.gotmpl @@ -0,0 +1,5 @@ +{{ define "chart.about" }} +## About + +This chart creates a monolith deployment, including an optionally enabled PostgreSQL dependency to connect to. +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/.helm-docs/appservices.gotmpl b/helm/dendrite/.helm-docs/appservices.gotmpl new file mode 100644 index 000000000..8a79a0780 --- /dev/null +++ b/helm/dendrite/.helm-docs/appservices.gotmpl @@ -0,0 +1,5 @@ +{{ define "chart.appservices" }} +## Usage with appservices + +Create a folder `appservices` and place your configurations in there. The configurations will be read and placed in a secret `dendrite-appservices-conf`. +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/.helm-docs/database.gotmpl b/helm/dendrite/.helm-docs/database.gotmpl new file mode 100644 index 000000000..85ef01ecc --- /dev/null +++ b/helm/dendrite/.helm-docs/database.gotmpl @@ -0,0 +1,18 @@ +{{ define "chart.dbCreation" }} +## Manual database creation + +(You can skip this, if you're deploying the PostgreSQL dependency) + +You'll need to create the following database before starting Dendrite (see [installation](https://matrix-org.github.io/dendrite/installation/database#single-database-creation)): + +```postgres +create database dendrite +``` + +or + +```bash +sudo -u postgres createdb -O dendrite -E UTF-8 dendrite +``` + +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/.helm-docs/monitoring.gotmpl b/helm/dendrite/.helm-docs/monitoring.gotmpl new file mode 100644 index 000000000..3618a1c1a --- /dev/null +++ b/helm/dendrite/.helm-docs/monitoring.gotmpl @@ -0,0 +1,22 @@ +{{ define "chart.monitoringSection" }} +## Monitoring + +[![Grafana Dashboard](https://grafana.com/api/dashboards/13916/images/9894/image)](https://grafana.com/grafana/dashboards/13916-dendrite/) + +* Works well with [Prometheus Operator](https://prometheus-operator.dev/) ([Helmchart](https://artifacthub.io/packages/helm/prometheus-community/kube-prometheus-stack)) and their setup of [Grafana](https://grafana.com/grafana/), by enabling the following values: +```yaml +prometheus: + servicemonitor: + enabled: true + labels: + release: "kube-prometheus-stack" + rules: + enabled: true # will deploy alert rules + labels: + release: "kube-prometheus-stack" +grafana: + dashboards: + enabled: true # will deploy default dashboards +``` +PS: The label `release=kube-prometheus-stack` is setup with the helmchart of the Prometheus Operator. For Grafana Dashboards it may be necessary to enable scanning in the correct namespaces (or ALL), enabled by `sidecar.dashboards.searchNamespace` in [Helmchart of grafana](https://artifacthub.io/packages/helm/grafana/grafana) (which is part of PrometheusOperator, so `grafana.sidecar.dashboards.searchNamespace`) +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/.helm-docs/state.gotmpl b/helm/dendrite/.helm-docs/state.gotmpl new file mode 100644 index 000000000..2fe987ddd --- /dev/null +++ b/helm/dendrite/.helm-docs/state.gotmpl @@ -0,0 +1,3 @@ +{{ define "chart.state" }} +Status: **NOT PRODUCTION READY** +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/Chart.yaml b/helm/dendrite/Chart.yaml new file mode 100644 index 000000000..b352601e8 --- /dev/null +++ b/helm/dendrite/Chart.yaml @@ -0,0 +1,19 @@ +apiVersion: v2 +name: dendrite +version: "0.12.0" +appVersion: "0.12.0" +description: Dendrite Matrix Homeserver +type: application +keywords: + - matrix + - chat + - homeserver + - dendrite +home: https://github.com/matrix-org/dendrite +sources: + - https://github.com/matrix-org/dendrite +dependencies: +- name: postgresql + version: 12.1.7 + repository: https://charts.bitnami.com/bitnami + condition: postgresql.enabled diff --git a/helm/dendrite/README.md b/helm/dendrite/README.md new file mode 100644 index 000000000..c3833edfb --- /dev/null +++ b/helm/dendrite/README.md @@ -0,0 +1,178 @@ +# dendrite + +![Version: 0.12.0](https://img.shields.io/badge/Version-0.12.0-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.12.0](https://img.shields.io/badge/AppVersion-0.12.0-informational?style=flat-square) +Dendrite Matrix Homeserver + +Status: **NOT PRODUCTION READY** + +## About + +This chart creates a monolith deployment, including an optionally enabled PostgreSQL dependency to connect to. + +## Manual database creation + +(You can skip this, if you're deploying the PostgreSQL dependency) + +You'll need to create the following database before starting Dendrite (see [installation](https://matrix-org.github.io/dendrite/installation/database#single-database-creation)): + +```postgres +create database dendrite +``` + +or + +```bash +sudo -u postgres createdb -O dendrite -E UTF-8 dendrite +``` + +## Usage with appservices + +Create a folder `appservices` and place your configurations in there. The configurations will be read and placed in a secret `dendrite-appservices-conf`. + +## Source Code + +* +## Requirements + +| Repository | Name | Version | +|------------|------|---------| +| https://charts.bitnami.com/bitnami | postgresql | 12.1.7 | +## Values + +| Key | Type | Default | Description | +|-----|------|---------|-------------| +| image.repository | string | `"ghcr.io/matrix-org/dendrite-monolith"` | Docker repository/image to use | +| image.pullPolicy | string | `"IfNotPresent"` | Kubernetes pullPolicy | +| image.tag | string | `""` | Overrides the image tag whose default is the chart appVersion. | +| signing_key.create | bool | `true` | Create a new signing key, if not exists | +| signing_key.existingSecret | string | `""` | Use an existing secret | +| resources | object | sets some sane default values | Default resource requests/limits. | +| persistence.storageClass | string | `""` | The storage class to use for volume claims. Defaults to the cluster default storage class. | +| persistence.jetstream.existingClaim | string | `""` | Use an existing volume claim for jetstream | +| persistence.jetstream.capacity | string | `"1Gi"` | PVC Storage Request for the jetstream volume | +| persistence.media.existingClaim | string | `""` | Use an existing volume claim for media files | +| persistence.media.capacity | string | `"1Gi"` | PVC Storage Request for the media volume | +| persistence.search.existingClaim | string | `""` | Use an existing volume claim for the fulltext search index | +| persistence.search.capacity | string | `"1Gi"` | PVC Storage Request for the search volume | +| dendrite_config.version | int | `2` | | +| dendrite_config.global.server_name | string | `""` | **REQUIRED** Servername for this Dendrite deployment. | +| dendrite_config.global.private_key | string | `"/etc/dendrite/secrets/signing.key"` | The private key to use. (**NOTE**: This is overriden in Helm) | +| dendrite_config.global.well_known_server_name | string | `""` | The server name to delegate server-server communications to, with optional port e.g. localhost:443 | +| dendrite_config.global.well_known_client_name | string | `""` | The server name to delegate client-server communications to, with optional port e.g. localhost:443 | +| dendrite_config.global.trusted_third_party_id_servers | list | `["matrix.org","vector.im"]` | Lists of domains that the server will trust as identity servers to verify third party identifiers such as phone numbers and email addresses. | +| dendrite_config.global.old_private_keys | string | `nil` | The paths and expiry timestamps (as a UNIX timestamp in millisecond precision) to old signing keys that were formerly in use on this domain name. These keys will not be used for federation request or event signing, but will be provided to any other homeserver that asks when trying to verify old events. | +| dendrite_config.global.disable_federation | bool | `false` | Disable federation. Dendrite will not be able to make any outbound HTTP requests to other servers and the federation API will not be exposed. | +| dendrite_config.global.key_validity_period | string | `"168h0m0s"` | | +| dendrite_config.global.database.connection_string | string | `""` | The connection string for connections to Postgres. This will be set automatically if using the Postgres dependency | +| dendrite_config.global.database.max_open_conns | int | `90` | Default database maximum open connections | +| dendrite_config.global.database.max_idle_conns | int | `5` | Default database maximum idle connections | +| dendrite_config.global.database.conn_max_lifetime | int | `-1` | Default database maximum lifetime | +| dendrite_config.global.jetstream.storage_path | string | `"/data/jetstream"` | Persistent directory to store JetStream streams in. | +| dendrite_config.global.jetstream.addresses | list | `[]` | NATS JetStream server addresses if not using internal NATS. | +| dendrite_config.global.jetstream.topic_prefix | string | `"Dendrite"` | The prefix for JetStream streams | +| dendrite_config.global.jetstream.in_memory | bool | `false` | Keep all data in memory. (**NOTE**: This is overriden in Helm to `false`) | +| dendrite_config.global.jetstream.disable_tls_validation | bool | `true` | Disables TLS validation. This should **NOT** be used in production. | +| dendrite_config.global.cache.max_size_estimated | string | `"1gb"` | The estimated maximum size for the global cache in bytes, or in terabytes, gigabytes, megabytes or kilobytes when the appropriate 'tb', 'gb', 'mb' or 'kb' suffix is specified. Note that this is not a hard limit, nor is it a memory limit for the entire process. A cache that is too small may ultimately provide little or no benefit. | +| dendrite_config.global.cache.max_age | string | `"1h"` | The maximum amount of time that a cache entry can live for in memory before it will be evicted and/or refreshed from the database. Lower values result in easier admission of new cache entries but may also increase database load in comparison to higher values, so adjust conservatively. Higher values may make it harder for new items to make it into the cache, e.g. if new rooms suddenly become popular. | +| dendrite_config.global.report_stats.enabled | bool | `false` | Configures phone-home statistics reporting. These statistics contain the server name, number of active users and some information on your deployment config. We use this information to understand how Dendrite is being used in the wild. | +| dendrite_config.global.report_stats.endpoint | string | `"https://matrix.org/report-usage-stats/push"` | Endpoint to report statistics to. | +| dendrite_config.global.presence.enable_inbound | bool | `false` | Controls whether we receive presence events from other servers | +| dendrite_config.global.presence.enable_outbound | bool | `false` | Controls whether we send presence events for our local users to other servers. (_May increase CPU/memory usage_) | +| dendrite_config.global.server_notices.enabled | bool | `false` | Server notices allows server admins to send messages to all users on the server. | +| dendrite_config.global.server_notices.local_part | string | `"_server"` | The local part for the user sending server notices. | +| dendrite_config.global.server_notices.display_name | string | `"Server Alerts"` | The display name for the user sending server notices. | +| dendrite_config.global.server_notices.avatar_url | string | `""` | The avatar URL (as a mxc:// URL) name for the user sending server notices. | +| dendrite_config.global.server_notices.room_name | string | `"Server Alerts"` | | +| dendrite_config.global.metrics.enabled | bool | `false` | Whether or not Prometheus metrics are enabled. | +| dendrite_config.global.metrics.basic_auth.user | string | `"metrics"` | HTTP basic authentication username | +| dendrite_config.global.metrics.basic_auth.password | string | `"metrics"` | HTTP basic authentication password | +| dendrite_config.global.dns_cache.enabled | bool | `false` | Whether or not the DNS cache is enabled. | +| dendrite_config.global.dns_cache.cache_size | int | `256` | Maximum number of entries to hold in the DNS cache | +| dendrite_config.global.dns_cache.cache_lifetime | string | `"10m"` | Duration for how long DNS cache items should be considered valid ([see time.ParseDuration](https://pkg.go.dev/time#ParseDuration) for more) | +| dendrite_config.global.profiling.enabled | bool | `false` | Enable pprof. You will need to manually create a port forwarding to the deployment to access PPROF, as it will only listen on localhost and the defined port. e.g. `kubectl port-forward deployments/dendrite 65432:65432` | +| dendrite_config.global.profiling.port | int | `65432` | pprof port, if enabled | +| dendrite_config.mscs | object | `{"mscs":["msc2946"]}` | Configuration for experimental MSC's. (Valid values are: msc2836 and msc2946) | +| dendrite_config.app_service_api.disable_tls_validation | bool | `false` | Disable the validation of TLS certificates of appservices. This is not recommended in production since it may allow appservice traffic to be sent to an insecure endpoint. | +| dendrite_config.app_service_api.config_files | list | `[]` | Appservice config files to load on startup. (**NOTE**: This is overriden by Helm, if a folder `./appservices/` exists) | +| dendrite_config.client_api.registration_disabled | bool | `true` | Prevents new users from being able to register on this homeserver, except when using the registration shared secret below. | +| dendrite_config.client_api.guests_disabled | bool | `true` | | +| dendrite_config.client_api.registration_shared_secret | string | `""` | If set, allows registration by anyone who knows the shared secret, regardless of whether registration is otherwise disabled. | +| dendrite_config.client_api.enable_registration_captcha | bool | `false` | enable reCAPTCHA registration | +| dendrite_config.client_api.recaptcha_public_key | string | `""` | reCAPTCHA public key | +| dendrite_config.client_api.recaptcha_private_key | string | `""` | reCAPTCHA private key | +| dendrite_config.client_api.recaptcha_bypass_secret | string | `""` | reCAPTCHA bypass secret | +| dendrite_config.client_api.recaptcha_siteverify_api | string | `""` | | +| dendrite_config.client_api.turn.turn_user_lifetime | string | `"24h"` | Duration for how long users should be considered valid ([see time.ParseDuration](https://pkg.go.dev/time#ParseDuration) for more) | +| dendrite_config.client_api.turn.turn_uris | list | `[]` | | +| dendrite_config.client_api.turn.turn_shared_secret | string | `""` | | +| dendrite_config.client_api.turn.turn_username | string | `""` | The TURN username | +| dendrite_config.client_api.turn.turn_password | string | `""` | The TURN password | +| dendrite_config.client_api.rate_limiting.enabled | bool | `true` | Enable rate limiting | +| dendrite_config.client_api.rate_limiting.threshold | int | `20` | After how many requests a rate limit should be activated | +| dendrite_config.client_api.rate_limiting.cooloff_ms | int | `500` | Cooloff time in milliseconds | +| dendrite_config.client_api.rate_limiting.exempt_user_ids | string | `nil` | Users which should be exempt from rate limiting | +| dendrite_config.federation_api.send_max_retries | int | `16` | Federation failure threshold. How many consecutive failures that we should tolerate when sending federation requests to a specific server. The backoff is 2**x seconds, so 1 = 2 seconds, 2 = 4 seconds, 3 = 8 seconds, etc. The default value is 16 if not specified, which is circa 18 hours. | +| dendrite_config.federation_api.disable_tls_validation | bool | `false` | Disable TLS validation. This should **NOT** be used in production. | +| dendrite_config.federation_api.prefer_direct_fetch | bool | `false` | | +| dendrite_config.federation_api.disable_http_keepalives | bool | `false` | Prevents Dendrite from keeping HTTP connections open for reuse for future requests. Connections will be closed quicker but we may spend more time on TLS handshakes instead. | +| dendrite_config.federation_api.key_perspectives | list | See value.yaml | Perspective keyservers, to use as a backup when direct key fetch requests don't succeed. | +| dendrite_config.media_api.base_path | string | `"/data/media_store"` | The path to store media files (e.g. avatars) in | +| dendrite_config.media_api.max_file_size_bytes | int | `10485760` | The max file size for uploaded media files | +| dendrite_config.media_api.dynamic_thumbnails | bool | `false` | | +| dendrite_config.media_api.max_thumbnail_generators | int | `10` | The maximum number of simultaneous thumbnail generators to run. | +| dendrite_config.media_api.thumbnail_sizes | list | See value.yaml | A list of thumbnail sizes to be generated for media content. | +| dendrite_config.sync_api.real_ip_header | string | `"X-Real-IP"` | This option controls which HTTP header to inspect to find the real remote IP address of the client. This is likely required if Dendrite is running behind a reverse proxy server. | +| dendrite_config.sync_api.search | object | `{"enabled":true,"index_path":"/data/search","language":"en"}` | Configuration for the full-text search engine. | +| dendrite_config.sync_api.search.enabled | bool | `true` | Whether fulltext search is enabled. | +| dendrite_config.sync_api.search.index_path | string | `"/data/search"` | The path to store the search index in. | +| dendrite_config.sync_api.search.language | string | `"en"` | The language most likely to be used on the server - used when indexing, to ensure the returned results match expectations. A full list of possible languages can be found [here](https://github.com/matrix-org/dendrite/blob/76db8e90defdfb9e61f6caea8a312c5d60bcc005/internal/fulltext/bleve.go#L25-L46) | +| dendrite_config.user_api.bcrypt_cost | int | `10` | bcrypt cost to use when hashing passwords. (ranges from 4-31; 4 being least secure, 31 being most secure; _NOTE: Using a too high value can cause clients to timeout and uses more CPU._) | +| dendrite_config.user_api.openid_token_lifetime_ms | int | `3600000` | OpenID Token lifetime in milliseconds. | +| dendrite_config.user_api.push_gateway_disable_tls_validation | bool | `false` | | +| dendrite_config.user_api.auto_join_rooms | list | `[]` | Rooms to join users to after registration | +| dendrite_config.logging | list | `[{"level":"info","type":"std"}]` | Default logging configuration | +| postgresql.enabled | bool | See value.yaml | Enable and configure postgres as the database for dendrite. | +| postgresql.image.repository | string | `"bitnami/postgresql"` | | +| postgresql.image.tag | string | `"15.1.0"` | | +| postgresql.auth.username | string | `"dendrite"` | | +| postgresql.auth.password | string | `"changeme"` | | +| postgresql.auth.database | string | `"dendrite"` | | +| postgresql.persistence.enabled | bool | `false` | | +| ingress.enabled | bool | `false` | Create an ingress for a monolith deployment | +| ingress.hosts | list | `[]` | | +| ingress.className | string | `""` | | +| ingress.hostName | string | `""` | | +| ingress.annotations | object | `{}` | Extra, custom annotations | +| ingress.tls | list | `[]` | | +| service.type | string | `"ClusterIP"` | | +| service.port | int | `8008` | | +| prometheus.servicemonitor.enabled | bool | `false` | Enable ServiceMonitor for Prometheus-Operator for scrape metric-endpoint | +| prometheus.servicemonitor.labels | object | `{}` | Extra Labels on ServiceMonitor for selector of Prometheus Instance | +| prometheus.rules.enabled | bool | `false` | Enable PrometheusRules for Prometheus-Operator for setup alerting | +| prometheus.rules.labels | object | `{}` | Extra Labels on PrometheusRules for selector of Prometheus Instance | +| prometheus.rules.additionalRules | list | `[]` | additional alertrules (no default alertrules are provided) | +| grafana.dashboards.enabled | bool | `false` | | +| grafana.dashboards.labels | object | `{"grafana_dashboard":"1"}` | Extra Labels on ConfigMap for selector of grafana sidecar | +| grafana.dashboards.annotations | object | `{}` | Extra Annotations on ConfigMap additional config in grafana sidecar | + +## Monitoring + +[![Grafana Dashboard](https://grafana.com/api/dashboards/13916/images/9894/image)](https://grafana.com/grafana/dashboards/13916-dendrite/) + +* Works well with [Prometheus Operator](https://prometheus-operator.dev/) ([Helmchart](https://artifacthub.io/packages/helm/prometheus-community/kube-prometheus-stack)) and their setup of [Grafana](https://grafana.com/grafana/), by enabling the following values: +```yaml +prometheus: + servicemonitor: + enabled: true + labels: + release: "kube-prometheus-stack" + rules: + enabled: true # will deploy alert rules + labels: + release: "kube-prometheus-stack" +grafana: + dashboards: + enabled: true # will deploy default dashboards +``` +PS: The label `release=kube-prometheus-stack` is setup with the helmchart of the Prometheus Operator. For Grafana Dashboards it may be necessary to enable scanning in the correct namespaces (or ALL), enabled by `sidecar.dashboards.searchNamespace` in [Helmchart of grafana](https://artifacthub.io/packages/helm/grafana/grafana) (which is part of PrometheusOperator, so `grafana.sidecar.dashboards.searchNamespace`) + diff --git a/helm/dendrite/README.md.gotmpl b/helm/dendrite/README.md.gotmpl new file mode 100644 index 000000000..9411733ce --- /dev/null +++ b/helm/dendrite/README.md.gotmpl @@ -0,0 +1,14 @@ +{{ template "chart.header" . }} +{{ template "chart.deprecationWarning" . }} +{{ template "chart.badgesSection" . }} +{{ template "chart.description" . }} +{{ template "chart.state" . }} +{{ template "chart.about" . }} +{{ template "chart.dbCreation" . }} +{{ template "chart.appservices" . }} +{{ template "chart.maintainersSection" . }} +{{ template "chart.sourcesSection" . }} +{{ template "chart.requirementsSection" . }} +{{ template "chart.valuesSection" . }} +{{ template "chart.monitoringSection" . }} +{{ template "helm-docs.versionFooter" . }} \ No newline at end of file diff --git a/helm/dendrite/ci/ct-ingress-values.yaml b/helm/dendrite/ci/ct-ingress-values.yaml new file mode 100644 index 000000000..f3f58b5ca --- /dev/null +++ b/helm/dendrite/ci/ct-ingress-values.yaml @@ -0,0 +1,18 @@ +--- +postgresql: + enabled: true + primary: + persistence: + size: 1Gi + +dendrite_config: + global: + server_name: "localhost" + +ingress: + enabled: true + +# dashboard is an ConfigMap with labels - it does not harm on testing +grafana: + dashboards: + enabled: true diff --git a/helm/dendrite/ci/ct-postgres-sharedsecret-values.yaml b/helm/dendrite/ci/ct-postgres-sharedsecret-values.yaml new file mode 100644 index 000000000..55e652c63 --- /dev/null +++ b/helm/dendrite/ci/ct-postgres-sharedsecret-values.yaml @@ -0,0 +1,16 @@ +--- +postgresql: + enabled: true + primary: + persistence: + size: 1Gi + +dendrite_config: + global: + server_name: "localhost" + + client_api: + registration_shared_secret: "d233f2fcb0470845a8e150a20ef594ddbe0b4cf7fe482fb9d5120c198557acbf" # echo "dendrite" | sha256sum + +ingress: + enabled: true diff --git a/helm/dendrite/grafana_dashboards/dendrite-rev1.json b/helm/dendrite/grafana_dashboards/dendrite-rev1.json new file mode 100644 index 000000000..206e8af87 --- /dev/null +++ b/helm/dendrite/grafana_dashboards/dendrite-rev1.json @@ -0,0 +1,1119 @@ +{ + "__inputs": [ + { + "name": "DS_INFLUXDB_DOMOTICA", + "label": "", + "description": "", + "type": "datasource", + "pluginId": "influxdb", + "pluginName": "InfluxDB" + }, + { + "name": "DS_PROMETHEUS", + "label": "Prometheus", + "description": "", + "type": "datasource", + "pluginId": "prometheus", + "pluginName": "Prometheus" + } + ], + "__requires": [ + { + "type": "grafana", + "id": "grafana", + "name": "Grafana", + "version": "7.4.2" + }, + { + "type": "panel", + "id": "graph", + "name": "Graph", + "version": "" + }, + { + "type": "panel", + "id": "heatmap", + "name": "Heatmap", + "version": "" + }, + { + "type": "datasource", + "id": "influxdb", + "name": "InfluxDB", + "version": "1.0.0" + }, + { + "type": "datasource", + "id": "prometheus", + "name": "Prometheus", + "version": "1.0.0" + }, + { + "type": "panel", + "id": "stat", + "name": "Stat", + "version": "" + } + ], + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": "-- Grafana --", + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "description": "Dendrite dashboard from https://github.com/matrix-org/dendrite/", + "editable": true, + "gnetId": 13916, + "graphTooltip": 0, + "id": null, + "iteration": 1613683251329, + "links": [], + "panels": [ + { + "collapsed": false, + "datasource": "${DS_INFLUXDB_DOMOTICA}", + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 0 + }, + "id": 4, + "panels": [], + "title": "Overview", + "type": "row" + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "$datasource", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 5, + "w": 10, + "x": 0, + "y": 1 + }, + "hiddenSeries": false, + "id": 2, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.4.2", + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "rate(process_cpu_seconds_total{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "{{job}}-{{index}} ", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "CPU usage", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "decimals": null, + "format": "percentunit", + "label": null, + "logBase": 1, + "max": "1", + "min": "0", + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "datasource": "${DS_PROMETHEUS}", + "description": "Total number of registered users", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "custom": {}, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 5, + "w": 4, + "x": 10, + "y": 1 + }, + "id": 20, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "text": {}, + "textMode": "auto" + }, + "pluginVersion": "7.4.2", + "targets": [ + { + "exemplar": false, + "expr": "dendrite_clientapi_reg_users_total", + "instant": false, + "interval": "", + "legendFormat": "Users", + "refId": "A" + } + ], + "title": "Registerd Users", + "type": "stat" + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_PROMETHEUS}", + "description": "The number of sync requests that are active right now and are waiting to be woken by a notifier", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 5, + "w": 10, + "x": 14, + "y": 1 + }, + "hiddenSeries": false, + "id": 6, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 2, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.4.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "sum(rate(dendrite_syncapi_active_sync_requests{instance=\"$instance\"}[$bucket_size]))without (job,index)", + "hide": false, + "interval": "", + "legendFormat": "active", + "refId": "A" + }, + { + "expr": "sum(rate(dendrite_syncapi_waiting_sync_requests{instance=\"$instance\"}[$bucket_size]))without (job,index)", + "hide": false, + "interval": "", + "legendFormat": "waiting", + "refId": "B" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Sync API", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:232", + "format": "hertz", + "label": null, + "logBase": 1, + "max": null, + "min": "0", + "show": true + }, + { + "$$hashKey": "object:233", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "cards": { + "cardPadding": null, + "cardRound": null + }, + "color": { + "cardColor": "#b4ff00", + "colorScale": "sqrt", + "colorScheme": "interpolateOranges", + "exponent": 0.5, + "mode": "spectrum" + }, + "dataFormat": "tsbuckets", + "datasource": "${DS_PROMETHEUS}", + "description": "How long it takes to build and submit a new event from the client API to the roomserver", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "gridPos": { + "h": 5, + "w": 24, + "x": 0, + "y": 6 + }, + "heatmap": {}, + "hideZeroBuckets": false, + "highlightCards": true, + "id": 24, + "legend": { + "show": false + }, + "pluginVersion": "7.4.2", + "reverseYBuckets": false, + "targets": [ + { + "expr": "dendrite_clientapi_sendevent_duration_millis_bucket{action=\"build\",instance=\"$instance\"}", + "interval": "", + "legendFormat": "{{le}}", + "refId": "A" + } + ], + "title": "Sendevent Duration", + "tooltip": { + "show": true, + "showHistogram": false + }, + "type": "heatmap", + "xAxis": { + "show": true + }, + "xBucketNumber": null, + "xBucketSize": null, + "yAxis": { + "decimals": null, + "format": "s", + "logBase": 1, + "max": null, + "min": "0", + "show": true, + "splitFactor": null + }, + "yBucketBound": "auto", + "yBucketNumber": null, + "yBucketSize": null + }, + { + "collapsed": false, + "datasource": "${DS_INFLUXDB_DOMOTICA}", + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 11 + }, + "id": 8, + "panels": [], + "title": "Federation", + "type": "row" + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_PROMETHEUS}", + "description": "Collection of queues for sending transactions to other matrix servers", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 6, + "w": 24, + "x": 0, + "y": 12 + }, + "hiddenSeries": false, + "id": 10, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.4.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "dendrite_federationsender_destination_queues_running", + "interval": "", + "legendFormat": "Queue Running", + "refId": "A" + }, + { + "expr": "dendrite_federationsender_destination_queues_total", + "hide": false, + "interval": "", + "legendFormat": "Queue Total", + "refId": "B" + }, + { + "expr": "dendrite_federationsender_destination_queues_backing_off", + "hide": false, + "interval": "", + "legendFormat": "Backing Off", + "refId": "C" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Federation Sender Destination", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:443", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "$$hashKey": "object:444", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "collapsed": false, + "datasource": "${DS_INFLUXDB_DOMOTICA}", + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 18 + }, + "id": 26, + "panels": [], + "title": "Rooms", + "type": "row" + }, + { + "cards": { + "cardPadding": null, + "cardRound": null + }, + "color": { + "cardColor": "#b4ff00", + "colorScale": "sqrt", + "colorScheme": "interpolateOranges", + "exponent": 0.5, + "mode": "spectrum" + }, + "dataFormat": "timeseries", + "datasource": "${DS_PROMETHEUS}", + "description": "How long it takes the roomserver to process an event", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "gridPos": { + "h": 7, + "w": 24, + "x": 0, + "y": 19 + }, + "heatmap": {}, + "hideZeroBuckets": false, + "highlightCards": true, + "id": 28, + "legend": { + "show": false + }, + "pluginVersion": "7.4.2", + "reverseYBuckets": false, + "targets": [ + { + "expr": "sum(rate(dendrite_roomserver_processroomevent_duration_millis_bucket{instance=\"$instance\"}[$bucket_size])) by (le)", + "interval": "", + "legendFormat": "{{le}}", + "refId": "A" + } + ], + "title": "Room Event Processing", + "tooltip": { + "show": true, + "showHistogram": false + }, + "type": "heatmap", + "xAxis": { + "show": true + }, + "xBucketNumber": null, + "xBucketSize": null, + "yAxis": { + "decimals": null, + "format": "s", + "logBase": 1, + "max": null, + "min": null, + "show": true, + "splitFactor": null + }, + "yBucketBound": "auto", + "yBucketNumber": null, + "yBucketSize": null + }, + { + "collapsed": false, + "datasource": "${DS_INFLUXDB_DOMOTICA}", + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 26 + }, + "id": 12, + "panels": [], + "title": "Caches", + "type": "row" + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_PROMETHEUS}", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 7, + "w": 8, + "x": 0, + "y": 27 + }, + "hiddenSeries": false, + "id": 14, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": false, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.4.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "dendrite_caching_in_memory_lru_server_key", + "interval": "", + "legendFormat": "Server keys", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Server Keys", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:667", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "$$hashKey": "object:668", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_PROMETHEUS}", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 7, + "w": 8, + "x": 8, + "y": 27 + }, + "hiddenSeries": false, + "id": 16, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": false, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.4.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "dendrite_caching_in_memory_lru_federation_event", + "interval": "", + "legendFormat": "Federation Event", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Federation Events", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:784", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "$$hashKey": "object:785", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_PROMETHEUS}", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 7, + "w": 8, + "x": 16, + "y": 27 + }, + "hiddenSeries": false, + "id": 18, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": false, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.4.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "dendrite_caching_in_memory_lru_roomserver_room_ids", + "interval": "", + "legendFormat": "Room IDs", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Room IDs", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:898", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "$$hashKey": "object:899", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + } + ], + "refresh": "10s", + "schemaVersion": 27, + "style": "dark", + "tags": [ + "matrix", + "dendrite" + ], + "templating": { + "list": [ + { + "current": { + "selected": false, + "text": "Prometheus", + "value": "Prometheus" + }, + "description": null, + "error": null, + "hide": 0, + "includeAll": false, + "label": null, + "multi": false, + "name": "datasource", + "options": [], + "query": "prometheus", + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "type": "datasource" + }, + { + "auto": true, + "auto_count": 100, + "auto_min": "30s", + "current": { + "selected": false, + "text": "auto", + "value": "$__auto_interval_bucket_size" + }, + "description": null, + "error": null, + "hide": 0, + "label": "Bucket Size", + "name": "bucket_size", + "options": [ + { + "selected": true, + "text": "auto", + "value": "$__auto_interval_bucket_size" + }, + { + "selected": false, + "text": "30s", + "value": "30s" + }, + { + "selected": false, + "text": "1m", + "value": "1m" + }, + { + "selected": false, + "text": "2m", + "value": "2m" + }, + { + "selected": false, + "text": "5m", + "value": "5m" + }, + { + "selected": false, + "text": "10m", + "value": "10m" + }, + { + "selected": false, + "text": "15m", + "value": "15m" + } + ], + "query": "30s,1m,2m,5m,10m,15m", + "queryValue": "", + "refresh": 2, + "skipUrlSync": false, + "type": "interval" + }, + { + "allValue": null, + "current": {}, + "datasource": "${DS_PROMETHEUS}", + "definition": "label_values(dendrite_caching_in_memory_lru_roominfo, instance)", + "description": null, + "error": null, + "hide": 0, + "includeAll": false, + "label": null, + "multi": false, + "name": "instance", + "options": [], + "query": { + "query": "label_values(dendrite_caching_in_memory_lru_roominfo, instance)", + "refId": "StandardVariableQuery" + }, + "refresh": 2, + "regex": "", + "skipUrlSync": false, + "sort": 0, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + }, + { + "allValue": null, + "current": {}, + "datasource": "${DS_PROMETHEUS}", + "definition": "label_values(dendrite_caching_in_memory_lru_roominfo, job)", + "description": null, + "error": null, + "hide": 0, + "includeAll": true, + "label": "Job", + "multi": true, + "name": "job", + "options": [], + "query": { + "query": "label_values(dendrite_caching_in_memory_lru_roominfo, job)", + "refId": "StandardVariableQuery" + }, + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "sort": 1, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + }, + { + "allValue": ".*", + "current": {}, + "datasource": "${DS_PROMETHEUS}", + "definition": "label_values(dendrite_caching_in_memory_lru_roominfo, index)", + "description": null, + "error": null, + "hide": 0, + "includeAll": true, + "label": null, + "multi": true, + "name": "index", + "options": [], + "query": { + "query": "label_values(dendrite_caching_in_memory_lru_roominfo, index)", + "refId": "StandardVariableQuery" + }, + "refresh": 2, + "regex": "", + "skipUrlSync": false, + "sort": 3, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + } + ] + }, + "time": { + "from": "now-3h", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "Dendrite", + "uid": "RoRt1jEGz", + "version": 8 +} \ No newline at end of file diff --git a/helm/dendrite/templates/_helpers.tpl b/helm/dendrite/templates/_helpers.tpl new file mode 100644 index 000000000..026706588 --- /dev/null +++ b/helm/dendrite/templates/_helpers.tpl @@ -0,0 +1,74 @@ +{{- define "validate.config" }} +{{- if not .Values.signing_key.create -}} +{{- fail "You must create a signing key for configuration.signing_key. (see https://github.com/matrix-org/dendrite/blob/master/docs/INSTALL.md#server-key-generation)" -}} +{{- end -}} +{{- if not (or .Values.dendrite_config.global.database.host .Values.postgresql.enabled) -}} +{{- fail "Database server must be set." -}} +{{- end -}} +{{- if not (or .Values.dendrite_config.global.database.user .Values.postgresql.enabled) -}} +{{- fail "Database user must be set." -}} +{{- end -}} +{{- if not (or .Values.dendrite_config.global.database.password .Values.postgresql.enabled) -}} +{{- fail "Database password must be set." -}} +{{- end -}} +{{- end -}} + + +{{- define "image.name" -}} +{{- with .Values.image -}} +image: {{ .repository }}:{{ .tag | default (printf "v%s" $.Chart.AppVersion) }} +imagePullPolicy: {{ .pullPolicy }} +{{- end -}} +{{- end -}} + +{{/* +Expand the name of the chart. +*/}} +{{- define "dendrite.name" -}} +{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Create a default fully qualified app name. +We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). +If release name contains chart name it will be used as a full name. +*/}} +{{- define "dendrite.fullname" -}} +{{- if .Values.fullnameOverride }} +{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- $name := default .Chart.Name .Values.nameOverride }} +{{- if contains $name .Release.Name }} +{{- .Release.Name | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} +{{- end }} +{{- end }} +{{- end }} + +{{/* +Create chart name and version as used by the chart label. +*/}} +{{- define "dendrite.chart" -}} +{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Common labels +*/}} +{{- define "dendrite.labels" -}} +helm.sh/chart: {{ include "dendrite.chart" . }} +{{ include "dendrite.selectorLabels" . }} +{{- if .Chart.AppVersion }} +app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} +{{- end }} +app.kubernetes.io/managed-by: {{ .Release.Service }} +{{- end }} + +{{/* +Selector labels +*/}} +{{- define "dendrite.selectorLabels" -}} +app.kubernetes.io/name: {{ include "dendrite.name" . }} +app.kubernetes.io/instance: {{ .Release.Name }} +{{- end }} \ No newline at end of file diff --git a/helm/dendrite/templates/_overrides.yaml b/helm/dendrite/templates/_overrides.yaml new file mode 100644 index 000000000..edb8ba83a --- /dev/null +++ b/helm/dendrite/templates/_overrides.yaml @@ -0,0 +1,16 @@ +{{- define "override.config" }} +{{- if .Values.postgresql.enabled }} +{{- $_ := set .Values.dendrite_config.global.database "connection_string" (print "postgresql://" .Values.postgresql.auth.username ":" .Values.postgresql.auth.password "@" .Release.Name "-postgresql/dendrite?sslmode=disable") -}} +{{ end }} +global: + private_key: /etc/dendrite/secrets/signing.key + jetstream: + in_memory: false +{{ if (gt (len (.Files.Glob "appservices/*")) 0) }} +app_service_api: + config_files: + {{- range $x, $y := .Files.Glob "appservices/*" }} + - /etc/dendrite/appservices/{{ base $x }} + {{ end }} +{{ end }} +{{ end }} diff --git a/helm/dendrite/templates/configmap_grafana_dashboards.yaml b/helm/dendrite/templates/configmap_grafana_dashboards.yaml new file mode 100644 index 000000000..e2abc4909 --- /dev/null +++ b/helm/dendrite/templates/configmap_grafana_dashboards.yaml @@ -0,0 +1,16 @@ +{{- if .Values.grafana.dashboards.enabled }} +{{- range $path, $bytes := .Files.Glob "grafana_dashboards/*" }} +--- +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ include "dendrite.fullname" $ }}-grafana-dashboards-{{ base $path }} + labels: + {{- include "dendrite.labels" $ | nindent 4 }} + {{- toYaml $.Values.grafana.dashboards.labels | nindent 4 }} + annotations: + {{- toYaml $.Values.grafana.dashboards.annotations | nindent 4 }} +data: + {{- ($.Files.Glob $path ).AsConfig | nindent 2 }} +{{- end }} +{{- end }} diff --git a/helm/dendrite/templates/deployment.yaml b/helm/dendrite/templates/deployment.yaml new file mode 100644 index 000000000..b463c7d0b --- /dev/null +++ b/helm/dendrite/templates/deployment.yaml @@ -0,0 +1,103 @@ +{{ template "validate.config" . }} +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + namespace: {{ $.Release.Namespace }} + name: {{ include "dendrite.fullname" . }} + labels: + {{- include "dendrite.labels" . | nindent 4 }} +spec: + selector: + matchLabels: + {{- include "dendrite.selectorLabels" . | nindent 6 }} + replicas: 1 + template: + metadata: + labels: + {{- include "dendrite.selectorLabels" . | nindent 8 }} + annotations: + confighash-global: secret-{{ .Values.global | toYaml | sha256sum | trunc 32 }} + confighash-clientapi: clientapi-{{ .Values.clientapi | toYaml | sha256sum | trunc 32 }} + confighash-federationapi: federationapi-{{ .Values.federationapi | toYaml | sha256sum | trunc 32 }} + confighash-mediaapi: mediaapi-{{ .Values.mediaapi | toYaml | sha256sum | trunc 32 }} + confighash-syncapi: syncapi-{{ .Values.syncapi | toYaml | sha256sum | trunc 32 }} + spec: + volumes: + - name: {{ include "dendrite.fullname" . }}-conf-vol + secret: + secretName: {{ include "dendrite.fullname" . }}-conf + - name: {{ include "dendrite.fullname" . }}-signing-key + secret: + secretName: {{ default (print ( include "dendrite.fullname" . ) "-signing-key") $.Values.signing_key.existingSecret | quote }} + {{- if (gt (len ($.Files.Glob "appservices/*")) 0) }} + - name: {{ include "dendrite.fullname" . }}-appservices + secret: + secretName: {{ include "dendrite.fullname" . }}-appservices-conf + {{- end }} + - name: {{ include "dendrite.fullname" . }}-jetstream + persistentVolumeClaim: + claimName: {{ default (print ( include "dendrite.fullname" . ) "-jetstream-pvc") $.Values.persistence.jetstream.existingClaim | quote }} + - name: {{ include "dendrite.fullname" . }}-media + persistentVolumeClaim: + claimName: {{ default (print ( include "dendrite.fullname" . ) "-media-pvc") $.Values.persistence.media.existingClaim | quote }} + - name: {{ include "dendrite.fullname" . }}-search + persistentVolumeClaim: + claimName: {{ default (print ( include "dendrite.fullname" . ) "-search-pvc") $.Values.persistence.search.existingClaim | quote }} + containers: + - name: {{ .Chart.Name }} + {{- include "image.name" . | nindent 8 }} + args: + - '--config' + - '/etc/dendrite/dendrite.yaml' + ports: + - name: http + containerPort: 8008 + protocol: TCP + {{- if $.Values.dendrite_config.global.profiling.enabled }} + env: + - name: PPROFLISTEN + value: "localhost:{{- $.Values.global.profiling.port -}}" + {{- end }} + resources: + {{- toYaml $.Values.resources | nindent 10 }} + volumeMounts: + - mountPath: /etc/dendrite/ + name: {{ include "dendrite.fullname" . }}-conf-vol + - mountPath: /etc/dendrite/secrets/ + name: {{ include "dendrite.fullname" . }}-signing-key + {{- if (gt (len ($.Files.Glob "appservices/*")) 0) }} + - mountPath: /etc/dendrite/appservices + name: {{ include "dendrite.fullname" . }}-appservices + readOnly: true + {{ end }} + - mountPath: {{ .Values.dendrite_config.media_api.base_path }} + name: {{ include "dendrite.fullname" . }}-media + - mountPath: {{ .Values.dendrite_config.global.jetstream.storage_path }} + name: {{ include "dendrite.fullname" . }}-jetstream + - mountPath: {{ .Values.dendrite_config.sync_api.search.index_path }} + name: {{ include "dendrite.fullname" . }}-search + livenessProbe: + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 10 + httpGet: + path: /_dendrite/monitor/health + port: http + readinessProbe: + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 10 + httpGet: + path: /_dendrite/monitor/health + port: http + startupProbe: + initialDelaySeconds: 5 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 10 + httpGet: + path: /_dendrite/monitor/up + port: http \ No newline at end of file diff --git a/helm/dendrite/templates/ingress.yaml b/helm/dendrite/templates/ingress.yaml new file mode 100644 index 000000000..8f86ad723 --- /dev/null +++ b/helm/dendrite/templates/ingress.yaml @@ -0,0 +1,55 @@ +{{- if .Values.ingress.enabled -}} + {{- $fullName := include "dendrite.fullname" . -}} + {{- $svcPort := .Values.service.port -}} + {{- if and .Values.ingress.className (not (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion)) }} + {{- if not (hasKey .Values.ingress.annotations "kubernetes.io/ingress.class") }} + {{- $_ := set .Values.ingress.annotations "kubernetes.io/ingress.class" .Values.ingress.className}} + {{- end }} + {{- end }} + {{- if semverCompare ">=1.19-0" .Capabilities.KubeVersion.GitVersion -}} +apiVersion: networking.k8s.io/v1 + {{- else if semverCompare ">=1.14-0" .Capabilities.KubeVersion.GitVersion -}} +apiVersion: networking.k8s.io/v1beta1 + {{- else -}} +apiVersion: extensions/v1beta1 + {{- end }} +kind: Ingress +metadata: + name: {{ $fullName }} + labels: + {{- include "dendrite.labels" . | nindent 4 }} + annotations: + {{- with .Values.ingress.annotations }} + {{- toYaml . | nindent 4 }} + {{- end }} +spec: + {{- if and .Values.ingress.className (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion) }} + ingressClassName: {{ .Values.ingress.className }} + {{- end }} + {{- if .Values.ingress.tls }} + tls: + {{- range .Values.ingress.tls }} + - hosts: + {{- range .hosts }} + - {{ . | quote }} + {{- end }} + secretName: {{ .secretName }} + {{- end }} + {{- end }} + rules: + - host: {{ .Values.ingress.hostName | quote }} + http: + paths: + - path: / + pathType: ImplementationSpecific + backend: + {{- if semverCompare ">=1.19-0" $.Capabilities.KubeVersion.GitVersion }} + service: + name: {{ $fullName }} + port: + number: {{ $svcPort }} + {{- else }} + serviceName: {{ $fullName }} + servicePort: {{ $svcPort }} + {{- end }} + {{- end }} \ No newline at end of file diff --git a/helm/dendrite/templates/jobs.yaml b/helm/dendrite/templates/jobs.yaml new file mode 100644 index 000000000..c10f358b0 --- /dev/null +++ b/helm/dendrite/templates/jobs.yaml @@ -0,0 +1,100 @@ +{{ if and .Values.signing_key.create (not .Values.signing_key.existingSecret ) }} +{{ $name := (print ( include "dendrite.fullname" . ) "-signing-key") }} +{{ $secretName := (print ( include "dendrite.fullname" . ) "-signing-key") }} +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{ $name }} + labels: + app.kubernetes.io/component: signingkey-job + {{- include "dendrite.labels" . | nindent 4 }} +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: {{ $name }} + labels: + app.kubernetes.io/component: signingkey-job + {{- include "dendrite.labels" . | nindent 4 }} +rules: + - apiGroups: + - "" + resources: + - secrets + resourceNames: + - {{ $secretName }} + verbs: + - get + - update + - patch +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: {{ $name }} + labels: + app.kubernetes.io/component: signingkey-job + {{- include "dendrite.labels" . | nindent 4 }} +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: {{ $name }} +subjects: + - kind: ServiceAccount + name: {{ $name }} + namespace: {{ .Release.Namespace }} +--- +apiVersion: batch/v1 +kind: Job +metadata: + name: generate-signing-key + labels: + {{- include "dendrite.labels" . | nindent 4 }} +spec: + template: + spec: + restartPolicy: "Never" + serviceAccount: {{ $name }} + containers: + - name: upload-key + image: bitnami/kubectl + command: + - sh + - -c + - | + # check if key already exists + key=$(kubectl get secret {{ $secretName }} -o jsonpath="{.data['signing\.key']}" 2> /dev/null) + [ $? -ne 0 ] && echo "Failed to get existing secret" && exit 1 + [ -n "$key" ] && echo "Key already created, exiting." && exit 0 + # wait for signing key + while [ ! -f /etc/dendrite/signing-key.pem ]; do + echo "Waiting for signing key.." + sleep 5; + done + # update secret + kubectl patch secret {{ $secretName }} -p "{\"data\":{\"signing.key\":\"$(base64 /etc/dendrite/signing-key.pem | tr -d '\n')\"}}" + [ $? -ne 0 ] && echo "Failed to update secret." && exit 1 + echo "Signing key successfully created." + volumeMounts: + - mountPath: /etc/dendrite/ + name: signing-key + readOnly: true + - name: generate-key + {{- include "image.name" . | nindent 8 }} + command: + - sh + - -c + - | + /usr/bin/generate-keys -private-key /etc/dendrite/signing-key.pem + chown 1001:1001 /etc/dendrite/signing-key.pem + volumeMounts: + - mountPath: /etc/dendrite/ + name: signing-key + volumes: + - name: signing-key + emptyDir: {} + parallelism: 1 + completions: 1 + backoffLimit: 1 +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/templates/prometheus-rules.yaml b/helm/dendrite/templates/prometheus-rules.yaml new file mode 100644 index 000000000..6693a4ed9 --- /dev/null +++ b/helm/dendrite/templates/prometheus-rules.yaml @@ -0,0 +1,16 @@ +{{- if and ( .Values.prometheus.rules.enabled ) ( .Capabilities.APIVersions.Has "monitoring.coreos.com/v1" ) }} +--- +apiVersion: monitoring.coreos.com/v1 +kind: PrometheusRule +metadata: + name: {{ include "dendrite.fullname" . }} + labels: + {{- include "dendrite.labels" . | nindent 4 }} + {{- toYaml .Values.prometheus.rules.labels | nindent 4 }} +spec: + groups: + {{- if .Values.prometheus.rules.additionalRules }} + - name: {{ template "dendrite.name" . }}-Additional + rules: {{- toYaml .Values.prometheus.rules.additionalRules | nindent 4 }} + {{- end }} +{{- end }} diff --git a/helm/dendrite/templates/pvc.yaml b/helm/dendrite/templates/pvc.yaml new file mode 100644 index 000000000..897957e60 --- /dev/null +++ b/helm/dendrite/templates/pvc.yaml @@ -0,0 +1,48 @@ +{{ if not .Values.persistence.media.existingClaim }} +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + annotations: + helm.sh/resource-policy: keep + name: {{ include "dendrite.fullname" . }}-media-pvc +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.persistence.media.capacity }} + storageClassName: {{ .Values.persistence.storageClass }} +{{ end }} +{{ if not .Values.persistence.jetstream.existingClaim }} +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + annotations: + helm.sh/resource-policy: keep + name: {{ include "dendrite.fullname" . }}-jetstream-pvc +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.persistence.jetstream.capacity }} + storageClassName: {{ .Values.persistence.storageClass }} +{{ end }} +{{ if not .Values.persistence.search.existingClaim }} +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + annotations: + helm.sh/resource-policy: keep + name: {{ include "dendrite.fullname" . }}-search-pvc +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.persistence.search.capacity }} + storageClassName: {{ .Values.persistence.storageClass }} +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/templates/secrets.yaml b/helm/dendrite/templates/secrets.yaml new file mode 100644 index 000000000..2084c9a56 --- /dev/null +++ b/helm/dendrite/templates/secrets.yaml @@ -0,0 +1,45 @@ +{{- if (gt (len (.Files.Glob "appservices/*")) 0) }} +--- +apiVersion: v1 +kind: Secret +metadata: + name: {{ include "dendrite.fullname" . }}-appservices-conf +type: Opaque +data: +{{ (.Files.Glob "appservices/*").AsSecrets | indent 2 }} +{{- end }} + +{{- if and .Values.signing_key.create (not .Values.signing_key.existingSecret) }} +--- +apiVersion: v1 +kind: Secret +metadata: + annotations: + helm.sh/resource-policy: keep + name: {{ include "dendrite.fullname" . }}-signing-key +type: Opaque +{{- end }} + +{{- with .Values.dendrite_config.global.metrics }} +{{- if .enabled }} +--- +apiVersion: v1 +kind: Secret +metadata: + name: {{ include "dendrite.fullname" $ }}-metrics-basic-auth +type: Opaque +stringData: + user: {{ .basic_auth.user | quote }} + password: {{ .basic_auth.password | quote }} +{{- end }} +{{- end }} + +--- +apiVersion: v1 +kind: Secret +metadata: + name: {{ include "dendrite.fullname" . }}-conf +type: Opaque +stringData: + dendrite.yaml: | + {{ toYaml ( mustMergeOverwrite .Values.dendrite_config ( fromYaml (include "override.config" .) ) .Values.dendrite_config ) | nindent 4 }} \ No newline at end of file diff --git a/helm/dendrite/templates/service.yaml b/helm/dendrite/templates/service.yaml new file mode 100644 index 000000000..3b571df1f --- /dev/null +++ b/helm/dendrite/templates/service.yaml @@ -0,0 +1,17 @@ +{{ template "validate.config" . }} +--- +apiVersion: v1 +kind: Service +metadata: + namespace: {{ $.Release.Namespace }} + name: {{ include "dendrite.fullname" . }} + labels: + {{- include "dendrite.labels" . | nindent 4 }} +spec: + selector: + {{- include "dendrite.selectorLabels" . | nindent 4 }} + ports: + - name: http + protocol: TCP + port: {{ .Values.service.port }} + targetPort: 8008 \ No newline at end of file diff --git a/helm/dendrite/templates/servicemonitor.yaml b/helm/dendrite/templates/servicemonitor.yaml new file mode 100644 index 000000000..3819c7d02 --- /dev/null +++ b/helm/dendrite/templates/servicemonitor.yaml @@ -0,0 +1,26 @@ +{{- if and + (and .Values.prometheus.servicemonitor.enabled .Values.dendrite_config.global.metrics.enabled ) + ( .Capabilities.APIVersions.Has "monitoring.coreos.com/v1" ) +}} +--- +apiVersion: monitoring.coreos.com/v1 +kind: ServiceMonitor +metadata: + name: {{ include "dendrite.fullname" . }} + labels: + {{- include "dendrite.labels" . | nindent 4 }} + {{- toYaml .Values.prometheus.servicemonitor.labels | nindent 4 }} +spec: + endpoints: + - port: http + basicAuth: + username: + name: {{ include "dendrite.fullname" . }}-metrics-basic-auth + key: "user" + password: + name: {{ include "dendrite.fullname" . }}-metrics-basic-auth + key: "password" + selector: + matchLabels: + {{- include "dendrite.selectorLabels" . | nindent 6 }} +{{- end }} diff --git a/helm/dendrite/templates/tests/test-version.yaml b/helm/dendrite/templates/tests/test-version.yaml new file mode 100644 index 000000000..d88751325 --- /dev/null +++ b/helm/dendrite/templates/tests/test-version.yaml @@ -0,0 +1,17 @@ +--- +apiVersion: v1 +kind: Pod +metadata: + name: "{{ include "dendrite.fullname" . }}-test-version" + labels: + {{- include "dendrite.selectorLabels" . | nindent 4 }} + annotations: + "helm.sh/hook": test +spec: + containers: + - name: curl + image: curlimages/curl + imagePullPolicy: IfNotPresent + args: + - 'http://{{- include "dendrite.fullname" . -}}:8008/_matrix/client/versions' + restartPolicy: Never diff --git a/helm/dendrite/values.yaml b/helm/dendrite/values.yaml new file mode 100644 index 000000000..c219d27f8 --- /dev/null +++ b/helm/dendrite/values.yaml @@ -0,0 +1,373 @@ +image: + # -- Docker repository/image to use + repository: "ghcr.io/matrix-org/dendrite-monolith" + # -- Kubernetes pullPolicy + pullPolicy: IfNotPresent + # -- Overrides the image tag whose default is the chart appVersion. + tag: "" + + +# signing key to use +signing_key: + # -- Create a new signing key, if not exists + create: true + # -- Use an existing secret + existingSecret: "" + +# -- Default resource requests/limits. +# @default -- sets some sane default values +resources: + requests: + memory: "512Mi" + + limits: + memory: "4096Mi" + +persistence: + # -- The storage class to use for volume claims. Defaults to the + # cluster default storage class. + storageClass: "" + jetstream: + # -- Use an existing volume claim for jetstream + existingClaim: "" + # -- PVC Storage Request for the jetstream volume + capacity: "1Gi" + media: + # -- Use an existing volume claim for media files + existingClaim: "" + # -- PVC Storage Request for the media volume + capacity: "1Gi" + search: + # -- Use an existing volume claim for the fulltext search index + existingClaim: "" + # -- PVC Storage Request for the search volume + capacity: "1Gi" + +dendrite_config: + version: 2 + global: + # -- **REQUIRED** Servername for this Dendrite deployment. + server_name: "" + + # -- The private key to use. (**NOTE**: This is overriden in Helm) + private_key: /etc/dendrite/secrets/signing.key + + # -- The server name to delegate server-server communications to, with optional port + # e.g. localhost:443 + well_known_server_name: "" + + # -- The server name to delegate client-server communications to, with optional port + # e.g. localhost:443 + well_known_client_name: "" + + # -- Lists of domains that the server will trust as identity servers to verify third + # party identifiers such as phone numbers and email addresses. + trusted_third_party_id_servers: + - matrix.org + - vector.im + + # -- The paths and expiry timestamps (as a UNIX timestamp in millisecond precision) + # to old signing keys that were formerly in use on this domain name. These + # keys will not be used for federation request or event signing, but will be + # provided to any other homeserver that asks when trying to verify old events. + old_private_keys: + # If the old private key file is available: + # - private_key: old_matrix_key.pem + # expired_at: 1601024554498 + # If only the public key (in base64 format) and key ID are known: + # - public_key: mn59Kxfdq9VziYHSBzI7+EDPDcBS2Xl7jeUdiiQcOnM= + # key_id: ed25519:mykeyid + # expired_at: 1601024554498 + + # -- Disable federation. Dendrite will not be able to make any outbound HTTP requests + # to other servers and the federation API will not be exposed. + disable_federation: false + + key_validity_period: 168h0m0s + + database: + # -- The connection string for connections to Postgres. + # This will be set automatically if using the Postgres dependency + connection_string: "" + + # -- Default database maximum open connections + max_open_conns: 90 + # -- Default database maximum idle connections + max_idle_conns: 5 + # -- Default database maximum lifetime + conn_max_lifetime: -1 + + jetstream: + # -- Persistent directory to store JetStream streams in. + storage_path: "/data/jetstream" + # -- NATS JetStream server addresses if not using internal NATS. + addresses: [] + # -- The prefix for JetStream streams + topic_prefix: "Dendrite" + # -- Keep all data in memory. (**NOTE**: This is overriden in Helm to `false`) + in_memory: false + # -- Disables TLS validation. This should **NOT** be used in production. + disable_tls_validation: true + + cache: + # -- The estimated maximum size for the global cache in bytes, or in terabytes, + # gigabytes, megabytes or kilobytes when the appropriate 'tb', 'gb', 'mb' or + # 'kb' suffix is specified. Note that this is not a hard limit, nor is it a + # memory limit for the entire process. A cache that is too small may ultimately + # provide little or no benefit. + max_size_estimated: 1gb + # -- The maximum amount of time that a cache entry can live for in memory before + # it will be evicted and/or refreshed from the database. Lower values result in + # easier admission of new cache entries but may also increase database load in + # comparison to higher values, so adjust conservatively. Higher values may make + # it harder for new items to make it into the cache, e.g. if new rooms suddenly + # become popular. + max_age: 1h + + report_stats: + # -- Configures phone-home statistics reporting. These statistics contain the server + # name, number of active users and some information on your deployment config. + # We use this information to understand how Dendrite is being used in the wild. + enabled: false + # -- Endpoint to report statistics to. + endpoint: https://matrix.org/report-usage-stats/push + + presence: + # -- Controls whether we receive presence events from other servers + enable_inbound: false + # -- Controls whether we send presence events for our local users to other servers. + # (_May increase CPU/memory usage_) + enable_outbound: false + + server_notices: + # -- Server notices allows server admins to send messages to all users on the server. + enabled: false + # -- The local part for the user sending server notices. + local_part: "_server" + # -- The display name for the user sending server notices. + display_name: "Server Alerts" + # -- The avatar URL (as a mxc:// URL) name for the user sending server notices. + avatar_url: "" + # The room name to be used when sending server notices. This room name will + # appear in user clients. + room_name: "Server Alerts" + + # prometheus metrics + metrics: + # -- Whether or not Prometheus metrics are enabled. + enabled: false + # HTTP basic authentication to protect access to monitoring. + basic_auth: + # -- HTTP basic authentication username + user: "metrics" + # -- HTTP basic authentication password + password: metrics + + dns_cache: + # -- Whether or not the DNS cache is enabled. + enabled: false + # -- Maximum number of entries to hold in the DNS cache + cache_size: 256 + # -- Duration for how long DNS cache items should be considered valid ([see time.ParseDuration](https://pkg.go.dev/time#ParseDuration) for more) + cache_lifetime: "10m" + + profiling: + # -- Enable pprof. You will need to manually create a port forwarding to the deployment to access PPROF, + # as it will only listen on localhost and the defined port. + # e.g. `kubectl port-forward deployments/dendrite 65432:65432` + enabled: false + # -- pprof port, if enabled + port: 65432 + + # -- Configuration for experimental MSC's. (Valid values are: msc2836 and msc2946) + mscs: + mscs: + - msc2946 + # A list of enabled MSC's + # Currently valid values are: + # - msc2836 (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836) + # - msc2946 (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946) + + + app_service_api: + # -- Disable the validation of TLS certificates of appservices. This is + # not recommended in production since it may allow appservice traffic + # to be sent to an insecure endpoint. + disable_tls_validation: false + # -- Appservice config files to load on startup. (**NOTE**: This is overriden by Helm, if a folder `./appservices/` exists) + config_files: [] + + client_api: + # -- Prevents new users from being able to register on this homeserver, except when + # using the registration shared secret below. + registration_disabled: true + + # Prevents new guest accounts from being created. Guest registration is also + # disabled implicitly by setting 'registration_disabled' above. + guests_disabled: true + + # -- If set, allows registration by anyone who knows the shared secret, regardless of + # whether registration is otherwise disabled. + registration_shared_secret: "" + + # -- enable reCAPTCHA registration + enable_registration_captcha: false + # -- reCAPTCHA public key + recaptcha_public_key: "" + # -- reCAPTCHA private key + recaptcha_private_key: "" + # -- reCAPTCHA bypass secret + recaptcha_bypass_secret: "" + recaptcha_siteverify_api: "" + + # TURN server information that this homeserver should send to clients. + turn: + # -- Duration for how long users should be considered valid ([see time.ParseDuration](https://pkg.go.dev/time#ParseDuration) for more) + turn_user_lifetime: "24h" + turn_uris: [] + turn_shared_secret: "" + # -- The TURN username + turn_username: "" + # -- The TURN password + turn_password: "" + + rate_limiting: + # -- Enable rate limiting + enabled: true + # -- After how many requests a rate limit should be activated + threshold: 20 + # -- Cooloff time in milliseconds + cooloff_ms: 500 + # -- Users which should be exempt from rate limiting + exempt_user_ids: + + federation_api: + # -- Federation failure threshold. How many consecutive failures that we should + # tolerate when sending federation requests to a specific server. The backoff + # is 2**x seconds, so 1 = 2 seconds, 2 = 4 seconds, 3 = 8 seconds, etc. + # The default value is 16 if not specified, which is circa 18 hours. + send_max_retries: 16 + # -- Disable TLS validation. This should **NOT** be used in production. + disable_tls_validation: false + prefer_direct_fetch: false + # -- Prevents Dendrite from keeping HTTP connections + # open for reuse for future requests. Connections will be closed quicker + # but we may spend more time on TLS handshakes instead. + disable_http_keepalives: false + # -- Perspective keyservers, to use as a backup when direct key fetch + # requests don't succeed. + # @default -- See value.yaml + key_perspectives: + - server_name: matrix.org + keys: + - key_id: ed25519:auto + public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw + - key_id: ed25519:a_RXGa + public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ + + media_api: + # -- The path to store media files (e.g. avatars) in + base_path: "/data/media_store" + # -- The max file size for uploaded media files + max_file_size_bytes: 10485760 + # Whether to dynamically generate thumbnails if needed. + dynamic_thumbnails: false + # -- The maximum number of simultaneous thumbnail generators to run. + max_thumbnail_generators: 10 + # -- A list of thumbnail sizes to be generated for media content. + # @default -- See value.yaml + thumbnail_sizes: + - width: 32 + height: 32 + method: crop + - width: 96 + height: 96 + method: crop + - width: 640 + height: 480 + method: scale + + sync_api: + # -- This option controls which HTTP header to inspect to find the real remote IP + # address of the client. This is likely required if Dendrite is running behind + # a reverse proxy server. + real_ip_header: X-Real-IP + # -- Configuration for the full-text search engine. + search: + # -- Whether fulltext search is enabled. + enabled: true + # -- The path to store the search index in. + index_path: "/data/search" + # -- The language most likely to be used on the server - used when indexing, to + # ensure the returned results match expectations. A full list of possible languages + # can be found [here](https://github.com/matrix-org/dendrite/blob/76db8e90defdfb9e61f6caea8a312c5d60bcc005/internal/fulltext/bleve.go#L25-L46) + language: "en" + + user_api: + # -- bcrypt cost to use when hashing passwords. + # (ranges from 4-31; 4 being least secure, 31 being most secure; _NOTE: Using a too high value can cause clients to timeout and uses more CPU._) + bcrypt_cost: 10 + # -- OpenID Token lifetime in milliseconds. + openid_token_lifetime_ms: 3600000 + # - Disable TLS validation when hitting push gateways. This should **NOT** be used in production. + push_gateway_disable_tls_validation: false + # -- Rooms to join users to after registration + auto_join_rooms: [] + + # -- Default logging configuration + logging: + - type: std + level: info + +postgresql: + # -- Enable and configure postgres as the database for dendrite. + # @default -- See value.yaml + enabled: false + image: + repository: bitnami/postgresql + tag: "15.1.0" + auth: + username: dendrite + password: changeme + database: dendrite + + persistence: + enabled: false + +ingress: + # -- Create an ingress for a monolith deployment + enabled: false + hosts: [] + className: "" + hostName: "" + # -- Extra, custom annotations + annotations: {} + + tls: [] + +service: + type: ClusterIP + port: 8008 + +prometheus: + servicemonitor: + # -- Enable ServiceMonitor for Prometheus-Operator for scrape metric-endpoint + enabled: false + # -- Extra Labels on ServiceMonitor for selector of Prometheus Instance + labels: {} + rules: + # -- Enable PrometheusRules for Prometheus-Operator for setup alerting + enabled: false + # -- Extra Labels on PrometheusRules for selector of Prometheus Instance + labels: {} + # -- additional alertrules (no default alertrules are provided) + additionalRules: [] + +grafana: + dashboards: + enabled: false + # -- Extra Labels on ConfigMap for selector of grafana sidecar + labels: + grafana_dashboard: "1" + # -- Extra Annotations on ConfigMap additional config in grafana sidecar + annotations: {} diff --git a/internal/caching/cache_eventstatekeys.go b/internal/caching/cache_eventstatekeys.go index 05580ab05..51e2499d5 100644 --- a/internal/caching/cache_eventstatekeys.go +++ b/internal/caching/cache_eventstatekeys.go @@ -7,6 +7,7 @@ import "github.com/matrix-org/dendrite/roomserver/types" type EventStateKeyCache interface { GetEventStateKey(eventStateKeyNID types.EventStateKeyNID) (string, bool) StoreEventStateKey(eventStateKeyNID types.EventStateKeyNID, eventStateKey string) + GetEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, bool) } func (c Caches) GetEventStateKey(eventStateKeyNID types.EventStateKeyNID) (string, bool) { @@ -15,4 +16,23 @@ func (c Caches) GetEventStateKey(eventStateKeyNID types.EventStateKeyNID) (strin func (c Caches) StoreEventStateKey(eventStateKeyNID types.EventStateKeyNID, eventStateKey string) { c.RoomServerStateKeys.Set(eventStateKeyNID, eventStateKey) + c.RoomServerStateKeyNIDs.Set(eventStateKey, eventStateKeyNID) +} + +func (c Caches) GetEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, bool) { + return c.RoomServerStateKeyNIDs.Get(eventStateKey) +} + +type EventTypeCache interface { + GetEventTypeKey(eventType string) (types.EventTypeNID, bool) + StoreEventTypeKey(eventTypeNID types.EventTypeNID, eventType string) +} + +func (c Caches) StoreEventTypeKey(eventTypeNID types.EventTypeNID, eventType string) { + c.RoomServerEventTypeNIDs.Set(eventType, eventTypeNID) + c.RoomServerEventTypes.Set(eventTypeNID, eventType) +} + +func (c Caches) GetEventTypeKey(eventType string) (types.EventTypeNID, bool) { + return c.RoomServerEventTypeNIDs.Get(eventType) } diff --git a/internal/caching/cache_roomevents.go b/internal/caching/cache_roomevents.go index 9d5d3b912..14b6c3af8 100644 --- a/internal/caching/cache_roomevents.go +++ b/internal/caching/cache_roomevents.go @@ -10,6 +10,7 @@ import ( type RoomServerEventsCache interface { GetRoomServerEvent(eventNID types.EventNID) (*gomatrixserverlib.Event, bool) StoreRoomServerEvent(eventNID types.EventNID, event *gomatrixserverlib.Event) + InvalidateRoomServerEvent(eventNID types.EventNID) } func (c Caches) GetRoomServerEvent(eventNID types.EventNID) (*gomatrixserverlib.Event, bool) { @@ -19,3 +20,7 @@ func (c Caches) GetRoomServerEvent(eventNID types.EventNID) (*gomatrixserverlib. func (c Caches) StoreRoomServerEvent(eventNID types.EventNID, event *gomatrixserverlib.Event) { c.RoomServerEvents.Set(int64(eventNID), event) } + +func (c Caches) InvalidateRoomServerEvent(eventNID types.EventNID) { + c.RoomServerEvents.Unset(int64(eventNID)) +} diff --git a/internal/caching/cache_roomservernids.go b/internal/caching/cache_roomservernids.go index 88a5b28bc..734a3a04f 100644 --- a/internal/caching/cache_roomservernids.go +++ b/internal/caching/cache_roomservernids.go @@ -9,19 +9,28 @@ type RoomServerCaches interface { RoomVersionCache RoomServerEventsCache EventStateKeyCache + EventTypeCache } // RoomServerNIDsCache contains the subset of functions needed for // a roomserver NID cache. type RoomServerNIDsCache interface { GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) + // StoreRoomServerRoomID stores roomNID -> roomID and roomID -> roomNID StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) + GetRoomServerRoomNID(roomID string) (types.RoomNID, bool) } func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) { return c.RoomServerRoomIDs.Get(roomNID) } +// StoreRoomServerRoomID stores roomNID -> roomID and roomID -> roomNID func (c Caches) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) { + c.RoomServerRoomNIDs.Set(roomID, roomNID) c.RoomServerRoomIDs.Set(roomNID, roomID) } + +func (c Caches) GetRoomServerRoomNID(roomID string) (types.RoomNID, bool) { + return c.RoomServerRoomNIDs.Get(roomID) +} diff --git a/internal/caching/caches.go b/internal/caching/caches.go index 78c9ab7ee..479920466 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -23,16 +23,19 @@ import ( // different implementations as long as they satisfy the Cache // interface. type Caches struct { - RoomVersions Cache[string, gomatrixserverlib.RoomVersion] // room ID -> room version - ServerKeys Cache[string, gomatrixserverlib.PublicKeyLookupResult] // server name -> server keys - RoomServerRoomNIDs Cache[string, types.RoomNID] // room ID -> room NID - RoomServerRoomIDs Cache[types.RoomNID, string] // room NID -> room ID - RoomServerEvents Cache[int64, *gomatrixserverlib.Event] // event NID -> event - RoomServerStateKeys Cache[types.EventStateKeyNID, string] // event NID -> event state key - FederationPDUs Cache[int64, *gomatrixserverlib.HeaderedEvent] // queue NID -> PDU - FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU - SpaceSummaryRooms Cache[string, gomatrixserverlib.MSC2946SpacesResponse] // room ID -> space response - LazyLoading Cache[lazyLoadingCacheKey, string] // composite key -> event ID + RoomVersions Cache[string, gomatrixserverlib.RoomVersion] // room ID -> room version + ServerKeys Cache[string, gomatrixserverlib.PublicKeyLookupResult] // server name -> server keys + RoomServerRoomNIDs Cache[string, types.RoomNID] // room ID -> room NID + RoomServerRoomIDs Cache[types.RoomNID, string] // room NID -> room ID + RoomServerEvents Cache[int64, *gomatrixserverlib.Event] // event NID -> event + RoomServerStateKeys Cache[types.EventStateKeyNID, string] // eventStateKey NID -> event state key + RoomServerStateKeyNIDs Cache[string, types.EventStateKeyNID] // event state key -> eventStateKey NID + RoomServerEventTypeNIDs Cache[string, types.EventTypeNID] // eventType -> eventType NID + RoomServerEventTypes Cache[types.EventTypeNID, string] // eventType NID -> eventType + FederationPDUs Cache[int64, *gomatrixserverlib.HeaderedEvent] // queue NID -> PDU + FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU + SpaceSummaryRooms Cache[string, gomatrixserverlib.MSC2946SpacesResponse] // room ID -> space response + LazyLoading Cache[lazyLoadingCacheKey, string] // composite key -> event ID } // Cache is the interface that an implementation must satisfy. diff --git a/internal/caching/impl_ristretto.go b/internal/caching/impl_ristretto.go index 49292d0dc..106b9c99f 100644 --- a/internal/caching/impl_ristretto.go +++ b/internal/caching/impl_ristretto.go @@ -22,11 +22,12 @@ import ( "github.com/dgraph-io/ristretto" "github.com/dgraph-io/ristretto/z" - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" ) const ( @@ -40,6 +41,9 @@ const ( spaceSummaryRoomsCache lazyLoadingCache eventStateKeyCache + eventTypeCache + eventTypeNIDCache + eventStateKeyNIDCache ) func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enablePrometheus bool) *Caches { @@ -95,9 +99,10 @@ func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enableProm }, RoomServerEvents: &RistrettoCostedCachePartition[int64, *gomatrixserverlib.Event]{ // event NID -> event &RistrettoCachePartition[int64, *gomatrixserverlib.Event]{ - cache: cache, - Prefix: roomEventsCache, - MaxAge: maxAge, + cache: cache, + Prefix: roomEventsCache, + MaxAge: maxAge, + Mutable: true, }, }, RoomServerStateKeys: &RistrettoCachePartition[types.EventStateKeyNID, string]{ // event NID -> event state key @@ -105,6 +110,21 @@ func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enableProm Prefix: eventStateKeyCache, MaxAge: maxAge, }, + RoomServerStateKeyNIDs: &RistrettoCachePartition[string, types.EventStateKeyNID]{ // eventStateKey -> eventStateKey NID + cache: cache, + Prefix: eventStateKeyNIDCache, + MaxAge: maxAge, + }, + RoomServerEventTypeNIDs: &RistrettoCachePartition[string, types.EventTypeNID]{ // eventType -> eventType NID + cache: cache, + Prefix: eventTypeCache, + MaxAge: maxAge, + }, + RoomServerEventTypes: &RistrettoCachePartition[types.EventTypeNID, string]{ // eventType NID -> eventType + cache: cache, + Prefix: eventTypeNIDCache, + MaxAge: maxAge, + }, FederationPDUs: &RistrettoCostedCachePartition[int64, *gomatrixserverlib.HeaderedEvent]{ // queue NID -> PDU &RistrettoCachePartition[int64, *gomatrixserverlib.HeaderedEvent]{ cache: cache, diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index 223282a25..d6c79e989 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -16,7 +16,9 @@ // Hooks can only be run in monolith mode. package hooks -import "sync" +import ( + "sync" +) const ( // KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent diff --git a/internal/httputil/http.go b/internal/httputil/http.go deleted file mode 100644 index ad26de512..000000000 --- a/internal/httputil/http.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package httputil - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "strings" - - "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/ext" -) - -// PostJSON performs a POST request with JSON on an internal HTTP API. -// The error will match the errtype if returned from the remote API, or -// will be a different type if there was a problem reaching the API. -func PostJSON[reqtype, restype any, errtype error]( - ctx context.Context, span opentracing.Span, httpClient *http.Client, - apiURL string, request *reqtype, response *restype, -) error { - jsonBytes, err := json.Marshal(request) - if err != nil { - return err - } - - parsedAPIURL, err := url.Parse(apiURL) - if err != nil { - return err - } - - parsedAPIURL.Path = InternalPathPrefix + strings.TrimLeft(parsedAPIURL.Path, "/") - apiURL = parsedAPIURL.String() - - req, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewReader(jsonBytes)) - if err != nil { - return err - } - - // Mark the span as being an RPC client. - ext.SpanKindRPCClient.Set(span) - carrier := opentracing.HTTPHeadersCarrier(req.Header) - tracer := opentracing.GlobalTracer() - - if err = tracer.Inject(span.Context(), opentracing.HTTPHeaders, carrier); err != nil { - return err - } - - req.Header.Set("Content-Type", "application/json") - - res, err := httpClient.Do(req.WithContext(ctx)) - if res != nil { - defer (func() { err = res.Body.Close() })() - } - if err != nil { - return err - } - var body []byte - body, err = io.ReadAll(res.Body) - if err != nil { - return err - } - if res.StatusCode != http.StatusOK { - if len(body) == 0 { - return fmt.Errorf("HTTP %d from %s (no response body)", res.StatusCode, apiURL) - } - var reserr errtype - if err = json.Unmarshal(body, &reserr); err != nil { - return fmt.Errorf("HTTP %d from %s - %w", res.StatusCode, apiURL, err) - } - return reserr - } - if err = json.Unmarshal(body, response); err != nil { - return fmt.Errorf("json.Unmarshal: %w", err) - } - return nil -} diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 85ebf6176..f7e739a87 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -24,16 +24,16 @@ import ( "strings" "github.com/getsentry/sentry-go" - "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" - opentracing "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/ext" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/clientapi/auth" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal" + userapi "github.com/matrix-org/dendrite/userapi/api" ) // BasicAuth is used for authorization on /metrics handlers @@ -236,9 +236,9 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse } } - span := opentracing.StartSpan(metricsName) - defer span.Finish() - req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span)) + trace, ctx := internal.StartTask(req.Context(), metricsName) + defer trace.EndTask() + req = req.WithContext(ctx) if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" { ips := strings.Split(forwardedFor, ", ") req.RemoteAddr = ips[0] @@ -252,17 +252,16 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse // MakeHTMLAPI adds Span metrics to the HTML Handler function // This is used to serve HTML alongside JSON error messages -func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler { +func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request)) http.Handler { withSpan := func(w http.ResponseWriter, req *http.Request) { - span := opentracing.StartSpan(metricsName) - defer span.Finish() - req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span)) - if err := f(w, req); err != nil { - h := util.MakeJSONAPI(util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse { - return *err - })) - h.ServeHTTP(w, req) - } + trace, ctx := internal.StartTask(req.Context(), metricsName) + defer trace.EndTask() + req = req.WithContext(ctx) + f(w, req) + } + + if !enableMetrics { + return http.HandlerFunc(withSpan) } return promhttp.InstrumentHandlerCounter( @@ -278,53 +277,6 @@ func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) ) } -// MakeInternalAPI turns a util.JSONRequestHandler function into an http.Handler. -// This is used for APIs that are internal to dendrite. -// If we are passed a tracing context in the request headers then we use that -// as the parent of any tracing spans we create. -func MakeInternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler { - h := util.MakeJSONAPI(util.NewJSONRequestHandler(f)) - withSpan := func(w http.ResponseWriter, req *http.Request) { - carrier := opentracing.HTTPHeadersCarrier(req.Header) - tracer := opentracing.GlobalTracer() - clientContext, err := tracer.Extract(opentracing.HTTPHeaders, carrier) - var span opentracing.Span - if err == nil { - // Default to a span without RPC context. - span = tracer.StartSpan(metricsName) - } else { - // Set the RPC context. - span = tracer.StartSpan(metricsName, ext.RPCServerOption(clientContext)) - } - defer span.Finish() - req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span)) - h.ServeHTTP(w, req) - } - - return promhttp.InstrumentHandlerCounter( - promauto.NewCounterVec( - prometheus.CounterOpts{ - Name: metricsName + "_requests_total", - Help: "Total number of internal API calls", - Namespace: "dendrite", - }, - []string{"code"}, - ), - promhttp.InstrumentHandlerResponseSize( - promauto.NewHistogramVec( - prometheus.HistogramOpts{ - Namespace: "dendrite", - Name: metricsName + "_response_size_bytes", - Help: "A histogram of response sizes for requests.", - Buckets: []float64{200, 500, 900, 1500, 5000, 15000, 50000, 100000}, - }, - []string{}, - ), - http.HandlerFunc(withSpan), - ), - ) -} - // WrapHandlerInBasicAuth adds basic auth to a handler. Only used for /metrics func WrapHandlerInBasicAuth(h http.Handler, b BasicAuth) http.HandlerFunc { if b.Username == "" || b.Password == "" { diff --git a/internal/httputil/internalapi.go b/internal/httputil/internalapi.go deleted file mode 100644 index 385092d9c..000000000 --- a/internal/httputil/internalapi.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package httputil - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "reflect" - - "github.com/matrix-org/util" - opentracing "github.com/opentracing/opentracing-go" -) - -type InternalAPIError struct { - Type string - Message string -} - -func (e InternalAPIError) Error() string { - return fmt.Sprintf("internal API returned %q error: %s", e.Type, e.Message) -} - -func MakeInternalRPCAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype, *restype) error) http.Handler { - return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse { - var request reqtype - var response restype - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := f(req.Context(), &request, &response); err != nil { - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: &InternalAPIError{ - Type: reflect.TypeOf(err).String(), - Message: fmt.Sprintf("%s", err), - }, - } - } - return util.JSONResponse{ - Code: http.StatusOK, - JSON: &response, - } - }) -} - -func MakeInternalProxyAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype) (*restype, error)) http.Handler { - return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse { - var request reqtype - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - response, err := f(req.Context(), &request) - if err != nil { - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: err, - } - } - return util.JSONResponse{ - Code: http.StatusOK, - JSON: response, - } - }) -} - -func CallInternalRPCAPI[reqtype, restype any](name, url string, client *http.Client, ctx context.Context, request *reqtype, response *restype) error { - span, ctx := opentracing.StartSpanFromContext(ctx, name) - defer span.Finish() - - return PostJSON[reqtype, restype, InternalAPIError](ctx, span, client, url, request, response) -} - -func CallInternalProxyAPI[reqtype, restype any, errtype error](name, url string, client *http.Client, ctx context.Context, request *reqtype) (restype, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, name) - defer span.Finish() - - var response restype - return response, PostJSON[reqtype, restype, errtype](ctx, span, client, url, request, &response) -} diff --git a/internal/httputil/paths.go b/internal/httputil/paths.go index 12cf59eb4..d06875428 100644 --- a/internal/httputil/paths.go +++ b/internal/httputil/paths.go @@ -19,8 +19,8 @@ const ( PublicFederationPathPrefix = "/_matrix/federation/" PublicKeyPathPrefix = "/_matrix/key/" PublicMediaPathPrefix = "/_matrix/media/" + PublicStaticPath = "/_matrix/static/" PublicWellKnownPrefix = "/.well-known/matrix/" - InternalPathPrefix = "/api/" DendriteAdminPathPrefix = "/_dendrite/" SynapseAdminPathPrefix = "/_synapse/" ) diff --git a/internal/log.go b/internal/log.go index a171555ab..8fe98f20c 100644 --- a/internal/log.go +++ b/internal/log.go @@ -24,6 +24,7 @@ import ( "path/filepath" "runtime" "strings" + "sync" "github.com/matrix-org/util" @@ -33,6 +34,12 @@ import ( "github.com/matrix-org/dendrite/setup/config" ) +// logrus is using a global variable when we're using `logrus.AddHook` +// this unfortunately results in us adding the same hook multiple times. +// This map ensures we only ever add one level hook. +var stdLevelLogAdded = make(map[logrus.Level]bool) +var levelLogAddedMu = &sync.Mutex{} + type utcFormatter struct { logrus.Formatter } @@ -94,6 +101,8 @@ func SetupPprof() { // SetupStdLogging configures the logging format to standard output. Typically, it is called when the config is not yet loaded. func SetupStdLogging() { + levelLogAddedMu.Lock() + defer levelLogAddedMu.Unlock() logrus.SetReportCaller(true) logrus.SetFormatter(&utcFormatter{ &logrus.TextFormatter{ @@ -120,9 +129,9 @@ func checkFileHookParams(params map[string]interface{}) { } // Add a new FSHook to the logger. Each component will log in its own file -func setupFileHook(hook config.LogrusHook, level logrus.Level, componentName string) { +func setupFileHook(hook config.LogrusHook, level logrus.Level) { dirPath := (hook.Params["path"]).(string) - fullPath := filepath.Join(dirPath, componentName+".log") + fullPath := filepath.Join(dirPath, "dendrite.log") if err := os.MkdirAll(path.Dir(fullPath), os.ModePerm); err != nil { logrus.Fatalf("Couldn't create directory %s: %q", path.Dir(fullPath), err) diff --git a/internal/log_unix.go b/internal/log_unix.go index 5e8dcaad6..32e04a0e3 100644 --- a/internal/log_unix.go +++ b/internal/log_unix.go @@ -18,8 +18,10 @@ package internal import ( + "io" "log/syslog" + "github.com/MFAshby/stdemuxerhook" "github.com/matrix-org/dendrite/setup/config" "github.com/sirupsen/logrus" lSyslog "github.com/sirupsen/logrus/hooks/syslog" @@ -28,7 +30,9 @@ import ( // SetupHookLogging configures the logging hooks defined in the configuration. // If something fails here it means that the logging was improperly configured, // so we just exit with the error -func SetupHookLogging(hooks []config.LogrusHook, componentName string) { +func SetupHookLogging(hooks []config.LogrusHook) { + levelLogAddedMu.Lock() + defer levelLogAddedMu.Unlock() for _, hook := range hooks { // Check we received a proper logging level level, err := logrus.ParseLevel(hook.Level) @@ -39,15 +43,19 @@ func SetupHookLogging(hooks []config.LogrusHook, componentName string) { switch hook.Type { case "file": checkFileHookParams(hook.Params) - setupFileHook(hook, level, componentName) + setupFileHook(hook, level) case "syslog": checkSyslogHookParams(hook.Params) - setupSyslogHook(hook, level, componentName) + setupSyslogHook(hook, level) case "std": + setupStdLogHook(level) default: logrus.Fatalf("Unrecognised logging hook type: %s", hook.Type) } } + setupStdLogHook(logrus.InfoLevel) + // Hooks are now configured for stdout/err, so throw away the default logger output + logrus.SetOutput(io.Discard) } func checkSyslogHookParams(params map[string]interface{}) { @@ -71,8 +79,16 @@ func checkSyslogHookParams(params map[string]interface{}) { } -func setupSyslogHook(hook config.LogrusHook, level logrus.Level, componentName string) { - syslogHook, err := lSyslog.NewSyslogHook(hook.Params["protocol"].(string), hook.Params["address"].(string), syslog.LOG_INFO, componentName) +func setupStdLogHook(level logrus.Level) { + if stdLevelLogAdded[level] { + return + } + logrus.AddHook(&logLevelHook{level, stdemuxerhook.New(logrus.StandardLogger())}) + stdLevelLogAdded[level] = true +} + +func setupSyslogHook(hook config.LogrusHook, level logrus.Level) { + syslogHook, err := lSyslog.NewSyslogHook(hook.Params["protocol"].(string), hook.Params["address"].(string), syslog.LOG_INFO, "dendrite") if err == nil { logrus.AddHook(&logLevelHook{level, syslogHook}) } diff --git a/internal/log_windows.go b/internal/log_windows.go index 39562328c..e1f0098a1 100644 --- a/internal/log_windows.go +++ b/internal/log_windows.go @@ -22,7 +22,7 @@ import ( // SetupHookLogging configures the logging hooks defined in the configuration. // If something fails here it means that the logging was improperly configured, // so we just exit with the error -func SetupHookLogging(hooks []config.LogrusHook, componentName string) { +func SetupHookLogging(hooks []config.LogrusHook) { logrus.SetReportCaller(true) for _, hook := range hooks { // Check we received a proper logging level @@ -40,7 +40,7 @@ func SetupHookLogging(hooks []config.LogrusHook, componentName string) { switch hook.Type { case "file": checkFileHookParams(hook.Params) - setupFileHook(hook, level, componentName) + setupFileHook(hook, level) default: logrus.Fatalf("Unrecognised logging hook type: %s", hook.Type) } diff --git a/internal/pushgateway/client.go b/internal/pushgateway/client.go index 95f5afd90..d5671be3b 100644 --- a/internal/pushgateway/client.go +++ b/internal/pushgateway/client.go @@ -9,7 +9,7 @@ import ( "net/http" "time" - "github.com/opentracing/opentracing-go" + "github.com/matrix-org/dendrite/internal" ) type httpClient struct { @@ -32,8 +32,8 @@ func NewHTTPClient(disableTLSValidation bool) Client { } func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "Notify") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "Notify") + defer trace.EndRegion() body, err := json.Marshal(req) if err != nil { @@ -50,8 +50,7 @@ func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, return err } - //nolint:errcheck - defer hresp.Body.Close() + defer internal.CloseAndLogIfError(ctx, hresp.Body, "failed to close response body") if hresp.StatusCode == http.StatusOK { return json.NewDecoder(hresp.Body).Decode(resp) diff --git a/internal/pushgateway/client_test.go b/internal/pushgateway/client_test.go new file mode 100644 index 000000000..bd0dca470 --- /dev/null +++ b/internal/pushgateway/client_test.go @@ -0,0 +1,54 @@ +package pushgateway + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func TestNotify(t *testing.T) { + wantResponse := NotifyResponse{ + Rejected: []string{"testing"}, + } + + var i = 0 + + svr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // /notify only accepts POST requests + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusNotImplemented) + return + } + + if i != 0 { // error path + w.WriteHeader(http.StatusBadRequest) + return + } + + // happy path + json.NewEncoder(w).Encode(wantResponse) + })) + defer svr.Close() + + cl := NewHTTPClient(true) + gotResponse := NotifyResponse{} + + // Test happy path + err := cl.Notify(context.Background(), svr.URL, &NotifyRequest{}, &gotResponse) + if err != nil { + t.Errorf("failed to notify client") + } + if !reflect.DeepEqual(gotResponse, wantResponse) { + t.Errorf("expected response %+v, got %+v", wantResponse, gotResponse) + } + + // Test error path + i++ + err = cl.Notify(context.Background(), svr.URL, &NotifyRequest{}, &gotResponse) + if err == nil { + t.Errorf("expected notifying the pushgateway to fail, but it succeeded") + } +} diff --git a/internal/pushrules/condition.go b/internal/pushrules/condition.go index 2d9773c0f..c7b30da8e 100644 --- a/internal/pushrules/condition.go +++ b/internal/pushrules/condition.go @@ -14,7 +14,7 @@ type Condition struct { // Pattern indicates the value pattern that must match. Required // for EventMatchCondition. - Pattern string `json:"pattern,omitempty"` + Pattern *string `json:"pattern,omitempty"` // Is indicates the condition that must be fulfilled. Required for // RoomMemberCountCondition. diff --git a/internal/pushrules/default_content.go b/internal/pushrules/default_content.go index 8982dd587..a055ba03c 100644 --- a/internal/pushrules/default_content.go +++ b/internal/pushrules/default_content.go @@ -15,13 +15,7 @@ func mRuleContainsUserNameDefinition(localpart string) *Rule { RuleID: MRuleContainsUserName, Default: true, Enabled: true, - Pattern: localpart, - Conditions: []*Condition{ - { - Kind: EventMatchCondition, - Key: "content.body", - }, - }, + Pattern: &localpart, Actions: []*Action{ {Kind: NotifyAction}, { @@ -32,7 +26,6 @@ func mRuleContainsUserNameDefinition(localpart string) *Rule { { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } diff --git a/internal/pushrules/default_override.go b/internal/pushrules/default_override.go index a9788df2f..f97427b71 100644 --- a/internal/pushrules/default_override.go +++ b/internal/pushrules/default_override.go @@ -22,15 +22,15 @@ const ( MRuleTombstone = ".m.rule.tombstone" MRuleRoomNotif = ".m.rule.roomnotif" MRuleReaction = ".m.rule.reaction" + MRuleRoomACLs = ".m.rule.room.server_acl" ) var ( mRuleMasterDefinition = Rule{ - RuleID: MRuleMaster, - Default: true, - Enabled: false, - Conditions: []*Condition{}, - Actions: []*Action{{Kind: DontNotifyAction}}, + RuleID: MRuleMaster, + Default: true, + Enabled: false, + Actions: []*Action{{Kind: DontNotifyAction}}, } mRuleSuppressNoticesDefinition = Rule{ RuleID: MRuleSuppressNotices, @@ -40,7 +40,7 @@ var ( { Kind: EventMatchCondition, Key: "content.msgtype", - Pattern: "m.notice", + Pattern: pointer("m.notice"), }, }, Actions: []*Action{{Kind: DontNotifyAction}}, @@ -53,7 +53,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.member", + Pattern: pointer("m.room.member"), }, }, Actions: []*Action{{Kind: DontNotifyAction}}, @@ -73,7 +73,6 @@ var ( { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } @@ -85,12 +84,12 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.tombstone", + Pattern: pointer("m.room.tombstone"), }, { Kind: EventMatchCondition, Key: "state_key", - Pattern: "", + Pattern: pointer(""), }, }, Actions: []*Action{ @@ -98,10 +97,27 @@ var ( { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } + mRuleACLsDefinition = Rule{ + RuleID: MRuleRoomACLs, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: pointer("m.room.server_acl"), + }, + { + Kind: EventMatchCondition, + Key: "state_key", + Pattern: pointer(""), + }, + }, + Actions: []*Action{}, + } mRuleRoomNotifDefinition = Rule{ RuleID: MRuleRoomNotif, Default: true, @@ -110,7 +126,7 @@ var ( { Kind: EventMatchCondition, Key: "content.body", - Pattern: "@room", + Pattern: pointer("@room"), }, { Kind: SenderNotificationPermissionCondition, @@ -122,7 +138,6 @@ var ( { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } @@ -134,7 +149,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.reaction", + Pattern: pointer("m.reaction"), }, }, Actions: []*Action{ @@ -152,17 +167,17 @@ func mRuleInviteForMeDefinition(userID string) *Rule { { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.member", + Pattern: pointer("m.room.member"), }, { Kind: EventMatchCondition, Key: "content.membership", - Pattern: "invite", + Pattern: pointer("invite"), }, { Kind: EventMatchCondition, Key: "state_key", - Pattern: userID, + Pattern: pointer(userID), }, }, Actions: []*Action{ @@ -172,11 +187,6 @@ func mRuleInviteForMeDefinition(userID string) *Rule { Tweak: SoundTweak, Value: "default", }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } } diff --git a/internal/pushrules/default_pushrules_test.go b/internal/pushrules/default_pushrules_test.go new file mode 100644 index 000000000..dea829842 --- /dev/null +++ b/internal/pushrules/default_pushrules_test.go @@ -0,0 +1,111 @@ +package pushrules + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Tests that the pre-defined rules as of +// https://spec.matrix.org/v1.4/client-server-api/#predefined-rules +// are correct +func TestDefaultRules(t *testing.T) { + type testCase struct { + name string + inputBytes []byte + want Rule + } + + testCases := []testCase{ + // Default override rules + { + name: ".m.rule.master", + inputBytes: []byte(`{"rule_id":".m.rule.master","default":true,"enabled":false,"actions":["dont_notify"]}`), + want: mRuleMasterDefinition, + }, + { + name: ".m.rule.suppress_notices", + inputBytes: []byte(`{"rule_id":".m.rule.suppress_notices","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"content.msgtype","pattern":"m.notice"}],"actions":["dont_notify"]}`), + want: mRuleSuppressNoticesDefinition, + }, + { + name: ".m.rule.invite_for_me", + inputBytes: []byte(`{"rule_id":".m.rule.invite_for_me","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"},{"kind":"event_match","key":"content.membership","pattern":"invite"},{"kind":"event_match","key":"state_key","pattern":"@test:localhost"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`), + want: *mRuleInviteForMeDefinition("@test:localhost"), + }, + { + name: ".m.rule.member_event", + inputBytes: []byte(`{"rule_id":".m.rule.member_event","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"}],"actions":["dont_notify"]}`), + want: mRuleMemberEventDefinition, + }, + { + name: ".m.rule.contains_display_name", + inputBytes: []byte(`{"rule_id":".m.rule.contains_display_name","default":true,"enabled":true,"conditions":[{"kind":"contains_display_name"}],"actions":["notify",{"set_tweak":"sound","value":"default"},{"set_tweak":"highlight"}]}`), + want: mRuleContainsDisplayNameDefinition, + }, + { + name: ".m.rule.tombstone", + inputBytes: []byte(`{"rule_id":".m.rule.tombstone","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.tombstone"},{"kind":"event_match","key":"state_key","pattern":""}],"actions":["notify",{"set_tweak":"highlight"}]}`), + want: mRuleTombstoneDefinition, + }, + { + name: ".m.rule.room.server_acl", + inputBytes: []byte(`{"rule_id":".m.rule.room.server_acl","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.server_acl"},{"kind":"event_match","key":"state_key","pattern":""}],"actions":[]}`), + want: mRuleACLsDefinition, + }, + { + name: ".m.rule.roomnotif", + inputBytes: []byte(`{"rule_id":".m.rule.roomnotif","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"content.body","pattern":"@room"},{"kind":"sender_notification_permission","key":"room"}],"actions":["notify",{"set_tweak":"highlight"}]}`), + want: mRuleRoomNotifDefinition, + }, + // Default content rules + { + name: ".m.rule.contains_user_name", + inputBytes: []byte(`{"rule_id":".m.rule.contains_user_name","default":true,"enabled":true,"actions":["notify",{"set_tweak":"sound","value":"default"},{"set_tweak":"highlight"}],"pattern":"myLocalUser"}`), + want: *mRuleContainsUserNameDefinition("myLocalUser"), + }, + // default underride rules + { + name: ".m.rule.call", + inputBytes: []byte(`{"rule_id":".m.rule.call","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.call.invite"}],"actions":["notify",{"set_tweak":"sound","value":"ring"}]}`), + want: mRuleCallDefinition, + }, + { + name: ".m.rule.encrypted_room_one_to_one", + inputBytes: []byte(`{"rule_id":".m.rule.encrypted_room_one_to_one","default":true,"enabled":true,"conditions":[{"kind":"room_member_count","is":"2"},{"kind":"event_match","key":"type","pattern":"m.room.encrypted"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`), + want: mRuleEncryptedRoomOneToOneDefinition, + }, + { + name: ".m.rule.room_one_to_one", + inputBytes: []byte(`{"rule_id":".m.rule.room_one_to_one","default":true,"enabled":true,"conditions":[{"kind":"room_member_count","is":"2"},{"kind":"event_match","key":"type","pattern":"m.room.message"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`), + want: mRuleRoomOneToOneDefinition, + }, + { + name: ".m.rule.message", + inputBytes: []byte(`{"rule_id":".m.rule.message","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.message"}],"actions":["notify"]}`), + want: mRuleMessageDefinition, + }, + { + name: ".m.rule.encrypted", + inputBytes: []byte(`{"rule_id":".m.rule.encrypted","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.encrypted"}],"actions":["notify"]}`), + want: mRuleEncryptedDefinition, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := Rule{} + // unmarshal predefined push rules + err := json.Unmarshal(tc.inputBytes, &r) + assert.NoError(t, err) + assert.Equal(t, tc.want, r) + + // and reverse it to check we get the expected result + got, err := json.Marshal(r) + assert.NoError(t, err) + assert.Equal(t, string(got), string(tc.inputBytes)) + }) + + } +} diff --git a/internal/pushrules/default_underride.go b/internal/pushrules/default_underride.go index 8da449a19..118bfae59 100644 --- a/internal/pushrules/default_underride.go +++ b/internal/pushrules/default_underride.go @@ -25,7 +25,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.call.invite", + Pattern: pointer("m.call.invite"), }, }, Actions: []*Action{ @@ -35,11 +35,6 @@ var ( Tweak: SoundTweak, Value: "ring", }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } mRuleEncryptedRoomOneToOneDefinition = Rule{ @@ -54,7 +49,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.encrypted", + Pattern: pointer("m.room.encrypted"), }, }, Actions: []*Action{ @@ -64,11 +59,6 @@ var ( Tweak: SoundTweak, Value: "default", }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } mRuleRoomOneToOneDefinition = Rule{ @@ -83,20 +73,15 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.message", + Pattern: pointer("m.room.message"), }, }, Actions: []*Action{ {Kind: NotifyAction}, { Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, + Tweak: SoundTweak, + Value: "default", }, }, } @@ -108,16 +93,11 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.message", + Pattern: pointer("m.room.message"), }, }, Actions: []*Action{ {Kind: NotifyAction}, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } mRuleEncryptedDefinition = Rule{ @@ -128,16 +108,11 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.encrypted", + Pattern: pointer("m.room.encrypted"), }, }, Actions: []*Action{ {Kind: NotifyAction}, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } ) diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index df22cb042..fc8e0f174 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -104,7 +104,10 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu case ContentKind: // TODO: "These configure behaviour for (unencrypted) messages // that match certain patterns." - Does that mean "content.body"? - return patternMatches("content.body", rule.Pattern, event) + if rule.Pattern == nil { + return false, nil + } + return patternMatches("content.body", *rule.Pattern, event) case RoomKind: return rule.RuleID == event.RoomID(), nil @@ -120,7 +123,10 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { switch cond.Kind { case EventMatchCondition: - return patternMatches(cond.Key, cond.Pattern, event) + if cond.Pattern == nil { + return false, fmt.Errorf("missing condition pattern") + } + return patternMatches(cond.Key, *cond.Pattern, event) case ContainsDisplayNameCondition: return patternMatches("content.body", ec.UserDisplayName(), event) @@ -145,6 +151,11 @@ func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec Evalua } func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) { + // It doesn't make sense for an empty pattern to match anything. + if pattern == "" { + return false, nil + } + re, err := globToRegexp(pattern) if err != nil { return false, err @@ -154,12 +165,20 @@ func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, if err = json.Unmarshal(event.JSON(), &eventMap); err != nil { return false, fmt.Errorf("parsing event: %w", err) } + // From the spec: + // "If the property specified by key is completely absent from + // the event, or does not have a string value, then the condition + // will not match, even if pattern is *." v, err := lookupMapPath(strings.Split(key, "."), eventMap) if err != nil { // An unknown path is a benign error that shouldn't stop rule // processing. It's just a non-match. return false, nil } + if _, ok := v.(string); !ok { + // A non-string never matches. + return false, nil + } return re.MatchString(fmt.Sprint(v)), nil } diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index eabd02415..ca8ae5519 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -79,8 +79,8 @@ func TestRuleMatches(t *testing.T) { {"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true}, {"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, - {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true}, - {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false}, + {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"content":{"body":"abc"}}`, true}, + {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"content":{"body":"abc"}}`, false}, {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true}, {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false}, @@ -106,38 +106,44 @@ func TestConditionMatches(t *testing.T) { Name string Cond Condition EventJSON string - Want bool + WantMatch bool + WantErr bool }{ - {"empty", Condition{}, `{}`, false}, - {"empty", Condition{Kind: "unknownstring"}, `{}`, false}, + {Name: "empty", Cond: Condition{}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "empty", Cond: Condition{Kind: "unknownstring"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, - {"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, true}, + // Neither of these should match because `content` is not a full string match, + // and `content.body` is not a string value. + {Name: "eventMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content", Pattern: pointer("")}, EventJSON: `{"content":{}}`, WantMatch: false, WantErr: false}, + {Name: "eventBodyMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3", Pattern: pointer("")}, EventJSON: `{"content":{"body": "3"}}`, WantMatch: false, WantErr: false}, + {Name: "eventBodyMatch matches", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Pattern: pointer("world")}, EventJSON: `{"content":{"body": "hello world!"}}`, WantMatch: true, WantErr: false}, + {Name: "EventMatch missing pattern", Cond: Condition{Kind: EventMatchCondition, Key: "content.body"}, EventJSON: `{"content":{"body": "hello world!"}}`, WantMatch: false, WantErr: true}, - {"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false}, - {"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true}, + {Name: "displayNameNoMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"content":{"body":"something without displayname"}}`, WantMatch: false, WantErr: false}, + {Name: "displayNameMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"content":{"body":"hello Dear User, how are you?"}}`, WantMatch: true, WantErr: false}, - {"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false}, - {"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true}, - {"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false}, - {"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true}, - {"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false}, - {"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true}, - {"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false}, - {"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true}, - {"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false}, - {"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true}, + {Name: "roomMemberCountLessNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<2"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountLessMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<3"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountLessEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountLessEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==1"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==2"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountGreaterEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountGreaterEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountGreaterNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">2"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountGreaterMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">1"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, - {"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true}, - {"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false}, + {Name: "senderNotificationPermissionMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"sender":"@poweruser:example.com"}`, WantMatch: true, WantErr: false}, + {Name: "senderNotificationPermissionNoMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"sender":"@nobody:example.com"}`, WantMatch: false, WantErr: false}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{2}) - if err != nil { + if err != nil && !tst.WantErr { t.Fatalf("conditionMatches failed: %v", err) } - if got != tst.Want { - t.Errorf("conditionMatches: got %v, want %v", got, tst.Want) + if got != tst.WantMatch { + t.Errorf("conditionMatches: got %v, want %v on %s", got, tst.WantMatch, tst.Name) } }) } @@ -161,9 +167,7 @@ func TestPatternMatches(t *testing.T) { }{ {"empty", "", "", `{}`, false}, - // Note that an empty pattern contains no wildcard characters, - // which implicitly means "*". - {"patternEmpty", "content", "", `{"content":{}}`, true}, + {"patternEmpty", "content", "", `{"content":{}}`, false}, {"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true}, {"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true}, @@ -178,7 +182,7 @@ func TestPatternMatches(t *testing.T) { t.Fatalf("patternMatches failed: %v", err) } if got != tst.Want { - t.Errorf("patternMatches: got %v, want %v", got, tst.Want) + t.Errorf("patternMatches: got %v, want %v on %s", got, tst.Want, tst.Name) } }) } diff --git a/internal/pushrules/pushrules.go b/internal/pushrules/pushrules.go index bbed1f95f..98deaf132 100644 --- a/internal/pushrules/pushrules.go +++ b/internal/pushrules/pushrules.go @@ -36,18 +36,18 @@ type Rule struct { // around. Required. Enabled bool `json:"enabled"` + // Conditions provide the rule's conditions for OverrideKind and + // UnderrideKind. Not allowed for other kinds. + Conditions []*Condition `json:"conditions,omitempty"` + // Actions describe the desired outcome, should the rule // match. Required. Actions []*Action `json:"actions"` - // Conditions provide the rule's conditions for OverrideKind and - // UnderrideKind. Not allowed for other kinds. - Conditions []*Condition `json:"conditions"` - // Pattern is the body pattern to match for ContentKind. Required // for that kind. The interpretation is the same as that of // Condition.Pattern. - Pattern string `json:"pattern"` + Pattern *string `json:"pattern,omitempty"` } // Scope only has one valid value. See also AccountRuleSets. diff --git a/internal/pushrules/util.go b/internal/pushrules/util.go index 8ab4eab94..de8fe5cd0 100644 --- a/internal/pushrules/util.go +++ b/internal/pushrules/util.go @@ -11,22 +11,27 @@ import ( // kind and a tweaks map. Returns a nil map if it would have been // empty. func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) { - kind := UnknownAction - tweaks := map[string]interface{}{} + var kind ActionKind + var tweaks map[string]interface{} for _, a := range as { - if a.Kind == SetTweakAction { - tweaks[string(a.Tweak)] = a.Value - continue - } - if kind != UnknownAction { - return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind) - } - kind = a.Kind - } + switch a.Kind { + case DontNotifyAction: + // Don't bother processing any further + return DontNotifyAction, nil, nil - if len(tweaks) == 0 { - tweaks = nil + case SetTweakAction: + if tweaks == nil { + tweaks = map[string]interface{}{} + } + tweaks[string(a.Tweak)] = a.Value + + default: + if kind != UnknownAction { + return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind) + } + kind = a.Kind + } } return kind, tweaks, nil @@ -123,3 +128,7 @@ func parseRoomMemberCountCondition(s string) (func(int) bool, error) { b = int(v) return cmp, nil } + +func pointer[t any](s t) *t { + return &s +} diff --git a/internal/pushrules/util_test.go b/internal/pushrules/util_test.go index a951c55a2..89f8243d9 100644 --- a/internal/pushrules/util_test.go +++ b/internal/pushrules/util_test.go @@ -17,6 +17,7 @@ func TestActionsToTweaks(t *testing.T) { {"empty", nil, UnknownAction, nil}, {"zero", []*Action{{}}, UnknownAction, nil}, {"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil}, + {"onlyPrimaryDontNotify", []*Action{{Kind: DontNotifyAction}}, DontNotifyAction, nil}, {"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}}, {"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}}, { diff --git a/internal/pushrules/validate.go b/internal/pushrules/validate.go index 5d260f0b9..f50c51bd7 100644 --- a/internal/pushrules/validate.go +++ b/internal/pushrules/validate.go @@ -34,7 +34,10 @@ func ValidateRule(kind Kind, rule *Rule) []error { } case ContentKind: - if rule.Pattern == "" { + if rule.Pattern == nil { + errs = append(errs, fmt.Errorf("missing content rule pattern")) + } + if rule.Pattern != nil && *rule.Pattern == "" { errs = append(errs, fmt.Errorf("missing content rule pattern")) } diff --git a/internal/pushrules/validate_test.go b/internal/pushrules/validate_test.go index b276eb551..966e46259 100644 --- a/internal/pushrules/validate_test.go +++ b/internal/pushrules/validate_test.go @@ -12,15 +12,16 @@ func TestValidateRuleNegatives(t *testing.T) { Rule Rule WantErrString string }{ - {"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"}, - {"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"}, - {"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"}, - {"noActions", OverrideKind, Rule{}, "missing actions"}, - {"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"}, - {"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"}, - {"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"}, - {"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"}, - {"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"}, + {Name: "emptyRuleID", Kind: OverrideKind, Rule: Rule{}, WantErrString: "invalid rule ID"}, + {Name: "invalidKind", Kind: Kind("something else"), Rule: Rule{}, WantErrString: "invalid rule kind"}, + {Name: "ruleIDBackslash", Kind: OverrideKind, Rule: Rule{RuleID: "#foo\\:example.com"}, WantErrString: "invalid rule ID"}, + {Name: "noActions", Kind: OverrideKind, Rule: Rule{}, WantErrString: "missing actions"}, + {Name: "invalidAction", Kind: OverrideKind, Rule: Rule{Actions: []*Action{{}}}, WantErrString: "invalid rule action kind"}, + {Name: "invalidCondition", Kind: OverrideKind, Rule: Rule{Conditions: []*Condition{{}}}, WantErrString: "invalid rule condition kind"}, + {Name: "overrideNoCondition", Kind: OverrideKind, Rule: Rule{}, WantErrString: "missing rule conditions"}, + {Name: "underrideNoCondition", Kind: UnderrideKind, Rule: Rule{}, WantErrString: "missing rule conditions"}, + {Name: "contentNoPattern", Kind: ContentKind, Rule: Rule{}, WantErrString: "missing content rule pattern"}, + {Name: "contentEmptyPattern", Kind: ContentKind, Rule: Rule{Pattern: pointer("")}, WantErrString: "missing content rule pattern"}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 19483b268..81c055edd 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -124,6 +124,11 @@ type QueryProvider interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } +// ExecProvider defines the interface for querys used by RunLimitedVariablesExec. +type ExecProvider interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + // SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement // SQLlite can handle. See https://www.sqlite.org/limits.html for more information. const SQLite3MaxVariables = 999 @@ -153,6 +158,22 @@ func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvide return nil } +// RunLimitedVariablesExec split up a query with more variables than the used database can handle in multiple queries. +func RunLimitedVariablesExec(ctx context.Context, query string, qp ExecProvider, variables []interface{}, limit uint) error { + var start int + for start < len(variables) { + n := minOfInts(len(variables)-start, int(limit)) + nextQuery := strings.Replace(query, "($1)", QueryVariadic(n), 1) + _, err := qp.ExecContext(ctx, nextQuery, variables[start:start+n]...) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("ExecContext returned an error") + return err + } + start = start + n + } + return nil +} + // StatementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. type StatementList []struct { Statement **sql.Stmt diff --git a/internal/sqlutil/sqlutil_test.go b/internal/sqlutil/sqlutil_test.go index 79469cddc..c40757893 100644 --- a/internal/sqlutil/sqlutil_test.go +++ b/internal/sqlutil/sqlutil_test.go @@ -3,10 +3,11 @@ package sqlutil import ( "context" "database/sql" + "errors" "reflect" "testing" - sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/DATA-DOG/go-sqlmock" ) func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) { @@ -164,6 +165,54 @@ func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) { } } +func TestRunLimitedVariablesExec(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + + // Query and expect two queries to be executed + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + variables := []interface{}{ + 1, 2, 3, 4, + } + + query := "DELETE FROM WHERE id IN ($1)" + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables, 2); err != nil { + t.Fatal(err) + } + + // Query again, but only 3 parameters, still queries two times + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables[:3], 2); err != nil { + t.Fatal(err) + } + + // Query again, but only 2 parameters, queries only once + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables[:2], 2); err != nil { + t.Fatal(err) + } + + // Test with invalid query (typo) should return an error + mock.ExpectExec(`DELTE FROM`). + WillReturnResult(sqlmock.NewResult(0, 0)). + WillReturnError(errors.New("typo in query")) + + if err = RunLimitedVariablesExec(context.Background(), "DELTE FROM", db, variables[:2], 2); err == nil { + t.Fatal("expected an error, but got none") + } +} + func assertNoError(t *testing.T, err error, msg string) { t.Helper() if err == nil { diff --git a/internal/tracing.go b/internal/tracing.go new file mode 100644 index 000000000..4e062aed3 --- /dev/null +++ b/internal/tracing.go @@ -0,0 +1,64 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "runtime/trace" + + "github.com/opentracing/opentracing-go" +) + +type Trace struct { + span opentracing.Span + region *trace.Region + task *trace.Task +} + +func StartTask(inCtx context.Context, name string) (Trace, context.Context) { + ctx, task := trace.NewTask(inCtx, name) + span, ctx := opentracing.StartSpanFromContext(ctx, name) + return Trace{ + span: span, + task: task, + }, ctx +} + +func StartRegion(inCtx context.Context, name string) (Trace, context.Context) { + region := trace.StartRegion(inCtx, name) + span, ctx := opentracing.StartSpanFromContext(inCtx, name) + return Trace{ + span: span, + region: region, + }, ctx +} + +func (t Trace) EndRegion() { + t.span.Finish() + if t.region != nil { + t.region.End() + } +} + +func (t Trace) EndTask() { + t.span.Finish() + if t.task != nil { + t.task.End() + } +} + +func (t Trace) SetTag(key string, value any) { + t.span.SetTag(key, value) +} diff --git a/internal/tracing_test.go b/internal/tracing_test.go new file mode 100644 index 000000000..582f50c3a --- /dev/null +++ b/internal/tracing_test.go @@ -0,0 +1,25 @@ +package internal + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTracing(t *testing.T) { + inCtx := context.Background() + + task, ctx := StartTask(inCtx, "testing") + assert.NotNil(t, ctx) + assert.NotNil(t, task) + assert.NotEqual(t, inCtx, ctx) + task.SetTag("key", "value") + + region, ctx2 := StartRegion(ctx, "testing") + assert.NotNil(t, ctx) + assert.NotNil(t, region) + assert.NotEqual(t, ctx, ctx2) + defer task.EndTask() + defer region.EndRegion() +} diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go new file mode 100644 index 000000000..13b00af50 --- /dev/null +++ b/internal/transactionrequest.go @@ -0,0 +1,356 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/getsentry/sentry-go" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/federationapi/types" + "github.com/matrix-org/dendrite/roomserver/api" + syncTypes "github.com/matrix-org/dendrite/syncapi/types" + userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" +) + +var ( + PDUCountTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "dendrite", + Subsystem: "federationapi", + Name: "recv_pdus", + Help: "Number of incoming PDUs from remote servers with labels for success", + }, + []string{"status"}, // 'success' or 'total' + ) + EDUCountTotal = prometheus.NewCounter( + prometheus.CounterOpts{ + Namespace: "dendrite", + Subsystem: "federationapi", + Name: "recv_edus", + Help: "Number of incoming EDUs from remote servers", + }, + ) +) + +type TxnReq struct { + gomatrixserverlib.Transaction + rsAPI api.FederationRoomserverAPI + userAPI userAPI.FederationUserAPI + ourServerName gomatrixserverlib.ServerName + keys gomatrixserverlib.JSONVerifier + roomsMu *MutexByRoom + producer *producers.SyncAPIProducer + inboundPresenceEnabled bool +} + +func NewTxnReq( + rsAPI api.FederationRoomserverAPI, + userAPI userAPI.FederationUserAPI, + ourServerName gomatrixserverlib.ServerName, + keys gomatrixserverlib.JSONVerifier, + roomsMu *MutexByRoom, + producer *producers.SyncAPIProducer, + inboundPresenceEnabled bool, + pdus []json.RawMessage, + edus []gomatrixserverlib.EDU, + origin gomatrixserverlib.ServerName, + transactionID gomatrixserverlib.TransactionID, + destination gomatrixserverlib.ServerName, +) TxnReq { + t := TxnReq{ + rsAPI: rsAPI, + userAPI: userAPI, + ourServerName: ourServerName, + keys: keys, + roomsMu: roomsMu, + producer: producer, + inboundPresenceEnabled: inboundPresenceEnabled, + } + + t.PDUs = pdus + t.EDUs = edus + t.Origin = origin + t.TransactionID = transactionID + t.Destination = destination + + return t +} + +func (t *TxnReq) ProcessTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if t.producer != nil { + t.processEDUs(ctx) + } + }() + + results := make(map[string]gomatrixserverlib.PDUResult) + roomVersions := make(map[string]gomatrixserverlib.RoomVersion) + getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion { + if v, ok := roomVersions[roomID]; ok { + return v + } + verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} + verRes := api.QueryRoomVersionForRoomResponse{} + if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID) + return "" + } + roomVersions[roomID] = verRes.RoomVersion + return verRes.RoomVersion + } + + for _, pdu := range t.PDUs { + PDUCountTotal.WithLabelValues("total").Inc() + var header struct { + RoomID string `json:"room_id"` + } + if err := json.Unmarshal(pdu, &header); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to extract room ID from event") + // We don't know the event ID at this point so we can't return the + // failure in the PDU results + continue + } + roomVersion := getRoomVersion(header.RoomID) + event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) + if err != nil { + if _, ok := err.(gomatrixserverlib.BadJSONError); ok { + // Room version 6 states that homeservers should strictly enforce canonical JSON + // on PDUs. + // + // This enforces that the entire transaction is rejected if a single bad PDU is + // sent. It is unclear if this is the correct behaviour or not. + // + // See https://github.com/matrix-org/synapse/issues/7543 + return nil, &util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON("PDU contains bad JSON"), + } + } + util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) + continue + } + if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { + continue + } + if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: "Forbidden by server ACLs", + } + continue + } + if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { + util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: err.Error(), + } + continue + } + + // pass the event to the roomserver which will do auth checks + // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently + // discarded by the caller of this function + if err = api.SendEvents( + ctx, + t.rsAPI, + api.KindNew, + []*gomatrixserverlib.HeaderedEvent{ + event.Headered(roomVersion), + }, + t.Destination, + t.Origin, + api.DoNotSendToOtherServers, + nil, + true, + ); err != nil { + util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: err.Error(), + } + continue + } + + results[event.EventID()] = gomatrixserverlib.PDUResult{} + PDUCountTotal.WithLabelValues("success").Inc() + } + + wg.Wait() + return &gomatrixserverlib.RespSend{PDUs: results}, nil +} + +// nolint:gocyclo +func (t *TxnReq) processEDUs(ctx context.Context) { + for _, e := range t.EDUs { + EDUCountTotal.Inc() + switch e.Type { + case gomatrixserverlib.MTyping: + // https://matrix.org/docs/spec/server_server/latest#typing-notifications + var typingPayload struct { + RoomID string `json:"room_id"` + UserID string `json:"user_id"` + Typing bool `json:"typing"` + } + if err := json.Unmarshal(e.Content, &typingPayload); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event") + continue + } + if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } + if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to JetStream") + } + case gomatrixserverlib.MDirectToDevice: + // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema + var directPayload gomatrixserverlib.ToDeviceMessage + if err := json.Unmarshal(e.Content, &directPayload); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events") + continue + } + if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } + for userID, byUser := range directPayload.Messages { + for deviceID, message := range byUser { + // TODO: check that the user and the device actually exist here + if err := t.producer.SendToDevice(ctx, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { + sentry.CaptureException(err) + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "sender": directPayload.Sender, + "user_id": userID, + "device_id": deviceID, + }).Error("Failed to send send-to-device event to JetStream") + } + } + } + case gomatrixserverlib.MDeviceListUpdate: + if err := t.producer.SendDeviceListUpdate(ctx, e.Content, t.Origin); err != nil { + sentry.CaptureException(err) + util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate") + } + case gomatrixserverlib.MReceipt: + // https://matrix.org/docs/spec/server_server/r0.1.4#receipts + payload := map[string]types.FederationReceiptMRead{} + + if err := json.Unmarshal(e.Content, &payload); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal receipt event") + continue + } + + for roomID, receipt := range payload { + for userID, mread := range receipt.User { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from receipt event sender") + continue + } + if t.Origin != domain { + util.GetLogger(ctx).Debugf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) + continue + } + if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil { + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "sender": t.Origin, + "user_id": userID, + "room_id": roomID, + "events": mread.EventIDs, + }).Error("Failed to send receipt event to JetStream") + continue + } + } + } + case types.MSigningKeyUpdate: + if err := t.producer.SendSigningKeyUpdate(ctx, e.Content, t.Origin); err != nil { + sentry.CaptureException(err) + logrus.WithError(err).Errorf("Failed to process signing key update") + } + case gomatrixserverlib.MPresence: + if t.inboundPresenceEnabled { + if err := t.processPresence(ctx, e); err != nil { + logrus.WithError(err).Errorf("Failed to process presence update") + } + } + default: + util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") + } + } +} + +// processPresence handles m.receipt events +func (t *TxnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) error { + payload := types.Presence{} + if err := json.Unmarshal(e.Content, &payload); err != nil { + return err + } + for _, content := range payload.Push { + if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } + presence, ok := syncTypes.PresenceFromString(content.Presence) + if !ok { + continue + } + if err := t.producer.SendPresence(ctx, content.UserID, presence, content.StatusMsg, content.LastActiveAgo); err != nil { + return err + } + } + return nil +} + +// processReceiptEvent sends receipt events to JetStream +func (t *TxnReq) processReceiptEvent(ctx context.Context, + userID, roomID, receiptType string, + timestamp gomatrixserverlib.Timestamp, + eventIDs []string, +) error { + if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil { + return nil + } else if serverName == t.ourServerName { + return nil + } else if serverName != t.Origin { + return nil + } + // store every event + for _, eventID := range eventIDs { + if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil { + return fmt.Errorf("unable to set receipt event: %w", err) + } + } + + return nil +} diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go new file mode 100644 index 000000000..8597ae24b --- /dev/null +++ b/internal/transactionrequest_test.go @@ -0,0 +1,820 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "testing" + "time" + + "github.com/matrix-org/dendrite/federationapi/producers" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" + keyAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" + "gotest.tools/v3/poll" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") + testDestination = gomatrixserverlib.ServerName("white.orchard") +) + +var ( + invalidSignatures = json.RawMessage(`{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localishhost","sender":"@userid:localhost","signatures":{"localhost":{"ed2559:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiaQiWAQ"}},"type":"m.room.member"}`) + testData = []json.RawMessage{ + []byte(`{"auth_events":[],"content":{"creator":"@userid:kaer.morhen"},"depth":0,"event_id":"$0ok8ynDp7kjc95e3:kaer.morhen","hashes":{"sha256":"17kPoH+h0Dk4Omn7Sus0qMb6+oGcf+CZFEgDhv7UKWs"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"jP4a04f5/F10Pw95FPpdCyKAO44JOwUQ/MZOOeA/RTU1Dn+AHPMzGSaZnuGjRr/xQuADt+I3ctb5ZQfLKNzHDw"}},"state_key":"","type":"m.room.create"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"content":{"membership":"join"},"depth":1,"event_id":"$LEwEu0kxrtu5fOiS:kaer.morhen","hashes":{"sha256":"B7M88PhXf3vd1LaFtjQutFu4x/w7fHD28XKZ4sAsJTo"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"p2vqmuJn7ZBRImctSaKbXCAxCcBlIjPH9JHte1ouIUGy84gpu4eLipOvSBCLL26hXfC0Zrm4WUto6Hr+ohdrCg"}},"state_key":"@userid:kaer.morhen","type":"m.room.member"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"join_rule":"public"},"depth":2,"event_id":"$SMHlqUrNhhBBRLeN:kaer.morhen","hashes":{"sha256":"vIuJQvmMjrGxshAkj1SXe0C4RqvMbv4ZADDw9pFCWqQ"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"hBMsb3Qppo3RaqqAl4JyTgaiWEbW5hlckATky6PrHun+F3YM203TzG7w9clwuQU5F5pZoB1a6nw+to0hN90FAw"}},"state_key":"","type":"m.room.join_rules"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"history_visibility":"shared"},"depth":3,"event_id":"$6F1yGIbO0J7TM93h:kaer.morhen","hashes":{"sha256":"Mr23GKSlZW7UCCYLgOWawI2Sg6KIoMjUWO2TDenuOgw"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$SMHlqUrNhhBBRLeN:kaer.morhen",{"sha256":"SylzE8U02I+6eyEHgL+FlU0L5YdqrVp8OOlxKS9VQW0"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sHLKrFI3hKGrEJfpMVZSDS3LvLasQsy50CTsOwru9XTVxgRsPo6wozNtRVjxo1J3Rk18RC9JppovmQ5VR5EcDw"}},"state_key":"","type":"m.room.history_visibility"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"ban":50,"events":null,"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":null,"users_default":0},"depth":4,"event_id":"$UKNe10XzYzG0TeA9:kaer.morhen","hashes":{"sha256":"ngbP3yja9U5dlckKerUs/fSOhtKxZMCVvsfhPURSS28"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$6F1yGIbO0J7TM93h:kaer.morhen",{"sha256":"A4CucrKSoWX4IaJXhq02mBg1sxIyZEftbC+5p3fZAvk"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"zOmwlP01QL3yFchzuR9WHvogOoBZA3oVtNIF3lM0ZfDnqlSYZB9sns27G/4HVq0k7alaK7ZE3oGoCrVnMkPNCw"}},"state_key":"","type":"m.room.power_levels"}`), + // messages + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`), + } + testEvent = []byte(`{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiuQiWAQ"}},"type":"m.room.message"}`) + testRoomVersion = gomatrixserverlib.RoomVersionV1 + testEvents = []*gomatrixserverlib.HeaderedEvent{} + testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) +) + +type FakeRsAPI struct { + rsAPI.RoomserverInternalAPI + shouldFailQuery bool + bannedFromRoom bool + shouldEventsFail bool +} + +func (r *FakeRsAPI) QueryRoomVersionForRoom( + ctx context.Context, + req *rsAPI.QueryRoomVersionForRoomRequest, + res *rsAPI.QueryRoomVersionForRoomResponse, +) error { + if r.shouldFailQuery { + return fmt.Errorf("Failure") + } + res.RoomVersion = gomatrixserverlib.RoomVersionV10 + return nil +} + +func (r *FakeRsAPI) QueryServerBannedFromRoom( + ctx context.Context, + req *rsAPI.QueryServerBannedFromRoomRequest, + res *rsAPI.QueryServerBannedFromRoomResponse, +) error { + if r.bannedFromRoom { + res.Banned = true + } else { + res.Banned = false + } + return nil +} + +func (r *FakeRsAPI) InputRoomEvents( + ctx context.Context, + req *rsAPI.InputRoomEventsRequest, + res *rsAPI.InputRoomEventsResponse, +) error { + if r.shouldEventsFail { + return fmt.Errorf("Failure") + } + return nil +} + +func TestEmptyTransactionRequest(t *testing.T) { + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", nil, nil, nil, false, []json.RawMessage{}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) +} + +func TestProcessTransactionRequestPDU(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.Empty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUs(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, append(testData, testEvent), []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.Empty(t, result.Error) + } +} + +func TestProcessTransactionRequestBadPDU(t *testing.T) { + pdu := json.RawMessage("{\"room_id\":\"asdf\"}") + pdu2 := json.RawMessage("\"roomid\":\"asdf\"") + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{pdu, pdu2, testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.Empty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUQueryFailure(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{shouldFailQuery: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) +} + +func TestProcessTransactionRequestPDUBannedFromRoom(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{bannedFromRoom: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.NotEmpty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUInvalidSignature(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{invalidSignatures}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.NotEmpty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUSendFail(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{shouldEventsFail: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.NotEmpty(t, result.Error) + } +} + +func createTransactionWithEDU(ctx *process.ProcessContext, edus []gomatrixserverlib.EDU) (TxnReq, nats.JetStreamContext, *config.Dendrite) { + cfg := &config.Dendrite{} + cfg.Defaults(config.DefaultOpts{ + Generate: true, + SingleDatabase: true, + }) + cfg.Global.JetStream.InMemory = true + natsInstance := &jetstream.NATSInstance{} + js, _ := natsInstance.Prepare(ctx, &cfg.Global.JetStream) + producer := &producers.SyncAPIProducer{ + JetStream: js, + TopicReceiptEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + Config: &cfg.FederationAPI, + UserAPI: nil, + } + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, producer, true, []json.RawMessage{}, edus, "kaer.morhen", "", "ourserver") + return txn, js, cfg +} + +func TestProcessTransactionRequestEDUTyping(t *testing.T) { + var err error + roomID := "!roomid:kaer.morhen" + userID := "@userid:kaer.morhen" + typing := true + edu := gomatrixserverlib.EDU{Type: "m.typing"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "room_id": roomID, + "user_id": userID, + "typing": typing, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.typing"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + room := msg.Header.Get(jetstream.RoomID) + assert.Equal(t, roomID, room) + user := msg.Header.Get(jetstream.UserID) + assert.Equal(t, userID, user) + typ, parseErr := strconv.ParseBool(msg.Header.Get("typing")) + if parseErr != nil { + return true + } + assert.Equal(t, typing, typ) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + cfg.Global.JetStream.Durable("TestTypingConsumer"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUToDevice(t *testing.T) { + var err error + sender := "@userid:kaer.morhen" + messageID := "$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg" + msgType := "m.dendrite.test" + edu := gomatrixserverlib.EDU{Type: "m.direct_to_device"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "sender": sender, + "type": msgType, + "message_id": messageID, + "messages": map[string]interface{}{ + "@alice:example.org": map[string]interface{}{ + "IWHQUZUIAH": map[string]interface{}{ + "algorithm": "m.megolm.v1.aes-sha2", + "room_id": "!Cuyf34gef24t:localhost", + "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ", + "session_key": "AgAAAADxKHa9uFxcXzwYoNueL5Xqi69IkD4sni8LlfJL7qNBEY...", + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.direct_to_device"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output types.OutputSendToDeviceEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + assert.Equal(t, sender, output.Sender) + assert.Equal(t, msgType, output.Type) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + cfg.Global.JetStream.Durable("TestToDevice"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUDeviceListUpdate(t *testing.T) { + var err error + deviceID := "QBUAZIFURK" + userID := "@john:example.com" + edu := gomatrixserverlib.EDU{Type: "m.device_list_update"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "device_display_name": "Mobile", + "device_id": deviceID, + "key": "value", + "keys": map[string]interface{}{ + "algorithms": []string{ + "m.olm.v1.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2", + }, + "device_id": "JLAFKJWSCS", + "keys": map[string]interface{}{ + "curve25519:JLAFKJWSCS": "3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI", + "ed25519:JLAFKJWSCS": "lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI", + }, + "signatures": map[string]interface{}{ + "@alice:example.com": map[string]interface{}{ + "ed25519:JLAFKJWSCS": "dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA", + }, + }, + "user_id": "@alice:example.com", + }, + "prev_id": []int{ + 5, + }, + "stream_id": 6, + "user_id": userID, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.device_list_update"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output gomatrixserverlib.DeviceListUpdateEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + assert.Equal(t, userID, output.UserID) + assert.Equal(t, deviceID, output.DeviceID) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + cfg.Global.JetStream.Durable("TestDeviceListUpdate"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUReceipt(t *testing.T) { + var err error + roomID := "!some_room:example.org" + edu := gomatrixserverlib.EDU{Type: "m.receipt"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "@john:kaer.morhen": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.receipt"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badUser := gomatrixserverlib.EDU{Type: "m.receipt"} + if badUser.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "johnkaer.morhen": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badDomain := gomatrixserverlib.EDU{Type: "m.receipt"} + if badDomain.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "@john:bad.domain": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + edus := []gomatrixserverlib.EDU{badEDU, badUser, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output types.OutputReceiptEvent + output.RoomID = msg.Header.Get(jetstream.RoomID) + assert.Equal(t, roomID, output.RoomID) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + cfg.Global.JetStream.Durable("TestReceipt"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUSigningKeyUpdate(t *testing.T) { + var err error + edu := gomatrixserverlib.EDU{Type: "m.signing_key_update"} + if edu.Content, err = json.Marshal(map[string]interface{}{}); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.signing_key_update"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output keyAPI.CrossSigningKeyUpdate + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + cfg.Global.JetStream.Durable("TestSigningKeyUpdate"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUPresence(t *testing.T) { + var err error + userID := "@john:kaer.morhen" + presence := "online" + edu := gomatrixserverlib.EDU{Type: "m.presence"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "push": []map[string]interface{}{{ + "currently_active": true, + "last_active_ago": 5000, + "presence": presence, + "status_msg": "Making cupcakes", + "user_id": userID, + }}, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.presence"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + userIDRes := msg.Header.Get(jetstream.UserID) + presenceRes := msg.Header.Get("presence") + assert.Equal(t, userID, userIDRes) + assert.Equal(t, presence, presenceRes) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + cfg.Global.JetStream.Durable("TestPresence"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUUnhandled(t *testing.T) { + var err error + edu := gomatrixserverlib.EDU{Type: "m.unhandled"} + if edu.Content, err = json.Marshal(map[string]interface{}{}); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, _, _ := createTransactionWithEDU(ctx, []gomatrixserverlib.EDU{edu}) + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) +} + +func init() { + for _, j := range testData { + e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion) + if err != nil { + panic("cannot load test data: " + err.Error()) + } + h := e.Headered(testRoomVersion) + testEvents = append(testEvents, h) + if e.StateKey() != nil { + testStateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: e.Type(), + StateKey: *e.StateKey(), + }] = h + } + } +} + +type testRoomserverAPI struct { + rsAPI.RoomserverInternalAPI + inputRoomEvents []rsAPI.InputRoomEvent + queryStateAfterEvents func(*rsAPI.QueryStateAfterEventsRequest) rsAPI.QueryStateAfterEventsResponse + queryEventsByID func(req *rsAPI.QueryEventsByIDRequest) rsAPI.QueryEventsByIDResponse + queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse +} + +func (t *testRoomserverAPI) InputRoomEvents( + ctx context.Context, + request *rsAPI.InputRoomEventsRequest, + response *rsAPI.InputRoomEventsResponse, +) error { + t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) + for _, ire := range request.InputRoomEvents { + fmt.Println("InputRoomEvents: ", ire.Event.EventID()) + } + return nil +} + +// Query the latest events and state for a room from the room server. +func (t *testRoomserverAPI) QueryLatestEventsAndState( + ctx context.Context, + request *rsAPI.QueryLatestEventsAndStateRequest, + response *rsAPI.QueryLatestEventsAndStateResponse, +) error { + r := t.queryLatestEventsAndState(request) + response.RoomExists = r.RoomExists + response.RoomVersion = testRoomVersion + response.LatestEvents = r.LatestEvents + response.StateEvents = r.StateEvents + response.Depth = r.Depth + return nil +} + +// Query the state after a list of events in a room from the room server. +func (t *testRoomserverAPI) QueryStateAfterEvents( + ctx context.Context, + request *rsAPI.QueryStateAfterEventsRequest, + response *rsAPI.QueryStateAfterEventsResponse, +) error { + response.RoomVersion = testRoomVersion + res := t.queryStateAfterEvents(request) + response.PrevEventsExist = res.PrevEventsExist + response.RoomExists = res.RoomExists + response.StateEvents = res.StateEvents + return nil +} + +// Query a list of events by event ID. +func (t *testRoomserverAPI) QueryEventsByID( + ctx context.Context, + request *rsAPI.QueryEventsByIDRequest, + response *rsAPI.QueryEventsByIDResponse, +) error { + res := t.queryEventsByID(request) + response.Events = res.Events + return nil +} + +// Query if a server is joined to a room +func (t *testRoomserverAPI) QueryServerJoinedToRoom( + ctx context.Context, + request *rsAPI.QueryServerJoinedToRoomRequest, + response *rsAPI.QueryServerJoinedToRoomResponse, +) error { + response.RoomExists = true + response.IsInRoom = true + return nil +} + +// Asks for the room version for a given room. +func (t *testRoomserverAPI) QueryRoomVersionForRoom( + ctx context.Context, + request *rsAPI.QueryRoomVersionForRoomRequest, + response *rsAPI.QueryRoomVersionForRoomResponse, +) error { + response.RoomVersion = testRoomVersion + return nil +} + +func (t *testRoomserverAPI) QueryServerBannedFromRoom( + ctx context.Context, req *rsAPI.QueryServerBannedFromRoomRequest, res *rsAPI.QueryServerBannedFromRoomResponse, +) error { + res.Banned = false + return nil +} + +func mustCreateTransaction(rsAPI rsAPI.FederationRoomserverAPI, pdus []json.RawMessage) *TxnReq { + t := NewTxnReq( + rsAPI, + nil, + "", + &test.NopJSONVerifier{}, + NewMutexByRoom(), + nil, + false, + pdus, + nil, + testOrigin, + gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())), + testDestination) + t.PDUs = pdus + t.Origin = testOrigin + t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + t.Destination = testDestination + return &t +} + +func mustProcessTransaction(t *testing.T, txn *TxnReq, pdusWithErrors []string) { + res, err := txn.ProcessTransaction(context.Background()) + if err != nil { + t.Errorf("txn.processTransaction returned an error: %v", err) + return + } + if len(res.PDUs) != len(txn.PDUs) { + t.Errorf("txn.processTransaction did not return results for all PDUs, got %d want %d", len(res.PDUs), len(txn.PDUs)) + return + } +NextPDU: + for eventID, result := range res.PDUs { + if result.Error == "" { + continue + } + for _, eventIDWantError := range pdusWithErrors { + if eventID == eventIDWantError { + break NextPDU + } + } + t.Errorf("txn.processTransaction PDU %s returned an error %s", eventID, result.Error) + } +} + +func assertInputRoomEvents(t *testing.T, got []rsAPI.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) { + for _, g := range got { + fmt.Println("GOT ", g.Event.EventID()) + } + if len(got) != len(want) { + t.Errorf("wrong number of InputRoomEvents: got %d want %d", len(got), len(want)) + return + } + for i := range got { + if got[i].Event.EventID() != want[i].EventID() { + t.Errorf("InputRoomEvents[%d] got %s want %s", i, got[i].Event.EventID(), want[i].EventID()) + } + } +} + +// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on +// to the roomserver. It's the most basic test possible. +func TestBasicTransaction(t *testing.T) { + rsAPI := &testRoomserverAPI{} + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, pdus) + mustProcessTransaction(t, txn, nil) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) +} + +// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver +// as it does the auth check. +func TestTransactionFailAuthChecks(t *testing.T) { + rsAPI := &testRoomserverAPI{} + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, pdus) + mustProcessTransaction(t, txn, []string{}) + // expect message to be sent to the roomserver + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) +} diff --git a/internal/validate.go b/internal/validate.go new file mode 100644 index 000000000..0461b897e --- /dev/null +++ b/internal/validate.go @@ -0,0 +1,110 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "errors" + "fmt" + "net/http" + "regexp" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +const ( + maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain + + minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based + maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 +) + +var ( + ErrPasswordTooLong = fmt.Errorf("password too long: max %d characters", maxPasswordLength) + ErrPasswordWeak = fmt.Errorf("password too weak: min %d characters", minPasswordLength) + ErrUsernameTooLong = fmt.Errorf("username exceeds the maximum length of %d characters", maxUsernameLength) + ErrUsernameInvalid = errors.New("username can only contain characters a-z, 0-9, or '_-./='") + ErrUsernameUnderscore = errors.New("username cannot start with a '_'") + validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) +) + +// ValidatePassword returns an error if the password is invalid +func ValidatePassword(password string) error { + // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + if len(password) > maxPasswordLength { + return ErrPasswordTooLong + } else if len(password) > 0 && len(password) < minPasswordLength { + return ErrPasswordWeak + } + return nil +} + +// PasswordResponse returns a util.JSONResponse for a given error, if any. +func PasswordResponse(err error) *util.JSONResponse { + switch err { + case ErrPasswordWeak: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error()), + } + case ErrPasswordTooLong: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error()), + } + } + return nil +} + +// ValidateUsername returns an error if the username is invalid +func ValidateUsername(localpart string, domain gomatrixserverlib.ServerName) error { + // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { + return ErrUsernameTooLong + } else if !validUsernameRegex.MatchString(localpart) { + return ErrUsernameInvalid + } else if localpart[0] == '_' { // Regex checks its not a zero length string + return ErrUsernameUnderscore + } + return nil +} + +// UsernameResponse returns a util.JSONResponse for the given error, if any. +func UsernameResponse(err error) *util.JSONResponse { + switch err { + case ErrUsernameTooLong: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(err.Error()), + } + case ErrUsernameInvalid, ErrUsernameUnderscore: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(err.Error()), + } + } + return nil +} + +// ValidateApplicationServiceUsername returns an error if the username is invalid for an application service +func ValidateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) error { + if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { + return ErrUsernameTooLong + } else if !validUsernameRegex.MatchString(localpart) { + return ErrUsernameInvalid + } + return nil +} diff --git a/internal/validate_test.go b/internal/validate_test.go new file mode 100644 index 000000000..d0ad04707 --- /dev/null +++ b/internal/validate_test.go @@ -0,0 +1,170 @@ +package internal + +import ( + "net/http" + "reflect" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +func Test_validatePassword(t *testing.T) { + tests := []struct { + name string + password string + wantError error + wantJSON *util.JSONResponse + }{ + { + name: "password too short", + password: "shortpw", + wantError: ErrPasswordWeak, + wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error())}, + }, + { + name: "password too long", + password: strings.Repeat("a", maxPasswordLength+1), + wantError: ErrPasswordTooLong, + wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error())}, + }, + { + name: "password OK", + password: util.RandomString(10), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotErr := ValidatePassword(tt.password) + if !reflect.DeepEqual(gotErr, tt.wantError) { + t.Errorf("validatePassword() = %v, wantJSON %v", gotErr, tt.wantError) + } + + if got := PasswordResponse(gotErr); !reflect.DeepEqual(got, tt.wantJSON) { + t.Errorf("validatePassword() = %v, wantJSON %v", got, tt.wantJSON) + } + }) + } +} + +func Test_validateUsername(t *testing.T) { + tooLongUsername := strings.Repeat("a", maxUsernameLength) + tests := []struct { + name string + localpart string + domain gomatrixserverlib.ServerName + wantErr error + wantJSON *util.JSONResponse + }{ + { + name: "empty username", + localpart: "", + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "invalid username", + localpart: "INVALIDUSERNAME", + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "username too long", + localpart: tooLongUsername, + domain: "localhost", + wantErr: ErrUsernameTooLong, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(ErrUsernameTooLong.Error()), + }, + }, + { + name: "localpart starting with an underscore", + localpart: "_notvalid", + domain: "localhost", + wantErr: ErrUsernameUnderscore, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameUnderscore.Error()), + }, + }, + { + name: "valid username", + localpart: "valid", + domain: "localhost", + }, + { + name: "complex username", + localpart: "f00_bar-baz.=40/", + domain: "localhost", + }, + { + name: "rejects emoji username 💥", + localpart: "💥", + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "special characters are allowed", + localpart: "/dev/null", + domain: "localhost", + }, + { + name: "special characters are allowed 2", + localpart: "i_am_allowed=1", + domain: "localhost", + }, + { + name: "not all special characters are allowed", + localpart: "notallowed#", // contains # + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "username containing numbers", + localpart: "hello1337", + domain: "localhost", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotErr := ValidateUsername(tt.localpart, tt.domain) + if !reflect.DeepEqual(gotErr, tt.wantErr) { + t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr) + } + if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) { + t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON) + } + + // Application services are allowed usernames starting with an underscore + if tt.wantErr == ErrUsernameUnderscore { + return + } + gotErr = ValidateApplicationServiceUsername(tt.localpart, tt.domain) + if !reflect.DeepEqual(gotErr, tt.wantErr) { + t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr) + } + if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) { + t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON) + } + }) + } +} diff --git a/internal/version.go b/internal/version.go index 685237b9e..907547589 100644 --- a/internal/version.go +++ b/internal/version.go @@ -16,8 +16,8 @@ var build string const ( VersionMajor = 0 - VersionMinor = 10 - VersionPatch = 8 + VersionMinor = 12 + VersionPatch = 0 VersionTag = "" // example: "rc1" ) diff --git a/keyserver/README.md b/keyserver/README.md deleted file mode 100644 index fd9f37d27..000000000 --- a/keyserver/README.md +++ /dev/null @@ -1,19 +0,0 @@ -## Key Server - -This is an internal component which manages E2E keys from clients. It handles all the [Key Management APIs](https://matrix.org/docs/spec/client_server/r0.6.1#key-management-api) with the exception of `/keys/changes` which is handled by Sync API. This component is designed to shard by user ID. - -Keys are uploaded and stored in this component, and key changes are emitted to a Kafka topic for downstream components such as Sync API. - -### Internal APIs -- `PerformUploadKeys` stores identity keys and one-time public keys for given user(s). -- `PerformClaimKeys` acquires one-time public keys for given user(s). This may involve outbound federation calls. -- `QueryKeys` returns identity keys for given user(s). This may involve outbound federation calls. This component may then cache federated identity keys to avoid repeatedly hitting remote servers. -- A topic which emits identity keys every time there is a change (addition or deletion). - -### Endpoint mappings -- Client API maps `/keys/upload` to `PerformUploadKeys`. -- Client API maps `/keys/query` to `QueryKeys`. -- Client API maps `/keys/claim` to `PerformClaimKeys`. -- Federation API maps `/user/keys/query` to `QueryKeys`. -- Federation API maps `/user/keys/claim` to `PerformClaimKeys`. -- Sync API maps `/keys/changes` to consuming from the Kafka topic. diff --git a/keyserver/api/api.go b/keyserver/api/api.go deleted file mode 100644 index 14fced3e8..000000000 --- a/keyserver/api/api.go +++ /dev/null @@ -1,346 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package api - -import ( - "bytes" - "context" - "encoding/json" - "strings" - "time" - - "github.com/matrix-org/gomatrixserverlib" - - "github.com/matrix-org/dendrite/keyserver/types" - userapi "github.com/matrix-org/dendrite/userapi/api" -) - -type KeyInternalAPI interface { - SyncKeyAPI - ClientKeyAPI - FederationKeyAPI - UserKeyAPI - - // SetUserAPI assigns a user API to query when extracting device names. - SetUserAPI(i userapi.KeyserverUserAPI) -} - -// API functions required by the clientapi -type ClientKeyAPI interface { - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error - PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error - PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error - // PerformClaimKeys claims one-time keys for use in pre-key messages - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error - PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error -} - -// API functions required by the userapi -type UserKeyAPI interface { - PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error - PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) error -} - -// API functions required by the syncapi -type SyncKeyAPI interface { - QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error - QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error - PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error -} - -type FederationKeyAPI interface { - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error - QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error - QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error -} - -// KeyError is returned if there was a problem performing/querying the server -type KeyError struct { - Err string `json:"error"` - IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE - IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM - IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM -} - -func (k *KeyError) Error() string { - return k.Err -} - -type DeviceMessageType int - -const ( - TypeDeviceKeyUpdate DeviceMessageType = iota - TypeCrossSigningUpdate -) - -// DeviceMessage represents the message produced into Kafka by the key server. -type DeviceMessage struct { - Type DeviceMessageType `json:"Type,omitempty"` - *DeviceKeys `json:"DeviceKeys,omitempty"` - *OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` - // A monotonically increasing number which represents device changes for this user. - StreamID int64 - DeviceChangeID int64 -} - -// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log -type OutputCrossSigningKeyUpdate struct { - CrossSigningKeyUpdate `json:"signing_keys"` -} - -type CrossSigningKeyUpdate struct { - MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"` - SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"` - UserID string `json:"user_id"` -} - -// DeviceKeysEqual returns true if the device keys updates contain the -// same display name and key JSON. This will return false if either of -// the updates is not a device keys update, or if the user ID/device ID -// differ between the two. -func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool { - if m1.DeviceKeys == nil || m2.DeviceKeys == nil { - return false - } - if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID { - return false - } - if m1.DisplayName != m2.DisplayName { - return false // different display names - } - if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 { - return false // either is empty - } - return bytes.Equal(m1.KeyJSON, m2.KeyJSON) -} - -// DeviceKeys represents a set of device keys for a single device -// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload -type DeviceKeys struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // The device display name - DisplayName string - // The raw device key JSON - KeyJSON []byte -} - -// WithStreamID returns a copy of this device message with the given stream ID -func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage { - return DeviceMessage{ - DeviceKeys: k, - StreamID: streamID, - } -} - -// OneTimeKeys represents a set of one-time keys for a single device -// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload -type OneTimeKeys struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // A map of algorithm:key_id => key JSON - KeyJSON map[string]json.RawMessage -} - -// Split a key in KeyJSON into algorithm and key ID -func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) { - segments := strings.Split(keyIDWithAlgo, ":") - return segments[0], segments[1] -} - -// OneTimeKeysCount represents the counts of one-time keys for a single device -type OneTimeKeysCount struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // algorithm to count e.g: - // { - // "curve25519": 10, - // "signed_curve25519": 20 - // } - KeyCount map[string]int -} - -// PerformUploadKeysRequest is the request to PerformUploadKeys -type PerformUploadKeysRequest struct { - UserID string // Required - User performing the request - DeviceID string // Optional - Device performing the request, for fetching OTK count - DeviceKeys []DeviceKeys - OneTimeKeys []OneTimeKeys - // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update - // the display name for their respective device, and NOT to modify the keys. The key - // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. - // Without this flag, requests to modify device display names would delete device keys. - OnlyDisplayNameUpdates bool -} - -// PerformUploadKeysResponse is the response to PerformUploadKeys -type PerformUploadKeysResponse struct { - // A fatal error when processing e.g database failures - Error *KeyError - // A map of user_id -> device_id -> Error for tracking failures. - KeyErrors map[string]map[string]*KeyError - OneTimeKeyCounts []OneTimeKeysCount -} - -// PerformDeleteKeysRequest asks the keyserver to forget about certain -// keys, and signatures related to those keys. -type PerformDeleteKeysRequest struct { - UserID string - KeyIDs []gomatrixserverlib.KeyID -} - -// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest. -type PerformDeleteKeysResponse struct { - Error *KeyError -} - -// KeyError sets a key error field on KeyErrors -func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) { - if r.KeyErrors[userID] == nil { - r.KeyErrors[userID] = make(map[string]*KeyError) - } - r.KeyErrors[userID][deviceID] = err -} - -type PerformClaimKeysRequest struct { - // Map of user_id to device_id to algorithm name - OneTimeKeys map[string]map[string]string - Timeout time.Duration -} - -type PerformClaimKeysResponse struct { - // Map of user_id to device_id to algorithm:key_id to key JSON - OneTimeKeys map[string]map[string]map[string]json.RawMessage - // Map of remote server domain to error JSON - Failures map[string]interface{} - // Set if there was a fatal error processing this action - Error *KeyError -} - -type PerformUploadDeviceKeysRequest struct { - gomatrixserverlib.CrossSigningKeys - // The user that uploaded the key, should be populated by the clientapi. - UserID string -} - -type PerformUploadDeviceKeysResponse struct { - Error *KeyError -} - -type PerformUploadDeviceSignaturesRequest struct { - Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice - // The user that uploaded the sig, should be populated by the clientapi. - UserID string -} - -type PerformUploadDeviceSignaturesResponse struct { - Error *KeyError -} - -type QueryKeysRequest struct { - // The user ID asking for the keys, e.g. if from a client API request. - // Will not be populated if the key request came from federation. - UserID string - // Maps user IDs to a list of devices - UserToDevices map[string][]string - Timeout time.Duration -} - -type QueryKeysResponse struct { - // Map of remote server domain to error JSON - Failures map[string]interface{} - // Map of user_id to device_id to device_key - DeviceKeys map[string]map[string]json.RawMessage - // Maps of user_id to cross signing key - MasterKeys map[string]gomatrixserverlib.CrossSigningKey - SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey - UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey - // Set if there was a fatal error processing this query - Error *KeyError -} - -type QueryKeyChangesRequest struct { - // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning - Offset int64 - // The inclusive offset where to track key changes up to. Messages with this offset are included in the response. - // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). - ToOffset int64 -} - -type QueryKeyChangesResponse struct { - // The set of users who have had their keys change. - UserIDs []string - // The latest offset represented in this response. - Offset int64 - // Set if there was a problem handling the request. - Error *KeyError -} - -type QueryOneTimeKeysRequest struct { - // The local user to query OTK counts for - UserID string - // The device to query OTK counts for - DeviceID string -} - -type QueryOneTimeKeysResponse struct { - // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84 - Count OneTimeKeysCount - Error *KeyError -} - -type QueryDeviceMessagesRequest struct { - UserID string -} - -type QueryDeviceMessagesResponse struct { - // The latest stream ID - StreamID int64 - Devices []DeviceMessage - Error *KeyError -} - -type QuerySignaturesRequest struct { - // A map of target user ID -> target key/device IDs to retrieve signatures for - TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"` -} - -type QuerySignaturesResponse struct { - // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures - Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap - // A map of target user ID -> cross-signing master key - MasterKeys map[string]gomatrixserverlib.CrossSigningKey - // A map of target user ID -> cross-signing self-signing key - SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey - // A map of target user ID -> cross-signing user-signing key - UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey - // The request error, if any - Error *KeyError -} - -type PerformMarkAsStaleRequest struct { - UserID string - Domain gomatrixserverlib.ServerName - DeviceID string -} diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go deleted file mode 100644 index 75d537d9c..000000000 --- a/keyserver/inthttp/client.go +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package inthttp - -import ( - "context" - "errors" - "net/http" - - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver/api" - userapi "github.com/matrix-org/dendrite/userapi/api" -) - -// HTTP paths for the internal HTTP APIs -const ( - InputDeviceListUpdatePath = "/keyserver/inputDeviceListUpdate" - PerformUploadKeysPath = "/keyserver/performUploadKeys" - PerformClaimKeysPath = "/keyserver/performClaimKeys" - PerformDeleteKeysPath = "/keyserver/performDeleteKeys" - PerformUploadDeviceKeysPath = "/keyserver/performUploadDeviceKeys" - PerformUploadDeviceSignaturesPath = "/keyserver/performUploadDeviceSignatures" - QueryKeysPath = "/keyserver/queryKeys" - QueryKeyChangesPath = "/keyserver/queryKeyChanges" - QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys" - QueryDeviceMessagesPath = "/keyserver/queryDeviceMessages" - QuerySignaturesPath = "/keyserver/querySignatures" - PerformMarkAsStalePath = "/keyserver/markAsStale" -) - -// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API. -// If httpClient is nil an error is returned -func NewKeyServerClient( - apiURL string, - httpClient *http.Client, -) (api.KeyInternalAPI, error) { - if httpClient == nil { - return nil, errors.New("NewKeyServerClient: httpClient is ") - } - return &httpKeyInternalAPI{ - apiURL: apiURL, - httpClient: httpClient, - }, nil -} - -type httpKeyInternalAPI struct { - apiURL string - httpClient *http.Client -} - -func (h *httpKeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) { - // no-op: doesn't need it -} - -func (h *httpKeyInternalAPI) PerformClaimKeys( - ctx context.Context, - request *api.PerformClaimKeysRequest, - response *api.PerformClaimKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformClaimKeys", h.apiURL+PerformClaimKeysPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) PerformDeleteKeys( - ctx context.Context, - request *api.PerformDeleteKeysRequest, - response *api.PerformDeleteKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformDeleteKeys", h.apiURL+PerformDeleteKeysPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) PerformUploadKeys( - ctx context.Context, - request *api.PerformUploadKeysRequest, - response *api.PerformUploadKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformUploadKeys", h.apiURL+PerformUploadKeysPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) QueryKeys( - ctx context.Context, - request *api.QueryKeysRequest, - response *api.QueryKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryKeys", h.apiURL+QueryKeysPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) QueryOneTimeKeys( - ctx context.Context, - request *api.QueryOneTimeKeysRequest, - response *api.QueryOneTimeKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryOneTimeKeys", h.apiURL+QueryOneTimeKeysPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) QueryDeviceMessages( - ctx context.Context, - request *api.QueryDeviceMessagesRequest, - response *api.QueryDeviceMessagesResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryDeviceMessages", h.apiURL+QueryDeviceMessagesPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) QueryKeyChanges( - ctx context.Context, - request *api.QueryKeyChangesRequest, - response *api.QueryKeyChangesResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryKeyChanges", h.apiURL+QueryKeyChangesPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) PerformUploadDeviceKeys( - ctx context.Context, - request *api.PerformUploadDeviceKeysRequest, - response *api.PerformUploadDeviceKeysResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformUploadDeviceKeys", h.apiURL+PerformUploadDeviceKeysPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) PerformUploadDeviceSignatures( - ctx context.Context, - request *api.PerformUploadDeviceSignaturesRequest, - response *api.PerformUploadDeviceSignaturesResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformUploadDeviceSignatures", h.apiURL+PerformUploadDeviceSignaturesPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) QuerySignatures( - ctx context.Context, - request *api.QuerySignaturesRequest, - response *api.QuerySignaturesResponse, -) error { - return httputil.CallInternalRPCAPI( - "QuerySignatures", h.apiURL+QuerySignaturesPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpKeyInternalAPI) PerformMarkAsStaleIfNeeded( - ctx context.Context, - request *api.PerformMarkAsStaleRequest, - response *struct{}, -) error { - return httputil.CallInternalRPCAPI( - "MarkAsStale", h.apiURL+PerformMarkAsStalePath, - h.httpClient, ctx, request, response, - ) -} diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go deleted file mode 100644 index 7af0ff6e5..000000000 --- a/keyserver/inthttp/server.go +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package inthttp - -import ( - "github.com/gorilla/mux" - - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver/api" -) - -func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) { - internalAPIMux.Handle( - PerformClaimKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformClaimKeys", s.PerformClaimKeys), - ) - - internalAPIMux.Handle( - PerformDeleteKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformDeleteKeys", s.PerformDeleteKeys), - ) - - internalAPIMux.Handle( - PerformUploadKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformUploadKeys", s.PerformUploadKeys), - ) - - internalAPIMux.Handle( - PerformUploadDeviceKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceKeys", s.PerformUploadDeviceKeys), - ) - - internalAPIMux.Handle( - PerformUploadDeviceSignaturesPath, - httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceSignatures", s.PerformUploadDeviceSignatures), - ) - - internalAPIMux.Handle( - QueryKeysPath, - httputil.MakeInternalRPCAPI("KeyserverQueryKeys", s.QueryKeys), - ) - - internalAPIMux.Handle( - QueryOneTimeKeysPath, - httputil.MakeInternalRPCAPI("KeyserverQueryOneTimeKeys", s.QueryOneTimeKeys), - ) - - internalAPIMux.Handle( - QueryDeviceMessagesPath, - httputil.MakeInternalRPCAPI("KeyserverQueryDeviceMessages", s.QueryDeviceMessages), - ) - - internalAPIMux.Handle( - QueryKeyChangesPath, - httputil.MakeInternalRPCAPI("KeyserverQueryKeyChanges", s.QueryKeyChanges), - ) - - internalAPIMux.Handle( - QuerySignaturesPath, - httputil.MakeInternalRPCAPI("KeyserverQuerySignatures", s.QuerySignatures), - ) - - internalAPIMux.Handle( - PerformMarkAsStalePath, - httputil.MakeInternalRPCAPI("KeyserverMarkAsStale", s.PerformMarkAsStaleIfNeeded), - ) -} diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go deleted file mode 100644 index a86c2da4e..000000000 --- a/keyserver/keyserver.go +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package keyserver - -import ( - "github.com/gorilla/mux" - "github.com/sirupsen/logrus" - - fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/consumers" - "github.com/matrix-org/dendrite/keyserver/internal" - "github.com/matrix-org/dendrite/keyserver/inthttp" - "github.com/matrix-org/dendrite/keyserver/producers" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" -) - -// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions -// on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { - inthttp.AddRoutes(router, intAPI) -} - -// NewInternalAPI returns a concerete implementation of the internal API. Callers -// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. -func NewInternalAPI( - base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI, -) api.KeyInternalAPI { - js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) - - db, err := storage.NewDatabase(base, &cfg.Database) - if err != nil { - logrus.WithError(err).Panicf("failed to connect to key server database") - } - keyChangeProducer := &producers.KeyChange{ - Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)), - JetStream: js, - DB: db, - } - ap := &internal.KeyInternalAPI{ - DB: db, - Cfg: cfg, - FedClient: fedClient, - Producer: keyChangeProducer, - } - updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable - ap.Updater = updater - go func() { - if err := updater.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start device list updater") - } - }() - - dlConsumer := consumers.NewDeviceListUpdateConsumer( - base.ProcessContext, cfg, js, updater, - ) - if err := dlConsumer.Start(); err != nil { - logrus.WithError(err).Panic("failed to start device list consumer") - } - - sigConsumer := consumers.NewSigningKeyUpdateConsumer( - base.ProcessContext, cfg, js, ap, - ) - if err := sigConsumer.Start(); err != nil { - logrus.WithError(err).Panic("failed to start signing key consumer") - } - - return ap -} diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go deleted file mode 100644 index 242e16a06..000000000 --- a/keyserver/storage/interface.go +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type Database interface { - // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination - // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. - ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) - - // StoreOneTimeKeys persists the given one-time keys. - StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) - - // OneTimeKeysCount returns a count of all OTKs for this device. - OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - - // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. - DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error - - // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). - // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. - // Returns an error if there was a problem storing the keys. - StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error - - // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior - // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. - StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error - - // PrevIDsExists returns true if all prev IDs exist for this user. - PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) - - // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. - // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. - DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) - - // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying - // cross-signing signatures relating to that device. - DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error - - // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key - // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. - ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) - - // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change. - // `userID` is the the user who has changed their keys in some way. - StoreKeyChange(ctx context.Context, userID string) (int64, error) - - // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). - // A to offset of types.OffsetNewest means no upper limit. - // Returns the offset of the latest key change. - KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) - - // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. - // If no domains are given, all user IDs with stale device lists are returned. - StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) - - // MarkDeviceListStale sets the stale bit for this user to isStale. - MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error - - CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) - CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) - CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) - - StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error - StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error -} diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go deleted file mode 100644 index 35e630559..000000000 --- a/keyserver/storage/postgres/storage.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/shared" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -// NewDatabase creates a new sync server database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { - var err error - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) - if err != nil { - return nil, err - } - otk, err := NewPostgresOneTimeKeysTable(db) - if err != nil { - return nil, err - } - dk, err := NewPostgresDeviceKeysTable(db) - if err != nil { - return nil, err - } - kc, err := NewPostgresKeyChangesTable(db) - if err != nil { - return nil, err - } - sdl, err := NewPostgresStaleDeviceListsTable(db) - if err != nil { - return nil, err - } - csk, err := NewPostgresCrossSigningKeysTable(db) - if err != nil { - return nil, err - } - css, err := NewPostgresCrossSigningSigsTable(db) - if err != nil { - return nil, err - } - if err = kc.Prepare(); err != nil { - return nil, err - } - d := &shared.Database{ - DB: db, - Writer: writer, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, - StaleDeviceListsTable: sdl, - CrossSigningKeysTable: csk, - CrossSigningSigsTable: css, - } - return d, nil -} diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go deleted file mode 100644 index 5beeed0f1..000000000 --- a/keyserver/storage/shared/storage.go +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package shared - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type Database struct { - DB *sql.DB - Writer sqlutil.Writer - OneTimeKeysTable tables.OneTimeKeys - DeviceKeysTable tables.DeviceKeys - KeyChangesTable tables.KeyChanges - StaleDeviceListsTable tables.StaleDeviceLists - CrossSigningKeysTable tables.CrossSigningKeys - CrossSigningSigsTable tables.CrossSigningSigs -} - -func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { - return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms) -} - -func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { - _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys) - return err - }) - return -} - -func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { - return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) -} - -func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { - return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) -} - -func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { - count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs) - if err != nil { - return false, err - } - return count == len(prevIDs), nil -} - -func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for _, userID := range clearUserIDs { - err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID) - if err != nil { - return err - } - } - return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) - }) -} - -func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { - // work out the latest stream IDs for each user - userIDToStreamID := make(map[string]int64) - for _, k := range keys { - userIDToStreamID[k.UserID] = 0 - } - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for userID := range userIDToStreamID { - streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) - if err != nil { - return err - } - userIDToStreamID[userID] = streamID - } - // set the stream IDs for each key - for i := range keys { - k := keys[i] - userIDToStreamID[k.UserID]++ // start stream from 1 - k.StreamID = userIDToStreamID[k.UserID] - keys[i] = k - } - return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) - }) -} - -func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { - return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) -} - -func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { - var result []api.OneTimeKeys - err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for userID, deviceToAlgo := range userToDeviceToAlgorithm { - for deviceID, algo := range deviceToAlgo { - keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo) - if err != nil { - return err - } - if keyJSON != nil { - result = append(result, api.OneTimeKeys{ - UserID: userID, - DeviceID: deviceID, - KeyJSON: keyJSON, - }) - } - } - } - return nil - }) - return result, err -} - -func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) { - err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID) - return err - }) - return -} - -func (d *Database) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { - return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset) -} - -// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. -// If no domains are given, all user IDs with stale device lists are returned. -func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { - return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) -} - -// MarkDeviceListStale sets the stale bit for this user to isStale. -func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { - return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) - }) -} - -// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying -// cross-signing signatures relating to that device. -func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for _, deviceID := range deviceIDs { - if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err) - } - if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) - } - if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err) - } - } - return nil - }) -} - -// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. -func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { - keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) - if err != nil { - return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err) - } - results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} - for purpose, key := range keyMap { - keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode()) - result := gomatrixserverlib.CrossSigningKey{ - UserID: userID, - Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose}, - Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ - keyID: key, - }, - } - sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID) - if err != nil { - continue - } - for sigUserID, forSigUserID := range sigMap { - if userID != sigUserID { - continue - } - if result.Signatures == nil { - result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - if _, ok := result.Signatures[sigUserID]; !ok { - result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - for sigKeyID, sigBytes := range forSigUserID { - result.Signatures[sigUserID][sigKeyID] = sigBytes - } - } - results[purpose] = result - } - return results, nil -} - -// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. -func (d *Database) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { - return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) -} - -// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. -func (d *Database) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { - return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID) -} - -// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. -func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for keyType, keyData := range keyMap { - if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { - return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err) - } - } - return nil - }) -} - -// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice. -func (d *Database) StoreCrossSigningSigsForTarget( - ctx context.Context, - originUserID string, originKeyID gomatrixserverlib.KeyID, - targetUserID string, targetKeyID gomatrixserverlib.KeyID, - signature gomatrixserverlib.Base64Bytes, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { - return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err) - } - return nil - }) -} diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go deleted file mode 100644 index 873fe3e24..000000000 --- a/keyserver/storage/sqlite3/storage.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/shared" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) - if err != nil { - return nil, err - } - otk, err := NewSqliteOneTimeKeysTable(db) - if err != nil { - return nil, err - } - dk, err := NewSqliteDeviceKeysTable(db) - if err != nil { - return nil, err - } - kc, err := NewSqliteKeyChangesTable(db) - if err != nil { - return nil, err - } - sdl, err := NewSqliteStaleDeviceListsTable(db) - if err != nil { - return nil, err - } - csk, err := NewSqliteCrossSigningKeysTable(db) - if err != nil { - return nil, err - } - css, err := NewSqliteCrossSigningSigsTable(db) - if err != nil { - return nil, err - } - - if err = kc.Prepare(); err != nil { - return nil, err - } - d := &shared.Database{ - DB: db, - Writer: writer, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, - StaleDeviceListsTable: sdl, - CrossSigningKeysTable: csk, - CrossSigningSigsTable: css, - } - return d, nil -} diff --git a/keyserver/storage/storage.go b/keyserver/storage/storage.go deleted file mode 100644 index ab6a35401..000000000 --- a/keyserver/storage/storage.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !wasm -// +build !wasm - -package storage - -import ( - "fmt" - - "github.com/matrix-org/dendrite/keyserver/storage/postgres" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) -// and sets postgres connection parameters -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) - case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties) - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go deleted file mode 100644 index e7a2af7c2..000000000 --- a/keyserver/storage/storage_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package storage_test - -import ( - "context" - "reflect" - "sync" - "testing" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" -) - -var ctx = context.Background() - -func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { - base, close := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database) - if err != nil { - t.Fatalf("failed to create new database: %v", err) - } - return db, close -} - -func MustNotError(t *testing.T, err error) { - t.Helper() - if err == nil { - return - } - t.Fatalf("operation failed: %s", err) -} - -func TestKeyChanges(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - _, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") - MustNotError(t, err) - deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeIDC { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) - } - if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } - }) -} - -func TestKeyChangesNoDupes(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - if deviceChangeIDA == deviceChangeIDB { - t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) - } - deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeID { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) - } - if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } - }) -} - -func TestKeyChangesUpperLimit(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") - MustNotError(t, err) - _, err = db.StoreKeyChange(ctx, "@charlie:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeIDB { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) - } - if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } - }) -} - -var dbLock sync.Mutex -var deviceArray = []string{"AAA", "another_device"} - -// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, -// and that they are returned correctly when querying for device keys. -func TestDeviceKeysStreamIDGeneration(t *testing.T) { - var err error - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - alice := "@alice:TestDeviceKeysStreamIDGeneration" - bob := "@bob:TestDeviceKeysStreamIDGeneration" - msgs := []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: alice, - KeyJSON: []byte(`{"key":"v1"}`), - }, - // StreamID: 1 - }, - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: bob, - KeyJSON: []byte(`{"key":"v1"}`), - }, - // StreamID: 1 as this is a different user - }, - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "another_device", - UserID: alice, - KeyJSON: []byte(`{"key":"v1"}`), - }, - // StreamID: 2 as this is a 2nd device key - }, - } - MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) - if msgs[0].StreamID != 1 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) - } - if msgs[1].StreamID != 1 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) - } - if msgs[2].StreamID != 2 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) - } - - // updating a device sets the next stream ID for that user - msgs = []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: alice, - KeyJSON: []byte(`{"key":"v2"}`), - }, - // StreamID: 3 - }, - } - MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) - if msgs[0].StreamID != 3 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) - } - - dbLock.Lock() - defer dbLock.Unlock() - // Querying for device keys returns the latest stream IDs - msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false) - - if err != nil { - t.Fatalf("DeviceKeysForUser returned error: %s", err) - } - wantStreamIDs := map[string]int64{ - "AAA": 3, - "another_device": 2, - } - if len(msgs) != len(wantStreamIDs) { - t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) - } - for _, m := range msgs { - if m.StreamID != wantStreamIDs[m.DeviceID] { - t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) - } - } - }) -} diff --git a/keyserver/storage/storage_wasm.go b/keyserver/storage/storage_wasm.go deleted file mode 100644 index 75c9053e8..000000000 --- a/keyserver/storage/storage_wasm.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "fmt" - - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) - case dbProperties.ConnectionString.IsPostgres(): - return nil, fmt.Errorf("can't use Postgres implementation") - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go deleted file mode 100644 index 37a010a7c..000000000 --- a/keyserver/storage/tables/interface.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tables - -import ( - "context" - "database/sql" - "encoding/json" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type OneTimeKeys interface { - SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) - CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) - // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. - // Returns an empty map if the key does not exist. - SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) - DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error -} - -type DeviceKeys interface { - SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error - InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error - SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) - CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) - SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) - DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error - DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error -} - -type KeyChanges interface { - InsertKeyChange(ctx context.Context, userID string) (int64, error) - // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets. - // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset. - SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) - - Prepare() error -} - -type StaleDeviceLists interface { - InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error - SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) -} - -type CrossSigningKeys interface { - SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) - UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error -} - -type CrossSigningSigs interface { - SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) - UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error - DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error -} diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 9dcfa955f..50af2f884 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -108,13 +108,16 @@ func makeDownloadAPI( activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, ) http.HandlerFunc { - counterVec := promauto.NewCounterVec( - prometheus.CounterOpts{ - Name: name, - Help: "Total number of media_api requests for either thumbnails or full downloads", - }, - []string{"code"}, - ) + var counterVec *prometheus.CounterVec + if cfg.Matrix.Metrics.Enabled { + counterVec = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: name, + Help: "Total number of media_api requests for either thumbnails or full downloads", + }, + []string{"code"}, + ) + } httpHandler := func(w http.ResponseWriter, req *http.Request) { req = util.RequestWithLogging(req) @@ -166,5 +169,12 @@ func makeDownloadAPI( vars["downloadName"], ) } - return promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler)) + + var handlerFunc http.HandlerFunc + if counterVec != nil { + handlerFunc = promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler)) + } else { + handlerFunc = http.HandlerFunc(httpHandler) + } + return handlerFunc } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 01e87ec8a..f6d003a44 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -54,7 +54,8 @@ type QueryBulkStateContentAPI interface { } type QueryEventsAPI interface { - // Query a list of events by event ID. + // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine + // which room to use by querying the first events roomID. QueryEventsByID( ctx context.Context, req *QueryEventsByIDRequest, @@ -71,7 +72,8 @@ type SyncRoomserverAPI interface { QueryBulkStateContentAPI // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error - // Query a list of events by event ID. + // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine + // which room to use by querying the first events roomID. QueryEventsByID( ctx context.Context, req *QueryEventsByIDRequest, @@ -108,7 +110,8 @@ type SyncRoomserverAPI interface { } type AppserviceRoomserverAPI interface { - // Query a list of events by event ID. + // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine + // which room to use by querying the first events roomID. QueryEventsByID( ctx context.Context, req *QueryEventsByIDRequest, @@ -150,6 +153,7 @@ type ClientRoomserverAPI interface { PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error + PerformAdminPurgeRoom(ctx context.Context, req *PerformAdminPurgeRoomRequest, res *PerformAdminPurgeRoomResponse) error PerformAdminDownloadState(ctx context.Context, req *PerformAdminDownloadStateRequest, res *PerformAdminDownloadStateResponse) error PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error @@ -165,6 +169,7 @@ type ClientRoomserverAPI interface { type UserRoomserverAPI interface { QueryLatestEventsAndStateAPI + KeyserverRoomserverAPI QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error @@ -180,6 +185,8 @@ type FederationRoomserverAPI interface { QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error + // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine + // which room to use by querying the first events roomID. QueryEventsByID(ctx context.Context, req *QueryEventsByIDRequest, res *QueryEventsByIDResponse) error // Query to get state and auth chain for a (potentially hypothetical) event. // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate @@ -191,7 +198,7 @@ type FederationRoomserverAPI interface { // Query missing events for a room from roomserver QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error // Query whether a server is allowed to see an event - QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error + QueryServerAllowedToSeeEvent(ctx context.Context, serverName gomatrixserverlib.ServerName, eventID string) (allowed bool, err error) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error @@ -199,3 +206,7 @@ type FederationRoomserverAPI interface { // Query a given amount (or less) of events prior to a given set of events. PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error } + +type KeyserverRoomserverAPI interface { + QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error +} diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go deleted file mode 100644 index 342a3904c..000000000 --- a/roomserver/api/api_trace.go +++ /dev/null @@ -1,411 +0,0 @@ -package api - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - - asAPI "github.com/matrix-org/dendrite/appservice/api" - fsAPI "github.com/matrix-org/dendrite/federationapi/api" - userapi "github.com/matrix-org/dendrite/userapi/api" -) - -// RoomserverInternalAPITrace wraps a RoomserverInternalAPI and logs the -// complete request/response/error -type RoomserverInternalAPITrace struct { - Impl RoomserverInternalAPI -} - -func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) { - t.Impl.SetFederationAPI(fsAPI, keyRing) -} - -func (t *RoomserverInternalAPITrace) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { - t.Impl.SetAppserviceAPI(asAPI) -} - -func (t *RoomserverInternalAPITrace) SetUserAPI(userAPI userapi.RoomserverUserAPI) { - t.Impl.SetUserAPI(userAPI) -} - -func (t *RoomserverInternalAPITrace) InputRoomEvents( - ctx context.Context, - req *InputRoomEventsRequest, - res *InputRoomEventsResponse, -) error { - err := t.Impl.InputRoomEvents(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformInvite( - ctx context.Context, - req *PerformInviteRequest, - res *PerformInviteResponse, -) error { - err := t.Impl.PerformInvite(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformInvite req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformPeek( - ctx context.Context, - req *PerformPeekRequest, - res *PerformPeekResponse, -) error { - err := t.Impl.PerformPeek(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformPeek req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformUnpeek( - ctx context.Context, - req *PerformUnpeekRequest, - res *PerformUnpeekResponse, -) error { - err := t.Impl.PerformUnpeek(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformUnpeek req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformRoomUpgrade( - ctx context.Context, - req *PerformRoomUpgradeRequest, - res *PerformRoomUpgradeResponse, -) error { - err := t.Impl.PerformRoomUpgrade(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformRoomUpgrade req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformJoin( - ctx context.Context, - req *PerformJoinRequest, - res *PerformJoinResponse, -) error { - err := t.Impl.PerformJoin(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformJoin req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformLeave( - ctx context.Context, - req *PerformLeaveRequest, - res *PerformLeaveResponse, -) error { - err := t.Impl.PerformLeave(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformLeave req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformPublish( - ctx context.Context, - req *PerformPublishRequest, - res *PerformPublishResponse, -) error { - err := t.Impl.PerformPublish(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformPublish req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformAdminEvacuateRoom( - ctx context.Context, - req *PerformAdminEvacuateRoomRequest, - res *PerformAdminEvacuateRoomResponse, -) error { - err := t.Impl.PerformAdminEvacuateRoom(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformAdminEvacuateUser( - ctx context.Context, - req *PerformAdminEvacuateUserRequest, - res *PerformAdminEvacuateUserResponse, -) error { - err := t.Impl.PerformAdminEvacuateUser(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateUser req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformAdminDownloadState( - ctx context.Context, - req *PerformAdminDownloadStateRequest, - res *PerformAdminDownloadStateResponse, -) error { - err := t.Impl.PerformAdminDownloadState(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformAdminDownloadState req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformInboundPeek( - ctx context.Context, - req *PerformInboundPeekRequest, - res *PerformInboundPeekResponse, -) error { - err := t.Impl.PerformInboundPeek(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformInboundPeek req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryPublishedRooms( - ctx context.Context, - req *QueryPublishedRoomsRequest, - res *QueryPublishedRoomsResponse, -) error { - err := t.Impl.QueryPublishedRooms(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryPublishedRooms req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryLatestEventsAndState( - ctx context.Context, - req *QueryLatestEventsAndStateRequest, - res *QueryLatestEventsAndStateResponse, -) error { - err := t.Impl.QueryLatestEventsAndState(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryLatestEventsAndState req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryStateAfterEvents( - ctx context.Context, - req *QueryStateAfterEventsRequest, - res *QueryStateAfterEventsResponse, -) error { - err := t.Impl.QueryStateAfterEvents(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryStateAfterEvents req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryEventsByID( - ctx context.Context, - req *QueryEventsByIDRequest, - res *QueryEventsByIDResponse, -) error { - err := t.Impl.QueryEventsByID(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryEventsByID req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryMembershipForUser( - ctx context.Context, - req *QueryMembershipForUserRequest, - res *QueryMembershipForUserResponse, -) error { - err := t.Impl.QueryMembershipForUser(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryMembershipForUser req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryMembershipsForRoom( - ctx context.Context, - req *QueryMembershipsForRoomRequest, - res *QueryMembershipsForRoomResponse, -) error { - err := t.Impl.QueryMembershipsForRoom(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryMembershipsForRoom req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryServerJoinedToRoom( - ctx context.Context, - req *QueryServerJoinedToRoomRequest, - res *QueryServerJoinedToRoomResponse, -) error { - err := t.Impl.QueryServerJoinedToRoom(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryServerJoinedToRoom req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryServerAllowedToSeeEvent( - ctx context.Context, - req *QueryServerAllowedToSeeEventRequest, - res *QueryServerAllowedToSeeEventResponse, -) error { - err := t.Impl.QueryServerAllowedToSeeEvent(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryServerAllowedToSeeEvent req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryMissingEvents( - ctx context.Context, - req *QueryMissingEventsRequest, - res *QueryMissingEventsResponse, -) error { - err := t.Impl.QueryMissingEvents(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryMissingEvents req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryStateAndAuthChain( - ctx context.Context, - req *QueryStateAndAuthChainRequest, - res *QueryStateAndAuthChainResponse, -) error { - err := t.Impl.QueryStateAndAuthChain(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryStateAndAuthChain req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformBackfill( - ctx context.Context, - req *PerformBackfillRequest, - res *PerformBackfillResponse, -) error { - err := t.Impl.PerformBackfill(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformBackfill req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) PerformForget( - ctx context.Context, - req *PerformForgetRequest, - res *PerformForgetResponse, -) error { - err := t.Impl.PerformForget(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("PerformForget req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryRoomVersionCapabilities( - ctx context.Context, - req *QueryRoomVersionCapabilitiesRequest, - res *QueryRoomVersionCapabilitiesResponse, -) error { - err := t.Impl.QueryRoomVersionCapabilities(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryRoomVersionCapabilities req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryRoomVersionForRoom( - ctx context.Context, - req *QueryRoomVersionForRoomRequest, - res *QueryRoomVersionForRoomResponse, -) error { - err := t.Impl.QueryRoomVersionForRoom(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryRoomVersionForRoom req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) SetRoomAlias( - ctx context.Context, - req *SetRoomAliasRequest, - res *SetRoomAliasResponse, -) error { - err := t.Impl.SetRoomAlias(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("SetRoomAlias req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) GetRoomIDForAlias( - ctx context.Context, - req *GetRoomIDForAliasRequest, - res *GetRoomIDForAliasResponse, -) error { - err := t.Impl.GetRoomIDForAlias(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("GetRoomIDForAlias req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) GetAliasesForRoomID( - ctx context.Context, - req *GetAliasesForRoomIDRequest, - res *GetAliasesForRoomIDResponse, -) error { - err := t.Impl.GetAliasesForRoomID(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("GetAliasesForRoomID req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) RemoveRoomAlias( - ctx context.Context, - req *RemoveRoomAliasRequest, - res *RemoveRoomAliasResponse, -) error { - err := t.Impl.RemoveRoomAlias(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("RemoveRoomAlias req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error { - err := t.Impl.QueryCurrentState(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryCurrentState req=%+v res=%+v", js(req), js(res)) - return err -} - -// QueryRoomsForUser retrieves a list of room IDs matching the given query. -func (t *RoomserverInternalAPITrace) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error { - err := t.Impl.QueryRoomsForUser(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryRoomsForUser req=%+v res=%+v", js(req), js(res)) - return err -} - -// QueryBulkStateContent does a bulk query for state event content in the given rooms. -func (t *RoomserverInternalAPITrace) QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error { - err := t.Impl.QueryBulkStateContent(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryBulkStateContent req=%+v res=%+v", js(req), js(res)) - return err -} - -// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. -func (t *RoomserverInternalAPITrace) QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error { - err := t.Impl.QuerySharedUsers(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QuerySharedUsers req=%+v res=%+v", js(req), js(res)) - return err -} - -// QueryKnownUsers returns a list of users that we know about from our joined rooms. -func (t *RoomserverInternalAPITrace) QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error { - err := t.Impl.QueryKnownUsers(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryKnownUsers req=%+v res=%+v", js(req), js(res)) - return err -} - -// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. -func (t *RoomserverInternalAPITrace) QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error { - err := t.Impl.QueryServerBannedFromRoom(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryServerBannedFromRoom req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryAuthChain( - ctx context.Context, - request *QueryAuthChainRequest, - response *QueryAuthChainResponse, -) error { - err := t.Impl.QueryAuthChain(ctx, request, response) - util.GetLogger(ctx).WithError(err).Infof("QueryAuthChain req=%+v res=%+v", js(request), js(response)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryRestrictedJoinAllowed( - ctx context.Context, - request *QueryRestrictedJoinAllowedRequest, - response *QueryRestrictedJoinAllowedResponse, -) error { - err := t.Impl.QueryRestrictedJoinAllowed(ctx, request, response) - util.GetLogger(ctx).WithError(err).Infof("QueryRestrictedJoinAllowed req=%+v res=%+v", js(request), js(response)) - return err -} - -func (t *RoomserverInternalAPITrace) QueryMembershipAtEvent( - ctx context.Context, - request *QueryMembershipAtEventRequest, - response *QueryMembershipAtEventResponse, -) error { - err := t.Impl.QueryMembershipAtEvent(ctx, request, response) - util.GetLogger(ctx).WithError(err).Infof("QueryMembershipAtEvent req=%+v res=%+v", js(request), js(response)) - return err -} - -func js(thing interface{}) string { - b, err := json.Marshal(thing) - if err != nil { - return fmt.Sprintf("Marshal error:%s", err) - } - return string(b) -} diff --git a/roomserver/api/output.go b/roomserver/api/output.go index 36d0625c7..0c0f52c45 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -55,6 +55,8 @@ const ( OutputTypeNewInboundPeek OutputType = "new_inbound_peek" // OutputTypeRetirePeek indicates that the kafka event is an OutputRetirePeek OutputTypeRetirePeek OutputType = "retire_peek" + // OutputTypePurgeRoom indicates the event is an OutputPurgeRoom + OutputTypePurgeRoom OutputType = "purge_room" ) // An OutputEvent is an entry in the roomserver output kafka log. @@ -78,6 +80,8 @@ type OutputEvent struct { NewInboundPeek *OutputNewInboundPeek `json:"new_inbound_peek,omitempty"` // The content of event with type OutputTypeRetirePeek RetirePeek *OutputRetirePeek `json:"retire_peek,omitempty"` + // The content of the event with type OutputPurgeRoom + PurgeRoom *OutputPurgeRoom `json:"purge_room,omitempty"` } // Type of the OutputNewRoomEvent. @@ -257,3 +261,7 @@ type OutputRetirePeek struct { UserID string DeviceID string } + +type OutputPurgeRoom struct { + RoomID string +} diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index e70e5ea9c..83cb0460a 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -78,6 +78,7 @@ const ( type PerformJoinRequest struct { RoomIDOrAlias string `json:"room_id_or_alias"` UserID string `json:"user_id"` + IsGuest bool `json:"is_guest"` Content map[string]interface{} `json:"content"` ServerNames []gomatrixserverlib.ServerName `json:"server_names"` Unsigned map[string]interface{} `json:"unsigned"` @@ -240,6 +241,14 @@ type PerformAdminEvacuateUserResponse struct { Error *PerformError } +type PerformAdminPurgeRoomRequest struct { + RoomID string `json:"room_id"` +} + +type PerformAdminPurgeRoomResponse struct { + Error *PerformError `json:"error,omitempty"` +} + type PerformAdminDownloadStateRequest struct { RoomID string `json:"room_id"` UserID string `json:"user_id"` diff --git a/roomserver/api/query.go b/roomserver/api/query.go index b62907f3c..24722db0b 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -86,6 +86,9 @@ type QueryStateAfterEventsResponse struct { // QueryEventsByIDRequest is a request to QueryEventsByID type QueryEventsByIDRequest struct { + // The roomID to query events for. If this is empty, we first try to fetch the roomID from the database + // as this is needed for further processing/parsing events. + RoomID string `json:"room_id"` // The event IDs to look up. EventIDs []string `json:"event_ids"` } @@ -433,7 +436,7 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error { return nil } -// QueryMembershipAtEventRequest requests the membership events for a user +// QueryMembershipAtEventRequest requests the membership event for a user // for a list of eventIDs. type QueryMembershipAtEventRequest struct { RoomID string @@ -443,7 +446,20 @@ type QueryMembershipAtEventRequest struct { // QueryMembershipAtEventResponse is the response to QueryMembershipAtEventRequest. type QueryMembershipAtEventResponse struct { - // Memberships is a map from eventID to a list of events (if any). Events that - // do not have known state will return an empty array here. - Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"` + // Membership is a map from eventID to membership event. Events that + // do not have known state will return a nil event, resulting in a "leave" membership + // when calculating history visibility. + Membership map[string]*gomatrixserverlib.HeaderedEvent `json:"membership"` +} + +// QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a +// a room with anymore. This is used to cleanup stale device list entries, where we would +// otherwise keep on trying to get device lists. +type QueryLeftUsersRequest struct { + StaleDeviceListUsers []string `json:"user_ids"` +} + +// QueryLeftUsersResponse is the response to QueryLeftUsersRequest. +type QueryLeftUsersResponse struct { + LeftUsers []string `json:"user_ids"` } diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 252be557f..f220560ed 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -108,9 +108,10 @@ func SendInputRoomEvents( } // GetEvent returns the event or nil, even on errors. -func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, eventID string) *gomatrixserverlib.HeaderedEvent { +func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID, eventID string) *gomatrixserverlib.HeaderedEvent { var res QueryEventsByIDResponse err := rsAPI.QueryEventsByID(ctx, &QueryEventsByIDRequest{ + RoomID: roomID, EventIDs: []string{eventID}, }, &res) if err != nil { diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 329e6af7f..fc61b7f4a 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -30,26 +30,6 @@ import ( "github.com/tidwall/sjson" ) -// RoomserverInternalAPIDatabase has the storage APIs needed to implement the alias API. -type RoomserverInternalAPIDatabase interface { - // Save a given room alias with the room ID it refers to. - // Returns an error if there was a problem talking to the database. - SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error - // Look up the room ID a given alias refers to. - // Returns an error if there was a problem talking to the database. - GetRoomIDForAlias(ctx context.Context, alias string) (string, error) - // Look up all aliases referring to a given room ID. - // Returns an error if there was a problem talking to the database. - GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) - // Remove a given room alias. - // Returns an error if there was a problem talking to the database. - RemoveRoomAlias(ctx context.Context, alias string) error - // Look up the room version for a given room. - GetRoomVersionForRoom( - ctx context.Context, roomID string, - ) (gomatrixserverlib.RoomVersion, error) -} - // SetRoomAlias implements alias.RoomserverInternalAPI func (r *RoomserverInternalAPI) SetRoomAlias( ctx context.Context, diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 1a3626609..c43b9d049 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -4,6 +4,10 @@ import ( "context" "github.com/getsentry/sentry-go" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" + asAPI "github.com/matrix-org/dendrite/appservice/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/caching" @@ -19,9 +23,6 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" ) // RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI @@ -104,6 +105,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio r.fsAPI = fsAPI r.KeyRing = keyRing + identity, err := r.Cfg.Matrix.SigningIdentityFor(r.ServerName) + if err != nil { + logrus.Panic(err) + } + r.Inputer = &input.Inputer{ Cfg: &r.Base.Cfg.RoomServer, Base: r.Base, @@ -114,7 +120,8 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio JetStream: r.JetStream, NATSClient: r.NATSClient, Durable: nats.Durable(r.Durable), - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, + SigningIdentity: identity, FSAPI: fsAPI, KeyRing: keyRing, ACLs: r.ServerACLs, @@ -135,7 +142,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Queryer: r.Queryer, } r.Peeker = &perform.Peeker{ - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, Cfg: r.Cfg, DB: r.DB, FSAPI: r.fsAPI, @@ -146,9 +153,8 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Inputer: r.Inputer, } r.Unpeeker = &perform.Unpeeker{ - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, Cfg: r.Cfg, - DB: r.DB, FSAPI: r.fsAPI, Inputer: r.Inputer, } @@ -193,6 +199,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) { r.Leaver.UserAPI = userAPI + r.Inputer.UserAPI = userAPI } func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 03d8bca0b..9defe7945 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -31,7 +31,8 @@ import ( // the soft-fail bool. func CheckForSoftFail( ctx context.Context, - db storage.Database, + db storage.RoomDatabase, + roomInfo *types.RoomInfo, event *gomatrixserverlib.HeaderedEvent, stateEventIDs []string, ) (bool, error) { @@ -45,16 +46,6 @@ func CheckForSoftFail( return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err) } } else { - // Work out if the room exists. - var roomInfo *types.RoomInfo - roomInfo, err = db.RoomInfo(ctx, event.RoomID()) - if err != nil { - return false, fmt.Errorf("db.RoomNID: %w", err) - } - if roomInfo == nil || roomInfo.IsStub() { - return false, nil - } - // Then get the state entries for the current state snapshot. // We'll use this to check if the event is allowed right now. roomState := state.NewStateResolution(db, roomInfo) @@ -76,7 +67,7 @@ func CheckForSoftFail( stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) if err != nil { return true, fmt.Errorf("loadAuthEvents: %w", err) } @@ -93,7 +84,8 @@ func CheckForSoftFail( // Returns the numeric IDs for the auth events. func CheckAuthEvents( ctx context.Context, - db storage.Database, + db storage.RoomDatabase, + roomInfo *types.RoomInfo, event *gomatrixserverlib.HeaderedEvent, authEventIDs []string, ) ([]types.EventNID, error) { @@ -108,7 +100,7 @@ func CheckAuthEvents( stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) if err != nil { return nil, fmt.Errorf("loadAuthEvents: %w", err) } @@ -201,6 +193,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * func loadAuthEvents( ctx context.Context, db state.StateResolutionStorage, + roomInfo *types.RoomInfo, needed gomatrixserverlib.StateNeeded, state []types.StateEntry, ) (result authEvents, err error) { @@ -223,7 +216,7 @@ func loadAuthEvents( eventNIDs = append(eventNIDs, eventNID) } } - if result.events, err = db.Events(ctx, eventNIDs); err != nil { + if result.events, err = db.Events(ctx, roomInfo, eventNIDs); err != nil { return } roomID := "" diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 7efad7af6..9a70bcc9c 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -85,7 +85,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam return false, err } - events, err := db.Events(ctx, eventNIDs) + events, err := db.Events(ctx, info, eventNIDs) if err != nil { return false, err } @@ -157,7 +157,7 @@ func IsInvitePending( // only keep the "m.room.member" events with a "join" membership. These events are returned. // Returns an error if there was an issue fetching the events. func GetMembershipsAtState( - ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool, + ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, joinedOnly bool, ) ([]types.Event, error) { var eventNIDs types.EventNIDs @@ -177,7 +177,7 @@ func GetMembershipsAtState( util.Unique(eventNIDs) // Get all of the events in this state - stateEvents, err := db.Events(ctx, eventNIDs) + stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) if err != nil { return nil, err } @@ -220,16 +220,16 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room return roomState.LoadCombinedStateAfterEvents(ctx, prevState) } -func MembershipAtEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) { +func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) { roomState := state.NewStateResolution(db, info) // Fetch the state as it was when this event was fired return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID) } func LoadEvents( - ctx context.Context, db storage.Database, eventNIDs []types.EventNID, + ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, eventNIDs []types.EventNID, ) ([]*gomatrixserverlib.Event, error) { - stateEvents, err := db.Events(ctx, eventNIDs) + stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) if err != nil { return nil, err } @@ -242,13 +242,13 @@ func LoadEvents( } func LoadStateEvents( - ctx context.Context, db storage.Database, stateEntries []types.StateEntry, + ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, ) ([]*gomatrixserverlib.Event, error) { eventNIDs := make([]types.EventNID, len(stateEntries)) for i := range stateEntries { eventNIDs[i] = stateEntries[i].EventNID } - return LoadEvents(ctx, db, eventNIDs) + return LoadEvents(ctx, db, roomInfo, eventNIDs) } func CheckServerAllowedToSeeEvent( @@ -326,7 +326,7 @@ func slowGetHistoryVisibilityState( return nil, nil } - return LoadStateEvents(ctx, db, filteredEntries) + return LoadStateEvents(ctx, db, info, filteredEntries) } // TODO: Remove this when we have tests to assert correctness of this function @@ -366,7 +366,7 @@ BFSLoop: next = make([]string, 0) } // Retrieve the events to process from the database. - events, err = db.EventsFromIDs(ctx, front) + events, err = db.EventsFromIDs(ctx, info, front) if err != nil { return resultNIDs, redactEventIDs, err } @@ -467,7 +467,7 @@ func QueryLatestEventsAndState( return err } - stateEvents, err := LoadStateEvents(ctx, db, stateEntries) + stateEvents, err := LoadStateEvents(ctx, db, roomInfo, stateEntries) if err != nil { return err } diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go index aa5c30e44..c056e704c 100644 --- a/roomserver/internal/helpers/helpers_test.go +++ b/roomserver/internal/helpers/helpers_test.go @@ -4,9 +4,10 @@ import ( "context" "testing" - "github.com/matrix-org/dendrite/roomserver/types" "github.com/stretchr/testify/assert" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/test" @@ -38,7 +39,18 @@ func TestIsInvitePendingWithoutNID(t *testing.T) { var authNIDs []types.EventNID for _, x := range room.Events() { - evNID, _, _, _, _, err := db.StoreEvent(context.Background(), x.Event, authNIDs, false) + roomInfo, err := db.GetOrCreateRoomInfo(context.Background(), x.Unwrap()) + assert.NoError(t, err) + assert.NotNil(t, roomInfo) + + eventTypeNID, err := db.GetOrCreateEventTypeNID(context.Background(), x.Type()) + assert.NoError(t, err) + assert.Greater(t, eventTypeNID, types.EventTypeNID(0)) + + eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey()) + assert.NoError(t, err) + + evNID, _, err := db.StoreEvent(context.Background(), x.Event, roomInfo, eventTypeNID, eventStateKeyNID, authNIDs, false) assert.NoError(t, err) authNIDs = append(authNIDs, evNID) } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index e965691c9..2ec19f010 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -23,6 +23,8 @@ import ( "sync" "time" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/Arceliar/phony" "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" @@ -74,11 +76,12 @@ type Inputer struct { Cfg *config.RoomServer Base *base.BaseDendrite ProcessContext *process.ProcessContext - DB storage.Database + DB storage.RoomDatabase NATSClient *nats.Conn JetStream nats.JetStreamContext Durable nats.SubOpt ServerName gomatrixserverlib.ServerName + SigningIdentity *gomatrixserverlib.SigningIdentity FSAPI fedapi.RoomserverFederationAPI KeyRing gomatrixserverlib.JSONVerifier ACLs *acls.ServerACLs @@ -87,6 +90,7 @@ type Inputer struct { workers sync.Map // room ID -> *worker Queryer *query.Queryer + UserAPI userapi.RoomserverUserAPI } // If a room consumer is inactive for a while then we will allow NATS diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 10b8ee27f..7c7a902f5 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -19,6 +19,7 @@ package input import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "time" @@ -27,17 +28,19 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" + + userAPI "github.com/matrix-org/dendrite/userapi/api" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -81,10 +84,10 @@ func (r *Inputer) processRoomEvent( default: } - span, ctx := opentracing.StartSpanFromContext(ctx, "processRoomEvent") - span.SetTag("room_id", input.Event.RoomID()) - span.SetTag("event_id", input.Event.EventID()) - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "processRoomEvent") + trace.SetTag("room_id", input.Event.RoomID()) + trace.SetTag("event_id", input.Event.EventID()) + defer trace.EndRegion() // Measure how long it takes to process this event. started := time.Now() @@ -163,6 +166,7 @@ func (r *Inputer) processRoomEvent( missingPrev = !input.HasState && len(missingPrevIDs) > 0 } + // If we have missing events (auth or prev), we build a list of servers to ask if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ RoomID: event.RoomID(), @@ -197,59 +201,8 @@ func (r *Inputer) processRoomEvent( } } - // First of all, check that the auth events of the event are known. - // If they aren't then we will ask the federation API for them. isRejected := false - authEvents := gomatrixserverlib.NewAuthEvents(nil) - knownEvents := map[string]*types.Event{} - if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { - return fmt.Errorf("r.fetchAuthEvents: %w", err) - } - - // Check if the event is allowed by its auth events. If it isn't then - // we consider the event to be "rejected" — it will still be persisted. var rejectionErr error - if rejectionErr = gomatrixserverlib.Allowed(event, &authEvents); rejectionErr != nil { - isRejected = true - logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) - } - - // Accumulate the auth event NIDs. - authEventIDs := event.AuthEventIDs() - authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) - for _, authEventID := range authEventIDs { - if _, ok := knownEvents[authEventID]; !ok { - // Unknown auth events only really matter if the event actually failed - // auth. If it passed auth then we can assume that everything that was - // known was sufficient, even if extraneous auth events were specified - // but weren't found. - if isRejected { - if event.StateKey() != nil { - return fmt.Errorf( - "missing auth event %s for state event %s (type %q, state key %q)", - authEventID, event.EventID(), event.Type(), *event.StateKey(), - ) - } else { - return fmt.Errorf( - "missing auth event %s for timeline event %s (type %q)", - authEventID, event.EventID(), event.Type(), - ) - } - } - } else { - authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) - } - } - - var softfail bool - if input.Kind == api.KindNew { - // Check that the event passes authentication checks based on the - // current room state. - softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) - if err != nil { - logger.WithError(err).Warn("Error authing soft-failed event") - } - } // At this point we are checking whether we know all of the prev events, and // if we know the state before the prev events. This is necessary before we @@ -311,13 +264,66 @@ func (r *Inputer) processRoomEvent( } } + // Check that the auth events of the event are known. + // If they aren't then we will ask the federation API for them. + authEvents := gomatrixserverlib.NewAuthEvents(nil) + knownEvents := map[string]*types.Event{} + if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + return fmt.Errorf("r.fetchAuthEvents: %w", err) + } + + // Check if the event is allowed by its auth events. If it isn't then + // we consider the event to be "rejected" — it will still be persisted. + if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + isRejected = true + rejectionErr = err + logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) + } + + // Accumulate the auth event NIDs. + authEventIDs := event.AuthEventIDs() + authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) + for _, authEventID := range authEventIDs { + if _, ok := knownEvents[authEventID]; !ok { + // Unknown auth events only really matter if the event actually failed + // auth. If it passed auth then we can assume that everything that was + // known was sufficient, even if extraneous auth events were specified + // but weren't found. + if isRejected { + if event.StateKey() != nil { + return fmt.Errorf( + "missing auth event %s for state event %s (type %q, state key %q)", + authEventID, event.EventID(), event.Type(), *event.StateKey(), + ) + } else { + return fmt.Errorf( + "missing auth event %s for timeline event %s (type %q)", + authEventID, event.EventID(), event.Type(), + ) + } + } + } else { + authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) + } + } + + var softfail bool + if input.Kind == api.KindNew && !isCreateEvent { + // Check that the event passes authentication checks based on the + // current room state. + softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs) + if err != nil { + logger.WithError(err).Warn("Error authing soft-failed event") + } + } + // Get the state before the event so that we can work out if the event was // allowed at the time, and also to get the history visibility. We won't // bother doing this if the event was already rejected as it just ends up // burning CPU time. historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared. - if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected { - historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev) + if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected && !isCreateEvent { + historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo, input, missingPrev) if err != nil { return fmt.Errorf("r.processStateBefore: %w", err) } @@ -326,17 +332,27 @@ func (r *Inputer) processRoomEvent( } } - // Store the event. - _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) - if err != nil { - return fmt.Errorf("updater.StoreEvent: %w", err) + if roomInfo == nil { + roomInfo, err = r.DB.GetOrCreateRoomInfo(ctx, event) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateRoomInfo: %w", err) + } } - // if storing this event results in it being redacted then do so. - if !isRejected && redactedEventID == event.EventID() { - if err = eventutil.RedactEvent(redactionEvent, event); err != nil { - return fmt.Errorf("eventutil.RedactEvent: %w", rerr) - } + eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, event.Type()) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateEventTypeNID: %w", err) + } + + eventStateKeyNID, err := r.DB.GetOrCreateEventStateKeyNID(ctx, event.StateKey()) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateEventStateKeyNID: %w", err) + } + + // Store the event. + eventNID, stateAtEvent, err := r.DB.StoreEvent(ctx, event, roomInfo, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) + if err != nil { + return fmt.Errorf("updater.StoreEvent: %w", err) } // For outliers we can stop after we've stored the event itself as it @@ -367,6 +383,24 @@ func (r *Inputer) processRoomEvent( } } + // if storing this event results in it being redacted then do so. + // we do this after calculating state for this event as we may need to get power levels + var ( + redactedEventID string + redactionEvent *gomatrixserverlib.Event + redactedEvent *gomatrixserverlib.Event + ) + if !isRejected && !isCreateEvent { + resolver := state.NewStateResolution(r.DB, roomInfo) + redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver) + if err != nil { + return err + } + if redactedEvent != nil { + redactedEventID = redactedEvent.EventID() + } + } + // We stop here if the event is rejected: We've stored it but won't update // forward extremities or notify downstream components about it. switch { @@ -440,6 +474,13 @@ func (r *Inputer) processRoomEvent( } } + // If guest_access changed and is not can_join, kick all guest users. + if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" { + if err = r.kickGuests(ctx, event, roomInfo); err != nil { + logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation") + } + } + // Everything was OK — the latest events updater didn't error and // we've sent output events. Finally, generate a hook call. hooks.Run(hooks.KindNewEventPersisted, headered) @@ -461,6 +502,7 @@ func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixse // nolint:nakedret func (r *Inputer) processStateBefore( ctx context.Context, + roomInfo *types.RoomInfo, input *api.InputRoomEvent, missingPrev bool, ) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) { @@ -476,7 +518,7 @@ func (r *Inputer) processStateBefore( case input.HasState: // If we're overriding the state then we need to go and retrieve // them from the database. It's a hard error if they are missing. - stateEvents, err := r.DB.EventsFromIDs(ctx, input.StateEventIDs) + stateEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, input.StateEventIDs) if err != nil { return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err) } @@ -554,6 +596,7 @@ func (r *Inputer) processStateBefore( // we've failed to retrieve the auth chain altogether (in which case // an error is returned) or we've successfully retrieved them all and // they are now in the database. +// nolint: gocyclo func (r *Inputer) fetchAuthEvents( ctx context.Context, logger *logrus.Entry, @@ -564,8 +607,8 @@ func (r *Inputer) fetchAuthEvents( known map[string]*types.Event, servers []gomatrixserverlib.ServerName, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "fetchAuthEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "fetchAuthEvents") + defer trace.EndRegion() unknown := map[string]struct{}{} authEventIDs := event.AuthEventIDs() @@ -574,7 +617,7 @@ func (r *Inputer) fetchAuthEvents( } for _, authEventID := range authEventIDs { - authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID}) + authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, []string{authEventID}) if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { unknown[authEventID] = struct{}{} continue @@ -660,8 +703,25 @@ nextAuthEvent: logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) } + if roomInfo == nil { + roomInfo, err = r.DB.GetOrCreateRoomInfo(ctx, authEvent) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateRoomInfo: %w", err) + } + } + + eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, authEvent.Type()) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateEventTypeNID: %w", err) + } + + eventStateKeyNID, err := r.DB.GetOrCreateEventStateKeyNID(ctx, event.StateKey()) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateEventStateKeyNID: %w", err) + } + // Finally, store the event in the database. - eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) + eventNID, _, err := r.DB.StoreEvent(ctx, authEvent, roomInfo, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) if err != nil { return fmt.Errorf("updater.StoreEvent: %w", err) } @@ -692,8 +752,8 @@ func (r *Inputer) calculateAndSetState( event *gomatrixserverlib.Event, isRejected bool, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "calculateAndSetState") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "calculateAndSetState") + defer trace.EndRegion() var succeeded bool updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) @@ -729,3 +789,98 @@ func (r *Inputer) calculateAndSetState( succeeded = true return nil } + +// kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited. +func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error { + membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) + if err != nil { + return err + } + + memberEvents, err := r.DB.Events(ctx, roomInfo, membershipNIDs) + if err != nil { + return err + } + + inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) + latestReq := &api.QueryLatestEventsAndStateRequest{ + RoomID: event.RoomID(), + } + latestRes := &api.QueryLatestEventsAndStateResponse{} + if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { + return err + } + + prevEvents := latestRes.LatestEvents + for _, memberEvent := range memberEvents { + if memberEvent.StateKey() == nil { + continue + } + + localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey()) + if err != nil { + continue + } + + accountRes := &userAPI.QueryAccountByLocalpartResponse{} + if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{ + Localpart: localpart, + ServerName: senderDomain, + }, accountRes); err != nil { + return err + } + if accountRes.Account == nil { + continue + } + + if accountRes.Account.AccountType != userAPI.AccountTypeGuest { + continue + } + + var memberContent gomatrixserverlib.MemberContent + if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil { + return err + } + memberContent.Membership = gomatrixserverlib.Leave + + stateKey := *memberEvent.StateKey() + fledglingEvent := &gomatrixserverlib.EventBuilder{ + RoomID: event.RoomID(), + Type: gomatrixserverlib.MRoomMember, + StateKey: &stateKey, + Sender: stateKey, + PrevEvents: prevEvents, + } + + if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { + return err + } + + eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent) + if err != nil { + return err + } + + event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes) + if err != nil { + return err + } + + inputEvents = append(inputEvents, api.InputRoomEvent{ + Kind: api.KindNew, + Event: event, + Origin: senderDomain, + SendAsServer: string(senderDomain), + }) + prevEvents = []gomatrixserverlib.EventReference{ + event.EventReference(), + } + } + + inputReq := &api.InputRoomEventsRequest{ + InputRoomEvents: inputEvents, + Asynchronous: true, // Needs to be async, as we otherwise create a deadlock + } + inputRes := &api.InputRoomEventsResponse{} + return r.InputRoomEvents(ctx, inputReq, inputRes) +} diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index a223820ef..09db18431 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -23,9 +23,9 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/opentracing/opentracing-go" "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" @@ -59,8 +59,8 @@ func (r *Inputer) updateLatestEvents( rewritesState bool, historyVisibility gomatrixserverlib.HistoryVisibility, ) (err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "updateLatestEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "updateLatestEvents") + defer trace.EndRegion() var succeeded bool updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) @@ -209,8 +209,8 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { } func (u *latestEventsUpdater) latestState() error { - span, ctx := opentracing.StartSpanFromContext(u.ctx, "processEventWithMissingState") - defer span.Finish() + trace, ctx := internal.StartRegion(u.ctx, "processEventWithMissingState") + defer trace.EndRegion() var err error roomState := state.NewStateResolution(u.updater, u.roomInfo) @@ -329,8 +329,8 @@ func (u *latestEventsUpdater) calculateLatest( newEvent *gomatrixserverlib.Event, newStateAndRef types.StateAtEventAndReference, ) (bool, error) { - span, _ := opentracing.StartSpanFromContext(u.ctx, "calculateLatest") - defer span.Finish() + trace, _ := internal.StartRegion(u.ctx, "calculateLatest") + defer trace.EndRegion() // First of all, get a list of all of the events in our current // set of forward extremities. diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 28a54623b..4028f0b5e 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -18,13 +18,14 @@ import ( "context" "fmt" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/opentracing/opentracing-go" ) // updateMembership updates the current membership and the invites for each @@ -36,8 +37,8 @@ func (r *Inputer) updateMemberships( updater *shared.RoomUpdater, removed, added []types.StateEntry, ) ([]api.OutputEvent, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "updateMemberships") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "updateMemberships") + defer trace.EndRegion() changes := membershipChanges(removed, added) var eventNIDs []types.EventNID @@ -53,7 +54,7 @@ func (r *Inputer) updateMemberships( // Load the event JSON so we can look up the "membership" key. // TODO: Maybe add a membership key to the events table so we can load that // key without having to load the entire event JSON? - events, err := updater.Events(ctx, eventNIDs) + events, err := updater.Events(ctx, nil, eventNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 03ac2b38d..daef957f1 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -7,16 +7,16 @@ import ( "sync" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/opentracing/opentracing-go" - "github.com/sirupsen/logrus" ) type parsedRespState struct { @@ -43,7 +43,7 @@ type missingStateReq struct { log *logrus.Entry virtualHost gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName - db storage.Database + db storage.RoomDatabase roomInfo *types.RoomInfo inputer *Inputer keys gomatrixserverlib.JSONVerifier @@ -62,8 +62,8 @@ type missingStateReq struct { func (t *missingStateReq) processEventWithMissingState( ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, ) (*parsedRespState, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "processEventWithMissingState") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "processEventWithMissingState") + defer trace.EndRegion() // We are missing the previous events for this events. // This means that there is a gap in our view of the history of the @@ -241,8 +241,8 @@ func (t *missingStateReq) processEventWithMissingState( } func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (*parsedRespState, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "lookupResolvedStateBeforeEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "lookupResolvedStateBeforeEvent") + defer trace.EndRegion() type respState struct { // A snapshot is considered trustworthy if it came from our own roomserver. @@ -319,8 +319,8 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // added into the mix. func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*parsedRespState, bool, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "lookupStateAfterEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "lookupStateAfterEvent") + defer trace.EndRegion() // try doing all this locally before we resort to querying federation respState := t.lookupStateAfterEventLocally(ctx, eventID) @@ -376,8 +376,8 @@ func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.Event) *gomatrixs } func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, eventID string) *parsedRespState { - span, ctx := opentracing.StartSpanFromContext(ctx, "lookupStateAfterEventLocally") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "lookupStateAfterEventLocally") + defer trace.EndRegion() var res parsedRespState roomState := state.NewStateResolution(t.db, t.roomInfo) @@ -395,7 +395,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even for _, entry := range stateEntries { stateEventNIDs = append(stateEventNIDs, entry.EventNID) } - stateEvents, err := t.db.Events(ctx, stateEventNIDs) + stateEvents, err := t.db.Events(ctx, t.roomInfo, stateEventNIDs) if err != nil { t.log.WithError(err).Warnf("failed to load state events locally") return nil @@ -432,7 +432,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even missingEventList = append(missingEventList, evID) } t.log.WithField("count", len(missingEventList)).Debugf("Fetching missing auth events") - events, err := t.db.EventsFromIDs(ctx, missingEventList) + events, err := t.db.EventsFromIDs(ctx, t.roomInfo, missingEventList) if err != nil { return nil } @@ -449,16 +449,16 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even // the server supports. func (t *missingStateReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( *parsedRespState, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "lookupStateBeforeEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "lookupStateBeforeEvent") + defer trace.EndRegion() // Attempt to fetch the missing state using /state_ids and /events return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion) } func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*parsedRespState, backwardsExtremity *gomatrixserverlib.Event) (*parsedRespState, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "resolveStatesAndCheck") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "resolveStatesAndCheck") + defer trace.EndRegion() var authEventList []*gomatrixserverlib.Event var stateEventList []*gomatrixserverlib.Event @@ -503,8 +503,8 @@ retryAllowedState: // get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject, // without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "getMissingEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "getMissingEvents") + defer trace.EndRegion() logger := t.log.WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID) @@ -633,8 +633,8 @@ func (t *missingStateReq) isPrevStateKnown(ctx context.Context, e *gomatrixserve func (t *missingStateReq) lookupMissingStateViaState( ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (respState *parsedRespState, err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaState") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "lookupMissingStateViaState") + defer trace.EndRegion() state, err := t.federation.LookupState(ctx, t.virtualHost, t.origin, roomID, eventID, roomVersion) if err != nil { @@ -665,8 +665,8 @@ func (t *missingStateReq) lookupMissingStateViaState( func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( *parsedRespState, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaStateIDs") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "lookupMissingStateViaStateIDs") + defer trace.EndRegion() t.log.Infof("lookupMissingStateViaStateIDs %s", eventID) // fetch the state event IDs at the time of the event @@ -702,7 +702,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo } t.haveEventsMutex.Unlock() - events, err := t.db.EventsFromIDs(ctx, missingEventList) + events, err := t.db.EventsFromIDs(ctx, t.roomInfo, missingEventList) if err != nil { return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err) } @@ -839,12 +839,12 @@ func (t *missingStateReq) createRespStateFromStateIDs( } func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.Event, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "lookupEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "lookupEvent") + defer trace.EndRegion() if localFirst { // fetch from the roomserver - events, err := t.db.EventsFromIDs(ctx, []string{missingEventID}) + events, err := t.db.EventsFromIDs(ctx, t.roomInfo, []string{missingEventID}) if err != nil { t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) } else if len(events) == 1 { diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index d42f4e45d..45089bdd1 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" ) type Admin struct { @@ -69,7 +70,7 @@ func (r *Admin) PerformAdminEvacuateRoom( return nil } - memberEvents, err := r.DB.Events(ctx, memberNIDs) + memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, @@ -242,6 +243,42 @@ func (r *Admin) PerformAdminEvacuateUser( return nil } +func (r *Admin) PerformAdminPurgeRoom( + ctx context.Context, + req *api.PerformAdminPurgeRoomRequest, + res *api.PerformAdminPurgeRoomResponse, +) error { + // Validate we actually got a room ID and nothing else + if _, _, err := gomatrixserverlib.SplitID('!', req.RoomID); err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("Malformed room ID: %s", err), + } + return nil + } + + logrus.WithField("room_id", req.RoomID).Warn("Purging room from roomserver") + if err := r.DB.PurgeRoom(ctx, req.RoomID); err != nil { + logrus.WithField("room_id", req.RoomID).WithError(err).Warn("Failed to purge room from roomserver") + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: err.Error(), + } + return nil + } + + logrus.WithField("room_id", req.RoomID).Warn("Room purged from roomserver") + + return r.Inputer.OutputProducer.ProduceRoomEvents(req.RoomID, []api.OutputEvent{ + { + Type: api.OutputTypePurgeRoom, + PurgeRoom: &api.OutputPurgeRoom{ + RoomID: req.RoomID, + }, + }, + }) +} + func (r *Admin) PerformAdminDownloadState( ctx context.Context, req *api.PerformAdminDownloadStateRequest, diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 069f017a9..23862b242 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -23,10 +23,10 @@ import ( "github.com/sirupsen/logrus" federationAPI "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/auth" "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -86,7 +86,7 @@ func (r *Backfiller) PerformBackfill( // Retrieve events from the list that was filled previously. If we fail to get // events from the database then attempt once to get them from federation instead. var loadedEvents []*gomatrixserverlib.Event - loadedEvents, err = helpers.LoadEvents(ctx, r.DB, resultNIDs) + loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info, resultNIDs) if err != nil { if _, ok := err.(types.MissingEventError); ok { return r.backfillViaFederation(ctx, request, response) @@ -122,17 +122,17 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform ctx, req.VirtualHost, requester, r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, ) - if err != nil { + // Only return an error if we really couldn't get any events. + if err != nil && len(events) == 0 { logrus.WithError(err).Errorf("gomatrixserverlib.RequestBackfill failed") return err } - logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) + // If we got an error but still got events, that's fine, because a server might have returned a 404 (or something) + // but other servers could provide the missing event. + logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) // persist these new events - auth checks have already been done roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) - if err != nil { - return err - } for _, ev := range backfilledEventMap { // now add state for these events @@ -255,6 +255,7 @@ type backfillRequester struct { eventIDToBeforeStateIDs map[string][]string eventIDMap map[string]*gomatrixserverlib.Event historyVisiblity gomatrixserverlib.HistoryVisibility + roomInfo types.RoomInfo } func newBackfillRequester( @@ -319,6 +320,7 @@ FederationHit: FedClient: b.fsAPI, RememberAuthEvents: false, Server: srv, + Origin: b.virtualHost, } res, err := c.StateIDsBeforeEvent(ctx, targetEvent) if err != nil { @@ -394,6 +396,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr FedClient: b.fsAPI, RememberAuthEvents: false, Server: srv, + Origin: b.virtualHost, } result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs) if err != nil { @@ -449,14 +452,14 @@ FindSuccessor: return nil } - stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID]) + stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") return nil } // possibly return all joined servers depending on history visiblity - memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.virtualHost) + memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, info, stateEntries, b.virtualHost) b.historyVisiblity = visibility if err != nil { logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") @@ -467,7 +470,7 @@ FindSuccessor: // Retrieve all "m.room.member" state events of "join" membership, which // contains the list of users in the room before the event, therefore all // the servers in it at that moment. - memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, stateEntries, true) + memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info, stateEntries, true) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event") return nil @@ -518,11 +521,15 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, } eventNIDs := make([]types.EventNID, len(nidMap)) i := 0 + roomNID := b.roomInfo.RoomNID for _, nid := range nidMap { - eventNIDs[i] = nid + eventNIDs[i] = nid.EventNID i++ + if roomNID == 0 { + roomNID = nid.RoomNID + } } - eventsWithNids, err := b.db.Events(ctx, eventNIDs) + eventsWithNids, err := b.db.Events(ctx, &b.roomInfo, eventNIDs) if err != nil { logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") return nil, err @@ -539,7 +546,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, // TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just // pull all events and then filter by that table. func joinEventsFromHistoryVisibility( - ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry, + ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, thisServer gomatrixserverlib.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) { var eventNIDs []types.EventNID @@ -552,7 +559,7 @@ func joinEventsFromHistoryVisibility( } // Get all of the events in this state - stateEvents, err := db.Events(ctx, eventNIDs) + stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) if err != nil { // even though the default should be shared, restricting the visibility to joined // feels more secure here. @@ -565,21 +572,17 @@ func joinEventsFromHistoryVisibility( // Can we see events in the room? canSeeEvents := auth.IsServerAllowed(thisServer, true, events) - visibility := gomatrixserverlib.HistoryVisibility(auth.HistoryVisibilityForRoom(events)) + visibility := auth.HistoryVisibilityForRoom(events) if !canSeeEvents { logrus.Infof("ServersAtEvent history not visible to us: %s", visibility) return nil, visibility, nil } // get joined members - info, err := db.RoomInfo(ctx, roomID) - if err != nil { - return nil, visibility, nil - } - joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) + joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, false) if err != nil { return nil, visibility, err } - evs, err := db.Events(ctx, joinEventNIDs) + evs, err := db.Events(ctx, roomInfo, joinEventNIDs) return evs, visibility, err } @@ -596,26 +599,47 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs authNids := make([]types.EventNID, len(nidMap)) i := 0 for _, nid := range nidMap { - authNids[i] = nid + authNids[i] = nid.EventNID i++ } - var redactedEventID string - var redactionEvent *gomatrixserverlib.Event - eventNID, roomNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), authNids, false) + + roomInfo, err := db.GetOrCreateRoomInfo(ctx, ev.Unwrap()) + if err != nil { + logrus.WithError(err).Error("failed to get or create roomNID") + continue + } + roomNID = roomInfo.RoomNID + + eventTypeNID, err := db.GetOrCreateEventTypeNID(ctx, ev.Type()) + if err != nil { + logrus.WithError(err).Error("failed to get or create eventType NID") + continue + } + + eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(ctx, ev.StateKey()) + if err != nil { + logrus.WithError(err).Error("failed to get or create eventStateKey NID") + continue + } + + eventNID, _, err = db.StoreEvent(ctx, ev.Unwrap(), roomInfo, eventTypeNID, eventStateKeyNID, authNids, false) if err != nil { logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") continue } + + resolver := state.NewStateResolution(db, roomInfo) + + _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.Unwrap(), &resolver) + if err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") + continue + } // If storing this event results in it being redacted, then do so. // It's also possible for this event to be a redaction which results in another event being // redacted, which we don't care about since we aren't returning it in this backfill. - if redactedEventID == ev.EventID() { - eventToRedact := ev.Unwrap() - if err := eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") - continue - } - ev = eventToRedact.Headered(ev.RoomVersion) + if redactedEvent != nil && redactedEvent.EventID() == ev.EventID() { + ev = redactedEvent.Headered(ev.RoomVersion) events[j] = ev } backfilledEventMap[ev.EventID()] = types.Event{ diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go index 29decd363..1fb6eb43a 100644 --- a/roomserver/internal/perform/perform_inbound_peek.go +++ b/roomserver/internal/perform/perform_inbound_peek.go @@ -29,7 +29,7 @@ import ( ) type InboundPeeker struct { - DB storage.Database + DB storage.RoomDatabase Inputer *input.Inputer } @@ -64,7 +64,7 @@ func (r *InboundPeeker) PerformInboundPeek( if err != nil { return err } - latestEvents, err := r.DB.EventsFromIDs(ctx, []string{latestEventRefs[0].EventID}) + latestEvents, err := r.DB.EventsFromIDs(ctx, info, []string{latestEventRefs[0].EventID}) if err != nil { return err } @@ -88,7 +88,7 @@ func (r *InboundPeeker) PerformInboundPeek( if err != nil { return err } - stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, stateEntries) + stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info, stateEntries) if err != nil { return err } @@ -100,7 +100,7 @@ func (r *InboundPeeker) PerformInboundPeek( } authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe - authEvents, err := query.GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) + authEvents, err := query.GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs) if err != nil { return err } diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index d593170d4..140ed7c8a 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -194,7 +194,7 @@ func (r *Inviter) PerformInvite( // try and see if the user is allowed to make this invite. We can't do // this for invites coming in over federation - we have to take those on // trust. - _, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs()) + _, err = helpers.CheckAuthEvents(ctx, r.DB, info, event, event.AuthEventIDs()) if err != nil { logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( "processInviteEvent.checkAuthEvents failed for event", @@ -291,7 +291,7 @@ func buildInviteStrippedState( for _, stateNID := range stateEntries { stateNIDs = append(stateNIDs, stateNID.EventNID) } - stateEvents, err := db.Events(ctx, stateNIDs) + stateEvents, err := db.Events(ctx, info, stateNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 4de008c66..fc7ba940c 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -16,6 +16,7 @@ package perform import ( "context" + "database/sql" "errors" "fmt" "strings" @@ -270,6 +271,28 @@ func (r *Joiner) performJoinRoomByID( } } + // If a guest is trying to join a room, check that the room has a m.room.guest_access event + if req.IsGuest { + var guestAccessEvent *gomatrixserverlib.HeaderedEvent + guestAccess := "forbidden" + guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "") + if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil { + logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'") + } + if guestAccessEvent != nil { + guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String() + } + + // Servers MUST only allow guest users to join rooms if the m.room.guest_access state event + // is present on the room and has the guest_access value can_join. + if guestAccess != "can_join" { + return "", "", &rsAPI.PerformError{ + Code: rsAPI.PerformErrorNotAllowed, + Msg: "Guest access is forbidden", + } + } + } + // If we should do a forced federated join then do that. var joinedVia gomatrixserverlib.ServerName if forceFederatedJoin { diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index fa998e3e1..86f1dfaee 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -20,6 +20,7 @@ import ( "fmt" "strings" + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -109,7 +110,7 @@ func (r *Leaver) performLeaveRoomByID( // mimic the returned values from Synapse res.Message = "You cannot reject this invite" res.Code = 403 - return nil, fmt.Errorf("You cannot reject this invite") + return nil, jsonerror.LeaveServerNoticeError() } } } diff --git a/roomserver/internal/perform/perform_unpeek.go b/roomserver/internal/perform/perform_unpeek.go index 0d97da4d6..4d714be66 100644 --- a/roomserver/internal/perform/perform_unpeek.go +++ b/roomserver/internal/perform/perform_unpeek.go @@ -22,7 +22,6 @@ import ( fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/input" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -31,9 +30,7 @@ type Unpeeker struct { ServerName gomatrixserverlib.ServerName Cfg *config.RoomServer FSAPI fsAPI.RoomserverFederationAPI - DB storage.Database - - Inputer *input.Inputer + Inputer *input.Inputer } // PerformPeek handles peeking into matrix rooms, including over federation by talking to the federationapi. diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index b1a1f9102..b18a906a2 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -25,6 +25,8 @@ import ( "github.com/matrix-org/util" "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/roomserver/acls" @@ -32,7 +34,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" - "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/version" ) @@ -102,7 +103,7 @@ func (r *Queryer) QueryStateAfterEvents( return err } - stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) + stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info, stateEntries) if err != nil { return err } @@ -114,7 +115,7 @@ func (r *Queryer) QueryStateAfterEvents( } authEventIDs = util.UniqueStrings(authEventIDs) - authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) + authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs) if err != nil { return fmt.Errorf("getAuthChain: %w", err) } @@ -132,34 +133,46 @@ func (r *Queryer) QueryStateAfterEvents( return nil } -// QueryEventsByID implements api.RoomserverInternalAPI +// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine +// which room to use by querying the first events roomID. func (r *Queryer) QueryEventsByID( ctx context.Context, request *api.QueryEventsByIDRequest, response *api.QueryEventsByIDResponse, ) error { - eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs) + if len(request.EventIDs) == 0 { + return nil + } + var err error + // We didn't receive a room ID, we need to fetch it first before we can continue. + // This happens for e.g. ` /_matrix/federation/v1/event/{eventId}` + var roomInfo *types.RoomInfo + if request.RoomID == "" { + var eventNIDs map[string]types.EventMetadata + eventNIDs, err = r.DB.EventNIDs(ctx, []string{request.EventIDs[0]}) + if err != nil { + return err + } + if len(eventNIDs) == 0 { + return nil + } + roomInfo, err = r.DB.RoomInfoByNID(ctx, eventNIDs[request.EventIDs[0]].RoomNID) + } else { + roomInfo, err = r.DB.RoomInfo(ctx, request.RoomID) + } if err != nil { return err } - - var eventNIDs []types.EventNID - for _, nid := range eventNIDMap { - eventNIDs = append(eventNIDs, nid) + if roomInfo == nil { + return nil } - - events, err := helpers.LoadEvents(ctx, r.DB, eventNIDs) + events, err := r.DB.EventsFromIDs(ctx, roomInfo, request.EventIDs) if err != nil { return err } for _, event := range events { - roomVersion, verr := r.roomVersion(event.RoomID()) - if verr != nil { - return verr - } - - response.Events = append(response.Events, event.Headered(roomVersion)) + response.Events = append(response.Events, event.Headered(roomInfo.RoomVersion)) } return nil @@ -201,7 +214,7 @@ func (r *Queryer) QueryMembershipForUser( response.IsInRoom = stillInRoom response.HasBeenInRoom = true - evs, err := r.DB.Events(ctx, []types.EventNID{membershipEventNID}) + evs, err := r.DB.Events(ctx, info, []types.EventNID{membershipEventNID}) if err != nil { return err } @@ -222,7 +235,8 @@ func (r *Queryer) QueryMembershipAtEvent( request *api.QueryMembershipAtEventRequest, response *api.QueryMembershipAtEventResponse, ) error { - response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent) + response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent) + info, err := r.DB.RoomInfo(ctx, request.RoomID) if err != nil { return fmt.Errorf("unable to get roomInfo: %w", err) @@ -240,7 +254,17 @@ func (r *Queryer) QueryMembershipAtEvent( return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID) } - stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, request.EventIDs, stateKeyNIDs[request.UserID]) + response.Membership, err = r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[request.UserID], info, request.EventIDs...) + switch err { + case nil: + return nil + case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event + default: + return err + } + + response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent) + stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID]) if err != nil { return fmt.Errorf("unable to get state before event: %w", err) } @@ -264,7 +288,7 @@ func (r *Queryer) QueryMembershipAtEvent( for _, eventID := range request.EventIDs { stateEntry, ok := stateEntries[eventID] if !ok || len(stateEntry) == 0 { - response.Memberships[eventID] = []*gomatrixserverlib.HeaderedEvent{} + response.Membership[eventID] = nil continue } @@ -272,24 +296,24 @@ func (r *Queryer) QueryMembershipAtEvent( // once. If we have more than one membership event, we need to get the state for each state entry. if canShortCircuit { if len(memberships) == 0 { - memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false) } } else { - memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false) } if err != nil { return fmt.Errorf("unable to get memberships at state: %w", err) } - res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships)) - + // Iterate over all membership events we got. Given we only query the membership for + // one user and assuming this user only ever has one membership event associated to + // a given event, overwrite any other existing membership events. for i := range memberships { ev := memberships[i] if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) { - res = append(res, ev.Headered(info.RoomVersion)) + response.Membership[eventID] = ev.Event.Headered(info.RoomVersion) } } - response.Memberships[eventID] = res } return nil @@ -322,7 +346,7 @@ func (r *Queryer) QueryMembershipsForRoom( } return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) } - events, err = r.DB.Events(ctx, eventNIDs) + events, err = r.DB.Events(ctx, info, eventNIDs) if err != nil { return fmt.Errorf("r.DB.Events: %w", err) } @@ -361,14 +385,14 @@ func (r *Queryer) QueryMembershipsForRoom( return err } - events, err = r.DB.Events(ctx, eventNIDs) + events, err = r.DB.Events(ctx, info, eventNIDs) } else { stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err } - events, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly) + events, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntries, request.JoinedOnly) } if err != nil { @@ -416,39 +440,39 @@ func (r *Queryer) QueryServerJoinedToRoom( // QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI func (r *Queryer) QueryServerAllowedToSeeEvent( ctx context.Context, - request *api.QueryServerAllowedToSeeEventRequest, - response *api.QueryServerAllowedToSeeEventResponse, -) (err error) { - events, err := r.DB.EventsFromIDs(ctx, []string{request.EventID}) + serverName gomatrixserverlib.ServerName, + eventID string, +) (allowed bool, err error) { + events, err := r.DB.EventNIDs(ctx, []string{eventID}) if err != nil { return } if len(events) == 0 { - response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see - return + return allowed, nil } - roomID := events[0].RoomID() - - inRoomReq := &api.QueryServerJoinedToRoomRequest{ - RoomID: roomID, - ServerName: request.ServerName, - } - inRoomRes := &api.QueryServerJoinedToRoomResponse{} - if err = r.QueryServerJoinedToRoom(ctx, inRoomReq, inRoomRes); err != nil { - return fmt.Errorf("r.Queryer.QueryServerJoinedToRoom: %w", err) - } - - info, err := r.DB.RoomInfo(ctx, roomID) + info, err := r.DB.RoomInfoByNID(ctx, events[eventID].RoomNID) if err != nil { - return err + return allowed, err } if info == nil || info.IsStub() { - return nil + return allowed, nil } - response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent( - ctx, r.DB, info, request.EventID, request.ServerName, inRoomRes.IsInRoom, + var isInRoom bool + if r.IsLocalServerName(serverName) || serverName == "" { + isInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID) + if err != nil { + return allowed, fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err) + } + } else { + isInRoom, err = r.DB.GetServerInRoom(ctx, info.RoomNID, serverName) + if err != nil { + return allowed, fmt.Errorf("r.DB.GetServerInRoom: %w", err) + } + } + + return helpers.CheckServerAllowedToSeeEvent( + ctx, r.DB, info, eventID, serverName, isInRoom, ) - return } // QueryMissingEvents implements api.RoomserverInternalAPI @@ -470,19 +494,22 @@ func (r *Queryer) QueryMissingEvents( eventsToFilter[id] = true } } - events, err := r.DB.EventsFromIDs(ctx, front) + if len(front) == 0 { + return nil // no events to query, give up. + } + events, err := r.DB.EventNIDs(ctx, []string{front[0]}) if err != nil { return err } if len(events) == 0 { return nil // we are missing the events being asked to search from, give up. } - info, err := r.DB.RoomInfo(ctx, events[0].RoomID()) + info, err := r.DB.RoomInfoByNID(ctx, events[front[0]].RoomNID) if err != nil { return err } if info == nil || info.IsStub() { - return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID()) + return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID) } resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) @@ -490,7 +517,7 @@ func (r *Queryer) QueryMissingEvents( return err } - loadedEvents, err := helpers.LoadEvents(ctx, r.DB, resultNIDs) + loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info, resultNIDs) if err != nil { return err } @@ -533,7 +560,7 @@ func (r *Queryer) QueryStateAndAuthChain( // TODO: this probably means it should be a different query operation... if request.OnlyFetchAuthChain { var authEvents []*gomatrixserverlib.Event - authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, request.AuthEventIDs) + authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, info, request.AuthEventIDs) if err != nil { return err } @@ -560,7 +587,7 @@ func (r *Queryer) QueryStateAndAuthChain( } authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe - authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) + authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs) if err != nil { return err } @@ -615,18 +642,18 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI return nil, rejected, false, err } - events, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) + events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo, stateEntries) return events, rejected, false, err } -type eventsFromIDs func(context.Context, []string) ([]types.Event, error) +type eventsFromIDs func(context.Context, *types.RoomInfo, []string) ([]types.Event, error) // GetAuthChain fetches the auth chain for the given auth events. An auth chain // is the list of all events that are referenced in the auth_events section, and // all their auth_events, recursively. The returned set of events contain the // given events. Will *not* error if we don't have all auth events. func GetAuthChain( - ctx context.Context, fn eventsFromIDs, authEventIDs []string, + ctx context.Context, fn eventsFromIDs, roomInfo *types.RoomInfo, authEventIDs []string, ) ([]*gomatrixserverlib.Event, error) { // List of event IDs to fetch. On each pass, these events will be requested // from the database and the `eventsToFetch` will be updated with any new @@ -637,7 +664,7 @@ func GetAuthChain( for len(eventsToFetch) > 0 { // Try to retrieve the events from the database. - events, err := fn(ctx, eventsToFetch) + events, err := fn(ctx, roomInfo, eventsToFetch) if err != nil { return nil, err } @@ -811,6 +838,12 @@ func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkS return nil } +func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersRequest, res *api.QueryLeftUsersResponse) error { + var err error + res.LeftUsers, err = r.DB.GetLeftUsers(ctx, req.StaleDeviceListUsers) + return err +} + func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join") if err != nil { @@ -850,7 +883,7 @@ func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryS } func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error { - chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, req.EventIDs) + chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, nil, req.EventIDs) if err != nil { return err } @@ -969,7 +1002,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query // For each of the joined users, let's see if we can get a valid // membership event. for _, joinNID := range joinNIDs { - events, err := r.DB.Events(ctx, []types.EventNID{joinNID}) + events, err := r.DB.Events(ctx, roomInfo, []types.EventNID{joinNID}) if err != nil || len(events) != 1 { continue } diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go index 03627ea97..265f326d4 100644 --- a/roomserver/internal/query/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -80,7 +80,7 @@ func (db *getEventDB) addFakeEvents(graph map[string][]string) error { } // EventsFromIDs implements RoomserverInternalAPIEventDB -func (db *getEventDB) EventsFromIDs(ctx context.Context, eventIDs []string) (res []types.Event, err error) { +func (db *getEventDB) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) (res []types.Event, err error) { for _, evID := range eventIDs { res = append(res, types.Event{ EventNID: 0, @@ -106,7 +106,7 @@ func TestGetAuthChainSingle(t *testing.T) { t.Fatalf("Failed to add events to db: %v", err) } - result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e"}) + result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e"}) if err != nil { t.Fatalf("getAuthChain failed: %v", err) } @@ -139,7 +139,7 @@ func TestGetAuthChainMultiple(t *testing.T) { t.Fatalf("Failed to add events to db: %v", err) } - result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e", "f"}) + result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e", "f"}) if err != nil { t.Fatalf("getAuthChain failed: %v", err) } diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go deleted file mode 100644 index 1bd1b3fb7..000000000 --- a/roomserver/inthttp/client.go +++ /dev/null @@ -1,555 +0,0 @@ -package inthttp - -import ( - "context" - "errors" - "net/http" - - "github.com/matrix-org/gomatrixserverlib" - - asAPI "github.com/matrix-org/dendrite/appservice/api" - fsInputAPI "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/roomserver/api" - userapi "github.com/matrix-org/dendrite/userapi/api" -) - -const ( - // Alias operations - RoomserverSetRoomAliasPath = "/roomserver/setRoomAlias" - RoomserverGetRoomIDForAliasPath = "/roomserver/GetRoomIDForAlias" - RoomserverGetAliasesForRoomIDPath = "/roomserver/GetAliasesForRoomID" - RoomserverGetCreatorIDForAliasPath = "/roomserver/GetCreatorIDForAlias" - RoomserverRemoveRoomAliasPath = "/roomserver/removeRoomAlias" - - // Input operations - RoomserverInputRoomEventsPath = "/roomserver/inputRoomEvents" - - // Perform operations - RoomserverPerformInvitePath = "/roomserver/performInvite" - RoomserverPerformPeekPath = "/roomserver/performPeek" - RoomserverPerformUnpeekPath = "/roomserver/performUnpeek" - RoomserverPerformRoomUpgradePath = "/roomserver/performRoomUpgrade" - RoomserverPerformJoinPath = "/roomserver/performJoin" - RoomserverPerformLeavePath = "/roomserver/performLeave" - RoomserverPerformBackfillPath = "/roomserver/performBackfill" - RoomserverPerformPublishPath = "/roomserver/performPublish" - RoomserverPerformInboundPeekPath = "/roomserver/performInboundPeek" - RoomserverPerformForgetPath = "/roomserver/performForget" - RoomserverPerformAdminEvacuateRoomPath = "/roomserver/performAdminEvacuateRoom" - RoomserverPerformAdminEvacuateUserPath = "/roomserver/performAdminEvacuateUser" - RoomserverPerformAdminDownloadStatePath = "/roomserver/performAdminDownloadState" - - // Query operations - RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" - RoomserverQueryStateAfterEventsPath = "/roomserver/queryStateAfterEvents" - RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID" - RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser" - RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom" - RoomserverQueryServerJoinedToRoomPath = "/roomserver/queryServerJoinedToRoomPath" - RoomserverQueryServerAllowedToSeeEventPath = "/roomserver/queryServerAllowedToSeeEvent" - RoomserverQueryMissingEventsPath = "/roomserver/queryMissingEvents" - RoomserverQueryStateAndAuthChainPath = "/roomserver/queryStateAndAuthChain" - RoomserverQueryRoomVersionCapabilitiesPath = "/roomserver/queryRoomVersionCapabilities" - RoomserverQueryRoomVersionForRoomPath = "/roomserver/queryRoomVersionForRoom" - RoomserverQueryPublishedRoomsPath = "/roomserver/queryPublishedRooms" - RoomserverQueryCurrentStatePath = "/roomserver/queryCurrentState" - RoomserverQueryRoomsForUserPath = "/roomserver/queryRoomsForUser" - RoomserverQueryBulkStateContentPath = "/roomserver/queryBulkStateContent" - RoomserverQuerySharedUsersPath = "/roomserver/querySharedUsers" - RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers" - RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom" - RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" - RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed" - RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent" -) - -type httpRoomserverInternalAPI struct { - roomserverURL string - httpClient *http.Client - cache caching.RoomVersionCache -} - -// NewRoomserverClient creates a RoomserverInputAPI implemented by talking to a HTTP POST API. -// If httpClient is nil an error is returned -func NewRoomserverClient( - roomserverURL string, - httpClient *http.Client, - cache caching.RoomVersionCache, -) (api.RoomserverInternalAPI, error) { - if httpClient == nil { - return nil, errors.New("NewRoomserverInternalAPIHTTP: httpClient is ") - } - return &httpRoomserverInternalAPI{ - roomserverURL: roomserverURL, - httpClient: httpClient, - cache: cache, - }, nil -} - -// SetFederationInputAPI no-ops in HTTP client mode as there is no chicken/egg scenario -func (h *httpRoomserverInternalAPI) SetFederationAPI(fsAPI fsInputAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) { -} - -// SetAppserviceAPI no-ops in HTTP client mode as there is no chicken/egg scenario -func (h *httpRoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { -} - -// SetUserAPI no-ops in HTTP client mode as there is no chicken/egg scenario -func (h *httpRoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) { -} - -// SetRoomAlias implements RoomserverAliasAPI -func (h *httpRoomserverInternalAPI) SetRoomAlias( - ctx context.Context, - request *api.SetRoomAliasRequest, - response *api.SetRoomAliasResponse, -) error { - return httputil.CallInternalRPCAPI( - "SetRoomAlias", h.roomserverURL+RoomserverSetRoomAliasPath, - h.httpClient, ctx, request, response, - ) -} - -// GetRoomIDForAlias implements RoomserverAliasAPI -func (h *httpRoomserverInternalAPI) GetRoomIDForAlias( - ctx context.Context, - request *api.GetRoomIDForAliasRequest, - response *api.GetRoomIDForAliasResponse, -) error { - return httputil.CallInternalRPCAPI( - "GetRoomIDForAlias", h.roomserverURL+RoomserverGetRoomIDForAliasPath, - h.httpClient, ctx, request, response, - ) -} - -// GetAliasesForRoomID implements RoomserverAliasAPI -func (h *httpRoomserverInternalAPI) GetAliasesForRoomID( - ctx context.Context, - request *api.GetAliasesForRoomIDRequest, - response *api.GetAliasesForRoomIDResponse, -) error { - return httputil.CallInternalRPCAPI( - "GetAliasesForRoomID", h.roomserverURL+RoomserverGetAliasesForRoomIDPath, - h.httpClient, ctx, request, response, - ) -} - -// RemoveRoomAlias implements RoomserverAliasAPI -func (h *httpRoomserverInternalAPI) RemoveRoomAlias( - ctx context.Context, - request *api.RemoveRoomAliasRequest, - response *api.RemoveRoomAliasResponse, -) error { - return httputil.CallInternalRPCAPI( - "RemoveRoomAlias", h.roomserverURL+RoomserverRemoveRoomAliasPath, - h.httpClient, ctx, request, response, - ) -} - -// InputRoomEvents implements RoomserverInputAPI -func (h *httpRoomserverInternalAPI) InputRoomEvents( - ctx context.Context, - request *api.InputRoomEventsRequest, - response *api.InputRoomEventsResponse, -) error { - if err := httputil.CallInternalRPCAPI( - "InputRoomEvents", h.roomserverURL+RoomserverInputRoomEventsPath, - h.httpClient, ctx, request, response, - ); err != nil { - response.ErrMsg = err.Error() - } - return nil -} - -func (h *httpRoomserverInternalAPI) PerformInvite( - ctx context.Context, - request *api.PerformInviteRequest, - response *api.PerformInviteResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformInvite", h.roomserverURL+RoomserverPerformInvitePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformJoin( - ctx context.Context, - request *api.PerformJoinRequest, - response *api.PerformJoinResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformJoin", h.roomserverURL+RoomserverPerformJoinPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformPeek( - ctx context.Context, - request *api.PerformPeekRequest, - response *api.PerformPeekResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformPeek", h.roomserverURL+RoomserverPerformPeekPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformInboundPeek( - ctx context.Context, - request *api.PerformInboundPeekRequest, - response *api.PerformInboundPeekResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformInboundPeek", h.roomserverURL+RoomserverPerformInboundPeekPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformUnpeek( - ctx context.Context, - request *api.PerformUnpeekRequest, - response *api.PerformUnpeekResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformUnpeek", h.roomserverURL+RoomserverPerformUnpeekPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformRoomUpgrade( - ctx context.Context, - request *api.PerformRoomUpgradeRequest, - response *api.PerformRoomUpgradeResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformRoomUpgrade", h.roomserverURL+RoomserverPerformRoomUpgradePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformLeave( - ctx context.Context, - request *api.PerformLeaveRequest, - response *api.PerformLeaveResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformLeave", h.roomserverURL+RoomserverPerformLeavePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformPublish( - ctx context.Context, - request *api.PerformPublishRequest, - response *api.PerformPublishResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformPublish", h.roomserverURL+RoomserverPerformPublishPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformAdminEvacuateRoom( - ctx context.Context, - request *api.PerformAdminEvacuateRoomRequest, - response *api.PerformAdminEvacuateRoomResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformAdminEvacuateRoom", h.roomserverURL+RoomserverPerformAdminEvacuateRoomPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformAdminDownloadState( - ctx context.Context, - request *api.PerformAdminDownloadStateRequest, - response *api.PerformAdminDownloadStateResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformAdminDownloadState", h.roomserverURL+RoomserverPerformAdminDownloadStatePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformAdminEvacuateUser( - ctx context.Context, - request *api.PerformAdminEvacuateUserRequest, - response *api.PerformAdminEvacuateUserResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformAdminEvacuateUser", h.roomserverURL+RoomserverPerformAdminEvacuateUserPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryLatestEventsAndState implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState( - ctx context.Context, - request *api.QueryLatestEventsAndStateRequest, - response *api.QueryLatestEventsAndStateResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryLatestEventsAndState", h.roomserverURL+RoomserverQueryLatestEventsAndStatePath, - h.httpClient, ctx, request, response, - ) -} - -// QueryStateAfterEvents implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryStateAfterEvents( - ctx context.Context, - request *api.QueryStateAfterEventsRequest, - response *api.QueryStateAfterEventsResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryStateAfterEvents", h.roomserverURL+RoomserverQueryStateAfterEventsPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryEventsByID implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryEventsByID( - ctx context.Context, - request *api.QueryEventsByIDRequest, - response *api.QueryEventsByIDResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryEventsByID", h.roomserverURL+RoomserverQueryEventsByIDPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QueryPublishedRooms( - ctx context.Context, - request *api.QueryPublishedRoomsRequest, - response *api.QueryPublishedRoomsResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryPublishedRooms", h.roomserverURL+RoomserverQueryPublishedRoomsPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryMembershipForUser implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryMembershipForUser( - ctx context.Context, - request *api.QueryMembershipForUserRequest, - response *api.QueryMembershipForUserResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryMembershipForUser", h.roomserverURL+RoomserverQueryMembershipForUserPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryMembershipsForRoom implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryMembershipsForRoom( - ctx context.Context, - request *api.QueryMembershipsForRoomRequest, - response *api.QueryMembershipsForRoomResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryMembershipsForRoom", h.roomserverURL+RoomserverQueryMembershipsForRoomPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryMembershipsForRoom implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryServerJoinedToRoom( - ctx context.Context, - request *api.QueryServerJoinedToRoomRequest, - response *api.QueryServerJoinedToRoomResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryServerJoinedToRoom", h.roomserverURL+RoomserverQueryServerJoinedToRoomPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryServerAllowedToSeeEvent implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryServerAllowedToSeeEvent( - ctx context.Context, - request *api.QueryServerAllowedToSeeEventRequest, - response *api.QueryServerAllowedToSeeEventResponse, -) (err error) { - return httputil.CallInternalRPCAPI( - "QueryServerAllowedToSeeEvent", h.roomserverURL+RoomserverQueryServerAllowedToSeeEventPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryMissingEvents implements RoomServerQueryAPI -func (h *httpRoomserverInternalAPI) QueryMissingEvents( - ctx context.Context, - request *api.QueryMissingEventsRequest, - response *api.QueryMissingEventsResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryMissingEvents", h.roomserverURL+RoomserverQueryMissingEventsPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryStateAndAuthChain implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryStateAndAuthChain( - ctx context.Context, - request *api.QueryStateAndAuthChainRequest, - response *api.QueryStateAndAuthChainResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryStateAndAuthChain", h.roomserverURL+RoomserverQueryStateAndAuthChainPath, - h.httpClient, ctx, request, response, - ) -} - -// PerformBackfill implements RoomServerQueryAPI -func (h *httpRoomserverInternalAPI) PerformBackfill( - ctx context.Context, - request *api.PerformBackfillRequest, - response *api.PerformBackfillResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformBackfill", h.roomserverURL+RoomserverPerformBackfillPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryRoomVersionCapabilities implements RoomServerQueryAPI -func (h *httpRoomserverInternalAPI) QueryRoomVersionCapabilities( - ctx context.Context, - request *api.QueryRoomVersionCapabilitiesRequest, - response *api.QueryRoomVersionCapabilitiesResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryRoomVersionCapabilities", h.roomserverURL+RoomserverQueryRoomVersionCapabilitiesPath, - h.httpClient, ctx, request, response, - ) -} - -// QueryRoomVersionForRoom implements RoomServerQueryAPI -func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom( - ctx context.Context, - request *api.QueryRoomVersionForRoomRequest, - response *api.QueryRoomVersionForRoomResponse, -) error { - if roomVersion, ok := h.cache.GetRoomVersion(request.RoomID); ok { - response.RoomVersion = roomVersion - return nil - } - err := httputil.CallInternalRPCAPI( - "QueryRoomVersionForRoom", h.roomserverURL+RoomserverQueryRoomVersionForRoomPath, - h.httpClient, ctx, request, response, - ) - if err == nil { - h.cache.StoreRoomVersion(request.RoomID, response.RoomVersion) - } - return err -} - -func (h *httpRoomserverInternalAPI) QueryCurrentState( - ctx context.Context, - request *api.QueryCurrentStateRequest, - response *api.QueryCurrentStateResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryCurrentState", h.roomserverURL+RoomserverQueryCurrentStatePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QueryRoomsForUser( - ctx context.Context, - request *api.QueryRoomsForUserRequest, - response *api.QueryRoomsForUserResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryRoomsForUser", h.roomserverURL+RoomserverQueryRoomsForUserPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QueryBulkStateContent( - ctx context.Context, - request *api.QueryBulkStateContentRequest, - response *api.QueryBulkStateContentResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryBulkStateContent", h.roomserverURL+RoomserverQueryBulkStateContentPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QuerySharedUsers( - ctx context.Context, - request *api.QuerySharedUsersRequest, - response *api.QuerySharedUsersResponse, -) error { - return httputil.CallInternalRPCAPI( - "QuerySharedUsers", h.roomserverURL+RoomserverQuerySharedUsersPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QueryKnownUsers( - ctx context.Context, - request *api.QueryKnownUsersRequest, - response *api.QueryKnownUsersResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryKnownUsers", h.roomserverURL+RoomserverQueryKnownUsersPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QueryAuthChain( - ctx context.Context, - request *api.QueryAuthChainRequest, - response *api.QueryAuthChainResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryAuthChain", h.roomserverURL+RoomserverQueryAuthChainPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom( - ctx context.Context, - request *api.QueryServerBannedFromRoomRequest, - response *api.QueryServerBannedFromRoomResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryServerBannedFromRoom", h.roomserverURL+RoomserverQueryServerBannedFromRoomPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) QueryRestrictedJoinAllowed( - ctx context.Context, - request *api.QueryRestrictedJoinAllowedRequest, - response *api.QueryRestrictedJoinAllowedResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryRestrictedJoinAllowed", h.roomserverURL+RoomserverQueryRestrictedJoinAllowed, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpRoomserverInternalAPI) PerformForget( - ctx context.Context, - request *api.PerformForgetRequest, - response *api.PerformForgetResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformForget", h.roomserverURL+RoomserverPerformForgetPath, - h.httpClient, ctx, request, response, - ) - -} - -func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context, request *api.QueryMembershipAtEventRequest, response *api.QueryMembershipAtEventResponse) error { - return httputil.CallInternalRPCAPI( - "QueryMembershiptAtEvent", h.roomserverURL+RoomserverQueryMembershipAtEventPath, - h.httpClient, ctx, request, response, - ) -} diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go deleted file mode 100644 index 4d37e90b5..000000000 --- a/roomserver/inthttp/server.go +++ /dev/null @@ -1,206 +0,0 @@ -package inthttp - -import ( - "github.com/gorilla/mux" - - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/roomserver/api" -) - -// AddRoutes adds the RoomserverInternalAPI handlers to the http.ServeMux. -// nolint: gocyclo -func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { - internalAPIMux.Handle( - RoomserverInputRoomEventsPath, - httputil.MakeInternalRPCAPI("RoomserverInputRoomEvents", r.InputRoomEvents), - ) - - internalAPIMux.Handle( - RoomserverPerformInvitePath, - httputil.MakeInternalRPCAPI("RoomserverPerformInvite", r.PerformInvite), - ) - - internalAPIMux.Handle( - RoomserverPerformJoinPath, - httputil.MakeInternalRPCAPI("RoomserverPerformJoin", r.PerformJoin), - ) - - internalAPIMux.Handle( - RoomserverPerformLeavePath, - httputil.MakeInternalRPCAPI("RoomserverPerformLeave", r.PerformLeave), - ) - - internalAPIMux.Handle( - RoomserverPerformPeekPath, - httputil.MakeInternalRPCAPI("RoomserverPerformPeek", r.PerformPeek), - ) - - internalAPIMux.Handle( - RoomserverPerformInboundPeekPath, - httputil.MakeInternalRPCAPI("RoomserverPerformInboundPeek", r.PerformInboundPeek), - ) - - internalAPIMux.Handle( - RoomserverPerformUnpeekPath, - httputil.MakeInternalRPCAPI("RoomserverPerformUnpeek", r.PerformUnpeek), - ) - - internalAPIMux.Handle( - RoomserverPerformRoomUpgradePath, - httputil.MakeInternalRPCAPI("RoomserverPerformRoomUpgrade", r.PerformRoomUpgrade), - ) - - internalAPIMux.Handle( - RoomserverPerformPublishPath, - httputil.MakeInternalRPCAPI("RoomserverPerformPublish", r.PerformPublish), - ) - - internalAPIMux.Handle( - RoomserverPerformAdminEvacuateRoomPath, - httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateRoom", r.PerformAdminEvacuateRoom), - ) - - internalAPIMux.Handle( - RoomserverPerformAdminEvacuateUserPath, - httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", r.PerformAdminEvacuateUser), - ) - - internalAPIMux.Handle( - RoomserverPerformAdminDownloadStatePath, - httputil.MakeInternalRPCAPI("RoomserverPerformAdminDownloadState", r.PerformAdminDownloadState), - ) - - internalAPIMux.Handle( - RoomserverQueryPublishedRoomsPath, - httputil.MakeInternalRPCAPI("RoomserverQueryPublishedRooms", r.QueryPublishedRooms), - ) - - internalAPIMux.Handle( - RoomserverQueryLatestEventsAndStatePath, - httputil.MakeInternalRPCAPI("RoomserverQueryLatestEventsAndState", r.QueryLatestEventsAndState), - ) - - internalAPIMux.Handle( - RoomserverQueryStateAfterEventsPath, - httputil.MakeInternalRPCAPI("RoomserverQueryStateAfterEvents", r.QueryStateAfterEvents), - ) - - internalAPIMux.Handle( - RoomserverQueryEventsByIDPath, - httputil.MakeInternalRPCAPI("RoomserverQueryEventsByID", r.QueryEventsByID), - ) - - internalAPIMux.Handle( - RoomserverQueryMembershipForUserPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMembershipForUser", r.QueryMembershipForUser), - ) - - internalAPIMux.Handle( - RoomserverQueryMembershipsForRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMembershipsForRoom", r.QueryMembershipsForRoom), - ) - - internalAPIMux.Handle( - RoomserverQueryServerJoinedToRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryServerJoinedToRoom", r.QueryServerJoinedToRoom), - ) - - internalAPIMux.Handle( - RoomserverQueryServerAllowedToSeeEventPath, - httputil.MakeInternalRPCAPI("RoomserverQueryServerAllowedToSeeEvent", r.QueryServerAllowedToSeeEvent), - ) - - internalAPIMux.Handle( - RoomserverQueryMissingEventsPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMissingEvents", r.QueryMissingEvents), - ) - - internalAPIMux.Handle( - RoomserverQueryStateAndAuthChainPath, - httputil.MakeInternalRPCAPI("RoomserverQueryStateAndAuthChain", r.QueryStateAndAuthChain), - ) - - internalAPIMux.Handle( - RoomserverPerformBackfillPath, - httputil.MakeInternalRPCAPI("RoomserverPerformBackfill", r.PerformBackfill), - ) - - internalAPIMux.Handle( - RoomserverPerformForgetPath, - httputil.MakeInternalRPCAPI("RoomserverPerformForget", r.PerformForget), - ) - - internalAPIMux.Handle( - RoomserverQueryRoomVersionCapabilitiesPath, - httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionCapabilities", r.QueryRoomVersionCapabilities), - ) - - internalAPIMux.Handle( - RoomserverQueryRoomVersionForRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionForRoom", r.QueryRoomVersionForRoom), - ) - - internalAPIMux.Handle( - RoomserverSetRoomAliasPath, - httputil.MakeInternalRPCAPI("RoomserverSetRoomAlias", r.SetRoomAlias), - ) - - internalAPIMux.Handle( - RoomserverGetRoomIDForAliasPath, - httputil.MakeInternalRPCAPI("RoomserverGetRoomIDForAlias", r.GetRoomIDForAlias), - ) - - internalAPIMux.Handle( - RoomserverGetAliasesForRoomIDPath, - httputil.MakeInternalRPCAPI("RoomserverGetAliasesForRoomID", r.GetAliasesForRoomID), - ) - - internalAPIMux.Handle( - RoomserverRemoveRoomAliasPath, - httputil.MakeInternalRPCAPI("RoomserverRemoveRoomAlias", r.RemoveRoomAlias), - ) - - internalAPIMux.Handle( - RoomserverQueryCurrentStatePath, - httputil.MakeInternalRPCAPI("RoomserverQueryCurrentState", r.QueryCurrentState), - ) - - internalAPIMux.Handle( - RoomserverQueryRoomsForUserPath, - httputil.MakeInternalRPCAPI("RoomserverQueryRoomsForUser", r.QueryRoomsForUser), - ) - - internalAPIMux.Handle( - RoomserverQueryBulkStateContentPath, - httputil.MakeInternalRPCAPI("RoomserverQueryBulkStateContent", r.QueryBulkStateContent), - ) - - internalAPIMux.Handle( - RoomserverQuerySharedUsersPath, - httputil.MakeInternalRPCAPI("RoomserverQuerySharedUsers", r.QuerySharedUsers), - ) - - internalAPIMux.Handle( - RoomserverQueryKnownUsersPath, - httputil.MakeInternalRPCAPI("RoomserverQueryKnownUsers", r.QueryKnownUsers), - ) - - internalAPIMux.Handle( - RoomserverQueryServerBannedFromRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryServerBannedFromRoom", r.QueryServerBannedFromRoom), - ) - - internalAPIMux.Handle( - RoomserverQueryAuthChainPath, - httputil.MakeInternalRPCAPI("RoomserverQueryAuthChain", r.QueryAuthChain), - ) - - internalAPIMux.Handle( - RoomserverQueryRestrictedJoinAllowed, - httputil.MakeInternalRPCAPI("RoomserverQueryRestrictedJoinAllowed", r.QueryRestrictedJoinAllowed), - ) - internalAPIMux.Handle( - RoomserverQueryMembershipAtEventPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", r.QueryMembershipAtEvent), - ) -} diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 1f707735b..5a8d8b570 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -15,23 +15,15 @@ package roomserver import ( - "github.com/gorilla/mux" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal" - "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/base" - "github.com/sirupsen/logrus" ) -// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions -// on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.RoomserverInternalAPI) { - inthttp.AddRoutes(intAPI, router) -} - -// NewInternalAPI returns a concerete implementation of the internal API. Callers -// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. +// NewInternalAPI returns a concrete implementation of the internal API. func NewInternalAPI( base *base.BaseDendrite, ) api.RoomserverInternalAPI { diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 24b5515e5..a3ca5909e 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -2,27 +2,168 @@ package roomserver_test import ( "context" + "crypto/ed25519" + "reflect" "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/userapi" + + userAPI "github.com/matrix-org/dendrite/userapi/api" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" - "github.com/matrix-org/gomatrixserverlib" ) func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) { + t.Helper() base, close := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.Open(base, &base.Cfg.KeyServer.Database, base.Caches) + db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches) if err != nil { t.Fatalf("failed to create Database: %v", err) } return base, db, close } -func Test_SharedUsers(t *testing.T) { +func TestUsers(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + rsAPI := roomserver.NewInternalAPI(base) + // SetFederationAPI starts the room event input consumer + rsAPI.SetFederationAPI(nil, nil) + + t.Run("shared users", func(t *testing.T) { + testSharedUsers(t, rsAPI) + }) + + t.Run("kick users", func(t *testing.T) { + usrAPI := userapi.NewInternalAPI(base, rsAPI, nil) + rsAPI.SetUserAPI(usrAPI) + testKickUsers(t, rsAPI, usrAPI) + }) + }) + +} + +func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) { + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite and join Bob + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Query the shared users for Alice, there should only be Bob. + // This is used by the SyncAPI keychange consumer. + res := &api.QuerySharedUsersResponse{} + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil { + t.Errorf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } + // Also verify that we get the expected result when specifying OtherUserIDs. + // This is used by the SyncAPI when getting device list changes. + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil { + t.Errorf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } +} + +func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI.UserInternalAPI) { + // Create users and room; Bob is going to be the guest and kicked on revocation of guest access + alice := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeGuest)) + + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true)) + + // Join with the guest user + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + // Create the users in the userapi, so the RSAPI can query the account type later + for _, u := range []*test.User{alice, bob} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &userAPI.PerformAccountCreationResponse{} + if err := usrAPI.PerformAccountCreation(ctx, &userAPI.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: "someRandomPassword", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + // Create the room in the database + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Get the membership events BEFORE revoking guest access + membershipRes := &api.QueryMembershipsForRoomResponse{} + if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes); err != nil { + t.Errorf("failed to query membership for room: %s", err) + } + + // revoke guest access + revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey("")) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // TODO: Even though we are sending the events sync, the "kickUsers" function is sending the events async, so we need + // to loop and wait for the events to be processed by the roomserver. + for i := 0; i <= 20; i++ { + // Get the membership events AFTER revoking guest access + membershipRes2 := &api.QueryMembershipsForRoomResponse{} + if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes2); err != nil { + t.Errorf("failed to query membership for room: %s", err) + } + + // The membership events should NOT match, as Bob (guest user) should now be kicked from the room + if !reflect.DeepEqual(membershipRes, membershipRes2) { + return + } + time.Sleep(time.Millisecond * 10) + } + + t.Errorf("memberships didn't change in time") +} + +func Test_QueryLeftUsers(t *testing.T) { alice := test.NewUser(t) bob := test.NewUser(t) room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) @@ -48,22 +189,385 @@ func Test_SharedUsers(t *testing.T) { t.Fatalf("failed to send events: %v", err) } - // Query the shared users for Alice, there should only be Bob. - // This is used by the SyncAPI keychange consumer. - res := &api.QuerySharedUsersResponse{} - if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil { - t.Fatalf("unable to query known users: %v", err) + // Query the left users, there should only be "@idontexist:test", + // as Alice and Bob are still joined. + res := &api.QueryLeftUsersResponse{} + leftUserID := "@idontexist:test" + getLeftUsersList := []string{alice.ID, bob.ID, leftUserID} + + testCase := func(rsAPI api.RoomserverInternalAPI) { + if err := rsAPI.QueryLeftUsers(ctx, &api.QueryLeftUsersRequest{StaleDeviceListUsers: getLeftUsersList}, res); err != nil { + t.Fatalf("unable to query left users: %v", err) + } + wantCount := 1 + if count := len(res.LeftUsers); count > wantCount { + t.Fatalf("unexpected left users count: want %d, got %d", wantCount, count) + } + if res.LeftUsers[0] != leftUserID { + t.Fatalf("unexpected left users : want %s, got %s", leftUserID, res.LeftUsers[0]) + } } - if _, ok := res.UserIDsToCount[bob.ID]; !ok { - t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + + testCase(rsAPI) + }) +} + +func TestPurgeRoom(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite Bob + inviteEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, db, close := mustCreateDatabase(t, dbType) + defer close() + + jsCtx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsCtx, &base.Cfg.Global.JetStream) + + fedClient := base.CreateFederationClient() + rsAPI := roomserver.NewInternalAPI(base) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + + // this starts the JetStream consumers + syncapi.AddPublicRoutes(base, userAPI, rsAPI) + federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) + rsAPI.SetFederationAPI(nil, nil) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) } - // Also verify that we get the expected result when specifying OtherUserIDs. - // This is used by the SyncAPI when getting device list changes. - if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil { - t.Fatalf("unable to query known users: %v", err) + + // some dummy entries to validate after purging + publishResp := &api.PerformPublishResponse{} + if err := rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{RoomID: room.ID, Visibility: "public"}, publishResp); err != nil { + t.Fatal(err) } - if _, ok := res.UserIDsToCount[bob.ID]; !ok { - t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + if publishResp.Error != nil { + t.Fatal(publishResp.Error) + } + + isPublished, err := db.GetPublishedRoom(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if !isPublished { + t.Fatalf("room should be published before purging") + } + + aliasResp := &api.SetRoomAliasResponse{} + if err = rsAPI.SetRoomAlias(ctx, &api.SetRoomAliasRequest{RoomID: room.ID, Alias: "myalias", UserID: alice.ID}, aliasResp); err != nil { + t.Fatal(err) + } + // check the alias is actually there + aliasesResp := &api.GetAliasesForRoomIDResponse{} + if err = rsAPI.GetAliasesForRoomID(ctx, &api.GetAliasesForRoomIDRequest{RoomID: room.ID}, aliasesResp); err != nil { + t.Fatal(err) + } + wantAliases := 1 + if gotAliases := len(aliasesResp.Aliases); gotAliases != wantAliases { + t.Fatalf("expected %d aliases, got %d", wantAliases, gotAliases) + } + + // validate the room exists before purging + roomInfo, err := db.RoomInfo(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if roomInfo == nil { + t.Fatalf("room does not exist") + } + + // + roomInfo2, err := db.RoomInfoByNID(ctx, roomInfo.RoomNID) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(roomInfo, roomInfo2) { + t.Fatalf("expected roomInfos to be the same, but they aren't") + } + + // remember the roomInfo before purging + existingRoomInfo := roomInfo + + // validate there is an invite for bob + nids, err := db.EventStateKeyNIDs(ctx, []string{bob.ID}) + if err != nil { + t.Fatal(err) + } + bobNID, ok := nids[bob.ID] + if !ok { + t.Fatalf("%s does not exist", bob.ID) + } + + _, inviteEventIDs, _, err := db.GetInvitesForUser(ctx, roomInfo.RoomNID, bobNID) + if err != nil { + t.Fatal(err) + } + wantInviteCount := 1 + if inviteCount := len(inviteEventIDs); inviteCount != wantInviteCount { + t.Fatalf("expected there to be only %d invite events, got %d", wantInviteCount, inviteCount) + } + if inviteEventIDs[0] != inviteEvent.EventID() { + t.Fatalf("expected invite event ID %s, got %s", inviteEvent.EventID(), inviteEventIDs[0]) + } + + // purge the room from the database + purgeResp := &api.PerformAdminPurgeRoomResponse{} + if err = rsAPI.PerformAdminPurgeRoom(ctx, &api.PerformAdminPurgeRoomRequest{RoomID: room.ID}, purgeResp); err != nil { + t.Fatal(err) + } + + // wait for all consumers to process the purge event + var sum = 1 + timeout := time.Second * 5 + deadline, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + for sum > 0 { + if deadline.Err() != nil { + t.Fatalf("test timed out after %s", timeout) + } + sum = 0 + consumerCh := jsCtx.Consumers(base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)) + for x := range consumerCh { + sum += x.NumAckPending + } + time.Sleep(time.Millisecond) + } + + roomInfo, err = db.RoomInfo(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if roomInfo != nil { + t.Fatalf("room should not exist after purging: %+v", roomInfo) + } + roomInfo2, err = db.RoomInfoByNID(ctx, existingRoomInfo.RoomNID) + if err == nil { + t.Fatalf("expected room to not exist, but it does: %#v", roomInfo2) + } + + // validation below + + // There should be no invite left + _, inviteEventIDs, _, err = db.GetInvitesForUser(ctx, existingRoomInfo.RoomNID, bobNID) + if err != nil { + t.Fatal(err) + } + + if inviteCount := len(inviteEventIDs); inviteCount > 0 { + t.Fatalf("expected there to be only %d invite events, got %d", wantInviteCount, inviteCount) + } + + // aliases should be deleted + aliases, err := db.GetAliasesForRoomID(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if aliasCount := len(aliases); aliasCount > 0 { + t.Fatalf("expected there to be only %d invite events, got %d", 0, aliasCount) + } + + // published room should be deleted + isPublished, err = db.GetPublishedRoom(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if isPublished { + t.Fatalf("room should not be published after purging") + } + }) +} + +type fledglingEvent struct { + Type string + StateKey *string + Sender string + RoomID string + Redacts string + Depth int64 + PrevEvents []interface{} +} + +func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) { + t.Helper() + roomVer := gomatrixserverlib.RoomVersionV9 + seed := make([]byte, ed25519.SeedSize) // zero seed + key := ed25519.NewKeyFromSeed(seed) + eb := gomatrixserverlib.EventBuilder{ + Sender: ev.Sender, + Type: ev.Type, + StateKey: ev.StateKey, + RoomID: ev.RoomID, + Redacts: ev.Redacts, + Depth: ev.Depth, + PrevEvents: ev.PrevEvents, + } + err := eb.SetContent(map[string]interface{}{}) + if err != nil { + t.Fatalf("mustCreateEvent: failed to marshal event content %v", err) + } + signedEvent, err := eb.Build(time.Now(), "localhost", "ed25519:test", key, roomVer) + if err != nil { + t.Fatalf("mustCreateEvent: failed to sign event: %s", err) + } + h := signedEvent.Headered(roomVer) + return h +} + +func TestRedaction(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t, test.WithSigningServer("notlocalhost", "abc", test.PrivateKeyB)) + + testCases := []struct { + name string + additionalEvents func(t *testing.T, room *test.Room) + wantRedacted bool + }{ + { + name: "can redact own message", + wantRedacted: true, + additionalEvents: func(t *testing.T, room *test.Room) { + redactedEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hello world"}) + + builderEv := mustCreateEvent(t, fledglingEvent{ + Type: gomatrixserverlib.MRoomRedaction, + Sender: alice.ID, + RoomID: room.ID, + Redacts: redactedEvent.EventID(), + Depth: redactedEvent.Depth() + 1, + PrevEvents: []interface{}{redactedEvent.EventID()}, + }) + room.InsertEvent(t, builderEv.Headered(gomatrixserverlib.RoomVersionV9)) + }, + }, + { + name: "can redact others message, allowed by PL", + wantRedacted: true, + additionalEvents: func(t *testing.T, room *test.Room) { + redactedEvent := room.CreateAndInsert(t, bob, "m.room.message", map[string]interface{}{"body": "hello world"}) + + builderEv := mustCreateEvent(t, fledglingEvent{ + Type: gomatrixserverlib.MRoomRedaction, + Sender: alice.ID, + RoomID: room.ID, + Redacts: redactedEvent.EventID(), + Depth: redactedEvent.Depth() + 1, + PrevEvents: []interface{}{redactedEvent.EventID()}, + }) + room.InsertEvent(t, builderEv.Headered(gomatrixserverlib.RoomVersionV9)) + }, + }, + { + name: "can redact others message, same server", + wantRedacted: true, + additionalEvents: func(t *testing.T, room *test.Room) { + redactedEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hello world"}) + + builderEv := mustCreateEvent(t, fledglingEvent{ + Type: gomatrixserverlib.MRoomRedaction, + Sender: bob.ID, + RoomID: room.ID, + Redacts: redactedEvent.EventID(), + Depth: redactedEvent.Depth() + 1, + PrevEvents: []interface{}{redactedEvent.EventID()}, + }) + room.InsertEvent(t, builderEv.Headered(gomatrixserverlib.RoomVersionV9)) + }, + }, + { + name: "can not redact others message, missing PL", + additionalEvents: func(t *testing.T, room *test.Room) { + redactedEvent := room.CreateAndInsert(t, bob, "m.room.message", map[string]interface{}{"body": "hello world"}) + + builderEv := mustCreateEvent(t, fledglingEvent{ + Type: gomatrixserverlib.MRoomRedaction, + Sender: charlie.ID, + RoomID: room.ID, + Redacts: redactedEvent.EventID(), + Depth: redactedEvent.Depth() + 1, + PrevEvents: []interface{}{redactedEvent.EventID()}, + }) + room.InsertEvent(t, builderEv.Headered(gomatrixserverlib.RoomVersionV9)) + }, + }, + } + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + _, db, close := mustCreateDatabase(t, dbType) + defer close() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + authEvents := []types.EventNID{} + var roomInfo *types.RoomInfo + var err error + + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, charlie, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(charlie.ID)) + + if tc.additionalEvents != nil { + tc.additionalEvents(t, room) + } + + for _, ev := range room.Events() { + roomInfo, err = db.GetOrCreateRoomInfo(ctx, ev.Event) + assert.NoError(t, err) + assert.NotNil(t, roomInfo) + evTypeNID, err := db.GetOrCreateEventTypeNID(ctx, ev.Type()) + assert.NoError(t, err) + + stateKeyNID, err := db.GetOrCreateEventStateKeyNID(ctx, ev.StateKey()) + assert.NoError(t, err) + + eventNID, stateAtEvent, err := db.StoreEvent(ctx, ev.Event, roomInfo, evTypeNID, stateKeyNID, authEvents, false) + assert.NoError(t, err) + if ev.StateKey() != nil { + authEvents = append(authEvents, eventNID) + } + + // Calculate the snapshotNID etc. + plResolver := state.NewStateResolution(db, roomInfo) + stateAtEvent.BeforeStateSnapshotNID, err = plResolver.CalculateAndStoreStateBeforeEvent(ctx, ev.Event, false) + assert.NoError(t, err) + + // Update the room + updater, err := db.GetRoomUpdater(ctx, roomInfo) + assert.NoError(t, err) + err = updater.SetState(ctx, eventNID, stateAtEvent.BeforeStateSnapshotNID) + assert.NoError(t, err) + err = updater.Commit() + assert.NoError(t, err) + + _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.Event, &plResolver) + assert.NoError(t, err) + if redactedEvent != nil { + assert.Equal(t, ev.Redacts(), redactedEvent.EventID()) + } + if ev.Type() == gomatrixserverlib.MRoomRedaction { + nids, err := db.EventNIDs(ctx, []string{ev.Redacts()}) + assert.NoError(t, err) + evs, err := db.Events(ctx, roomInfo, []types.EventNID{nids[ev.Redacts()].EventNID}) + assert.NoError(t, err) + assert.Equal(t, 1, len(evs)) + assert.Equal(t, tc.wantRedacted, evs[0].Redacted()) + } + } + }) } }) } diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 1cfde5e4b..c3842784e 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -25,9 +25,9 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -41,8 +41,8 @@ type StateResolutionStorage interface { StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) - Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) - EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) + Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) + EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) } type StateResolution struct { @@ -59,14 +59,55 @@ func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) Sta } } +type PowerLevelResolver interface { + Resolve(ctx context.Context, eventID string) (*gomatrixserverlib.PowerLevelContent, error) +} + +func (p *StateResolution) Resolve(ctx context.Context, eventID string) (*gomatrixserverlib.PowerLevelContent, error) { + stateEntries, err := p.LoadStateAtEvent(ctx, eventID) + if err != nil { + return nil, err + } + + wantTuple := types.StateKeyTuple{ + EventTypeNID: types.MRoomPowerLevelsNID, + EventStateKeyNID: types.EmptyStateKeyNID, + } + + var plNID types.EventNID + for _, entry := range stateEntries { + if entry.StateKeyTuple == wantTuple { + plNID = entry.EventNID + break + } + } + if plNID == 0 { + return nil, fmt.Errorf("unable to find power level event") + } + + events, err := p.db.Events(ctx, p.roomInfo, []types.EventNID{plNID}) + if err != nil { + return nil, err + } + if len(events) == 0 { + return nil, fmt.Errorf("unable to find power level event") + } + powerlevels, err := events[0].PowerLevels() + if err != nil { + return nil, err + } + + return powerlevels, nil +} + // LoadStateAtSnapshot loads the full state of a room at a particular snapshot. // This is typically the state before an event or the current state of a room. // Returns a sorted list of state entries or an error if there was a problem talking to the database. func (v *StateResolution) LoadStateAtSnapshot( ctx context.Context, stateNID types.StateSnapshotNID, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtSnapshot") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.LoadStateAtSnapshot") + defer trace.EndRegion() stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) if err != nil { @@ -106,8 +147,8 @@ func (v *StateResolution) LoadStateAtSnapshot( func (v *StateResolution) LoadStateAtEvent( ctx context.Context, eventID string, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.LoadStateAtEvent") + defer trace.EndRegion() snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) if err != nil { @@ -128,8 +169,8 @@ func (v *StateResolution) LoadStateAtEvent( func (v *StateResolution) LoadMembershipAtEvent( ctx context.Context, eventIDs []string, stateKeyNID types.EventStateKeyNID, ) (map[string][]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.LoadMembershipAtEvent") + defer trace.EndRegion() // Get a mapping from snapshotNID -> eventIDs snapshotNIDMap, err := v.db.BulkSelectSnapshotsFromEventIDs(ctx, eventIDs) @@ -197,8 +238,8 @@ func (v *StateResolution) LoadMembershipAtEvent( func (v *StateResolution) LoadStateAtEventForHistoryVisibility( ctx context.Context, eventID string, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.LoadStateAtEventForHistoryVisibility") + defer trace.EndRegion() snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) if err != nil { @@ -222,8 +263,8 @@ func (v *StateResolution) LoadStateAtEventForHistoryVisibility( func (v *StateResolution) LoadCombinedStateAfterEvents( ctx context.Context, prevStates []types.StateAtEvent, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadCombinedStateAfterEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.LoadCombinedStateAfterEvents") + defer trace.EndRegion() stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) for i, state := range prevStates { @@ -297,8 +338,8 @@ func (v *StateResolution) LoadCombinedStateAfterEvents( func (v *StateResolution) DifferenceBetweeenStateSnapshots( ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID, ) (removed, added []types.StateEntry, err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.DifferenceBetweeenStateSnapshots") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.DifferenceBetweeenStateSnapshots") + defer trace.EndRegion() if oldStateNID == newStateNID { // If the snapshot NIDs are the same then nothing has changed @@ -361,8 +402,8 @@ func (v *StateResolution) LoadStateAtSnapshotForStringTuples( stateNID types.StateSnapshotNID, stateKeyTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtSnapshotForStringTuples") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.LoadStateAtSnapshotForStringTuples") + defer trace.EndRegion() numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) if err != nil { @@ -378,8 +419,8 @@ func (v *StateResolution) stringTuplesToNumericTuples( ctx context.Context, stringTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateKeyTuple, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.stringTuplesToNumericTuples") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.stringTuplesToNumericTuples") + defer trace.EndRegion() eventTypes := make([]string, len(stringTuples)) stateKeys := make([]string, len(stringTuples)) @@ -423,8 +464,8 @@ func (v *StateResolution) loadStateAtSnapshotForNumericTuples( stateNID types.StateSnapshotNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateAtSnapshotForNumericTuples") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.loadStateAtSnapshotForNumericTuples") + defer trace.EndRegion() stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) if err != nil { @@ -474,8 +515,8 @@ func (v *StateResolution) LoadStateAfterEventsForStringTuples( prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAfterEventsForStringTuples") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.LoadStateAfterEventsForStringTuples") + defer trace.EndRegion() numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) if err != nil { @@ -489,8 +530,8 @@ func (v *StateResolution) loadStateAfterEventsForNumericTuples( prevStates []types.StateAtEvent, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateAfterEventsForNumericTuples") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.loadStateAfterEventsForNumericTuples") + defer trace.EndRegion() if len(prevStates) == 1 { // Fast path for a single event. @@ -664,8 +705,8 @@ func (v *StateResolution) CalculateAndStoreStateBeforeEvent( event *gomatrixserverlib.Event, isRejected bool, ) (types.StateSnapshotNID, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.CalculateAndStoreStateBeforeEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.CalculateAndStoreStateBeforeEvent") + defer trace.EndRegion() // Load the state at the prev events. prevStates, err := v.db.StateAtEventIDs(ctx, event.PrevEventIDs()) @@ -683,8 +724,8 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents( ctx context.Context, prevStates []types.StateAtEvent, ) (types.StateSnapshotNID, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.CalculateAndStoreStateAfterEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.CalculateAndStoreStateAfterEvents") + defer trace.EndRegion() metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} @@ -758,8 +799,8 @@ func (v *StateResolution) calculateAndStoreStateAfterManyEvents( prevStates []types.StateAtEvent, metrics calculateStateMetrics, ) (types.StateSnapshotNID, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.calculateAndStoreStateAfterManyEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.calculateAndStoreStateAfterManyEvents") + defer trace.EndRegion() state, algorithm, conflictLength, err := v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates) @@ -779,8 +820,8 @@ func (v *StateResolution) calculateStateAfterManyEvents( ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, prevStates []types.StateAtEvent, ) (state []types.StateEntry, algorithm string, conflictLength int, err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.calculateStateAfterManyEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.calculateStateAfterManyEvents") + defer trace.EndRegion() var combined []types.StateEntry // Conflict resolution. @@ -834,8 +875,8 @@ func (v *StateResolution) resolveConflicts( ctx context.Context, version gomatrixserverlib.RoomVersion, notConflicted, conflicted []types.StateEntry, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflicts") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.resolveConflicts") + defer trace.EndRegion() stateResAlgo, err := version.StateResAlgorithm() if err != nil { @@ -861,8 +902,8 @@ func (v *StateResolution) resolveConflictsV1( ctx context.Context, notConflicted, conflicted []types.StateEntry, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflictsV1") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.resolveConflictsV1") + defer trace.EndRegion() // Load the conflicted events conflictedEvents, eventIDMap, err := v.loadStateEvents(ctx, conflicted) @@ -926,8 +967,8 @@ func (v *StateResolution) resolveConflictsV2( ctx context.Context, notConflicted, conflicted []types.StateEntry, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflictsV2") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.resolveConflictsV2") + defer trace.EndRegion() estimate := len(conflicted) + len(notConflicted) eventIDMap := make(map[string]types.StateEntry, estimate) @@ -959,8 +1000,8 @@ func (v *StateResolution) resolveConflictsV2( // For each conflicted event, let's try and get the needed auth events. if err = func() error { - span, sctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadAuthEvents") - defer span.Finish() + loadAuthEventsTrace, sctx := internal.StartRegion(ctx, "StateResolution.loadAuthEvents") + defer loadAuthEventsTrace.EndRegion() loader := authEventLoader{ v: v, @@ -975,7 +1016,7 @@ func (v *StateResolution) resolveConflictsV2( // Store the newly found auth events in the auth set for this event. var authEventMap map[string]types.StateEntry - authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, conflictedEvent, knownAuthEvents) + authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, v.roomInfo, conflictedEvent, knownAuthEvents) if err != nil { return err } @@ -1004,8 +1045,8 @@ func (v *StateResolution) resolveConflictsV2( // Resolve the conflicts. resolvedEvents := func() []*gomatrixserverlib.Event { - span, _ := opentracing.StartSpanFromContext(ctx, "gomatrixserverlib.ResolveStateConflictsV2") - defer span.Finish() + resolvedTrace, _ := internal.StartRegion(ctx, "StateResolution.ResolveStateConflictsV2") + defer resolvedTrace.EndRegion() return gomatrixserverlib.ResolveStateConflictsV2( conflictedEvents, @@ -1077,8 +1118,8 @@ func (v *StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.E func (v *StateResolution) loadStateEvents( ctx context.Context, entries []types.StateEntry, ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.loadStateEvents") + defer trace.EndRegion() result := make([]*gomatrixserverlib.Event, 0, len(entries)) eventEntries := make([]types.StateEntry, 0, len(entries)) @@ -1091,7 +1132,7 @@ func (v *StateResolution) loadStateEvents( eventNIDs = append(eventNIDs, entry.EventNID) } } - events, err := v.db.Events(ctx, eventNIDs) + events, err := v.db.Events(ctx, v.roomInfo, eventNIDs) if err != nil { return nil, nil, err } @@ -1120,7 +1161,7 @@ type authEventLoader struct { // loadAuthEvents loads all of the auth events for a given event recursively, // along with a map that contains state entries for all of the auth events. func (l *authEventLoader) loadAuthEvents( - ctx context.Context, event *gomatrixserverlib.Event, eventMap map[string]types.Event, + ctx context.Context, roomInfo *types.RoomInfo, event *gomatrixserverlib.Event, eventMap map[string]types.Event, ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { l.Lock() defer l.Unlock() @@ -1155,7 +1196,7 @@ func (l *authEventLoader) loadAuthEvents( // If we need to get events from the database, go and fetch // those now. if len(l.lookupFromDB) > 0 { - eventsFromDB, err := l.v.db.EventsFromIDs(ctx, l.lookupFromDB) + eventsFromDB, err := l.v.db.EventsFromIDs(ctx, roomInfo, l.lookupFromDB) if err != nil { return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index a13f4f04d..a577f4650 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -19,6 +19,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -29,6 +30,7 @@ type Database interface { SupportsConcurrentRoomInputs() bool // RoomInfo returns room information for the given room ID, or nil if there is no room. RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) + RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) // Store the room state at an event in the database AddState( ctx context.Context, @@ -69,15 +71,12 @@ type Database interface { ) ([]types.StateEntryList, error) // Look up the Events for a list of numeric event IDs. // Returns a sorted list of events. - Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) + Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) // Look up snapshot NID for an event ID string SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) - // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. - StoreEvent( - ctx context.Context, event *gomatrixserverlib.Event, authEventNIDs []types.EventNID, - isRejected bool, - ) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) + // Stores a matrix room event in the database. Returns the room NID, the state snapshot or an error. + StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) // Look up the state entries for a list of string event IDs // Returns an error if the there is an error talking to the database // Returns a types.MissingEventError if the event IDs aren't in the database. @@ -87,7 +86,7 @@ type Database interface { EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) // Look up the numeric IDs for a list of events. // Returns an error if there was a problem talking to the database. - EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) + EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventMetadata, error) // Set the state at an event. FIXME TODO: "at" SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error // Lookup the event IDs for a batch of event numeric IDs. @@ -138,7 +137,7 @@ type Database interface { // EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was // not found. // Returns an error if the retrieval went wrong. - EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) + EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) // Publish or unpublish a room from the room directory. PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error // Returns a list of room IDs for rooms which are published. @@ -172,5 +171,63 @@ type Database interface { ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) + GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) + PurgeRoom(ctx context.Context, roomID string) error UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error + + // GetMembershipForHistoryVisibility queries the membership events for the given eventIDs. + // Returns a map from (input) eventID -> membership event. If no membership event is found, returns an empty event, resulting in + // a membership of "leave" when calculating history visibility. + GetMembershipForHistoryVisibility( + ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string, + ) (map[string]*gomatrixserverlib.HeaderedEvent, error) + GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error) + GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) + GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) + MaybeRedactEvent( + ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, plResolver state.PowerLevelResolver, + ) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) +} + +type RoomDatabase interface { + EventDatabase + // RoomInfo returns room information for the given room ID, or nil if there is no room. + RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) + RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) + // IsEventRejected returns true if the event is known and rejected. + IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (rejected bool, err error) + MissingAuthPrevEvents(ctx context.Context, e *gomatrixserverlib.Event) (missingAuth, missingPrev []string, err error) + UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error + GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error) + GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error) + StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) + StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) + AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) + LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) + GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error) + GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) + GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) +} + +type EventDatabase interface { + EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) + EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error) + EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventMetadata, error) + SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error + StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) + Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) + // MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error + // (nil if there was nothing to do) + MaybeRedactEvent( + ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, plResolver state.PowerLevelResolver, + ) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) + StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) } diff --git a/roomserver/storage/postgres/deltas/20230131091021_published_appservice_pkey.go b/roomserver/storage/postgres/deltas/20230131091021_published_appservice_pkey.go new file mode 100644 index 000000000..add66446b --- /dev/null +++ b/roomserver/storage/postgres/deltas/20230131091021_published_appservice_pkey.go @@ -0,0 +1,32 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpPulishedAppservicePrimaryKey(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_published RENAME CONSTRAINT roomserver_published_pkey TO roomserver_published_pkeyold; +CREATE UNIQUE INDEX roomserver_published_pkey ON roomserver_published (room_id, appservice_id, network_id); +ALTER TABLE roomserver_published DROP CONSTRAINT roomserver_published_pkeyold; +ALTER TABLE roomserver_published ADD PRIMARY KEY USING INDEX roomserver_published_pkey;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 9b5ed6eda..c935608a5 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -69,6 +69,9 @@ CREATE TABLE IF NOT EXISTS roomserver_events ( auth_event_nids BIGINT[] NOT NULL, is_rejected BOOLEAN NOT NULL DEFAULT FALSE ); + +-- Create an index which helps in resolving membership events (event_type_nid = 5) - (used for history visibility) +CREATE INDEX IF NOT EXISTS roomserver_events_memberships_idx ON roomserver_events (room_nid, event_state_key_nid) WHERE (event_type_nid = 5); ` const insertEventSQL = "" + @@ -137,10 +140,10 @@ const bulkSelectEventIDSQL = "" + "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid = ANY($1)" const bulkSelectEventNIDSQL = "" + - "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1)" + "SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE event_id = ANY($1)" const bulkSelectUnsentEventNIDSQL = "" + - "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE" + "SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE" const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)" @@ -517,20 +520,20 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) { return s.bulkSelectEventNID(ctx, txn, eventIDs, false) } // BulkSelectEventNIDs returns a map from string event ID to numeric event ID // only for events that haven't already been sent to the roomserver output. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) { return s.bulkSelectEventNID(ctx, txn, eventIDs, true) } // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) { +func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventMetadata, error) { var stmt *sql.Stmt if onlyUnsent { stmt = sqlutil.TxStmt(txn, s.bulkSelectUnsentEventNIDStmt) @@ -542,14 +545,18 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") - results := make(map[string]types.EventNID, len(eventIDs)) + results := make(map[string]types.EventMetadata, len(eventIDs)) var eventID string var eventNID int64 + var roomNID int64 for rows.Next() { - if err = rows.Scan(&eventID, &eventNID); err != nil { + if err = rows.Scan(&eventID, &eventNID, &roomNID); err != nil { return nil, err } - results[eventID] = types.EventNID(eventNID) + results[eventID] = types.EventMetadata{ + EventNID: types.EventNID(eventNID), + RoomNID: types.RoomNID(roomNID), + } } return results, rows.Err() } diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 0150534e1..d774b7892 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -21,12 +21,13 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) const membershipSchema = ` @@ -157,6 +158,12 @@ const selectServerInRoomSQL = "" + " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1" +const selectJoinedUsersSQL = ` +SELECT DISTINCT target_nid +FROM roomserver_membership m +WHERE membership_nid > $1 AND target_nid = ANY($2) +` + type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt @@ -174,6 +181,7 @@ type membershipStatements struct { selectLocalServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt deleteMembershipStmt *sql.Stmt + selectJoinedUsersStmt *sql.Stmt } func CreateMembershipTable(db *sql.DB) error { @@ -209,9 +217,33 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, {&s.selectServerInRoomStmt, selectServerInRoomSQL}, {&s.deleteMembershipStmt, deleteMembershipSQL}, + {&s.selectJoinedUsersStmt, selectJoinedUsersSQL}, }.Prepare(db) } +func (s *membershipStatements) SelectJoinedUsers( + ctx context.Context, txn *sql.Tx, + targetUserNIDs []types.EventStateKeyNID, +) ([]types.EventStateKeyNID, error) { + result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs)) + + stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt) + rows, err := stmt.QueryContext(ctx, tables.MembershipStateLeaveOrBan, pq.Array(targetUserNIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed") + var targetNID types.EventStateKeyNID + for rows.Next() { + if err = rows.Scan(&targetNID); err != nil { + return nil, err + } + result = append(result, targetNID) + } + + return result, rows.Err() +} + func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index 61caccb0e..eca81d81f 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -65,10 +65,16 @@ func CreatePublishedTable(db *sql.DB) error { return err } m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "roomserver: published appservice", - Up: deltas.UpPulishedAppservice, - }) + m.AddMigrations([]sqlutil.Migration{ + { + Version: "roomserver: published appservice", + Up: deltas.UpPulishedAppservice, + }, + { + Version: "roomserver: published appservice pkey", + Up: deltas.UpPulishedAppservicePrimaryKey, + }, + }...) return m.Up(context.Background()) } diff --git a/roomserver/storage/postgres/purge_statements.go b/roomserver/storage/postgres/purge_statements.go new file mode 100644 index 000000000..efba439bd --- /dev/null +++ b/roomserver/storage/postgres/purge_statements.go @@ -0,0 +1,133 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const purgeEventJSONSQL = "" + + "DELETE FROM roomserver_event_json WHERE event_nid = ANY(" + + " SELECT event_nid FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeEventsSQL = "" + + "DELETE FROM roomserver_events WHERE room_nid = $1" + +const purgeInvitesSQL = "" + + "DELETE FROM roomserver_invites WHERE room_nid = $1" + +const purgeMembershipsSQL = "" + + "DELETE FROM roomserver_membership WHERE room_nid = $1" + +const purgePreviousEventsSQL = "" + + "DELETE FROM roomserver_previous_events WHERE event_nids && ANY(" + + " SELECT ARRAY_AGG(event_nid) FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgePublishedSQL = "" + + "DELETE FROM roomserver_published WHERE room_id = $1" + +const purgeRedactionsSQL = "" + + "DELETE FROM roomserver_redactions WHERE redaction_event_id = ANY(" + + " SELECT event_id FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeRoomAliasesSQL = "" + + "DELETE FROM roomserver_room_aliases WHERE room_id = $1" + +const purgeRoomSQL = "" + + "DELETE FROM roomserver_rooms WHERE room_nid = $1" + +const purgeStateBlockEntriesSQL = "" + + "DELETE FROM roomserver_state_block WHERE state_block_nid = ANY(" + + " SELECT DISTINCT UNNEST(state_block_nids) FROM roomserver_state_snapshots WHERE room_nid = $1" + + ")" + +const purgeStateSnapshotEntriesSQL = "" + + "DELETE FROM roomserver_state_snapshots WHERE room_nid = $1" + +type purgeStatements struct { + purgeEventJSONStmt *sql.Stmt + purgeEventsStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt + purgePreviousEventsStmt *sql.Stmt + purgePublishedStmt *sql.Stmt + purgeRedactionStmt *sql.Stmt + purgeRoomAliasesStmt *sql.Stmt + purgeRoomStmt *sql.Stmt + purgeStateBlockEntriesStmt *sql.Stmt + purgeStateSnapshotEntriesStmt *sql.Stmt +} + +func PreparePurgeStatements(db *sql.DB) (*purgeStatements, error) { + s := &purgeStatements{} + + return s, sqlutil.StatementList{ + {&s.purgeEventJSONStmt, purgeEventJSONSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, + {&s.purgePublishedStmt, purgePublishedSQL}, + {&s.purgePreviousEventsStmt, purgePreviousEventsSQL}, + {&s.purgeRedactionStmt, purgeRedactionsSQL}, + {&s.purgeRoomAliasesStmt, purgeRoomAliasesSQL}, + {&s.purgeRoomStmt, purgeRoomSQL}, + {&s.purgeStateBlockEntriesStmt, purgeStateBlockEntriesSQL}, + {&s.purgeStateSnapshotEntriesStmt, purgeStateSnapshotEntriesSQL}, + }.Prepare(db) +} + +func (s *purgeStatements) PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, +) error { + + // purge by roomID + purgeByRoomID := []*sql.Stmt{ + s.purgeRoomAliasesStmt, + s.purgePublishedStmt, + } + for _, stmt := range purgeByRoomID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomID) + if err != nil { + return err + } + } + + // purge by roomNID + purgeByRoomNID := []*sql.Stmt{ + s.purgeStateBlockEntriesStmt, + s.purgeStateSnapshotEntriesStmt, + s.purgeInvitesStmt, + s.purgeMembershipsStmt, + s.purgePreviousEventsStmt, + s.purgeEventJSONStmt, + s.purgeRedactionStmt, + s.purgeEventsStmt, + s.purgeRoomStmt, + } + for _, stmt := range purgeByRoomNID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomNID) + if err != nil { + return err + } + } + return nil +} diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 994399532..c8346733d 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -58,6 +58,9 @@ const insertRoomNIDSQL = "" + const selectRoomNIDSQL = "" + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" +const selectRoomNIDForUpdateSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1 FOR UPDATE" + const selectLatestEventNIDsSQL = "" + "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" @@ -85,6 +88,7 @@ const bulkSelectRoomNIDsSQL = "" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt + selectRoomNIDForUpdateStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt @@ -106,6 +110,7 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) { return s, sqlutil.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, + {&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, @@ -169,6 +174,15 @@ func (s *roomStatements) SelectRoomNID( return types.RoomNID(roomNID), err } +func (s *roomStatements) SelectRoomNIDForUpdate( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDForUpdateStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index a00c026f4..0e83cfc25 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -21,10 +21,10 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -99,10 +99,26 @@ const bulkSelectStateForHistoryVisibilitySQL = ` AND (event_type_nid = 7 OR event_state_key LIKE '%:' || $2); ` +// bulkSelectMembershipForHistoryVisibilitySQL is an optimization to get membership events for a specific user for defined set of events. +// Returns the event_id of the event we want the membership event for, the event_id of the membership event and the membership event JSON. +const bulkSelectMembershipForHistoryVisibilitySQL = ` +SELECT re.event_id, re2.event_id, rej.event_json +FROM roomserver_events re +LEFT JOIN roomserver_state_snapshots rss on re.state_snapshot_nid = rss.state_snapshot_nid +CROSS JOIN unnest(rss.state_block_nids) AS blocks(block_nid) +LEFT JOIN roomserver_state_block rsb ON rsb.state_block_nid = blocks.block_nid +CROSS JOIN unnest(rsb.event_nids) AS rsb2(event_nid) +JOIN roomserver_events re2 ON re2.room_nid = $3 AND re2.event_type_nid = 5 AND re2.event_nid = rsb2.event_nid AND re2.event_state_key_nid = $1 +LEFT JOIN roomserver_event_json rej ON rej.event_nid = re2.event_nid +WHERE re.event_id = ANY($2) + +` + type stateSnapshotStatements struct { - insertStateStmt *sql.Stmt - bulkSelectStateBlockNIDsStmt *sql.Stmt - bulkSelectStateForHistoryVisibilityStmt *sql.Stmt + insertStateStmt *sql.Stmt + bulkSelectStateBlockNIDsStmt *sql.Stmt + bulkSelectStateForHistoryVisibilityStmt *sql.Stmt + bulktSelectMembershipForHistoryVisibilityStmt *sql.Stmt } func CreateStateSnapshotTable(db *sql.DB) error { @@ -110,13 +126,14 @@ func CreateStateSnapshotTable(db *sql.DB) error { return err } -func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func PrepareStateSnapshotTable(db *sql.DB) (*stateSnapshotStatements, error) { s := &stateSnapshotStatements{} return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, {&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL}, + {&s.bulktSelectMembershipForHistoryVisibilityStmt, bulkSelectMembershipForHistoryVisibilitySQL}, }.Prepare(db) } @@ -185,3 +202,45 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( } return results, rows.Err() } + +func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility( + ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, +) (map[string]*gomatrixserverlib.HeaderedEvent, error) { + stmt := sqlutil.TxStmt(txn, s.bulktSelectMembershipForHistoryVisibilityStmt) + rows, err := stmt.QueryContext(ctx, userNID, pq.Array(eventIDs), roomInfo.RoomNID) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + result := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs)) + var evJson []byte + var eventID string + var membershipEventID string + + knownEvents := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs)) + + for rows.Next() { + if err = rows.Scan(&eventID, &membershipEventID, &evJson); err != nil { + return nil, err + } + if len(evJson) == 0 { + result[eventID] = &gomatrixserverlib.HeaderedEvent{} + continue + } + // If we already know this event, don't try to marshal the json again + if ev, ok := knownEvents[membershipEventID]; ok { + result[eventID] = ev + continue + } + event, err := gomatrixserverlib.NewEventFromTrustedJSON(evJson, false, roomInfo.RoomVersion) + if err != nil { + result[eventID] = &gomatrixserverlib.HeaderedEvent{} + // not fatal + continue + } + he := event.Headered(roomInfo.RoomVersion) + result[eventID] = he + knownEvents[membershipEventID] = he + } + return result, rows.Err() +} diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 23a5f79eb..d98a5cf97 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -189,23 +189,33 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + purge, err := PreparePurgeStatements(db) + if err != nil { + return err + } d.Database = shared.Database{ - DB: db, - Cache: cache, - Writer: writer, - EventTypesTable: eventTypes, - EventStateKeysTable: eventStateKeys, - EventJSONTable: eventJSON, - EventsTable: events, - RoomsTable: rooms, - StateBlockTable: stateBlock, - StateSnapshotTable: stateSnapshot, - PrevEventsTable: prevEvents, - RoomAliasesTable: roomAliases, - InvitesTable: invites, - MembershipTable: membership, - PublishedTable: published, - RedactionsTable: redactions, + DB: db, + EventDatabase: shared.EventDatabase{ + DB: db, + Cache: cache, + Writer: writer, + EventsTable: events, + EventJSONTable: eventJSON, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + PrevEventsTable: prevEvents, + RedactionsTable: redactions, + }, + Cache: cache, + Writer: writer, + RoomsTable: rooms, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + RoomAliasesTable: roomAliases, + InvitesTable: invites, + MembershipTable: membership, + PublishedTable: published, + Purge: purge, } return nil } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index cc880a6c8..dc1db0825 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -116,10 +116,8 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent }) } -func (u *RoomUpdater) Events( - ctx context.Context, eventNIDs []types.EventNID, -) ([]types.Event, error) { - return u.d.events(ctx, u.txn, eventNIDs) +func (u *RoomUpdater) Events(ctx context.Context, _ *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) { + return u.d.events(ctx, u.txn, u.roomInfo, eventNIDs) } func (u *RoomUpdater) SnapshotNIDFromEventID( @@ -197,12 +195,8 @@ func (u *RoomUpdater) StateAtEventIDs( return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) } -func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false) -} - -func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true) +func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) { + return u.d.eventsFromIDs(ctx, u.txn, u.roomInfo, eventIDs, NoFilter) } // IsReferenced implements types.RoomRecentEventsUpdater diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 64e553a9c..d40ef4b63 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -27,6 +28,23 @@ import ( const redactionsArePermanent = true type Database struct { + DB *sql.DB + EventDatabase + Cache caching.RoomServerCaches + Writer sqlutil.Writer + RoomsTable tables.Rooms + StateSnapshotTable tables.StateSnapshot + StateBlockTable tables.StateBlock + RoomAliasesTable tables.RoomAliases + InvitesTable tables.Invites + MembershipTable tables.Membership + PublishedTable tables.Published + Purge tables.Purge + GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) +} + +// EventDatabase contains all tables needed to work with events +type EventDatabase struct { DB *sql.DB Cache caching.RoomServerCaches Writer sqlutil.Writer @@ -34,43 +52,54 @@ type Database struct { EventJSONTable tables.EventJSON EventTypesTable tables.EventTypes EventStateKeysTable tables.EventStateKeys - RoomsTable tables.Rooms - StateSnapshotTable tables.StateSnapshot - StateBlockTable tables.StateBlock - RoomAliasesTable tables.RoomAliases PrevEventsTable tables.PreviousEvents - InvitesTable tables.Invites - MembershipTable tables.Membership - PublishedTable tables.Published RedactionsTable tables.Redactions - GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } func (d *Database) SupportsConcurrentRoomInputs() bool { return true } -func (d *Database) EventTypeNIDs( +func (d *Database) GetMembershipForHistoryVisibility( + ctx context.Context, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, +) (map[string]*gomatrixserverlib.HeaderedEvent, error) { + return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...) +} + +func (d *EventDatabase) EventTypeNIDs( ctx context.Context, eventTypes []string, ) (map[string]types.EventTypeNID, error) { return d.eventTypeNIDs(ctx, nil, eventTypes) } -func (d *Database) eventTypeNIDs( +func (d *EventDatabase) eventTypeNIDs( ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { result := make(map[string]types.EventTypeNID) - nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, eventTypes) - if err != nil { - return nil, err + // first try the cache + fetchEventTypes := make([]string, 0, len(eventTypes)) + for _, eventType := range eventTypes { + eventTypeNID, ok := d.Cache.GetEventTypeKey(eventType) + if ok { + result[eventType] = eventTypeNID + continue + } + fetchEventTypes = append(fetchEventTypes, eventType) } - for eventType, nid := range nids { - result[eventType] = nid + if len(fetchEventTypes) > 0 { + nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, fetchEventTypes) + if err != nil { + return nil, err + } + for eventType, nid := range nids { + result[eventType] = nid + d.Cache.StoreEventTypeKey(nid, eventType) + } } return result, nil } -func (d *Database) EventStateKeys( +func (d *EventDatabase) EventStateKeys( ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) @@ -82,38 +111,56 @@ func (d *Database) EventStateKeys( fetch = append(fetch, nid) } } - fromDB, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, fetch) - if err != nil { - return nil, err - } - for nid, key := range fromDB { - result[nid] = key - d.Cache.StoreEventStateKey(nid, key) + if len(fetch) > 0 { + fromDB, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, fetch) + if err != nil { + return nil, err + } + for nid, key := range fromDB { + result[nid] = key + d.Cache.StoreEventStateKey(nid, key) + } } return result, nil } -func (d *Database) EventStateKeyNIDs( +func (d *EventDatabase) EventStateKeyNIDs( ctx context.Context, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { return d.eventStateKeyNIDs(ctx, nil, eventStateKeys) } -func (d *Database) eventStateKeyNIDs( +func (d *EventDatabase) eventStateKeyNIDs( ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { result := make(map[string]types.EventStateKeyNID) eventStateKeys = util.UniqueStrings(eventStateKeys) - nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) - if err != nil { - return nil, err + // first ask the cache about these keys + fetchEventStateKeys := make([]string, 0, len(eventStateKeys)) + for _, eventStateKey := range eventStateKeys { + eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey) + if ok { + result[eventStateKey] = eventStateKeyNID + continue + } + fetchEventStateKeys = append(fetchEventStateKeys, eventStateKey) } - for eventStateKey, nid := range nids { - result[eventStateKey] = nid + + if len(fetchEventStateKeys) > 0 { + nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, fetchEventStateKeys) + if err != nil { + return nil, err + } + for eventStateKey, nid := range nids { + result[eventStateKey] = nid + d.Cache.StoreEventStateKey(nid, eventStateKey) + } } + // We received some nids, but are still missing some, work out which and create them if len(eventStateKeys) > len(result) { var nid types.EventStateKeyNID + var err error err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { for _, eventStateKey := range eventStateKeys { if _, ok := result[eventStateKey]; ok { @@ -135,7 +182,7 @@ func (d *Database) eventStateKeyNIDs( return result, nil } -func (d *Database) StateEntriesForEventIDs( +func (d *EventDatabase) StateEntriesForEventIDs( ctx context.Context, eventIDs []string, excludeRejected bool, ) ([]types.StateEntry, error) { return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs, excludeRejected) @@ -174,6 +221,17 @@ func (d *Database) stateEntriesForTuples( return lists, nil } +func (d *Database) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) { + roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{roomNID}) + if err != nil { + return nil, err + } + if len(roomIDs) == 0 { + return nil, fmt.Errorf("room does not exist") + } + return d.roomInfo(ctx, nil, roomIDs[0]) +} + func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { return d.roomInfo(ctx, nil, roomID) } @@ -253,9 +311,9 @@ func (d *Database) addState( return } -func (d *Database) EventNIDs( +func (d *EventDatabase) EventNIDs( ctx context.Context, eventIDs []string, -) (map[string]types.EventNID, error) { +) (map[string]types.EventMetadata, error) { return d.eventNIDs(ctx, nil, eventIDs, NoFilter) } @@ -266,9 +324,9 @@ const ( FilterUnsentOnly UnsentFilter = true ) -func (d *Database) eventNIDs( +func (d *EventDatabase) eventNIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter, -) (map[string]types.EventNID, error) { +) (map[string]types.EventMetadata, error) { switch filter { case FilterUnsentOnly: return d.EventsTable.BulkSelectUnsentEventNID(ctx, txn, eventIDs) @@ -279,7 +337,7 @@ func (d *Database) eventNIDs( } } -func (d *Database) SetState( +func (d *EventDatabase) SetState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -287,19 +345,19 @@ func (d *Database) SetState( }) } -func (d *Database) StateAtEventIDs( +func (d *EventDatabase) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs) } -func (d *Database) SnapshotNIDFromEventID( +func (d *EventDatabase) SnapshotNIDFromEventID( ctx context.Context, eventID string, ) (types.StateSnapshotNID, error) { return d.snapshotNIDFromEventID(ctx, nil, eventID) } -func (d *Database) snapshotNIDFromEventID( +func (d *EventDatabase) snapshotNIDFromEventID( ctx context.Context, txn *sql.Tx, eventID string, ) (types.StateSnapshotNID, error) { _, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID) @@ -312,17 +370,17 @@ func (d *Database) snapshotNIDFromEventID( return stateNID, err } -func (d *Database) EventIDs( +func (d *EventDatabase) EventIDs( ctx context.Context, eventNIDs []types.EventNID, ) (map[types.EventNID]string, error) { return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) } -func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - return d.eventsFromIDs(ctx, nil, eventIDs, NoFilter) +func (d *EventDatabase) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) { + return d.eventsFromIDs(ctx, nil, roomInfo, eventIDs, NoFilter) } -func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter) ([]types.Event, error) { +func (d *EventDatabase) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventIDs []string, filter UnsentFilter) ([]types.Event, error) { nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter) if err != nil { return nil, err @@ -330,10 +388,10 @@ func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []st var nids []types.EventNID for _, nid := range nidMap { - nids = append(nids, nid) + nids = append(nids, nid.EventNID) } - return d.events(ctx, txn, nids) + return d.events(ctx, txn, roomInfo, nids) } func (d *Database) LatestEventIDs( @@ -472,15 +530,17 @@ func (d *Database) GetInvitesForUser( return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID) } -func (d *Database) Events( - ctx context.Context, eventNIDs []types.EventNID, -) ([]types.Event, error) { - return d.events(ctx, nil, eventNIDs) +func (d *EventDatabase) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) { + return d.events(ctx, nil, roomInfo, eventNIDs) } -func (d *Database) events( - ctx context.Context, txn *sql.Tx, inputEventNIDs types.EventNIDs, +func (d *EventDatabase) events( + ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, inputEventNIDs types.EventNIDs, ) ([]types.Event, error) { + if roomInfo == nil { // this should never happen + return nil, fmt.Errorf("unable to parse events without roomInfo") + } + sort.Sort(inputEventNIDs) events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs)) eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs)) @@ -507,47 +567,21 @@ func (d *Database) events( if !redactionsArePermanent { d.applyRedactions(results) } + return results, nil } eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs) if err != nil { return nil, err } - eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs) + eventIDs, err := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs) if err != nil { eventIDs = map[types.EventNID]string{} } - var roomNIDs map[types.EventNID]types.RoomNID - roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs) - if err != nil { - return nil, err - } - uniqueRoomNIDs := make(map[types.RoomNID]struct{}) - for _, n := range roomNIDs { - uniqueRoomNIDs[n] = struct{}{} - } - roomVersions := make(map[types.RoomNID]gomatrixserverlib.RoomVersion) - fetchNIDList := make([]types.RoomNID, 0, len(uniqueRoomNIDs)) - for n := range uniqueRoomNIDs { - if roomID, ok := d.Cache.GetRoomServerRoomID(n); ok { - if roomVersion, ok := d.Cache.GetRoomVersion(roomID); ok { - roomVersions[n] = roomVersion - continue - } - } - fetchNIDList = append(fetchNIDList, n) - } - dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList) - if err != nil { - return nil, err - } - for n, v := range dbRoomVersions { - roomVersions[n] = v - } + for _, eventJSON := range eventJSONs { - roomNID := roomNIDs[eventJSON.EventNID] - roomVersion := roomVersions[roomNID] + redacted := gjson.GetBytes(eventJSON.EventJSON, "unsigned.redacted_because").Exists() events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID( - eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion, + eventIDs[eventJSON.EventNID], eventJSON.EventJSON, redacted, roomInfo.RoomVersion, ) if err != nil { return nil, err @@ -617,77 +651,85 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID) } -func (d *Database) StoreEvent( - ctx context.Context, event *gomatrixserverlib.Event, - authEventNIDs []types.EventNID, isRejected bool, -) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { - return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected) +// GetOrCreateRoomInfo gets or creates a new RoomInfo, which is only safe to use with functions only needing a roomVersion or roomNID. +func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (roomInfo *types.RoomInfo, err error) { + // Get the default room version. If the client doesn't supply a room_version + // then we will use our configured default to create the room. + // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom + // Note that the below logic depends on the m.room.create event being the + // first event that is persisted to the database when creating or joining a + // room. + var roomVersion gomatrixserverlib.RoomVersion + if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { + return nil, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) + } + if roomVersion == "" { + rv, ok := d.Cache.GetRoomVersion(event.RoomID()) + if ok { + roomVersion = rv + } + } + var roomNID types.RoomNID + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion) + if err != nil { + return err + } + return nil + }) + if roomVersion != "" { + d.Cache.StoreRoomVersion(event.RoomID(), roomVersion) + } + return &types.RoomInfo{ + RoomVersion: roomVersion, + RoomNID: roomNID, + }, err } -func (d *Database) storeEvent( - ctx context.Context, updater *RoomUpdater, event *gomatrixserverlib.Event, - authEventNIDs []types.EventNID, isRejected bool, -) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { - var ( - roomNID types.RoomNID - eventTypeNID types.EventTypeNID - eventStateKeyNID types.EventStateKeyNID - eventNID types.EventNID - stateNID types.StateSnapshotNID - redactionEvent *gomatrixserverlib.Event - redactedEventID string - err error - ) - var txn *sql.Tx - if updater != nil && updater.txn != nil { - txn = updater.txn - } - // First writer is with a database-provided transaction, so that NIDs are assigned - // globally outside of the updater context, to help avoid races. +func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - // TODO: Here we should aim to have two different code paths for new rooms - // vs existing ones. - - // Get the default room version. If the client doesn't supply a room_version - // then we will use our configured default to create the room. - // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom - // Note that the below logic depends on the m.room.create event being the - // first event that is persisted to the database when creating or joining a - // room. - var roomVersion gomatrixserverlib.RoomVersion - if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { - return fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) - } - - if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil { - return fmt.Errorf("d.assignRoomNID: %w", err) - } - - if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil { + if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, eventType); err != nil { return fmt.Errorf("d.assignEventTypeNID: %w", err) } + return nil + }) + return eventTypeNID, err +} - eventStateKey := event.StateKey() - // Assigned a numeric ID for the state_key if there is one present. - // Otherwise set the numeric ID for the state_key to 0. - if eventStateKey != nil { - if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { - return fmt.Errorf("d.assignStateKeyNID: %w", err) - } +func (d *Database) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (eventStateKeyNID types.EventStateKeyNID, err error) { + if eventStateKey == nil { + return 0, nil + } + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { + return fmt.Errorf("d.assignStateKeyNID: %w", err) } - return nil }) if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) + return 0, err } - // Second writer is using the database-provided transaction, probably from the - // room updater, for easy roll-back if required. - err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { + + return eventStateKeyNID, nil +} + +func (d *EventDatabase) StoreEvent( + ctx context.Context, event *gomatrixserverlib.Event, + roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, + authEventNIDs []types.EventNID, isRejected bool, +) (types.EventNID, types.StateAtEvent, error) { + var ( + eventNID types.EventNID + stateNID types.StateSnapshotNID + err error + ) + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if eventNID, stateNID, err = d.EventsTable.InsertEvent( ctx, txn, - roomNID, + roomInfo.RoomNID, eventTypeNID, eventStateKeyNID, event.EventID(), @@ -710,16 +752,26 @@ func (d *Database) storeEvent( if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) } - if !isRejected { // ignore rejected redaction events - redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event) - if err != nil { - return fmt.Errorf("d.handleRedactions: %w", err) + + if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { + // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of + // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This + // function only does SELECTs though so the created txn (at this point) is just a read txn like + // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater + // to do writes however then this will need to go inside `Writer.Do`. + + // The following is a copy of RoomUpdater.StorePreviousEvents + for _, ref := range prevEvents { + if err = d.PrevEventsTable.InsertPreviousEvent(ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) + } } } + return nil }) if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) + return 0, types.StateAtEvent{}, fmt.Errorf("d.Writer.Do: %w", err) } // We should attempt to update the previous events table with any @@ -727,35 +779,8 @@ func (d *Database) storeEvent( // events updater because it somewhat works as a mutex, ensuring // that there's a row-level lock on the latest room events (well, // on Postgres at least). - if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { - // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of - // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This - // function only does SELECTs though so the created txn (at this point) is just a read txn like - // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater - // to do writes however then this will need to go inside `Writer.Do`. - succeeded := false - if updater == nil { - var roomInfo *types.RoomInfo - roomInfo, err = d.roomInfo(ctx, txn, event.RoomID()) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) - } - if roomInfo == nil && len(prevEvents) > 0 { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) - } - updater, err = d.GetRoomUpdater(ctx, roomInfo) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err) - } - defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) - } - if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err) - } - succeeded = true - } - return eventNID, roomNID, types.StateAtEvent{ + return eventNID, types.StateAtEvent{ BeforeStateSnapshotNID: stateNID, StateEntry: types.StateEntry{ StateKeyTuple: types.StateKeyTuple{ @@ -764,7 +789,7 @@ func (d *Database) storeEvent( }, EventNID: eventNID, }, - }, redactionEvent, redactedEventID, err + }, err } func (d *Database) PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error { @@ -807,6 +832,10 @@ func (d *Database) MissingAuthPrevEvents( func (d *Database) assignRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (types.RoomNID, error) { + roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID) + if ok { + return roomNID, nil + } // Check if we already have a numeric ID in the database. roomNID, err := d.RoomsTable.SelectRoomNID(ctx, txn, roomID) if err == sql.ErrNoRows { @@ -817,12 +846,21 @@ func (d *Database) assignRoomNID( roomNID, err = d.RoomsTable.SelectRoomNID(ctx, txn, roomID) } } - return roomNID, err + if err != nil { + return 0, err + } + d.Cache.StoreRoomServerRoomID(roomNID, roomID) + d.Cache.StoreRoomVersion(roomID, roomVersion) + return roomNID, nil } func (d *Database) assignEventTypeNID( ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { + eventTypeNID, ok := d.Cache.GetEventTypeKey(eventType) + if ok { + return eventTypeNID, nil + } // Check if we already have a numeric ID in the database. eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) if err == sql.ErrNoRows { @@ -833,12 +871,20 @@ func (d *Database) assignEventTypeNID( eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) } } - return eventTypeNID, err + if err != nil { + return 0, err + } + d.Cache.StoreEventTypeKey(eventTypeNID, eventType) + return eventTypeNID, nil } -func (d *Database) assignStateKeyNID( +func (d *EventDatabase) assignStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { + eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey) + if ok { + return eventStateKeyNID, nil + } // Check if we already have a numeric ID in the database. eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) if err == sql.ErrNoRows { @@ -849,6 +895,7 @@ func (d *Database) assignStateKeyNID( eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) } } + d.Cache.StoreEventStateKey(eventStateKeyNID, eventStateKey) return eventStateKeyNID, err } @@ -875,7 +922,7 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) ( return roomVersion, err } -// handleRedactions manages the redacted status of events. There's two cases to consider in order to comply with the spec: +// MaybeRedactEvent manages the redacted status of events. There's two cases to consider in order to comply with the spec: // "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid." // https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events // These cases are: @@ -890,95 +937,109 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) ( // when loading events to determine whether to apply redactions. This keeps the hot-path of reading events quick as we don't need // to cross-reference with other tables when loading. // -// Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction. -func (d *Database) handleRedactions( - ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event *gomatrixserverlib.Event, -) (*gomatrixserverlib.Event, string, error) { - var err error - isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil - if isRedactionEvent { - // an event which redacts itself should be ignored - if event.EventID() == event.Redacts() { - return nil, "", nil +// Returns the redaction event and the redacted event if this call resulted in a redaction. +func (d *EventDatabase) MaybeRedactEvent( + ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, plResolver state.PowerLevelResolver, +) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) { + var ( + redactionEvent, redactedEvent *types.Event + err error + validated bool + ignoreRedaction bool + ) + + wErr := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil + if isRedactionEvent { + // an event which redacts itself should be ignored + if event.EventID() == event.Redacts() { + return nil + } + + err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{ + Validated: false, + RedactionEventID: event.EventID(), + RedactsEventID: event.Redacts(), + }) + if err != nil { + return fmt.Errorf("d.RedactionsTable.InsertRedaction: %w", err) + } } - err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{ - Validated: false, - RedactionEventID: event.EventID(), - RedactsEventID: event.Redacts(), - }) + redactionEvent, redactedEvent, validated, err = d.loadRedactionPair(ctx, txn, roomInfo, eventNID, event) + switch { + case err != nil: + return fmt.Errorf("d.loadRedactionPair: %w", err) + case validated || redactedEvent == nil || redactionEvent == nil: + // we've seen this redaction before or there is nothing to redact + return nil + case redactedEvent.RoomID() != redactionEvent.RoomID(): + // redactions across rooms aren't allowed + ignoreRedaction = true + return nil + } + + _, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.Sender()) + _, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.Sender()) + var powerlevels *gomatrixserverlib.PowerLevelContent + powerlevels, err = plResolver.Resolve(ctx, redactionEvent.EventID()) if err != nil { - return nil, "", fmt.Errorf("d.RedactionsTable.InsertRedaction: %w", err) + return err } - } - redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, eventNID, event) - if err != nil { - return nil, "", fmt.Errorf("d.loadRedactionPair: %w", err) - } - if validated || redactedEvent == nil || redactionEvent == nil { - // we've seen this redaction before or there is nothing to redact - return nil, "", nil - } - if redactedEvent.RoomID() != redactionEvent.RoomID() { - // redactions across rooms aren't allowed - return nil, "", nil - } + switch { + case powerlevels.UserLevel(redactionEvent.Sender()) >= powerlevels.Redact: + // 1. The power level of the redaction event’s sender is greater than or equal to the redact level. + case sender1 == sender2: + // 2. The domain of the redaction event’s sender matches that of the original event’s sender. + default: + ignoreRedaction = true + return nil + } - // Get the power level from the database, so we can verify the user is allowed to redact the event - powerLevels, err := d.GetStateEvent(ctx, event.RoomID(), gomatrixserverlib.MRoomPowerLevels, "") - if err != nil { - return nil, "", fmt.Errorf("d.GetStateEvent: %w", err) - } - if powerLevels == nil { - return nil, "", fmt.Errorf("unable to fetch m.room.power_levels event from database for room %s", event.RoomID()) - } - pl, err := powerLevels.PowerLevels() - if err != nil { - return nil, "", fmt.Errorf("unable to get powerlevels for room: %w", err) - } + // mark the event as redacted + if redactionsArePermanent { + redactedEvent.Redact() + } - redactUser := pl.UserLevel(redactionEvent.Sender()) - switch { - case redactUser >= pl.Redact: - // The power level of the redaction event’s sender is greater than or equal to the redact level. - case redactedEvent.Sender() == redactionEvent.Sender(): - // The domain of the redaction event’s sender matches that of the original event’s sender. - default: - return nil, "", nil - } + err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent) + if err != nil { + return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) + } + // NOTSPEC: sytest relies on this unspecced field existing :( + err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID()) + if err != nil { + return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) + } + // overwrite the eventJSON table + err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON()) + if err != nil { + return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) + } - // mark the event as redacted - if redactionsArePermanent { - redactedEvent.Redact() - } + err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true) + if err != nil { + return fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err) + } - err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent) - if err != nil { - return nil, "", fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) - } - // NOTSPEC: sytest relies on this unspecced field existing :( - err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID()) - if err != nil { - return nil, "", fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) - } - // overwrite the eventJSON table - err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON()) - if err != nil { - return nil, "", fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) - } + // We remove the entry from the cache, as if we just "StoreRoomServerEvent", we can't be + // certain that the cached entry actually is updated, since ristretto is eventual-persistent. + d.Cache.InvalidateRoomServerEvent(redactedEvent.EventNID) - err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true) - if err != nil { - err = fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err) + return nil + }) + if wErr != nil { + return nil, nil, err } - - return redactionEvent.Event, redactedEvent.EventID(), err + if ignoreRedaction || redactionEvent == nil || redactedEvent == nil { + return nil, nil, nil + } + return redactionEvent.Event, redactedEvent.Event, nil } // loadRedactionPair returns both the redaction event and the redacted event, else nil. -func (d *Database) loadRedactionPair( - ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event *gomatrixserverlib.Event, +func (d *EventDatabase) loadRedactionPair( + ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, ) (*types.Event, *types.Event, bool, error) { var redactionEvent, redactedEvent *types.Event var info *tables.RedactionInfo @@ -1010,16 +1071,16 @@ func (d *Database) loadRedactionPair( } if isRedactionEvent { - redactedEvent = d.loadEvent(ctx, info.RedactsEventID) + redactedEvent = d.loadEvent(ctx, roomInfo, info.RedactsEventID) } else { - redactionEvent = d.loadEvent(ctx, info.RedactionEventID) + redactionEvent = d.loadEvent(ctx, roomInfo, info.RedactionEventID) } return redactionEvent, redactedEvent, info.Validated, nil } // applyRedactions will redact events that have an `unsigned.redacted_because` field. -func (d *Database) applyRedactions(events []types.Event) { +func (d *EventDatabase) applyRedactions(events []types.Event) { for i := range events { if result := gjson.GetBytes(events[i].Unsigned(), "redacted_because"); result.Exists() { events[i].Redact() @@ -1028,7 +1089,7 @@ func (d *Database) applyRedactions(events []types.Event) { } // loadEvent loads a single event or returns nil on any problems/missing event -func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event { +func (d *EventDatabase) loadEvent(ctx context.Context, roomInfo *types.RoomInfo, eventID string) *types.Event { nids, err := d.EventNIDs(ctx, []string{eventID}) if err != nil { return nil @@ -1036,7 +1097,7 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event { if len(nids) == 0 { return nil } - evs, err := d.Events(ctx, []types.EventNID{nids[eventID]}) + evs, err := d.Events(ctx, roomInfo, []types.EventNID{nids[eventID].EventNID}) if err != nil { return nil } @@ -1082,7 +1143,7 @@ func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *type // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { - roomInfo, err := d.RoomInfo(ctx, roomID) + roomInfo, err := d.roomInfo(ctx, nil, roomID) if err != nil { return nil, err } @@ -1147,7 +1208,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s // Same as GetStateEvent but returns all matching state events with this event type. Returns no error // if there are no events with this event type. func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) { - roomInfo, err := d.RoomInfo(ctx, roomID) + roomInfo, err := d.roomInfo(ctx, nil, roomID) if err != nil { return nil, err } @@ -1278,7 +1339,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu eventNIDToVer := make(map[types.EventNID]gomatrixserverlib.RoomVersion) // TODO: This feels like this is going to be really slow... for _, roomID := range roomIDs { - roomInfo, err2 := d.RoomInfo(ctx, roomID) + roomInfo, err2 := d.roomInfo(ctx, nil, roomID) if err2 != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to load room info for room %s : %w", roomID, err2) } @@ -1365,6 +1426,43 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [ return result, nil } +// GetLeftUsers calculates users we (the server) don't share a room with anymore. +func (d *Database) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) { + // Get the userNID for all users with a stale device list + stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, userIDs) + if err != nil { + return nil, err + } + + userNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap)) + userNIDtoUserID := make(map[types.EventStateKeyNID]string, len(stateKeyNIDMap)) + // Create a map from userNID -> userID + for userID, nid := range stateKeyNIDMap { + userNIDs = append(userNIDs, nid) + userNIDtoUserID[nid] = userID + } + + // Get all users whose membership is still join, knock or invite. + stillJoinedUsersNIDs, err := d.MembershipTable.SelectJoinedUsers(ctx, nil, userNIDs) + if err != nil { + return nil, err + } + + // Remove joined users from the "user with stale devices" list, which contains left AND joined users + for _, joinedUser := range stillJoinedUsersNIDs { + delete(userNIDtoUserID, joinedUser) + } + + // The users still in our userNIDtoUserID map are the users we don't share a room with anymore, + // and the return value we are looking for. + leftUsers := make([]string, 0, len(userNIDtoUserID)) + for _, userID := range userNIDtoUserID { + leftUsers = append(leftUsers, userID) + } + + return leftUsers, nil +} + // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID) @@ -1408,16 +1506,37 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget }) } +// PurgeRoom removes all information about a given room from the roomserver. +// For large rooms this operation may take a considerable amount of time. +func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomNID, err := d.RoomsTable.SelectRoomNIDForUpdate(ctx, txn, roomID) + if err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("room %s does not exist", roomID) + } + return fmt.Errorf("failed to lock the room: %w", err) + } + return d.Purge.PurgeRoom(ctx, txn, roomNID, roomID) + }) +} + func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - // un-publish old room - if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, oldRoomID, "", "", false); err != nil { - return fmt.Errorf("failed to unpublish room: %w", err) + published, err := d.PublishedTable.SelectPublishedFromRoomID(ctx, txn, oldRoomID) + if err != nil { + return fmt.Errorf("failed to get published room: %w", err) } - // publish new room - if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, newRoomID, "", "", true); err != nil { - return fmt.Errorf("failed to publish room: %w", err) + if published { + // un-publish old room + if err = d.PublishedTable.UpsertRoomPublished(ctx, txn, oldRoomID, "", "", false); err != nil { + return fmt.Errorf("failed to unpublish room: %w", err) + } + // publish new room + if err = d.PublishedTable.UpsertRoomPublished(ctx, txn, newRoomID, "", "", true); err != nil { + return fmt.Errorf("failed to publish room: %w", err) + } } // Migrate any existing room aliases diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go new file mode 100644 index 000000000..684e80b8f --- /dev/null +++ b/roomserver/storage/shared/storage_test.go @@ -0,0 +1,103 @@ +package shared_test + +import ( + "context" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/stretchr/testify/assert" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" +) + +func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Database, func()) { + t.Helper() + + connStr, clearDB := test.PrepareDBConnectionString(t, dbType) + base, _, _ := testrig.Base(nil) + dbOpts := &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)} + + db, err := sqlutil.Open(dbOpts, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + + var membershipTable tables.Membership + var stateKeyTable tables.EventStateKeys + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateEventStateKeysTable(db) + assert.NoError(t, err) + err = postgres.CreateMembershipTable(db) + assert.NoError(t, err) + membershipTable, err = postgres.PrepareMembershipTable(db) + assert.NoError(t, err) + stateKeyTable, err = postgres.PrepareEventStateKeysTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateEventStateKeysTable(db) + assert.NoError(t, err) + err = sqlite3.CreateMembershipTable(db) + assert.NoError(t, err) + membershipTable, err = sqlite3.PrepareMembershipTable(db) + assert.NoError(t, err) + stateKeyTable, err = sqlite3.PrepareEventStateKeysTable(db) + } + assert.NoError(t, err) + + cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false) + + evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache} + + return &shared.Database{ + DB: db, + EventDatabase: evDb, + MembershipTable: membershipTable, + Writer: sqlutil.NewExclusiveWriter(), + Cache: cache, + }, func() { + err := base.Close() + assert.NoError(t, err) + clearDB() + err = db.Close() + assert.NoError(t, err) + } +} + +func Test_GetLeftUsers(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRoomserverDatabase(t, dbType) + defer close() + + // Create dummy entries + for _, user := range []*test.User{alice, bob, charlie} { + nid, err := db.EventStateKeysTable.InsertEventStateKeyNID(ctx, nil, user.ID) + assert.NoError(t, err) + err = db.MembershipTable.InsertMembership(ctx, nil, 1, nid, true) + assert.NoError(t, err) + // We must update the membership with a non-zero event NID or it will get filtered out in later queries + membershipNID := tables.MembershipStateLeaveOrBan + if user == alice { + membershipNID = tables.MembershipStateJoin + } + _, err = db.MembershipTable.UpdateMembership(ctx, nil, 1, nid, nid, membershipNID, 1, false) + assert.NoError(t, err) + } + + // Now try to get the left users, this should be Bob and Charlie, since they have a "leave" membership + expectedUserIDs := []string{bob.ID, charlie.ID} + leftUsers, err := db.GetLeftUsers(context.Background(), []string{alice.ID, bob.ID, charlie.ID}) + assert.NoError(t, err) + assert.ElementsMatch(t, expectedUserIDs, leftUsers) + }) +} diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index f39b9902d..aacf4bc9a 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -110,10 +110,10 @@ const bulkSelectEventIDSQL = "" + "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)" const bulkSelectEventNIDSQL = "" + - "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)" + "SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE event_id IN ($1)" const bulkSelectUnsentEventNIDSQL = "" + - "SELECT event_id, event_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)" + "SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)" const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" @@ -572,20 +572,20 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) { return s.bulkSelectEventNID(ctx, txn, eventIDs, false) } // BulkSelectEventNIDs returns a map from string event ID to numeric event ID // only for events that haven't already been sent to the roomserver output. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) { return s.bulkSelectEventNID(ctx, txn, eventIDs, true) } // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) { +func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventMetadata, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { @@ -609,14 +609,18 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") - results := make(map[string]types.EventNID, len(eventIDs)) + results := make(map[string]types.EventMetadata, len(eventIDs)) var eventID string var eventNID int64 + var roomNID int64 for rows.Next() { - if err = rows.Scan(&eventID, &eventNID); err != nil { + if err = rows.Scan(&eventID, &eventNID, &roomNID); err != nil { return nil, err } - results[eventID] = types.EventNID(eventNID) + results[eventID] = types.EventMetadata{ + EventNID: types.EventNID(eventNID), + RoomNID: types.RoomNID(roomNID), + } } return results, nil } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index cd149f0ed..8a60b359f 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -21,12 +21,13 @@ import ( "fmt" "strings" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) const membershipSchema = ` @@ -133,6 +134,12 @@ const selectServerInRoomSQL = "" + const deleteMembershipSQL = "" + "DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2" +const selectJoinedUsersSQL = ` +SELECT DISTINCT target_nid +FROM roomserver_membership m +WHERE membership_nid > $1 AND target_nid IN ($2) +` + type membershipStatements struct { db *sql.DB insertMembershipStmt *sql.Stmt @@ -149,6 +156,7 @@ type membershipStatements struct { selectLocalServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt deleteMembershipStmt *sql.Stmt + // selectJoinedUsersStmt *sql.Stmt // Prepared at runtime } func CreateMembershipTable(db *sql.DB) error { @@ -412,3 +420,40 @@ func (s *membershipStatements) DeleteMembership( ) return err } + +func (s *membershipStatements) SelectJoinedUsers( + ctx context.Context, txn *sql.Tx, + targetUserNIDs []types.EventStateKeyNID, +) ([]types.EventStateKeyNID, error) { + result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs)) + + qry := strings.Replace(selectJoinedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(targetUserNIDs), 1), 1) + + stmt, err := s.db.Prepare(qry) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsers: stmt.Close failed") + + params := make([]any, len(targetUserNIDs)+1) + params[0] = tables.MembershipStateLeaveOrBan + for i := range targetUserNIDs { + params[i+1] = targetUserNIDs[i] + } + + stmt = sqlutil.TxStmt(txn, stmt) + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed") + var targetNID types.EventStateKeyNID + for rows.Next() { + if err = rows.Scan(&targetNID); err != nil { + return nil, err + } + result = append(result, targetNID) + } + + return result, rows.Err() +} diff --git a/roomserver/storage/sqlite3/purge_statements.go b/roomserver/storage/sqlite3/purge_statements.go new file mode 100644 index 000000000..c7b4d27a5 --- /dev/null +++ b/roomserver/storage/sqlite3/purge_statements.go @@ -0,0 +1,153 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const purgeEventJSONSQL = "" + + "DELETE FROM roomserver_event_json WHERE event_nid IN (" + + " SELECT event_nid FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeEventsSQL = "" + + "DELETE FROM roomserver_events WHERE room_nid = $1" + +const purgeInvitesSQL = "" + + "DELETE FROM roomserver_invites WHERE room_nid = $1" + +const purgeMembershipsSQL = "" + + "DELETE FROM roomserver_membership WHERE room_nid = $1" + +const purgePreviousEventsSQL = "" + + "DELETE FROM roomserver_previous_events WHERE event_nids IN(" + + " SELECT event_nid FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgePublishedSQL = "" + + "DELETE FROM roomserver_published WHERE room_id = $1" + +const purgeRedactionsSQL = "" + + "DELETE FROM roomserver_redactions WHERE redaction_event_id IN(" + + " SELECT event_id FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeRoomAliasesSQL = "" + + "DELETE FROM roomserver_room_aliases WHERE room_id = $1" + +const purgeRoomSQL = "" + + "DELETE FROM roomserver_rooms WHERE room_nid = $1" + +const purgeStateSnapshotEntriesSQL = "" + + "DELETE FROM roomserver_state_snapshots WHERE room_nid = $1" + +type purgeStatements struct { + purgeEventJSONStmt *sql.Stmt + purgeEventsStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt + purgePreviousEventsStmt *sql.Stmt + purgePublishedStmt *sql.Stmt + purgeRedactionStmt *sql.Stmt + purgeRoomAliasesStmt *sql.Stmt + purgeRoomStmt *sql.Stmt + purgeStateSnapshotEntriesStmt *sql.Stmt + stateSnapshot *stateSnapshotStatements +} + +func PreparePurgeStatements(db *sql.DB, stateSnapshot *stateSnapshotStatements) (*purgeStatements, error) { + s := &purgeStatements{stateSnapshot: stateSnapshot} + return s, sqlutil.StatementList{ + {&s.purgeEventJSONStmt, purgeEventJSONSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, + {&s.purgePublishedStmt, purgePublishedSQL}, + {&s.purgePreviousEventsStmt, purgePreviousEventsSQL}, + {&s.purgeRedactionStmt, purgeRedactionsSQL}, + {&s.purgeRoomAliasesStmt, purgeRoomAliasesSQL}, + {&s.purgeRoomStmt, purgeRoomSQL}, + //{&s.purgeStateBlockEntriesStmt, purgeStateBlockEntriesSQL}, + {&s.purgeStateSnapshotEntriesStmt, purgeStateSnapshotEntriesSQL}, + }.Prepare(db) +} + +func (s *purgeStatements) PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, +) error { + + // purge by roomID + purgeByRoomID := []*sql.Stmt{ + s.purgeRoomAliasesStmt, + s.purgePublishedStmt, + } + for _, stmt := range purgeByRoomID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomID) + if err != nil { + return err + } + } + + // purge by roomNID + if err := s.purgeStateBlocks(ctx, txn, roomNID); err != nil { + return err + } + + purgeByRoomNID := []*sql.Stmt{ + s.purgeStateSnapshotEntriesStmt, + s.purgeInvitesStmt, + s.purgeMembershipsStmt, + s.purgePreviousEventsStmt, + s.purgeEventJSONStmt, + s.purgeRedactionStmt, + s.purgeEventsStmt, + s.purgeRoomStmt, + } + for _, stmt := range purgeByRoomNID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomNID) + if err != nil { + return err + } + } + return nil +} + +func (s *purgeStatements) purgeStateBlocks( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) error { + // Get all stateBlockNIDs + stateBlockNIDs, err := s.stateSnapshot.selectStateBlockNIDsForRoomNID(ctx, txn, roomNID) + if err != nil { + return err + } + params := make([]interface{}, len(stateBlockNIDs)) + seenNIDs := make(map[types.StateBlockNID]struct{}, len(stateBlockNIDs)) + // dedupe NIDs + for k, v := range stateBlockNIDs { + if _, ok := seenNIDs[v]; ok { + continue + } + params[k] = v + seenNIDs[v] = struct{}{} + } + + query := "DELETE FROM roomserver_state_block WHERE state_block_nid IN($1)" + return sqlutil.RunLimitedVariablesExec(ctx, query, txn, params, sqlutil.SQLite3MaxVariables) +} diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 25b611b3e..7556b3461 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -74,10 +74,14 @@ const bulkSelectRoomIDsSQL = "" + const bulkSelectRoomNIDsSQL = "" + "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" +const selectRoomNIDForUpdateSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" + type roomStatements struct { db *sql.DB insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt + selectRoomNIDForUpdateStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt @@ -105,6 +109,7 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) { //{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, {&s.selectRoomIDsStmt, selectRoomIDsSQL}, + {&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL}, }.Prepare(db) } @@ -169,6 +174,15 @@ func (s *roomStatements) SelectRoomNID( return types.RoomNID(roomNID), err } +func (s *roomStatements) SelectRoomNIDForUpdate( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDForUpdateStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 4e67d4da1..ae8181cfa 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -24,7 +24,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" ) @@ -68,7 +67,7 @@ func CreateStateBlockTable(db *sql.DB) error { return err } -func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { +func PrepareStateBlockTable(db *sql.DB) (*stateBlockStatements, error) { s := &stateBlockStatements{ db: db, } diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 73827522c..e57e1a4bf 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -62,10 +63,14 @@ const bulkSelectStateBlockNIDsSQL = "" + "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" +const selectStateBlockNIDsForRoomNID = "" + + "SELECT state_block_nids FROM roomserver_state_snapshots WHERE room_nid = $1" + type stateSnapshotStatements struct { db *sql.DB insertStateStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt + selectStateBlockNIDsStmt *sql.Stmt } func CreateStateSnapshotTable(db *sql.DB) error { @@ -73,7 +78,7 @@ func CreateStateSnapshotTable(db *sql.DB) error { return err } -func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func PrepareStateSnapshotTable(db *sql.DB) (*stateSnapshotStatements, error) { s := &stateSnapshotStatements{ db: db, } @@ -81,6 +86,7 @@ func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, + {&s.selectStateBlockNIDsStmt, selectStateBlockNIDsForRoomNID}, }.Prepare(db) } @@ -146,3 +152,33 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( ) ([]types.EventNID, error) { return nil, tables.OptimisationNotSupportedError } + +func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string) (map[string]*gomatrixserverlib.HeaderedEvent, error) { + return nil, tables.OptimisationNotSupportedError +} + +func (s *stateSnapshotStatements) selectStateBlockNIDsForRoomNID( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) ([]types.StateBlockNID, error) { + var res []types.StateBlockNID + rows, err := sqlutil.TxStmt(txn, s.selectStateBlockNIDsStmt).QueryContext(ctx, roomNID) + if err != nil { + return res, nil + } + defer internal.CloseAndLogIfError(ctx, rows, "selectStateBlockNIDsForRoomNID: rows.close() failed") + + var stateBlockNIDs []types.StateBlockNID + var stateBlockNIDsJSON string + for rows.Next() { + if err = rows.Scan(&stateBlockNIDsJSON); err != nil { + return nil, err + } + if err = json.Unmarshal([]byte(stateBlockNIDsJSON), &stateBlockNIDs); err != nil { + return nil, err + } + + res = append(res, stateBlockNIDs...) + } + + return res, rows.Err() +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 01c3f879c..2adedd2d8 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -197,24 +197,35 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + purge, err := PreparePurgeStatements(db, stateSnapshot) + if err != nil { + return err + } + d.Database = shared.Database{ - DB: db, - Cache: cache, - Writer: writer, - EventsTable: events, - EventTypesTable: eventTypes, - EventStateKeysTable: eventStateKeys, - EventJSONTable: eventJSON, - RoomsTable: rooms, - StateBlockTable: stateBlock, - StateSnapshotTable: stateSnapshot, - PrevEventsTable: prevEvents, - RoomAliasesTable: roomAliases, - InvitesTable: invites, - MembershipTable: membership, - PublishedTable: published, - RedactionsTable: redactions, - GetRoomUpdaterFn: d.GetRoomUpdater, + DB: db, + EventDatabase: shared.EventDatabase{ + DB: db, + Cache: cache, + Writer: writer, + EventsTable: events, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + EventJSONTable: eventJSON, + PrevEventsTable: prevEvents, + RedactionsTable: redactions, + }, + Cache: cache, + Writer: writer, + RoomsTable: rooms, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + RoomAliasesTable: roomAliases, + InvitesTable: invites, + MembershipTable: membership, + PublishedTable: published, + GetRoomUpdaterFn: d.GetRoomUpdater, + Purge: purge, } return nil } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 50d27c756..4ce2a9c4e 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -63,8 +63,8 @@ type Events interface { BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. - BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) - BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) + BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) + BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) SelectEventRejected(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventID string) (rejected bool, err error) @@ -73,6 +73,7 @@ type Events interface { type Rooms interface { InsertRoomNID(ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion) (types.RoomNID, error) SelectRoomNID(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error) + SelectRoomNIDForUpdate(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error) SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error @@ -90,6 +91,10 @@ type StateSnapshot interface { // which users are in a room faster than having to load the entire room state. In the // case of SQLite, this will return tables.OptimisationNotSupportedError. BulkSelectStateForHistoryVisibility(ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string) ([]types.EventNID, error) + + BulkSelectMembershipForHistoryVisibility( + ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, + ) (map[string]*gomatrixserverlib.HeaderedEvent, error) } type StateBlock interface { @@ -144,6 +149,7 @@ type Membership interface { SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error + SelectJoinedUsers(ctx context.Context, txn *sql.Tx, targetUserNIDs []types.EventStateKeyNID) ([]types.EventStateKeyNID, error) } type Published interface { @@ -172,6 +178,12 @@ type Redactions interface { MarkRedactionValidated(ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool) error } +type Purge interface { + PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, + ) error +} + // StrippedEvent represents a stripped event for returning extracted content values. type StrippedEvent struct { RoomID string diff --git a/roomserver/storage/tables/membership_table_test.go b/roomserver/storage/tables/membership_table_test.go index c9541d9d2..c4524ee44 100644 --- a/roomserver/storage/tables/membership_table_test.go +++ b/roomserver/storage/tables/membership_table_test.go @@ -129,5 +129,11 @@ func TestMembershipTable(t *testing.T) { knownUsers, err := tab.SelectKnownUsers(ctx, nil, userNIDs[0], "localhost", 2) assert.NoError(t, err) assert.Equal(t, 1, len(knownUsers)) + + // get users we share a room with, given their userNID + joinedUsers, err := tab.SelectJoinedUsers(ctx, nil, userNIDs) + assert.NoError(t, err) + // Only userNIDs[0] is actually joined, so we only expect this userNID + assert.Equal(t, userNIDs[:1], joinedUsers) }) } diff --git a/roomserver/storage/tables/state_snapshot_table_test.go b/roomserver/storage/tables/state_snapshot_table_test.go index b2e59377d..c7c991b20 100644 --- a/roomserver/storage/tables/state_snapshot_table_test.go +++ b/roomserver/storage/tables/state_snapshot_table_test.go @@ -29,6 +29,8 @@ func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables. assert.NoError(t, err) err = postgres.CreateEventsTable(db) assert.NoError(t, err) + err = postgres.CreateEventJSONTable(db) + assert.NoError(t, err) err = postgres.CreateStateBlockTable(db) assert.NoError(t, err) // ... and then the snapshot table itself diff --git a/roomserver/types/types.go b/roomserver/types/types.go index f40980994..6401a94be 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -38,6 +38,11 @@ type EventNID int64 // RoomNID is a numeric ID for a room. type RoomNID int64 +type EventMetadata struct { + EventNID EventNID + RoomNID RoomNID +} + // StateSnapshotNID is a numeric ID for the state at an event. type StateSnapshotNID int64 diff --git a/roomserver/version/version.go b/roomserver/version/version.go index 729d00a80..c40d8e0f7 100644 --- a/roomserver/version/version.go +++ b/roomserver/version/version.go @@ -23,7 +23,7 @@ import ( // DefaultRoomVersion contains the room version that will, by // default, be used to create new rooms on this server. func DefaultRoomVersion() gomatrixserverlib.RoomVersion { - return gomatrixserverlib.RoomVersionV9 + return gomatrixserverlib.RoomVersionV10 } // RoomVersions returns a map of all known room versions to this diff --git a/setup/base/base.go b/setup/base/base.go index e5a6a3c87..dfe48ff3c 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -15,12 +15,16 @@ package base import ( + "bytes" "context" - "crypto/tls" "database/sql" + "embed" "encoding/json" + "errors" "fmt" + "html/template" "io" + "io/fs" "net" "net/http" _ "net/http/pprof" @@ -35,8 +39,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/prometheus/client_golang/prometheus/promhttp" "go.uber.org/atomic" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/caching" @@ -50,21 +52,14 @@ import ( "github.com/sirupsen/logrus" - appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - asinthttp "github.com/matrix-org/dendrite/appservice/inthttp" - federationAPI "github.com/matrix-org/dendrite/federationapi/api" - federationIntHTTP "github.com/matrix-org/dendrite/federationapi/inthttp" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" - keyinthttp "github.com/matrix-org/dendrite/keyserver/inthttp" - roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - rsinthttp "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" - userapi "github.com/matrix-org/dendrite/userapi/api" - userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp" ) +//go:embed static/*.gotmpl +var staticContent embed.FS + // BaseDendrite is a base for creating new instances of dendrite. It parses // command line flags and config, and exposes methods for creating various // resources. All errors are handled by logging then exiting, so all methods @@ -72,19 +67,16 @@ import ( // Must be closed when shutting down. type BaseDendrite struct { *process.ProcessContext - componentName string tracerCloser io.Closer PublicClientAPIMux *mux.Router PublicFederationAPIMux *mux.Router PublicKeyAPIMux *mux.Router PublicMediaAPIMux *mux.Router PublicWellKnownAPIMux *mux.Router - InternalAPIMux *mux.Router + PublicStaticMux *mux.Router DendriteAdminMux *mux.Router SynapseAdminMux *mux.Router NATS *jetstream.NATSInstance - UseHTTPAPIs bool - apiHttpClient *http.Client Cfg *config.Dendrite Caches *caching.Caches DNSCache *gomatrixserverlib.DNSCache @@ -95,41 +87,27 @@ type BaseDendrite struct { startupLock sync.Mutex } -const NoListener = "" - const HTTPServerTimeout = time.Minute * 5 -const HTTPClientTimeout = time.Second * 30 type BaseDendriteOptions int const ( DisableMetrics BaseDendriteOptions = iota - UseHTTPAPIs - PolylithMode ) // NewBaseDendrite creates a new instance to be used by a component. -// The componentName is used for logging purposes, and should be a friendly name -// of the compontent running, e.g. "SyncAPI" -func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...BaseDendriteOptions) *BaseDendrite { +func NewBaseDendrite(cfg *config.Dendrite, options ...BaseDendriteOptions) *BaseDendrite { platformSanityChecks() - useHTTPAPIs := false enableMetrics := true - isMonolith := true for _, opt := range options { switch opt { case DisableMetrics: enableMetrics = false - case UseHTTPAPIs: - useHTTPAPIs = true - case PolylithMode: - isMonolith = false - useHTTPAPIs = true } } configErrors := &config.ConfigErrors{} - cfg.Verify(configErrors, isMonolith) + cfg.Verify(configErrors) if len(*configErrors) > 0 { for _, err := range *configErrors { logrus.Errorf("Configuration error: %s", err) @@ -137,7 +115,8 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base logrus.Fatalf("Failed to start due to configuration errors") } - internal.SetupHookLogging(cfg.Logging, componentName) + internal.SetupStdLogging() + internal.SetupHookLogging(cfg.Logging) internal.SetupPprof() logrus.Infof("Dendrite version %s", internal.VersionString()) @@ -146,14 +125,13 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base logrus.Warn("Open registration is enabled") } - closer, err := cfg.SetupTracing("Dendrite" + componentName) + closer, err := cfg.SetupTracing() if err != nil { logrus.WithError(err).Panicf("failed to start opentracing") } var fts *fulltext.Search - isSyncOrMonolith := componentName == "syncapi" || isMonolith - if cfg.SyncAPI.Fulltext.Enabled && isSyncOrMonolith { + if cfg.SyncAPI.Fulltext.Enabled { fts, err = fulltext.New(cfg.SyncAPI.Fulltext) if err != nil { logrus.WithError(err).Panicf("failed to create full text") @@ -188,32 +166,12 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base ) } - apiClient := http.Client{ - Timeout: time.Minute * 10, - Transport: &http2.Transport{ - AllowHTTP: true, - DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) { - // Ordinarily HTTP/2 would expect TLS, but the remote listener is - // H2C-enabled (HTTP/2 without encryption). Overriding the DialTLS - // function with a plain Dial allows us to trick the HTTP client - // into establishing a HTTP/2 connection without TLS. - // TODO: Eventually we will want to look at authenticating and - // encrypting these internal HTTP APIs, at which point we will have - // to reconsider H2C and change all this anyway. - return net.Dial(network, addr) - }, - }, - } - // If we're in monolith mode, we'll set up a global pool of database // connections. A component is welcome to use this pool if they don't // have a separate database config of their own. var db *sql.DB var writer sqlutil.Writer if cfg.Global.DatabaseOptions.ConnectionString != "" { - if !isMonolith { - logrus.Panic("Using a global database connection pool is not supported in polylith deployments") - } if cfg.Global.DatabaseOptions.ConnectionString.IsSQLite() { logrus.Panic("Using a global database connection pool is not supported with SQLite databases") } @@ -238,8 +196,6 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base return &BaseDendrite{ ProcessContext: process.NewProcessContext(), - componentName: componentName, - UseHTTPAPIs: useHTTPAPIs, tracerCloser: closer, Cfg: cfg, Caches: caching.NewRistrettoCache(cfg.Global.Cache.EstimatedMaxSize, cfg.Global.Cache.MaxAge, enableMetrics), @@ -249,11 +205,10 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base PublicKeyAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicKeyPathPrefix).Subrouter().UseEncodedPath(), PublicMediaAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicMediaPathPrefix).Subrouter().UseEncodedPath(), PublicWellKnownAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicWellKnownPrefix).Subrouter().UseEncodedPath(), - InternalAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.InternalPathPrefix).Subrouter().UseEncodedPath(), + PublicStaticMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicStaticPath).Subrouter().UseEncodedPath(), DendriteAdminMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.DendriteAdminPathPrefix).Subrouter().UseEncodedPath(), SynapseAdminMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.SynapseAdminPathPrefix).Subrouter().UseEncodedPath(), NATS: &jetstream.NATSInstance{}, - apiHttpClient: &apiClient, Database: db, // set if monolith with global connection pool only DatabaseWriter: writer, // set if monolith with global connection pool only EnableMetrics: enableMetrics, @@ -263,6 +218,8 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base // Close implements io.Closer func (b *BaseDendrite) Close() error { + b.ProcessContext.ShutdownDendrite() + b.ProcessContext.WaitForShutdown() return b.tracerCloser.Close() } @@ -289,52 +246,6 @@ func (b *BaseDendrite) DatabaseConnection(dbProperties *config.DatabaseOptions, return nil, nil, fmt.Errorf("no database connections configured") } -// AppserviceHTTPClient returns the AppServiceInternalAPI for hitting the appservice component over HTTP. -func (b *BaseDendrite) AppserviceHTTPClient() appserviceAPI.AppServiceInternalAPI { - a, err := asinthttp.NewAppserviceClient(b.Cfg.AppServiceURL(), b.apiHttpClient) - if err != nil { - logrus.WithError(err).Panic("CreateHTTPAppServiceAPIs failed") - } - return a -} - -// RoomserverHTTPClient returns RoomserverInternalAPI for hitting the roomserver over HTTP. -func (b *BaseDendrite) RoomserverHTTPClient() roomserverAPI.RoomserverInternalAPI { - rsAPI, err := rsinthttp.NewRoomserverClient(b.Cfg.RoomServerURL(), b.apiHttpClient, b.Caches) - if err != nil { - logrus.WithError(err).Panic("RoomserverHTTPClient failed", b.apiHttpClient) - } - return rsAPI -} - -// UserAPIClient returns UserInternalAPI for hitting the userapi over HTTP. -func (b *BaseDendrite) UserAPIClient() userapi.UserInternalAPI { - userAPI, err := userapiinthttp.NewUserAPIClient(b.Cfg.UserAPIURL(), b.apiHttpClient) - if err != nil { - logrus.WithError(err).Panic("UserAPIClient failed", b.apiHttpClient) - } - return userAPI -} - -// FederationAPIHTTPClient returns FederationInternalAPI for hitting -// the federation API server over HTTP -func (b *BaseDendrite) FederationAPIHTTPClient() federationAPI.FederationInternalAPI { - f, err := federationIntHTTP.NewFederationAPIClient(b.Cfg.FederationAPIURL(), b.apiHttpClient, b.Caches) - if err != nil { - logrus.WithError(err).Panic("FederationAPIHTTPClient failed", b.apiHttpClient) - } - return f -} - -// KeyServerHTTPClient returns KeyInternalAPI for hitting the key server over HTTP -func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI { - f, err := keyinthttp.NewKeyServerClient(b.Cfg.KeyServerURL(), b.apiHttpClient) - if err != nil { - logrus.WithError(err).Panic("KeyServerHTTPClient failed", b.apiHttpClient) - } - return f -} - // PushGatewayHTTPClient returns a new client for interacting with (external) Push Gateways. func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client { return pushgateway.NewHTTPClient(b.Cfg.UserAPI.PushGatewayDisableTLSValidation) @@ -402,6 +313,7 @@ func (b *BaseDendrite) configureHTTPErrors() { for _, router := range []*mux.Router{ b.PublicMediaAPIMux, b.DendriteAdminMux, b.SynapseAdminMux, b.PublicWellKnownAPIMux, + b.PublicStaticMux, } { router.NotFoundHandler = notFoundCORSHandler router.MethodNotAllowedHandler = notAllowedCORSHandler @@ -412,56 +324,7 @@ func (b *BaseDendrite) configureHTTPErrors() { b.PublicClientAPIMux.MethodNotAllowedHandler = http.HandlerFunc(clientNotFoundHandler) } -// SetupAndServeHTTP sets up the HTTP server to serve endpoints registered on -// ApiMux under /api/ and adds a prometheus handler under /metrics. -func (b *BaseDendrite) SetupAndServeHTTP( - internalHTTPAddr, externalHTTPAddr config.HTTPAddress, - certFile, keyFile *string, -) { - // Manually unlocked right before actually serving requests, - // as we don't return from this method (defer doesn't work). - b.startupLock.Lock() - internalAddr, _ := internalHTTPAddr.Address() - externalAddr, _ := externalHTTPAddr.Address() - - externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() - internalRouter := externalRouter - - externalServ := &http.Server{ - Addr: string(externalAddr), - WriteTimeout: HTTPServerTimeout, - Handler: externalRouter, - BaseContext: func(_ net.Listener) context.Context { - return b.ProcessContext.Context() - }, - } - internalServ := externalServ - - if internalAddr != NoListener && externalAddr != internalAddr { - // H2C allows us to accept HTTP/2 connections without TLS - // encryption. Since we don't currently require any form of - // authentication or encryption on these internal HTTP APIs, - // H2C gives us all of the advantages of HTTP/2 (such as - // stream multiplexing and avoiding head-of-line blocking) - // without enabling TLS. - internalH2S := &http2.Server{} - internalRouter = mux.NewRouter().SkipClean(true).UseEncodedPath() - internalServ = &http.Server{ - Addr: string(internalAddr), - Handler: h2c.NewHandler(internalRouter, internalH2S), - BaseContext: func(_ net.Listener) context.Context { - return b.ProcessContext.Context() - }, - } - } - - b.configureHTTPErrors() - - internalRouter.PathPrefix(httputil.InternalPathPrefix).Handler(b.InternalAPIMux) - if b.Cfg.Global.Metrics.Enabled { - internalRouter.Handle("/metrics", httputil.WrapHandlerInBasicAuth(promhttp.Handler(), b.Cfg.Global.Metrics.BasicAuth)) - } - +func (b *BaseDendrite) ConfigureAdminEndpoints() { b.DendriteAdminMux.HandleFunc("/monitor/up", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) }) @@ -477,6 +340,54 @@ func (b *BaseDendrite) SetupAndServeHTTP( } w.WriteHeader(200) }) +} + +// SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs +// and adds a prometheus handler under /_dendrite/metrics. +func (b *BaseDendrite) SetupAndServeHTTP( + externalHTTPAddr config.ServerAddress, + certFile, keyFile *string, +) { + // Manually unlocked right before actually serving requests, + // as we don't return from this method (defer doesn't work). + b.startupLock.Lock() + + externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() + + externalServ := &http.Server{ + Addr: externalHTTPAddr.Address, + WriteTimeout: HTTPServerTimeout, + Handler: externalRouter, + BaseContext: func(_ net.Listener) context.Context { + return b.ProcessContext.Context() + }, + } + + b.configureHTTPErrors() + + //Redirect for Landing Page + externalRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, httputil.PublicStaticPath, http.StatusFound) + }) + + if b.Cfg.Global.Metrics.Enabled { + externalRouter.Handle("/metrics", httputil.WrapHandlerInBasicAuth(promhttp.Handler(), b.Cfg.Global.Metrics.BasicAuth)) + } + + b.ConfigureAdminEndpoints() + + // Parse and execute the landing page template + tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl")) + landingPage := &bytes.Buffer{} + if err := tmpl.ExecuteTemplate(landingPage, "index.gotmpl", map[string]string{ + "Version": internal.VersionString(), + }); err != nil { + logrus.WithError(err).Fatal("failed to execute landing page template") + } + + b.PublicStaticMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write(landingPage.Bytes()) + }) var clientHandler http.Handler clientHandler = b.PublicClientAPIMux @@ -494,7 +405,7 @@ func (b *BaseDendrite) SetupAndServeHTTP( }) federationHandler = sentryHandler.Handle(b.PublicFederationAPIMux) } - internalRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(b.DendriteAdminMux) + externalRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(b.DendriteAdminMux) externalRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(clientHandler) if !b.Cfg.Global.DisableFederation { externalRouter.PathPrefix(httputil.PublicKeyPathPrefix).Handler(b.PublicKeyAPIMux) @@ -503,40 +414,14 @@ func (b *BaseDendrite) SetupAndServeHTTP( externalRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(b.SynapseAdminMux) externalRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(b.PublicMediaAPIMux) externalRouter.PathPrefix(httputil.PublicWellKnownPrefix).Handler(b.PublicWellKnownAPIMux) + externalRouter.PathPrefix(httputil.PublicStaticPath).Handler(b.PublicStaticMux) b.startupLock.Unlock() - if internalAddr != NoListener && internalAddr != externalAddr { - go func() { - var internalShutdown atomic.Bool // RegisterOnShutdown can be called more than once - logrus.Infof("Starting internal %s listener on %s", b.componentName, internalServ.Addr) - b.ProcessContext.ComponentStarted() - internalServ.RegisterOnShutdown(func() { - if internalShutdown.CompareAndSwap(false, true) { - b.ProcessContext.ComponentFinished() - logrus.Infof("Stopped internal HTTP listener") - } - }) - if certFile != nil && keyFile != nil { - if err := internalServ.ListenAndServeTLS(*certFile, *keyFile); err != nil { - if err != http.ErrServerClosed { - logrus.WithError(err).Fatal("failed to serve HTTPS") - } - } - } else { - if err := internalServ.ListenAndServe(); err != nil { - if err != http.ErrServerClosed { - logrus.WithError(err).Fatal("failed to serve HTTP") - } - } - } - logrus.Infof("Stopped internal %s listener on %s", b.componentName, internalServ.Addr) - }() - } - if externalAddr != NoListener { + if externalHTTPAddr.Enabled() { go func() { var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once - logrus.Infof("Starting external %s listener on %s", b.componentName, externalServ.Addr) + logrus.Infof("Starting external listener on %s", externalServ.Addr) b.ProcessContext.ComponentStarted() externalServ.RegisterOnShutdown(func() { if externalShutdown.CompareAndSwap(false, true) { @@ -551,13 +436,34 @@ func (b *BaseDendrite) SetupAndServeHTTP( } } } else { - if err := externalServ.ListenAndServe(); err != nil { - if err != http.ErrServerClosed { - logrus.WithError(err).Fatal("failed to serve HTTP") + if externalHTTPAddr.IsUnixSocket() { + err := os.Remove(externalHTTPAddr.Address) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + logrus.WithError(err).Fatal("failed to remove existing unix socket") + } + listener, err := net.Listen(externalHTTPAddr.Network(), externalHTTPAddr.Address) + if err != nil { + logrus.WithError(err).Fatal("failed to serve unix socket") + } + err = os.Chmod(externalHTTPAddr.Address, externalHTTPAddr.UnixSocketPermission) + if err != nil { + logrus.WithError(err).Fatal("failed to set unix socket permissions") + } + if err := externalServ.Serve(listener); err != nil { + if err != http.ErrServerClosed { + logrus.WithError(err).Fatal("failed to serve unix socket") + } + } + + } else { + if err := externalServ.ListenAndServe(); err != nil { + if err != http.ErrServerClosed { + logrus.WithError(err).Fatal("failed to serve HTTP") + } } } } - logrus.Infof("Stopped external %s listener on %s", b.componentName, externalServ.Addr) + logrus.Infof("Stopped external listener on %s", externalServ.Addr) }() } @@ -565,7 +471,6 @@ func (b *BaseDendrite) SetupAndServeHTTP( <-b.ProcessContext.WaitForShutdown() logrus.Infof("Stopping HTTP listeners") - _ = internalServ.Shutdown(context.Background()) _ = externalServ.Shutdown(context.Background()) logrus.Infof("Stopped HTTP listeners") } @@ -588,6 +493,12 @@ func (b *BaseDendrite) WaitForShutdown() { logrus.Warnf("failed to flush all Sentry events!") } } + if b.Fulltext != nil { + err := b.Fulltext.Close() + if err != nil { + logrus.Warnf("failed to close full text search!") + } + } logrus.Warnf("Dendrite is exiting now") } diff --git a/setup/base/base_test.go b/setup/base/base_test.go new file mode 100644 index 000000000..658dc5b03 --- /dev/null +++ b/setup/base/base_test.go @@ -0,0 +1,102 @@ +package base_test + +import ( + "bytes" + "context" + "embed" + "html/template" + "net" + "net/http" + "net/http/httptest" + "path" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/stretchr/testify/assert" +) + +//go:embed static/*.gotmpl +var staticContent embed.FS + +func TestLandingPage_Tcp(t *testing.T) { + // generate the expected result + tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl")) + expectedRes := &bytes.Buffer{} + err := tmpl.ExecuteTemplate(expectedRes, "index.gotmpl", map[string]string{ + "Version": internal.VersionString(), + }) + assert.NoError(t, err) + + b, _, _ := testrig.Base(nil) + defer b.Close() + + // hack: create a server and close it immediately, just to get a random port assigned + s := httptest.NewServer(nil) + s.Close() + + // start base with the listener and wait for it to be started + address, err := config.HTTPAddress(s.URL) + assert.NoError(t, err) + go b.SetupAndServeHTTP(address, nil, nil) + time.Sleep(time.Millisecond * 10) + + // When hitting /, we should be redirected to /_matrix/static, which should contain the landing page + req, err := http.NewRequest(http.MethodGet, s.URL, nil) + assert.NoError(t, err) + + // do the request + resp, err := s.Client().Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // read the response + buf := &bytes.Buffer{} + _, err = buf.ReadFrom(resp.Body) + assert.NoError(t, err) + + // Using .String() for user friendly output + assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch") +} + +func TestLandingPage_UnixSocket(t *testing.T) { + // generate the expected result + tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl")) + expectedRes := &bytes.Buffer{} + err := tmpl.ExecuteTemplate(expectedRes, "index.gotmpl", map[string]string{ + "Version": internal.VersionString(), + }) + assert.NoError(t, err) + + b, _, _ := testrig.Base(nil) + defer b.Close() + + tempDir := t.TempDir() + socket := path.Join(tempDir, "socket") + // start base with the listener and wait for it to be started + address := config.UnixSocketAddress(socket, 0755) + assert.NoError(t, err) + go b.SetupAndServeHTTP(address, nil, nil) + time.Sleep(time.Millisecond * 100) + + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", socket) + }, + }, + } + resp, err := client.Get("http://unix/") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // read the response + buf := &bytes.Buffer{} + _, err = buf.ReadFrom(resp.Body) + assert.NoError(t, err) + + // Using .String() for user friendly output + assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch") +} diff --git a/setup/base/static/index.gotmpl b/setup/base/static/index.gotmpl new file mode 100644 index 000000000..b3c5576eb --- /dev/null +++ b/setup/base/static/index.gotmpl @@ -0,0 +1,63 @@ + + + + Dendrite is running + + + + +

It works! Dendrite {{ .Version }} is running

+

Your Dendrite server is listening on this port and is ready for messages.

+

To use this server you'll need a Matrix client. +

+

Welcome to the Matrix universe :)

+
+

+ + + matrix.org + + +

+ + diff --git a/setup/config/config.go b/setup/config/config.go index 2b438f988..67106fb1c 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -20,7 +20,6 @@ import ( "encoding/pem" "fmt" "io" - "net/url" "os" "path/filepath" "regexp" @@ -30,7 +29,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" "golang.org/x/crypto/ed25519" - yaml "gopkg.in/yaml.v2" + "gopkg.in/yaml.v2" jaegerconfig "github.com/uber/jaeger-client-go/config" jaegermetrics "github.com/uber/jaeger-lib/metrics" @@ -79,8 +78,6 @@ type Dendrite struct { // Any information derived from the configuration options for later use. Derived Derived `yaml:"-"` - - IsMonolith bool `yaml:"-"` } // TODO: Kill Derived @@ -114,15 +111,6 @@ type Derived struct { // servers from creating RoomIDs in exclusive application service namespaces } -type InternalAPIOptions struct { - Listen HTTPAddress `yaml:"listen"` - Connect HTTPAddress `yaml:"connect"` -} - -type ExternalAPIOptions struct { - Listen HTTPAddress `yaml:"listen"` -} - // A Path on the filesystem. type Path string @@ -142,20 +130,6 @@ func (d DataSource) IsPostgres() bool { // A Topic in kafka. type Topic string -// An Address to listen on. -type Address string - -// An HTTPAddress to listen on, starting with either http:// or https://. -type HTTPAddress string - -func (h HTTPAddress) Address() (Address, error) { - url, err := url.Parse(string(h)) - if err != nil { - return "", err - } - return Address(url.Host), nil -} - // FileSizeBytes is a file size in bytes type FileSizeBytes int64 @@ -191,7 +165,7 @@ type ConfigErrors []string // Load a yaml config file for a server run as multiple processes or as a monolith. // Checks the config to ensure that it is valid. -func Load(configPath string, monolith bool) (*Dendrite, error) { +func Load(configPath string) (*Dendrite, error) { configData, err := os.ReadFile(configPath) if err != nil { return nil, err @@ -202,34 +176,32 @@ func Load(configPath string, monolith bool) (*Dendrite, error) { } // Pass the current working directory and os.ReadFile so that they can // be mocked in the tests - return loadConfig(basePath, configData, os.ReadFile, monolith) + return loadConfig(basePath, configData, os.ReadFile) } func loadConfig( basePath string, configData []byte, readFile func(string) ([]byte, error), - monolithic bool, ) (*Dendrite, error) { var c Dendrite c.Defaults(DefaultOpts{ - Generate: false, - Monolithic: monolithic, + Generate: false, + SingleDatabase: true, }) - c.IsMonolith = monolithic var err error if err = yaml.Unmarshal(configData, &c); err != nil { return nil, err } - if err = c.check(monolithic); err != nil { + if err = c.check(); err != nil { return nil, err } privateKeyPath := absPath(basePath, c.Global.PrivateKeyPath) if c.Global.KeyID, c.Global.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil { - return nil, err + return nil, fmt.Errorf("failed to load private_key: %w", err) } for _, v := range c.Global.VirtualHosts { @@ -243,7 +215,7 @@ func loadConfig( } privateKeyPath := absPath(basePath, v.PrivateKeyPath) if v.KeyID, v.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil { - return nil, err + return nil, fmt.Errorf("failed to load private_key for virtualhost %s: %w", v.ServerName, err) } } @@ -324,11 +296,13 @@ func (config *Dendrite) Derive() error { if config.ClientAPI.RecaptchaEnabled { config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey} - config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, - authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}}) + config.Derived.Registration.Flows = []authtypes.Flow{ + {Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}}, + } } else { - config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, - authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}) + config.Derived.Registration.Flows = []authtypes.Flow{ + {Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}, + } } if config.ClientAPI.ThreePidDelegate != "" { config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, @@ -343,8 +317,8 @@ func (config *Dendrite) Derive() error { } type DefaultOpts struct { - Generate bool - Monolithic bool + Generate bool + SingleDatabase bool } // SetDefaults sets default config values if they are not explicitly set. @@ -364,9 +338,9 @@ func (c *Dendrite) Defaults(opts DefaultOpts) { c.Wiring() } -func (c *Dendrite) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *Dendrite) Verify(configErrs *ConfigErrors) { type verifiable interface { - Verify(configErrs *ConfigErrors, isMonolith bool) + Verify(configErrs *ConfigErrors) } for _, c := range []verifiable{ &c.Global, &c.ClientAPI, &c.FederationAPI, @@ -374,7 +348,7 @@ func (c *Dendrite) Verify(configErrs *ConfigErrors, isMonolith bool) { &c.SyncAPI, &c.UserAPI, &c.AppServiceAPI, &c.MSCs, } { - c.Verify(configErrs, isMonolith) + c.Verify(configErrs) } } @@ -423,14 +397,6 @@ func checkNotEmpty(configErrs *ConfigErrors, key, value string) { } } -// checkNotZero verifies the given value is not zero in the configuration. -// If it is, adds an error to the list. -func checkNotZero(configErrs *ConfigErrors, key string, value int64) { - if value == 0 { - configErrs.Add(fmt.Sprintf("missing config key %q", key)) - } -} - // checkPositive verifies the given value is positive (zero included) // in the configuration. If it is not, adds an error to the list. func checkPositive(configErrs *ConfigErrors, key string, value int64) { @@ -439,26 +405,6 @@ func checkPositive(configErrs *ConfigErrors, key string, value int64) { } } -// checkURL verifies that the parameter is a valid URL -func checkURL(configErrs *ConfigErrors, key, value string) { - if value == "" { - configErrs.Add(fmt.Sprintf("missing config key %q", key)) - return - } - url, err := url.Parse(value) - if err != nil { - configErrs.Add(fmt.Sprintf("config key %q contains invalid URL (%s)", key, err.Error())) - return - } - switch url.Scheme { - case "http": - case "https": - default: - configErrs.Add(fmt.Sprintf("config key %q URL should be http:// or https://", key)) - return - } -} - // checkLogging verifies the parameters logging.* are valid. func (config *Dendrite) checkLogging(configErrs *ConfigErrors) { for _, logrusHook := range config.Logging { @@ -469,7 +415,7 @@ func (config *Dendrite) checkLogging(configErrs *ConfigErrors) { // check returns an error type containing all errors found within the config // file. -func (config *Dendrite) check(_ bool) error { // monolithic +func (config *Dendrite) check() error { // monolithic var configErrs ConfigErrors if config.Version != Version { @@ -536,58 +482,13 @@ func readKeyPEM(path string, data []byte, enforceKeyIDFormat bool) (gomatrixserv } } -// AppServiceURL returns a HTTP URL for where the appservice component is listening. -func (config *Dendrite) AppServiceURL() string { - // Hard code the appservice server to talk HTTP for now. - // If we support HTTPS we need to think of a practical way to do certificate validation. - // People setting up servers shouldn't need to get a certificate valid for the public - // internet for an internal API. - return string(config.AppServiceAPI.InternalAPI.Connect) -} - -// FederationAPIURL returns an HTTP URL for where the federation API is listening. -func (config *Dendrite) FederationAPIURL() string { - // Hard code the federationapi to talk HTTP for now. - // If we support HTTPS we need to think of a practical way to do certificate validation. - // People setting up servers shouldn't need to get a certificate valid for the public - // internet for an internal API. - return string(config.FederationAPI.InternalAPI.Connect) -} - -// RoomServerURL returns an HTTP URL for where the roomserver is listening. -func (config *Dendrite) RoomServerURL() string { - // Hard code the roomserver to talk HTTP for now. - // If we support HTTPS we need to think of a practical way to do certificate validation. - // People setting up servers shouldn't need to get a certificate valid for the public - // internet for an internal API. - return string(config.RoomServer.InternalAPI.Connect) -} - -// UserAPIURL returns an HTTP URL for where the userapi is listening. -func (config *Dendrite) UserAPIURL() string { - // Hard code the userapi to talk HTTP for now. - // If we support HTTPS we need to think of a practical way to do certificate validation. - // People setting up servers shouldn't need to get a certificate valid for the public - // internet for an internal API. - return string(config.UserAPI.InternalAPI.Connect) -} - -// KeyServerURL returns an HTTP URL for where the key server is listening. -func (config *Dendrite) KeyServerURL() string { - // Hard code the key server to talk HTTP for now. - // If we support HTTPS we need to think of a practical way to do certificate validation. - // People setting up servers shouldn't need to get a certificate valid for the public - // internet for an internal API. - return string(config.KeyServer.InternalAPI.Connect) -} - // SetupTracing configures the opentracing using the supplied configuration. -func (config *Dendrite) SetupTracing(serviceName string) (closer io.Closer, err error) { +func (config *Dendrite) SetupTracing() (closer io.Closer, err error) { if !config.Tracing.Enabled { return io.NopCloser(bytes.NewReader([]byte{})), nil } return config.Tracing.Jaeger.InitGlobalTracer( - serviceName, + "Dendrite", jaegerconfig.Logger(logrusLogger{logrus.StandardLogger()}), jaegerconfig.Metrics(jaegermetrics.NullFactory), ) diff --git a/setup/config/config_address.go b/setup/config/config_address.go new file mode 100644 index 000000000..0e4f0296f --- /dev/null +++ b/setup/config/config_address.go @@ -0,0 +1,45 @@ +package config + +import ( + "io/fs" + "net/url" +) + +const ( + NetworkTCP = "tcp" + NetworkUnix = "unix" +) + +type ServerAddress struct { + Address string + Scheme string + UnixSocketPermission fs.FileMode +} + +func (s ServerAddress) Enabled() bool { + return s.Address != "" +} + +func (s ServerAddress) IsUnixSocket() bool { + return s.Scheme == NetworkUnix +} + +func (s ServerAddress) Network() string { + if s.Scheme == NetworkUnix { + return NetworkUnix + } else { + return NetworkTCP + } +} + +func UnixSocketAddress(path string, perm fs.FileMode) ServerAddress { + return ServerAddress{Address: path, Scheme: NetworkUnix, UnixSocketPermission: perm} +} + +func HTTPAddress(urlAddress string) (ServerAddress, error) { + parsedUrl, err := url.Parse(urlAddress) + if err != nil { + return ServerAddress{}, err + } + return ServerAddress{parsedUrl.Host, parsedUrl.Scheme, 0}, nil +} diff --git a/setup/config/config_address_test.go b/setup/config/config_address_test.go new file mode 100644 index 000000000..1be484fd5 --- /dev/null +++ b/setup/config/config_address_test.go @@ -0,0 +1,25 @@ +package config + +import ( + "io/fs" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHttpAddress_ParseGood(t *testing.T) { + address, err := HTTPAddress("http://localhost:123") + assert.NoError(t, err) + assert.Equal(t, "localhost:123", address.Address) + assert.Equal(t, "tcp", address.Network()) +} + +func TestHttpAddress_ParseBad(t *testing.T) { + _, err := HTTPAddress(":") + assert.Error(t, err) +} + +func TestUnixSocketAddress_Network(t *testing.T) { + address := UnixSocketAddress("/tmp", fs.FileMode(0755)) + assert.Equal(t, "unix", address.Network()) +} diff --git a/setup/config/config_appservice.go b/setup/config/config_appservice.go index 706d2dfd2..588d50bd4 100644 --- a/setup/config/config_appservice.go +++ b/setup/config/config_appservice.go @@ -21,15 +21,13 @@ import ( "regexp" "strings" - yaml "gopkg.in/yaml.v2" + "gopkg.in/yaml.v2" ) type AppServiceAPI struct { Matrix *Global `yaml:"-"` Derived *Derived `yaml:"-"` // TODO: Nuke Derived from orbit - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - // DisableTLSValidation disables the validation of X.509 TLS certs // on appservice endpoints. This is not recommended in production! DisableTLSValidation bool `yaml:"disable_tls_validation"` @@ -38,18 +36,9 @@ type AppServiceAPI struct { } func (c *AppServiceAPI) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7777" - c.InternalAPI.Connect = "http://localhost:7777" - } } -func (c *AppServiceAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { - if isMonolith { // polylith required configs below - return - } - checkURL(configErrs, "app_service_api.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "app_service_api.internal_api.connect", string(c.InternalAPI.Connect)) +func (c *AppServiceAPI) Verify(configErrs *ConfigErrors) { } // ApplicationServiceNamespace is the namespace that a specific application diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 7a24c7e44..0d54573f9 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -12,9 +12,6 @@ type ClientAPI struct { Matrix *Global `yaml:"-"` Derived *Derived `yaml:"-"` // TODO: Nuke Derived from orbit - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"` - // If set disables new users from registering (except via shared // secrets) RegistrationDisabled bool `yaml:"registration_disabled"` @@ -75,11 +72,6 @@ type JwtConfig struct { } func (c *ClientAPI) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7771" - c.InternalAPI.Connect = "http://localhost:7771" - c.ExternalAPI.Listen = "http://[::]:8071" - } c.RegistrationSharedSecret = "" c.RecaptchaPublicKey = "" c.RecaptchaPrivateKey = "" @@ -91,13 +83,10 @@ func (c *ClientAPI) Defaults(opts DefaultOpts) { c.RateLimiting.Defaults() } -func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *ClientAPI) Verify(configErrs *ConfigErrors) { c.TURN.Verify(configErrs) c.RateLimiting.Verify(configErrs) if c.RecaptchaEnabled { - checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) - checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) - checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI) if c.RecaptchaSiteVerifyAPI == "" { c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify" } @@ -105,11 +94,15 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { c.RecaptchaApiJsUrl = "https://www.google.com/recaptcha/api.js" } if c.RecaptchaFormField == "" { - c.RecaptchaFormField = "g-recaptcha" + c.RecaptchaFormField = "g-recaptcha-response" } if c.RecaptchaSitekeyClass == "" { - c.RecaptchaSitekeyClass = "g-recaptcha-response" + c.RecaptchaSitekeyClass = "g-recaptcha" } + checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) + checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) + checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI) + checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass) } // Ensure there is any spam counter measure when enabling registration if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled { @@ -124,12 +117,6 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { ) } } - if isMonolith { // polylith required configs below - return - } - checkURL(configErrs, "client_api.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "client_api.internal_api.connect", string(c.InternalAPI.Connect)) - checkURL(configErrs, "client_api.external_api.listen", string(c.ExternalAPI.Listen)) } type TURN struct { diff --git a/setup/config/config_federationapi.go b/setup/config/config_federationapi.go index 0f853865f..8c1540b57 100644 --- a/setup/config/config_federationapi.go +++ b/setup/config/config_federationapi.go @@ -1,13 +1,12 @@ package config -import "github.com/matrix-org/gomatrixserverlib" +import ( + "github.com/matrix-org/gomatrixserverlib" +) type FederationAPI struct { Matrix *Global `yaml:"-"` - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"` - // The database stores information used by the federation destination queues to // send transactions to remote servers. Database DatabaseOptions `yaml:"database,omitempty"` @@ -18,6 +17,12 @@ type FederationAPI struct { // The default value is 16 if not specified, which is circa 18 hours. FederationMaxRetries uint32 `yaml:"send_max_retries"` + // P2P Feature: How many consecutive failures that we should tolerate when + // sending federation requests to a specific server until we should assume they + // are offline. If we assume they are offline then we will attempt to send + // messages to their relay server if we know of one that is appropriate. + P2PFederationRetriesUntilAssumedOffline uint32 `yaml:"p2p_retries_until_assumed_offline"` + // FederationDisableTLSValidation disables the validation of X.509 TLS certs // on remote federation endpoints. This is not recommended in production! DisableTLSValidation bool `yaml:"disable_tls_validation"` @@ -36,13 +41,8 @@ type FederationAPI struct { } func (c *FederationAPI) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7772" - c.InternalAPI.Connect = "http://localhost:7772" - c.ExternalAPI.Listen = "http://[::]:8072" - c.Database.Defaults(10) - } c.FederationMaxRetries = 16 + c.P2PFederationRetriesUntilAssumedOffline = 1 c.DisableTLSValidation = false c.DisableHTTPKeepalives = false if opts.Generate { @@ -61,22 +61,16 @@ func (c *FederationAPI) Defaults(opts DefaultOpts) { }, }, } - if !opts.Monolithic { + if !opts.SingleDatabase { c.Database.ConnectionString = "file:federationapi.db" } } } -func (c *FederationAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { - if isMonolith { // polylith required configs below - return - } +func (c *FederationAPI) Verify(configErrs *ConfigErrors) { if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "federation_api.database.connection_string", string(c.Database.ConnectionString)) } - checkURL(configErrs, "federation_api.external_api.listen", string(c.ExternalAPI.Listen)) - checkURL(configErrs, "federation_api.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "federation_api.internal_api.connect", string(c.InternalAPI.Connect)) } // The config for setting a proxy to use for server->server requests diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 511951fe6..7d3ab6a40 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -38,7 +38,6 @@ type Global struct { // component does not specify any database options of its own, then this pool of // connections will be used instead. This way we don't have to manage connection // counts on a per-component basis, but can instead do it for the entire monolith. - // In a polylith deployment, this will be ignored. DatabaseOptions DatabaseOptions `yaml:"database,omitempty"` // The server name to delegate server-server communications to, with optional port @@ -93,7 +92,7 @@ func (c *Global) Defaults(opts DefaultOpts) { } } c.KeyValidityPeriod = time.Hour * 24 * 7 - if opts.Monolithic { + if opts.SingleDatabase { c.DatabaseOptions.Defaults(90) } c.JetStream.Defaults(opts) @@ -105,7 +104,7 @@ func (c *Global) Defaults(opts DefaultOpts) { c.Cache.Defaults() } -func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *Global) Verify(configErrs *ConfigErrors) { checkNotEmpty(configErrs, "global.server_name", string(c.ServerName)) checkNotEmpty(configErrs, "global.private_key", string(c.PrivateKeyPath)) @@ -113,13 +112,13 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { v.Verify(configErrs) } - c.JetStream.Verify(configErrs, isMonolith) - c.Metrics.Verify(configErrs, isMonolith) - c.Sentry.Verify(configErrs, isMonolith) - c.DNSCache.Verify(configErrs, isMonolith) - c.ServerNotices.Verify(configErrs, isMonolith) - c.ReportStats.Verify(configErrs, isMonolith) - c.Cache.Verify(configErrs, isMonolith) + c.JetStream.Verify(configErrs) + c.Metrics.Verify(configErrs) + c.Sentry.Verify(configErrs) + c.DNSCache.Verify(configErrs) + c.ServerNotices.Verify(configErrs) + c.ReportStats.Verify(configErrs) + c.Cache.Verify(configErrs) } func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool { @@ -174,7 +173,7 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g return id, nil } } - return nil, fmt.Errorf("no signing identity %q", serverName) + return nil, fmt.Errorf("no signing identity for %q", serverName) } func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { @@ -267,7 +266,7 @@ func (c *Metrics) Defaults(opts DefaultOpts) { } } -func (c *Metrics) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *Metrics) Verify(configErrs *ConfigErrors) { } // ServerNotices defines the configuration used for sending server notices @@ -293,7 +292,7 @@ func (c *ServerNotices) Defaults(opts DefaultOpts) { } } -func (c *ServerNotices) Verify(errors *ConfigErrors, isMonolith bool) {} +func (c *ServerNotices) Verify(errors *ConfigErrors) {} type Cache struct { EstimatedMaxSize DataUnit `yaml:"max_size_estimated"` @@ -305,7 +304,7 @@ func (c *Cache) Defaults() { c.MaxAge = time.Hour } -func (c *Cache) Verify(errors *ConfigErrors, isMonolith bool) { +func (c *Cache) Verify(errors *ConfigErrors) { checkPositive(errors, "max_size_estimated", int64(c.EstimatedMaxSize)) } @@ -320,10 +319,15 @@ type ReportStats struct { func (c *ReportStats) Defaults() { c.Enabled = false - c.Endpoint = "https://matrix.org/report-usage-stats/push" + c.Endpoint = "https://panopticon.matrix.org/push" } -func (c *ReportStats) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *ReportStats) Verify(configErrs *ConfigErrors) { + // We prefer to hit panopticon (https://github.com/matrix-org/panopticon) directly over + // the "old" matrix.org endpoint. + if c.Endpoint == "https://matrix.org/report-usage-stats/push" { + c.Endpoint = "https://panopticon.matrix.org/push" + } if c.Enabled { checkNotEmpty(configErrs, "global.report_stats.endpoint", c.Endpoint) } @@ -344,7 +348,7 @@ func (c *Sentry) Defaults() { c.Enabled = false } -func (c *Sentry) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *Sentry) Verify(configErrs *ConfigErrors) { } type DatabaseOptions struct { @@ -364,8 +368,7 @@ func (c *DatabaseOptions) Defaults(conns int) { c.ConnMaxLifetimeSeconds = -1 } -func (c *DatabaseOptions) Verify(configErrs *ConfigErrors, isMonolith bool) { -} +func (c *DatabaseOptions) Verify(configErrs *ConfigErrors) {} // MaxIdleConns returns maximum idle connections to the DB func (c DatabaseOptions) MaxIdleConns() int { @@ -397,7 +400,7 @@ func (c *DNSCacheOptions) Defaults() { c.CacheLifetime = time.Minute * 5 } -func (c *DNSCacheOptions) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *DNSCacheOptions) Verify(configErrs *ConfigErrors) { checkPositive(configErrs, "cache_size", int64(c.CacheSize)) checkPositive(configErrs, "cache_lifetime", int64(c.CacheLifetime)) } diff --git a/setup/config/config_jetstream.go b/setup/config/config_jetstream.go index ef8bf014b..b8abed25c 100644 --- a/setup/config/config_jetstream.go +++ b/setup/config/config_jetstream.go @@ -41,11 +41,4 @@ func (c *JetStream) Defaults(opts DefaultOpts) { } } -func (c *JetStream) Verify(configErrs *ConfigErrors, isMonolith bool) { - if isMonolith { // polylith required configs below - return - } - // If we are running in a polylith deployment then we need at least - // one NATS JetStream server to talk to. - checkNotZero(configErrs, "global.jetstream.addresses", int64(len(c.Addresses))) -} +func (c *JetStream) Verify(configErrs *ConfigErrors) {} diff --git a/setup/config/config_keyserver.go b/setup/config/config_keyserver.go index dca9ca9f5..64710d957 100644 --- a/setup/config/config_keyserver.go +++ b/setup/config/config_keyserver.go @@ -3,31 +3,19 @@ package config type KeyServer struct { Matrix *Global `yaml:"-"` - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - Database DatabaseOptions `yaml:"database,omitempty"` } func (c *KeyServer) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7779" - c.InternalAPI.Connect = "http://localhost:7779" - c.Database.Defaults(10) - } if opts.Generate { - if !opts.Monolithic { + if !opts.SingleDatabase { c.Database.ConnectionString = "file:keyserver.db" } } } -func (c *KeyServer) Verify(configErrs *ConfigErrors, isMonolith bool) { - if isMonolith { // polylith required configs below - return - } +func (c *KeyServer) Verify(configErrs *ConfigErrors) { if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "key_server.database.connection_string", string(c.Database.ConnectionString)) } - checkURL(configErrs, "key_server.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "key_server.internal_api.connect", string(c.InternalAPI.Connect)) } diff --git a/setup/config/config_mediaapi.go b/setup/config/config_mediaapi.go index 53a8219eb..030bc3754 100644 --- a/setup/config/config_mediaapi.go +++ b/setup/config/config_mediaapi.go @@ -7,9 +7,6 @@ import ( type MediaAPI struct { Matrix *Global `yaml:"-"` - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"` - // The MediaAPI database stores information about files uploaded and downloaded // by local users. It is only accessed by the MediaAPI. Database DatabaseOptions `yaml:"database,omitempty"` @@ -39,12 +36,6 @@ type MediaAPI struct { var DefaultMaxFileSizeBytes = FileSizeBytes(10485760) func (c *MediaAPI) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7774" - c.InternalAPI.Connect = "http://localhost:7774" - c.ExternalAPI.Listen = "http://[::]:8074" - c.Database.Defaults(5) - } c.MaxFileSizeBytes = DefaultMaxFileSizeBytes c.MaxThumbnailGenerators = 10 if opts.Generate { @@ -65,14 +56,14 @@ func (c *MediaAPI) Defaults(opts DefaultOpts) { ResizeMethod: "scale", }, } - if !opts.Monolithic { + if !opts.SingleDatabase { c.Database.ConnectionString = "file:mediaapi.db" } c.BasePath = "./media_store" } } -func (c *MediaAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *MediaAPI) Verify(configErrs *ConfigErrors) { checkNotEmpty(configErrs, "media_api.base_path", string(c.BasePath)) checkPositive(configErrs, "media_api.max_file_size_bytes", int64(c.MaxFileSizeBytes)) checkPositive(configErrs, "media_api.max_thumbnail_generators", int64(c.MaxThumbnailGenerators)) @@ -81,13 +72,8 @@ func (c *MediaAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { checkPositive(configErrs, fmt.Sprintf("media_api.thumbnail_sizes[%d].width", i), int64(size.Width)) checkPositive(configErrs, fmt.Sprintf("media_api.thumbnail_sizes[%d].height", i), int64(size.Height)) } - if isMonolith { // polylith required configs below - return - } + if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "media_api.database.connection_string", string(c.Database.ConnectionString)) } - checkURL(configErrs, "media_api.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "media_api.internal_api.connect", string(c.InternalAPI.Connect)) - checkURL(configErrs, "media_api.external_api.listen", string(c.ExternalAPI.Listen)) } diff --git a/setup/config/config_mscs.go b/setup/config/config_mscs.go index 6d5ff39a5..21d4b4da0 100644 --- a/setup/config/config_mscs.go +++ b/setup/config/config_mscs.go @@ -14,11 +14,8 @@ type MSCs struct { } func (c *MSCs) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.Database.Defaults(5) - } if opts.Generate { - if !opts.Monolithic { + if !opts.SingleDatabase { c.Database.ConnectionString = "file:mscs.db" } } @@ -34,10 +31,7 @@ func (c *MSCs) Enabled(msc string) bool { return false } -func (c *MSCs) Verify(configErrs *ConfigErrors, isMonolith bool) { - if isMonolith { // polylith required configs below - return - } +func (c *MSCs) Verify(configErrs *ConfigErrors) { if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "mscs.database.connection_string", string(c.Database.ConnectionString)) } diff --git a/setup/config/config_roomserver.go b/setup/config/config_roomserver.go index 5e3b7f2ec..319c2419c 100644 --- a/setup/config/config_roomserver.go +++ b/setup/config/config_roomserver.go @@ -3,31 +3,19 @@ package config type RoomServer struct { Matrix *Global `yaml:"-"` - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - Database DatabaseOptions `yaml:"database,omitempty"` } func (c *RoomServer) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7770" - c.InternalAPI.Connect = "http://localhost:7770" - c.Database.Defaults(20) - } if opts.Generate { - if !opts.Monolithic { + if !opts.SingleDatabase { c.Database.ConnectionString = "file:roomserver.db" } } } -func (c *RoomServer) Verify(configErrs *ConfigErrors, isMonolith bool) { - if isMonolith { // polylith required configs below - return - } +func (c *RoomServer) Verify(configErrs *ConfigErrors) { if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "room_server.database.connection_string", string(c.Database.ConnectionString)) } - checkURL(configErrs, "room_server.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "room_server.internal_ap.connect", string(c.InternalAPI.Connect)) } diff --git a/setup/config/config_syncapi.go b/setup/config/config_syncapi.go index a87da3732..756f4cfb3 100644 --- a/setup/config/config_syncapi.go +++ b/setup/config/config_syncapi.go @@ -3,9 +3,6 @@ package config type SyncAPI struct { Matrix *Global `yaml:"-"` - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"` - Database DatabaseOptions `yaml:"database,omitempty"` RealIPHeader string `yaml:"real_ip_header"` @@ -14,31 +11,19 @@ type SyncAPI struct { } func (c *SyncAPI) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7773" - c.InternalAPI.Connect = "http://localhost:7773" - c.ExternalAPI.Listen = "http://localhost:8073" - c.Database.Defaults(20) - } c.Fulltext.Defaults(opts) if opts.Generate { - if !opts.Monolithic { + if !opts.SingleDatabase { c.Database.ConnectionString = "file:syncapi.db" } } } -func (c *SyncAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { - c.Fulltext.Verify(configErrs, isMonolith) - if isMonolith { // polylith required configs below - return - } +func (c *SyncAPI) Verify(configErrs *ConfigErrors) { + c.Fulltext.Verify(configErrs) if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "sync_api.database", string(c.Database.ConnectionString)) } - checkURL(configErrs, "sync_api.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "sync_api.internal_api.connect", string(c.InternalAPI.Connect)) - checkURL(configErrs, "sync_api.external_api.listen", string(c.ExternalAPI.Listen)) } type Fulltext struct { @@ -54,7 +39,7 @@ func (f *Fulltext) Defaults(opts DefaultOpts) { f.Language = "en" } -func (f *Fulltext) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (f *Fulltext) Verify(configErrs *ConfigErrors) { if !f.Enabled { return } diff --git a/setup/config/config_test.go b/setup/config/config_test.go index ee7e7389c..79407f30d 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -16,22 +16,33 @@ package config import ( "fmt" + "reflect" "testing" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" ) func TestLoadConfigRelative(t *testing.T) { - _, err := loadConfig("/my/config/dir", []byte(testConfig), + cfg, err := loadConfig("/my/config/dir", []byte(testConfig), mockReadFile{ "/my/config/dir/matrix_key.pem": testKey, "/my/config/dir/tls_cert.pem": testCert, }.readFile, - false, ) if err != nil { t.Error("failed to load config:", err) } + + configErrors := &ConfigErrors{} + cfg.Verify(configErrors) + if len(*configErrors) > 0 { + for _, err := range *configErrors { + logrus.Errorf("Configuration error: %s", err) + } + t.Error("configuration verification failed") + } } const testConfig = ` @@ -66,10 +77,9 @@ global: display_name: "Server alerts" avatar: "" room_name: "Server Alerts" + jetstream: + addresses: ["test"] app_service_api: - internal_api: - listen: http://localhost:7777 - connect: http://localhost:7777 database: connection_string: file:appservice.db max_open_conns: 100 @@ -77,12 +87,7 @@ app_service_api: conn_max_lifetime: -1 config_files: [] client_api: - internal_api: - listen: http://localhost:7771 - connect: http://localhost:7771 - external_api: - listen: http://[::]:8071 - registration_disabled: false + registration_disabled: true registration_shared_secret: "" enable_registration_captcha: false recaptcha_public_key: "" @@ -95,36 +100,16 @@ client_api: turn_shared_secret: "" turn_username: "" turn_password: "" -current_state_server: - internal_api: - listen: http://localhost:7782 - connect: http://localhost:7782 - database: - connection_string: file:currentstate.db - max_open_conns: 100 - max_idle_conns: 2 - conn_max_lifetime: -1 federation_api: - internal_api: - listen: http://localhost:7772 - connect: http://localhost:7772 - external_api: - listen: http://[::]:8072 + database: + connection_string: file:federationapi.db key_server: - internal_api: - listen: http://localhost:7779 - connect: http://localhost:7779 database: connection_string: file:keyserver.db max_open_conns: 100 max_idle_conns: 2 conn_max_lifetime: -1 media_api: - internal_api: - listen: http://localhost:7774 - connect: http://localhost:7774 - external_api: - listen: http://[::]:8074 database: connection_string: file:mediaapi.db max_open_conns: 100 @@ -145,18 +130,12 @@ media_api: height: 480 method: scale room_server: - internal_api: - listen: http://localhost:7770 - connect: http://localhost:7770 database: connection_string: file:roomserver.db max_open_conns: 100 max_idle_conns: 2 conn_max_lifetime: -1 server_key_api: - internal_api: - listen: http://localhost:7780 - connect: http://localhost:7780 database: connection_string: file:serverkeyapi.db max_open_conns: 100 @@ -170,18 +149,12 @@ server_key_api: - key_id: ed25519:a_RXGa public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ sync_api: - internal_api: - listen: http://localhost:7773 - connect: http://localhost:7773 database: connection_string: file:syncapi.db max_open_conns: 100 max_idle_conns: 2 conn_max_lifetime: -1 user_api: - internal_api: - listen: http://localhost:7781 - connect: http://localhost:7781 account_database: connection_string: file:userapi_accounts.db max_open_conns: 100 @@ -192,6 +165,12 @@ user_api: max_open_conns: 100 max_idle_conns: 2 conn_max_lifetime: -1 +relay_api: + database: + connection_string: file:relayapi.db +mscs: + database: + connection_string: file:mscs.db tracing: enabled: false jaeger: @@ -290,3 +269,55 @@ func TestUnmarshalDataUnit(t *testing.T) { } } } + +func Test_SigningIdentityFor(t *testing.T) { + tests := []struct { + name string + virtualHosts []*VirtualHost + serverName gomatrixserverlib.ServerName + want *gomatrixserverlib.SigningIdentity + wantErr bool + }{ + { + name: "no virtual hosts defined", + wantErr: true, + }, + { + name: "no identity found", + serverName: gomatrixserverlib.ServerName("doesnotexist"), + wantErr: true, + }, + { + name: "found identity", + serverName: gomatrixserverlib.ServerName("main"), + want: &gomatrixserverlib.SigningIdentity{ServerName: "main"}, + }, + { + name: "identity found on virtual hosts", + serverName: gomatrixserverlib.ServerName("vh2"), + virtualHosts: []*VirtualHost{ + {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}}, + {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh2"}}, + }, + want: &gomatrixserverlib.SigningIdentity{ServerName: "vh2"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Global{ + VirtualHosts: tt.virtualHosts, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "main", + }, + } + got, err := c.SigningIdentityFor(tt.serverName) + if (err != nil) != tt.wantErr { + t.Errorf("SigningIdentityFor() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SigningIdentityFor() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go index f8ad41d93..e64a3910c 100644 --- a/setup/config/config_userapi.go +++ b/setup/config/config_userapi.go @@ -5,8 +5,6 @@ import "golang.org/x/crypto/bcrypt" type UserAPI struct { Matrix *Global `yaml:"-"` - InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` - // The cost when hashing passwords. BCryptCost int `yaml:"bcrypt_cost"` @@ -28,28 +26,18 @@ type UserAPI struct { const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes func (c *UserAPI) Defaults(opts DefaultOpts) { - if !opts.Monolithic { - c.InternalAPI.Listen = "http://localhost:7781" - c.InternalAPI.Connect = "http://localhost:7781" - c.AccountDatabase.Defaults(10) - } c.BCryptCost = bcrypt.DefaultCost c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS if opts.Generate { - if !opts.Monolithic { + if !opts.SingleDatabase { c.AccountDatabase.ConnectionString = "file:userapi_accounts.db" } } } -func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { +func (c *UserAPI) Verify(configErrs *ConfigErrors) { checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS) - if isMonolith { // polylith required configs below - return - } if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString)) } - checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect)) } diff --git a/setup/flags.go b/setup/flags.go index a9dac61a1..869caa280 100644 --- a/setup/flags.go +++ b/setup/flags.go @@ -43,7 +43,7 @@ func ParseFlags(monolith bool) *config.Dendrite { logrus.Fatal("--config must be supplied") } - cfg, err := config.Load(*configPath, monolith) + cfg, err := config.Load(*configPath) if err != nil { logrus.Fatalf("Invalid config file: %s", err) diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index c1ce9583f..533652160 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -77,6 +77,11 @@ func JetStreamConsumer( // The consumer was deleted so stop. return } else { + // Unfortunately, there's no ErrServerShutdown or similar, so we need to compare the string + if err.Error() == "nats: Server Shutdown" { + logrus.WithContext(ctx).Warn("nats server shutting down") + return + } // Something else went wrong, so we'll panic. sentry.CaptureException(err) logrus.WithContext(ctx).WithField("subject", subj).Fatal(err) diff --git a/setup/jetstream/log.go b/setup/jetstream/log.go new file mode 100644 index 000000000..880f7120b --- /dev/null +++ b/setup/jetstream/log.go @@ -0,0 +1,42 @@ +package jetstream + +import ( + "github.com/nats-io/nats-server/v2/server" + "github.com/sirupsen/logrus" +) + +var _ server.Logger = &LogAdapter{} + +type LogAdapter struct { + entry *logrus.Entry +} + +func NewLogAdapter() *LogAdapter { + return &LogAdapter{ + entry: logrus.StandardLogger().WithField("component", "jetstream"), + } +} + +func (l *LogAdapter) Noticef(format string, v ...interface{}) { + l.entry.Infof(format, v...) +} + +func (l *LogAdapter) Warnf(format string, v ...interface{}) { + l.entry.Warnf(format, v...) +} + +func (l *LogAdapter) Fatalf(format string, v ...interface{}) { + l.entry.Fatalf(format, v...) +} + +func (l *LogAdapter) Errorf(format string, v ...interface{}) { + l.entry.Errorf(format, v...) +} + +func (l *LogAdapter) Debugf(format string, v ...interface{}) { + l.entry.Debugf(format, v...) +} + +func (l *LogAdapter) Tracef(format string, v ...interface{}) { + l.entry.Tracef(format, v...) +} diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index adaeb873d..f7f245d36 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -40,7 +40,7 @@ func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetS } if s.Server == nil { var err error - s.Server, err = natsserver.NewServer(&natsserver.Options{ + opts := &natsserver.Options{ ServerName: "monolith", DontListen: true, JetStream: true, @@ -49,11 +49,12 @@ func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetS MaxPayload: 16 * 1024 * 1024, NoSigs: true, NoLog: cfg.NoLog, - }) + } + s.Server, err = natsserver.NewServer(opts) if err != nil { panic(err) } - s.ConfigureLogger() + s.SetLogger(NewLogAdapter(), opts.Debug, opts.Trace) go func() { process.ComponentStarted() s.Start() diff --git a/setup/monolith.go b/setup/monolith.go index 41a897024..54bab2dcc 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/federationapi" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/transactions" - keyAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/mediaapi" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" @@ -43,7 +42,6 @@ type Monolith struct { FederationAPI federationAPI.FederationInternalAPI RoomserverAPI roomserverAPI.RoomserverInternalAPI UserAPI userapi.UserInternalAPI - KeyAPI keyAPI.KeyInternalAPI // Optional ExtPublicRoomsProvider api.ExtraPublicRoomsProvider @@ -58,17 +56,13 @@ func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) { } clientapi.AddPublicRoutes( base, m.FedClient, m.RoomserverAPI, m.AppserviceAPI, transactions.New(), - m.FederationAPI, m.UserAPI, userDirectoryProvider, m.KeyAPI, + m.FederationAPI, m.UserAPI, userDirectoryProvider, m.ExtPublicRoomsProvider, ) federationapi.AddPublicRoutes( - base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, - m.KeyAPI, nil, - ) - mediaapi.AddPublicRoutes( - base, m.UserAPI, m.Client, - ) - syncapi.AddPublicRoutes( - base, m.UserAPI, m.RoomserverAPI, m.KeyAPI, + base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, nil, ) + mediaapi.AddPublicRoutes(base, m.UserAPI, m.Client) + syncapi.AddPublicRoutes(base, m.UserAPI, m.RoomserverAPI) + } diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index bc369c166..4bb6a5eee 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -253,7 +253,7 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo var res MSC2836EventRelationshipsResponse var returnEvents []*gomatrixserverlib.HeaderedEvent // Can the user see (according to history visibility) event_id? If no, reject the request, else continue. - event := rc.getLocalEvent(rc.req.EventID) + event := rc.getLocalEvent(rc.req.RoomID, rc.req.EventID) if event == nil { event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID) } @@ -592,7 +592,7 @@ func (rc *reqCtx) remoteEventRelationships(eventID string) *MSC2836EventRelation // lookForEvent returns the event for the event ID given, by trying to query remote servers // if the event ID is unknown via /event_relationships. func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent { - event := rc.getLocalEvent(eventID) + event := rc.getLocalEvent(rc.req.RoomID, eventID) if event == nil { queryRes := rc.remoteEventRelationships(eventID) if queryRes != nil { @@ -622,9 +622,10 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent return nil } -func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent { +func (rc *reqCtx) getLocalEvent(roomID, eventID string) *gomatrixserverlib.HeaderedEvent { var queryEventsRes roomserver.QueryEventsByIDResponse err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{ + RoomID: roomID, EventIDs: []string{eventID}, }, &queryEventsRes) if err != nil { diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 0388fcc53..f12fbbfcb 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -499,7 +499,7 @@ func assertUnsignedChildren(t *testing.T, ev gomatrixserverlib.ClientEvent, relT } type testUserAPI struct { - userapi.UserInternalAPITrace + userapi.UserInternalAPI accessTokens map[string]userapi.Device } @@ -516,7 +516,7 @@ func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAc type testRoomserverAPI struct { // use a trace API as it implements method stubs so we don't need to have them here. // We'll override the functions we care about. - roomserver.RoomserverInternalAPITrace + roomserver.RoomserverInternalAPI userToJoinedRooms map[string][]string events map[string]*gomatrixserverlib.HeaderedEvent } @@ -548,8 +548,8 @@ func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserve t.Helper() cfg := &config.Dendrite{} cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: true, }) cfg.Global.ServerName = "localhost" cfg.MSCs.Database.ConnectionString = "file:msc2836_test.db" diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 92f081500..5faaefb8e 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -19,7 +19,6 @@ import ( "encoding/json" "github.com/getsentry/sentry-go" - "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" @@ -28,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index 145059c2d..6e3150c29 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -78,7 +78,7 @@ func (s *PresenceConsumer) Start() error { // Normal NATS subscription, used by Request/Reply _, err := s.nats.Subscribe(s.requestTopic, func(msg *nats.Msg) { userID := msg.Header.Get(jetstream.UserID) - presence, err := s.db.GetPresence(context.Background(), userID) + presences, err := s.db.GetPresences(context.Background(), []string{userID}) m := &nats.Msg{ Header: nats.Header{}, } @@ -89,10 +89,12 @@ func (s *PresenceConsumer) Start() error { } return } - if presence == nil { - presence = &types.PresenceInternal{ - UserID: userID, - } + + presence := &types.PresenceInternal{ + UserID: userID, + } + if len(presences) > 0 { + presence = presences[0] } deviceRes := api.QueryDevicesResponse{} diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 1b67f5684..a8d4d2b2c 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -23,6 +23,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -127,6 +128,12 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms s.onRetirePeek(s.ctx, *output.RetirePeek) case api.OutputTypeRedactedEvent: err = s.onRedactEvent(s.ctx, *output.RedactedEvent) + case api.OutputTypePurgeRoom: + err = s.onPurgeRoom(s.ctx, *output.PurgeRoom) + if err != nil { + logrus.WithField("room_id", output.PurgeRoom.RoomID).WithError(err).Error("Failed to purge room from sync API") + return true // non-fatal, as otherwise we end up in a loop of trying to purge the room + } default: log.WithField("type", output.Type).Debug( "roomserver output log: ignoring unknown output type", @@ -205,6 +212,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( // Finally, work out if there are any more events missing. if len(missingEventIDs) > 0 { eventsReq := &api.QueryEventsByIDRequest{ + RoomID: ev.RoomID(), EventIDs: missingEventIDs, } eventsRes := &api.QueryEventsByIDResponse{} @@ -473,6 +481,20 @@ func (s *OutputRoomEventConsumer) onRetirePeek( s.notifier.OnRetirePeek(msg.RoomID, msg.UserID, msg.DeviceID, types.StreamingToken{PDUPosition: sp}) } +func (s *OutputRoomEventConsumer) onPurgeRoom( + ctx context.Context, req api.OutputPurgeRoom, +) error { + logrus.WithField("room_id", req.RoomID).Warn("Purging room from sync API") + + if err := s.db.PurgeRoom(ctx, req.RoomID); err != nil { + logrus.WithField("room_id", req.RoomID).WithError(err).Error("Failed to purge room from sync API") + return err + } else { + logrus.WithField("room_id", req.RoomID).Warn("Room purged from sync API") + return nil + } +} + func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.HeaderedEvent) (*gomatrixserverlib.HeaderedEvent, error) { if event.StateKey() == nil { return event, nil diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index 356e83263..32208c585 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -25,7 +25,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -33,6 +32,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" ) // OutputSendToDeviceEventConsumer consumes events that originated in the EDU server. @@ -42,7 +42,7 @@ type OutputSendToDeviceEventConsumer struct { durable string topic string db storage.Database - keyAPI keyapi.SyncKeyAPI + userAPI api.SyncKeyAPI isLocalServerName func(gomatrixserverlib.ServerName) bool stream streams.StreamProvider notifier *notifier.Notifier @@ -55,7 +55,7 @@ func NewOutputSendToDeviceEventConsumer( cfg *config.SyncAPI, js nats.JetStreamContext, store storage.Database, - keyAPI keyapi.SyncKeyAPI, + userAPI api.SyncKeyAPI, notifier *notifier.Notifier, stream streams.StreamProvider, ) *OutputSendToDeviceEventConsumer { @@ -65,7 +65,7 @@ func NewOutputSendToDeviceEventConsumer( topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), durable: cfg.Matrix.JetStream.Durable("SyncAPISendToDeviceConsumer"), db: store, - keyAPI: keyAPI, + userAPI: userAPI, isLocalServerName: cfg.Matrix.IsLocalServerName, notifier: notifier, stream: stream, @@ -116,7 +116,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs [] _, senderDomain, _ := gomatrixserverlib.SplitID('@', output.Sender) if requestingDeviceID != "" && !s.isLocalServerName(senderDomain) { // Mark the requesting device as stale, if we don't know about it. - if err = s.keyAPI.PerformMarkAsStaleIfNeeded(ctx, &keyapi.PerformMarkAsStaleRequest{ + if err = s.userAPI.PerformMarkAsStaleIfNeeded(ctx, &api.PerformMarkAsStaleRequest{ UserID: output.Sender, Domain: senderDomain, DeviceID: requestingDeviceID, }, &struct{}{}); err != nil { logger.WithError(err).Errorf("failed to mark as stale if needed") diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index 71d7ddd15..ee695f0f5 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -33,8 +33,7 @@ func init() { } // calculateHistoryVisibilityDuration stores the time it takes to -// calculate the history visibility. In polylith mode the roundtrip -// to the roomserver is included in this time. +// calculate the history visibility. var calculateHistoryVisibilityDuration = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: "dendrite", @@ -121,10 +120,7 @@ func ApplyHistoryVisibilityFilter( // Get the mapping from eventID -> eventVisibility eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events)) - visibilities, err := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID()) - if err != nil { - return eventsFiltered, err - } + visibilities := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID()) for _, ev := range events { evVis := visibilities[ev.EventID()] evVis.membershipCurrent = membershipCurrent @@ -175,7 +171,7 @@ func visibilityForEvents( rsAPI api.SyncRoomserverAPI, events []*gomatrixserverlib.HeaderedEvent, userID, roomID string, -) (map[string]eventVisibility, error) { +) map[string]eventVisibility { eventIDs := make([]string, len(events)) for i := range events { eventIDs[i] = events[i].EventID() @@ -185,6 +181,7 @@ func visibilityForEvents( // get the membership events for all eventIDs membershipResp := &api.QueryMembershipAtEventResponse{} + err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{ RoomID: roomID, EventIDs: eventIDs, @@ -201,19 +198,20 @@ func visibilityForEvents( membershipAtEvent: gomatrixserverlib.Leave, // default to leave, to not expose events by accident visibility: event.Visibility, } - membershipEvs, ok := membershipResp.Memberships[eventID] - if !ok { + ev, ok := membershipResp.Membership[eventID] + if !ok || ev == nil { result[eventID] = vis continue } - for _, ev := range membershipEvs { - membership, err := ev.Membership() - if err != nil { - return result, err - } - vis.membershipAtEvent = membership + + membership, err := ev.Membership() + if err != nil { + result[eventID] = vis + continue } + vis.membershipAtEvent = membership + result[eventID] = vis } - return result, nil + return result } diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 3d6b2a7f3..e7f677c85 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -18,22 +18,22 @@ import ( "context" "strings" + keytypes "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - keyapi "github.com/matrix-org/dendrite/keyserver/api" - keytypes "github.com/matrix-org/dendrite/keyserver/types" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" ) // DeviceOTKCounts adds one-time key counts to the /sync response -func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error { - var queryRes keyapi.QueryOneTimeKeysResponse - _ = keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{ +func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceID string, res *types.Response) error { + var queryRes api.QueryOneTimeKeysResponse + _ = keyAPI.QueryOneTimeKeys(ctx, &api.QueryOneTimeKeysRequest{ UserID: userID, DeviceID: deviceID, }, &queryRes) @@ -48,7 +48,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, devi // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // be already filled in with join/leave information. func DeviceListCatchup( - ctx context.Context, db storage.SharedUsers, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI, + ctx context.Context, db storage.SharedUsers, userAPI api.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI, userID string, res *types.Response, from, to types.StreamPosition, ) (newPos types.StreamPosition, hasNew bool, err error) { @@ -74,8 +74,8 @@ func DeviceListCatchup( if from > 0 { offset = int64(from) } - var queryRes keyapi.QueryKeyChangesResponse - _ = keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{ + var queryRes api.QueryKeyChangesResponse + _ = userAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{ Offset: offset, ToOffset: toOffset, }, &queryRes) @@ -144,38 +144,42 @@ func TrackChangedUsers( // - Loop set of users and decrement by 1 for each user in newly left room. // - If count=0 then they share no more rooms so inform BOTH parties of this via 'left'=[...] in /sync. var queryRes roomserverAPI.QuerySharedUsersResponse - err = rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ - UserID: userID, - IncludeRoomIDs: newlyLeftRooms, - }, &queryRes) - if err != nil { - return nil, nil, err - } var stateRes roomserverAPI.QueryBulkStateContentResponse - err = rsAPI.QueryBulkStateContent(ctx, &roomserverAPI.QueryBulkStateContentRequest{ - RoomIDs: newlyLeftRooms, - StateTuples: []gomatrixserverlib.StateKeyTuple{ - { - EventType: gomatrixserverlib.MRoomMember, - StateKey: "*", - }, - }, - AllowWildcards: true, - }, &stateRes) - if err != nil { - return nil, nil, err - } - for _, state := range stateRes.Rooms { - for tuple, membership := range state { - if membership != gomatrixserverlib.Join { - continue - } - queryRes.UserIDsToCount[tuple.StateKey]-- + if len(newlyLeftRooms) > 0 { + err = rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ + UserID: userID, + IncludeRoomIDs: newlyLeftRooms, + }, &queryRes) + if err != nil { + return nil, nil, err } - } - for userID, count := range queryRes.UserIDsToCount { - if count <= 0 { - left = append(left, userID) // left is returned + + err = rsAPI.QueryBulkStateContent(ctx, &roomserverAPI.QueryBulkStateContentRequest{ + RoomIDs: newlyLeftRooms, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + { + EventType: gomatrixserverlib.MRoomMember, + StateKey: "*", + }, + }, + AllowWildcards: true, + }, &stateRes) + if err != nil { + return nil, nil, err + } + for _, state := range stateRes.Rooms { + for tuple, membership := range state { + if membership != gomatrixserverlib.Join { + continue + } + queryRes.UserIDsToCount[tuple.StateKey]-- + } + } + + for userID, count := range queryRes.UserIDsToCount { + if count <= 0 { + left = append(left, userID) // left is returned + } } } diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index 53f3e5a40..4bb851668 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -9,7 +9,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -22,49 +21,49 @@ var ( type mockKeyAPI struct{} -func (k *mockKeyAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *keyapi.PerformMarkAsStaleRequest, res *struct{}) error { +func (k *mockKeyAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *userapi.PerformMarkAsStaleRequest, res *struct{}) error { return nil } -func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) error { +func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *userapi.PerformUploadKeysRequest, res *userapi.PerformUploadKeysResponse) error { return nil } func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {} // PerformClaimKeys claims one-time keys for use in pre-key messages -func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) error { +func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *userapi.PerformClaimKeysRequest, res *userapi.PerformClaimKeysResponse) error { return nil } -func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) error { +func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *userapi.PerformDeleteKeysRequest, res *userapi.PerformDeleteKeysResponse) error { return nil } -func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) error { +func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *userapi.PerformUploadDeviceKeysRequest, res *userapi.PerformUploadDeviceKeysResponse) error { return nil } -func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) error { +func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *userapi.PerformUploadDeviceSignaturesRequest, res *userapi.PerformUploadDeviceSignaturesResponse) error { return nil } -func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) error { +func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *userapi.QueryKeysRequest, res *userapi.QueryKeysResponse) error { return nil } -func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error { +func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyChangesRequest, res *userapi.QueryKeyChangesResponse) error { return nil } -func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error { +func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error { return nil } -func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) error { +func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error { return nil } -func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) error { +func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *userapi.QuerySignaturesRequest, res *userapi.QuerySignaturesResponse) error { return nil } type mockRoomserverAPI struct { - api.RoomserverInternalAPITrace + api.RoomserverInternalAPI roomIDToJoinedMembers map[string][]string } diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 095a868c7..76f003671 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -67,6 +67,8 @@ func Context( errMsg = "unable to parse filter" case *strconv.NumError: errMsg = "unable to parse limit" + default: + errMsg = err.Error() } return util.JSONResponse{ Code: http.StatusBadRequest, @@ -167,7 +169,18 @@ func Context( eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBeforeFiltered, gomatrixserverlib.FormatAll) eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll) - newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache) + + newState := state + if filter.LazyLoadMembers { + allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) + allEvents = append(allEvents, &requestedEvent) + evs := gomatrixserverlib.HeaderedToClientEvents(allEvents, gomatrixserverlib.FormatAll) + newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) + if err != nil { + logrus.WithError(err).Error("unable to load membership events") + return jsonerror.InternalServerError() + } + } ev := gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll) response := ContextRespsonse{ @@ -244,41 +257,43 @@ func getStartEnd(ctx context.Context, snapshot storage.DatabaseTransaction, star } func applyLazyLoadMembers( + ctx context.Context, device *userapi.Device, - filter *gomatrixserverlib.RoomEventFilter, - eventsAfter, eventsBefore []gomatrixserverlib.ClientEvent, - state []*gomatrixserverlib.HeaderedEvent, + snapshot storage.DatabaseTransaction, + roomID string, + events []gomatrixserverlib.ClientEvent, lazyLoadCache caching.LazyLoadCache, -) []*gomatrixserverlib.HeaderedEvent { - if filter == nil || !filter.LazyLoadMembers { - return state - } - allEvents := append(eventsBefore, eventsAfter...) - x := make(map[string]struct{}) +) ([]*gomatrixserverlib.HeaderedEvent, error) { + eventSenders := make(map[string]struct{}) // get members who actually send an event - for _, e := range allEvents { + for _, e := range events { // Don't add membership events the client should already know about if _, cached := lazyLoadCache.IsLazyLoadedUserCached(device, e.RoomID, e.Sender); cached { continue } - x[e.Sender] = struct{}{} + eventSenders[e.Sender] = struct{}{} } - newState := []*gomatrixserverlib.HeaderedEvent{} - membershipEvents := []*gomatrixserverlib.HeaderedEvent{} - for _, event := range state { - if event.Type() != gomatrixserverlib.MRoomMember { - newState = append(newState, event) - } else { - // did the user send an event? - if _, ok := x[event.Sender()]; ok { - membershipEvents = append(membershipEvents, event) - lazyLoadCache.StoreLazyLoadedUser(device, event.RoomID(), event.Sender(), event.EventID()) - } - } + wantUsers := make([]string, 0, len(eventSenders)) + for userID := range eventSenders { + wantUsers = append(wantUsers, userID) } - // Add the membershipEvents to the end of the list, to make Sytest happy - return append(newState, membershipEvents...) + + // Query missing membership events + filter := gomatrixserverlib.DefaultStateFilter() + filter.Senders = &wantUsers + filter.Types = &[]string{gomatrixserverlib.MRoomMember} + memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter) + if err != nil { + return nil, err + } + + // cache the membership events + for _, membership := range memberships { + lazyLoadCache.StoreLazyLoadedUser(device, roomID, *membership.StateKey(), membership.EventID()) + } + + return memberships, nil } func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) { diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 3fcc3235c..8efd77cef 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -16,16 +16,16 @@ package routing import ( "encoding/json" + "math" "net/http" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" ) type getMembershipResponse struct { @@ -87,19 +87,18 @@ func GetMemberships( if err != nil { return jsonerror.InternalServerError() } + defer db.Rollback() // nolint: errcheck atToken, err := types.NewTopologyTokenFromString(at) if err != nil { + atToken = types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} if queryRes.HasBeenInRoom && !queryRes.IsInRoom { // If you have left the room then this will be the members of the room when you left. atToken, err = db.EventPositionInTopology(req.Context(), queryRes.EventID) - } else { - // If you are joined to the room then this will be the current members of the room. - atToken, err = db.MaxTopologicalPosition(req.Context(), roomID) - } - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'") - return jsonerror.InternalServerError() + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'") + return jsonerror.InternalServerError() + } } } @@ -110,7 +109,7 @@ func GetMemberships( } qryRes := &api.QueryEventsByIDResponse{} - if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs}, qryRes); err != nil { + if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs, RoomID: roomID}, qryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed") return jsonerror.InternalServerError() } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 0d740ebfc..02d8fcc7e 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -17,6 +17,7 @@ package routing import ( "context" "fmt" + "math" "net/http" "sort" "time" @@ -57,12 +58,13 @@ type messagesResp struct { StartStream string `json:"start_stream,omitempty"` // NOTSPEC: used by Cerulean, so clients can hit /messages then immediately /sync with a latest sync token End string `json:"end,omitempty"` Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` - State []gomatrixserverlib.ClientEvent `json:"state"` + State []gomatrixserverlib.ClientEvent `json:"state,omitempty"` } // OnIncomingMessagesRequest implements the /messages endpoint from the // client-server API. // See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages +// nolint:gocyclo func OnIncomingMessagesRequest( req *http.Request, db storage.Database, roomID string, device *userapi.Device, rsAPI api.SyncRoomserverAPI, @@ -177,10 +179,11 @@ func OnIncomingMessagesRequest( // If "to" isn't provided, it defaults to either the earliest stream // position (if we're going backward) or to the latest one (if we're // going forward). - to, err = setToDefault(req.Context(), snapshot, backwardOrdering, roomID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("setToDefault failed") - return jsonerror.InternalServerError() + to = types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} + if backwardOrdering { + // go 1 earlier than the first event so we correctly fetch the earliest event + // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. + to = types.TopologyToken{} } wasToProvided = false } @@ -244,7 +247,14 @@ func OnIncomingMessagesRequest( Start: start.String(), End: end.String(), } - res.applyLazyLoadMembers(req.Context(), snapshot, roomID, device, filter.LazyLoadMembers, lazyLoadCache) + if filter.LazyLoadMembers { + membershipEvents, err := applyLazyLoadMembers(req.Context(), device, snapshot, roomID, clientEvents, lazyLoadCache) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("failed to apply lazy loading") + return jsonerror.InternalServerError() + } + res.State = append(res.State, gomatrixserverlib.HeaderedToClientEvents(membershipEvents, gomatrixserverlib.FormatAll)...) + } // If we didn't return any events, set the end to an empty string, so it will be omitted // in the response JSON. @@ -263,40 +273,6 @@ func OnIncomingMessagesRequest( } } -// applyLazyLoadMembers loads membership events for users returned in Chunk, if the filter has -// LazyLoadMembers enabled. -func (m *messagesResp) applyLazyLoadMembers( - ctx context.Context, - db storage.DatabaseTransaction, - roomID string, - device *userapi.Device, - lazyLoad bool, - lazyLoadCache caching.LazyLoadCache, -) { - if !lazyLoad { - return - } - membershipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent) - for _, evt := range m.Chunk { - // Don't add membership events the client should already know about - if _, cached := lazyLoadCache.IsLazyLoadedUserCached(device, roomID, evt.Sender); cached { - continue - } - membership, err := db.GetStateEvent(ctx, roomID, gomatrixserverlib.MRoomMember, evt.Sender) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("failed to get membership event for user") - continue - } - if membership != nil { - membershipToUser[evt.Sender] = membership - lazyLoadCache.StoreLazyLoadedUser(device, roomID, evt.Sender, membership.EventID()) - } - } - for _, evt := range membershipToUser { - m.State = append(m.State, gomatrixserverlib.HeaderedToClientEvent(evt, gomatrixserverlib.FormatAll)) - } -} - func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) { req := api.QueryMembershipForUserRequest{ RoomID: roomID, @@ -577,24 +553,3 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] return events, nil } - -// setToDefault returns the default value for the "to" query parameter of a -// request to /messages if not provided. It defaults to either the earliest -// topological position (if we're going backward) or to the latest one (if we're -// going forward). -// Returns an error if there was an issue with retrieving the latest position -// from the database -func setToDefault( - ctx context.Context, snapshot storage.DatabaseTransaction, backwardOrdering bool, - roomID string, -) (to types.TopologyToken, err error) { - if backwardOrdering { - // go 1 earlier than the first event so we correctly fetch the earliest event - // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. - to = types.TopologyToken{} - } else { - to, err = snapshot.MaxTopologicalPosition(ctx, roomID) - } - - return -} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 12c43cba8..26e4f6f6c 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -45,8 +45,8 @@ type DatabaseTransaction interface { GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) - GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) - RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) + GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error) + RecentEvents(ctx context.Context, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) @@ -84,8 +84,6 @@ type DatabaseTransaction interface { EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error) - // MaxTopologicalPosition returns the highest topological position for a given room. - MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error) // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // matches the streamevent.transactionID device then the transaction ID gets // added to the unsigned section of the output event. @@ -106,7 +104,7 @@ type DatabaseTransaction interface { SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) - GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) + GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error) SelectMultiRoomData(ctx context.Context, r *types.Range, joinedRooms []string) (types.MultiRoom, error) @@ -135,6 +133,8 @@ type Database interface { // PurgeRoomState completely purges room state from the sync API. This is done when // receiving an output event that completely resets the state. PurgeRoomState(ctx context.Context, roomID string) error + // PurgeRoom entirely eliminates a room from the sync API, timeline, state and all. + PurgeRoom(ctx context.Context, roomID string) error // UpsertAccountData keeps track of new or updated account data, by saving the type // of the new/updated data, and the user ID and room ID the data is related to (empty) // room ID means the data isn't specific to any room) @@ -187,7 +187,7 @@ type Database interface { } type Presence interface { - GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) + GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) diff --git a/syncapi/storage/postgres/backwards_extremities_table.go b/syncapi/storage/postgres/backwards_extremities_table.go index 8fc92091f..c20d860a7 100644 --- a/syncapi/storage/postgres/backwards_extremities_table.go +++ b/syncapi/storage/postgres/backwards_extremities_table.go @@ -47,10 +47,14 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" +const purgeBackwardExtremitiesSQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + type backwardExtremitiesStatements struct { insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt + purgeBackwardExtremitiesStmt *sql.Stmt } func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -59,16 +63,12 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti if err != nil { return nil, err } - if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { - return nil, err - } - if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { - return nil, err - } - if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertBackwardExtremityStmt, insertBackwardExtremitySQL}, + {&s.selectBackwardExtremitiesForRoomStmt, selectBackwardExtremitiesForRoomSQL}, + {&s.deleteBackwardExtremityStmt, deleteBackwardExtremitySQL}, + {&s.purgeBackwardExtremitiesStmt, purgeBackwardExtremitiesSQL}, + }.Prepare(db) } func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( @@ -106,3 +106,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return } + +func (s *backwardExtremitiesStatements) PurgeBackwardExtremities( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeBackwardExtremitiesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 48ed20021..0d607b7c0 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" @@ -110,6 +111,15 @@ const selectSharedUsersSQL = "" + " SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" + ") AND type = 'm.room.member' AND state_key = ANY($2) AND membership IN ('join', 'invite');" +const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2` + +const selectRoomHeroes = ` +SELECT state_key FROM syncapi_current_room_state +WHERE type = 'm.room.member' AND room_id = $1 AND membership = ANY($2) AND state_key != $3 +ORDER BY added_at, state_key +LIMIT 5 +` + type currentRoomStateStatements struct { upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt @@ -122,6 +132,8 @@ type currentRoomStateStatements struct { selectEventsWithEventIDsStmt *sql.Stmt selectStateEventStmt *sql.Stmt selectSharedUsersStmt *sql.Stmt + selectMembershipCountStmt *sql.Stmt + selectRoomHeroesStmt *sql.Stmt } func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { @@ -141,40 +153,21 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro return nil, err } - if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { - return nil, err - } - if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { - return nil, err - } - if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil { - return nil, err - } - if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil { - return nil, err - } - if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { - return nil, err - } - if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { - return nil, err - } - if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil { - return nil, err - } - if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { - return nil, err - } - if s.selectSharedUsersStmt, err = db.Prepare(selectSharedUsersSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertRoomStateStmt, upsertRoomStateSQL}, + {&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL}, + {&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL}, + {&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL}, + {&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL}, + {&s.selectCurrentStateStmt, selectCurrentStateSQL}, + {&s.selectJoinedUsersStmt, selectJoinedUsersSQL}, + {&s.selectJoinedUsersInRoomStmt, selectJoinedUsersInRoomSQL}, + {&s.selectEventsWithEventIDsStmt, selectEventsWithEventIDsSQL}, + {&s.selectStateEventStmt, selectStateEventSQL}, + {&s.selectSharedUsersStmt, selectSharedUsersSQL}, + {&s.selectMembershipCountStmt, selectMembershipCount}, + {&s.selectRoomHeroesStmt, selectRoomHeroes}, + }.Prepare(db) } // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. @@ -282,6 +275,15 @@ func (s *currentRoomStateStatements) SelectCurrentState( ) ([]*gomatrixserverlib.HeaderedEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) senders, notSenders := getSendersStateFilterFilter(stateFilter) + // We're going to query members later, so remove them from this request + if stateFilter.LazyLoadMembers && !stateFilter.IncludeRedundantMembers { + notTypes := &[]string{gomatrixserverlib.MRoomMember} + if stateFilter.NotTypes != nil { + *stateFilter.NotTypes = append(*stateFilter.NotTypes, gomatrixserverlib.MRoomMember) + } else { + stateFilter.NotTypes = notTypes + } + } rows, err := stmt.QueryContext(ctx, roomID, pq.StringArray(senders), pq.StringArray(notSenders), @@ -447,3 +449,34 @@ func (s *currentRoomStateStatements) SelectSharedUsers( } return result, rows.Err() } + +func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomHeroesStmt) + rows, err := stmt.QueryContext(ctx, roomID, pq.StringArray(memberships), excludeUserID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroesStmt: rows.close() failed") + + var stateKey string + result := make([]string, 0, 5) + for rows.Next() { + if err = rows.Scan(&stateKey); err != nil { + return nil, err + } + result = append(result, stateKey) + } + return result, rows.Err() +} + +func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt) + err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + return 0, err + } + return count, nil +} diff --git a/build/dendritejs-pinecone/main_test.go b/syncapi/storage/postgres/deltas/20230201152200_rename_index.go similarity index 58% rename from build/dendritejs-pinecone/main_test.go rename to syncapi/storage/postgres/deltas/20230201152200_rename_index.go index 17fea6cce..5a0ec5050 100644 --- a/build/dendritejs-pinecone/main_test.go +++ b/syncapi/storage/postgres/deltas/20230201152200_rename_index.go @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,15 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build wasm -// +build wasm - -package main +package deltas import ( - "testing" + "context" + "database/sql" + "fmt" ) -func TestStartup(t *testing.T) { - startup() +func UpRenameOutputRoomEventsIndex(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, `ALTER TABLE syncapi_output_room_events RENAME CONSTRAINT syncapi_event_id_idx TO syncapi_output_room_event_id_idx;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil } diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index aada70d5e..151bffa5d 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -62,11 +62,15 @@ const selectInviteEventsInRangeSQL = "" + const selectMaxInviteIDSQL = "" + "SELECT MAX(id) FROM syncapi_invite_events" +const purgeInvitesSQL = "" + + "DELETE FROM syncapi_invite_events WHERE room_id = $1" + type inviteEventsStatements struct { insertInviteEventStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt deleteInviteEventStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt } func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { @@ -75,19 +79,13 @@ func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { if err != nil { return nil, err } - if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { - return nil, err - } - if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { - return nil, err - } - if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { - return nil, err - } - if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertInviteEventStmt, insertInviteEventSQL}, + {&s.selectInviteEventsInRangeStmt, selectInviteEventsInRangeSQL}, + {&s.deleteInviteEventStmt, deleteInviteEventSQL}, + {&s.selectMaxInviteIDStmt, selectMaxInviteIDSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + }.Prepare(db) } func (s *inviteEventsStatements) InsertInviteEvent( @@ -181,3 +179,10 @@ func (s *inviteEventsStatements) SelectMaxInviteID( } return } + +func (s *inviteEventsStatements) PurgeInvites( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeInvitesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index b555e8456..47833893a 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -19,10 +19,8 @@ import ( "database/sql" "fmt" - "github.com/lib/pq" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" @@ -64,25 +62,25 @@ const selectMembershipCountSQL = "" + " SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" + ") t WHERE t.membership = $3" -const selectHeroesSQL = "" + - "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5" - const selectMembershipBeforeSQL = "" + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" +const purgeMembershipsSQL = "" + + "DELETE FROM syncapi_memberships WHERE room_id = $1" + const selectMembersSQL = ` -SELECT event_id FROM ( - SELECT DISTINCT ON (room_id, user_id) room_id, user_id, event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC -) t -WHERE ($3::text IS NULL OR t.membership = $3) - AND ($4::text IS NULL OR t.membership <> $4) + SELECT event_id FROM ( + SELECT DISTINCT ON (room_id, user_id) room_id, user_id, event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC + ) t + WHERE ($3::text IS NULL OR t.membership = $3) + AND ($4::text IS NULL OR t.membership <> $4) ` type membershipsStatements struct { upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt - selectHeroesStmt *sql.Stmt selectMembershipForUserStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt selectMembersStmt *sql.Stmt } @@ -95,8 +93,8 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { return s, sqlutil.StatementList{ {&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL}, - {&s.selectHeroesStmt, selectHeroesSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, {&s.selectMembersStmt, selectMembersSQL}, }.Prepare(db) } @@ -129,26 +127,6 @@ func (s *membershipsStatements) SelectMembershipCount( return } -func (s *membershipsStatements) SelectHeroes( - ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string, -) (heroes []string, err error) { - stmt := sqlutil.TxStmt(txn, s.selectHeroesStmt) - var rows *sql.Rows - rows, err = stmt.QueryContext(ctx, roomID, userID, pq.StringArray(memberships)) - if err != nil { - return - } - defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed") - var hero string - for rows.Next() { - if err = rows.Scan(&hero); err != nil { - return - } - heroes = append(heroes, hero) - } - return heroes, rows.Err() -} - // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty // string as the membership. @@ -166,6 +144,13 @@ func (s *membershipsStatements) SelectMembershipForUser( return membership, topologyPos, nil } +func (s *membershipsStatements) PurgeMemberships( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeMembershipsStmt).ExecContext(ctx, roomID) + return err +} + func (s *membershipsStatements) SelectMemberships( ctx context.Context, txn *sql.Tx, roomID string, pos types.TopologyToken, diff --git a/syncapi/storage/postgres/notification_data_table.go b/syncapi/storage/postgres/notification_data_table.go index 2c7b24800..7edfd54a6 100644 --- a/syncapi/storage/postgres/notification_data_table.go +++ b/syncapi/storage/postgres/notification_data_table.go @@ -37,6 +37,7 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, {&r.selectMaxID, selectMaxNotificationIDSQL}, + {&r.purgeNotificationData, purgeNotificationDataSQL}, }.Prepare(db) } @@ -44,6 +45,7 @@ type notificationDataStatements struct { upsertRoomUnreadCounts *sql.Stmt selectUserUnreadCountsForRooms *sql.Stmt selectMaxID *sql.Stmt + purgeNotificationData *sql.Stmt } const notificationDataSchema = ` @@ -70,6 +72,9 @@ const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_coun const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` +const purgeNotificationDataSQL = "" + + "DELETE FROM syncapi_notification_data WHERE room_id = $1" + func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { err = sqlutil.TxStmt(txn, r.upsertRoomUnreadCounts).QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos) return @@ -106,3 +111,10 @@ func (r *notificationDataStatements) SelectMaxID(ctx context.Context, txn *sql.T err := sqlutil.TxStmt(txn, r.selectMaxID).QueryRowContext(ctx).Scan(&id) return id, err } + +func (s *notificationDataStatements) PurgeNotificationData( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeNotificationData).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 3b69b26f6..59fb99aa3 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -19,18 +19,17 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "sort" + "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - - "github.com/lib/pq" "github.com/matrix-org/gomatrixserverlib" - - "github.com/matrix-org/dendrite/internal/sqlutil" ) const outputRoomEventsSchema = ` @@ -44,7 +43,7 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events ( -- This isn't a problem for us since we just want to order by this field. id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'), -- The event ID for the event - event_id TEXT NOT NULL CONSTRAINT syncapi_event_id_idx UNIQUE, + event_id TEXT NOT NULL CONSTRAINT syncapi_output_room_event_id_idx UNIQUE, -- The 'room_id' key for the event. room_id TEXT NOT NULL, -- The headered JSON for the event, containing potentially additional metadata such as @@ -79,13 +78,16 @@ CREATE INDEX IF NOT EXISTS syncapi_output_room_events_room_id_idx ON syncapi_out CREATE INDEX IF NOT EXISTS syncapi_output_room_events_exclude_from_sync_idx ON syncapi_output_room_events (exclude_from_sync); CREATE INDEX IF NOT EXISTS syncapi_output_room_events_add_state_ids_idx ON syncapi_output_room_events ((add_state_ids IS NOT NULL)); CREATE INDEX IF NOT EXISTS syncapi_output_room_events_remove_state_ids_idx ON syncapi_output_room_events ((remove_state_ids IS NOT NULL)); +CREATE INDEX IF NOT EXISTS syncapi_output_room_events_recent_events_idx ON syncapi_output_room_events (room_id, exclude_from_sync, id, sender, type); + + ` const insertEventSQL = "" + "INSERT INTO syncapi_output_room_events (" + "room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync, history_visibility" + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " + - "ON CONFLICT ON CONSTRAINT syncapi_event_id_idx DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $11) " + + "ON CONFLICT ON CONSTRAINT syncapi_output_room_event_id_idx DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $11) " + "RETURNING id" const selectEventsSQL = "" + @@ -109,14 +111,29 @@ const selectRecentEventsSQL = "" + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + " ORDER BY id DESC LIMIT $8" -const selectRecentEventsForSyncSQL = "" + - "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" + - " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + - " AND ( $4::text[] IS NULL OR sender = ANY($4) )" + - " AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" + - " AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" + - " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + - " ORDER BY id DESC LIMIT $8" +// selectRecentEventsForSyncSQL contains an optimization to get the recent events for a list of rooms, using a LATERAL JOIN +// The sub select inside LATERAL () is executed for all room_ids it gets as a parameter $1 +const selectRecentEventsForSyncSQL = ` +WITH room_ids AS ( + SELECT unnest($1::text[]) AS room_id +) +SELECT x.* +FROM room_ids, + LATERAL ( + SELECT room_id, event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility + FROM syncapi_output_room_events recent_events + WHERE + recent_events.room_id = room_ids.room_id + AND recent_events.exclude_from_sync = FALSE + AND id > $2 AND id <= $3 + AND ( $4::text[] IS NULL OR sender = ANY($4) ) + AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) ) + AND ( $6::text[] IS NULL OR type LIKE ANY($6) ) + AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) ) + ORDER BY recent_events.id DESC + LIMIT $8 + ) AS x +` const selectEarlyEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" + @@ -176,6 +193,9 @@ const selectContextAfterEventSQL = "" + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + " ORDER BY id ASC LIMIT $3" +const purgeEventsSQL = "" + + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type = ANY($2) ORDER BY id ASC LIMIT $3" type outputRoomEventsStatements struct { @@ -193,6 +213,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + purgeEventsStmt *sql.Stmt selectSearchStmt *sql.Stmt } @@ -203,12 +224,30 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { return nil, err } + migrationName := "syncapi: rename dupe index (output_room_events)" + + var cName string + err = db.QueryRowContext(context.Background(), "select constraint_name from information_schema.table_constraints where table_name = 'syncapi_output_room_events' AND constraint_name = 'syncapi_event_id_idx'").Scan(&cName) + switch err { + case sql.ErrNoRows: // migration was already executed, as the index was renamed + if err = sqlutil.InsertMigration(context.Background(), db, migrationName); err != nil { + return nil, fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) + } + case nil: + default: + return nil, err + } + m := sqlutil.NewMigrator(db) m.AddMigrations( sqlutil.Migration{ Version: "syncapi: add history visibility column (output_room_events)", Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents, }, + sqlutil.Migration{ + Version: migrationName, + Up: deltas.UpRenameOutputRoomEventsIndex, + }, ) err = m.Up(context.Background()) if err != nil { @@ -230,6 +269,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, {&s.selectSearchStmt, selectSearchSQL}, }.Prepare(db) } @@ -393,9 +433,9 @@ func (s *outputRoomEventsStatements) InsertEvent( // from sync. func (s *outputRoomEventsStatements) SelectRecentEvents( ctx context.Context, txn *sql.Tx, - roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, + roomIDs []string, ra types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool, -) ([]types.StreamEvent, bool, error) { +) (map[string]types.RecentEvents, error) { var stmt *sql.Stmt if onlySyncEvents { stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt) @@ -403,8 +443,9 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) } senders, notSenders := getSendersRoomEventFilter(eventFilter) + rows, err := stmt.QueryContext( - ctx, roomID, r.Low(), r.High(), + ctx, pq.StringArray(roomIDs), ra.Low(), ra.High(), pq.StringArray(senders), pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), @@ -412,34 +453,80 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( eventFilter.Limit+1, ) if err != nil { - return nil, false, err + return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") - events, err := rowsToStreamEvents(rows) - if err != nil { - return nil, false, err - } - if chronologicalOrder { - // The events need to be returned from oldest to latest, which isn't - // necessary the way the SQL query returns them, so a sort is necessary to - // ensure the events are in the right order in the slice. - sort.SliceStable(events, func(i int, j int) bool { - return events[i].StreamPosition < events[j].StreamPosition - }) - } - // we queried for 1 more than the limit, so if we returned one more mark limited=true - limited := false - if len(events) > eventFilter.Limit { - limited = true - // re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. - if chronologicalOrder { - events = events[1:] - } else { - events = events[:len(events)-1] + + result := make(map[string]types.RecentEvents) + + for rows.Next() { + var ( + roomID string + eventID string + streamPos types.StreamPosition + eventBytes []byte + excludeFromSync bool + sessionID *int64 + txnID *string + transactionID *api.TransactionID + historyVisibility gomatrixserverlib.HistoryVisibility + ) + if err := rows.Scan(&roomID, &eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID, &historyVisibility); err != nil { + return nil, err } + // TODO: Handle redacted events + var ev gomatrixserverlib.HeaderedEvent + if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + return nil, err + } + + if sessionID != nil && txnID != nil { + transactionID = &api.TransactionID{ + SessionID: *sessionID, + TransactionID: *txnID, + } + } + + r := result[roomID] + + ev.Visibility = historyVisibility + r.Events = append(r.Events, types.StreamEvent{ + HeaderedEvent: &ev, + StreamPosition: streamPos, + TransactionID: transactionID, + ExcludeFromSync: excludeFromSync, + }) + + result[roomID] = r } - return events, limited, nil + if chronologicalOrder { + for roomID, evs := range result { + // The events need to be returned from oldest to latest, which isn't + // necessary the way the SQL query returns them, so a sort is necessary to + // ensure the events are in the right order in the slice. + sort.SliceStable(evs.Events, func(i int, j int) bool { + return evs.Events[i].StreamPosition < evs.Events[j].StreamPosition + }) + + if len(evs.Events) > eventFilter.Limit { + evs.Limited = true + evs.Events = evs.Events[1:] + } + + result[roomID] = evs + } + } else { + for roomID, evs := range result { + if len(evs.Events) > eventFilter.Limit { + evs.Limited = true + evs.Events = evs.Events[:len(evs.Events)-1] + } + + result[roomID] = evs + } + } + return result, rows.Err() } // selectEarlyEvents returns the earliest events in the given room, starting @@ -658,6 +745,13 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { return result, rows.Err() } +func (s *outputRoomEventsStatements) PurgeEvents( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsStmt).ExecContext(ctx, roomID) + return err +} + func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { rows, err := sqlutil.TxStmt(txn, s.selectSearchStmt).QueryContext(ctx, afterID, pq.StringArray(types), limit) if err != nil { diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index 6fab900eb..2382fca5c 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -18,11 +18,12 @@ import ( "context" "database/sql" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const outputRoomEventsTopologySchema = ` @@ -65,28 +66,23 @@ const selectPositionInTopologySQL = "" + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE event_id = $1" - // Select the max topological position for the room, then sort by stream position and take the highest, - // returning both topological and stream positions. -const selectMaxPositionInTopologySQL = "" + - "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + - " WHERE topological_position=(" + - "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" + - ") ORDER BY stream_position DESC LIMIT 1" - const selectStreamToTopologicalPositionAscSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;" const selectStreamToTopologicalPositionDescSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position <= $2 ORDER BY topological_position DESC LIMIT 1;" +const purgeEventsTopologySQL = "" + + "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + type outputRoomEventsTopologyStatements struct { insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt - selectMaxPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt + purgeEventsTopologyStmt *sql.Stmt } func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { @@ -95,28 +91,15 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { if err != nil { return nil, err } - if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { - return nil, err - } - if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { - return nil, err - } - if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionDescStmt, err = db.Prepare(selectStreamToTopologicalPositionDescSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertEventInTopologyStmt, insertEventInTopologySQL}, + {&s.selectEventIDsInRangeASCStmt, selectEventIDsInRangeASCSQL}, + {&s.selectEventIDsInRangeDESCStmt, selectEventIDsInRangeDESCSQL}, + {&s.selectPositionInTopologyStmt, selectPositionInTopologySQL}, + {&s.selectStreamToTopologicalPositionAscStmt, selectStreamToTopologicalPositionAscSQL}, + {&s.selectStreamToTopologicalPositionDescStmt, selectStreamToTopologicalPositionDescSQL}, + {&s.purgeEventsTopologyStmt, purgeEventsTopologySQL}, + }.Prepare(db) } // InsertEventInTopology inserts the given event in the room's topology, based @@ -190,9 +173,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( return } -func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( +func (s *outputRoomEventsTopologyStatements) PurgeEventsTopology( ctx context.Context, txn *sql.Tx, roomID string, -) (pos types.StreamPosition, spos types.StreamPosition, err error) { - err = sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt).QueryRowContext(ctx, roomID).Scan(&pos, &spos) - return +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsTopologyStmt).ExecContext(ctx, roomID) + return err } diff --git a/syncapi/storage/postgres/peeks_table.go b/syncapi/storage/postgres/peeks_table.go index e20a4882f..64183073d 100644 --- a/syncapi/storage/postgres/peeks_table.go +++ b/syncapi/storage/postgres/peeks_table.go @@ -65,6 +65,9 @@ const selectPeekingDevicesSQL = "" + const selectMaxPeekIDSQL = "" + "SELECT MAX(id) FROM syncapi_peeks" +const purgePeeksSQL = "" + + "DELETE FROM syncapi_peeks WHERE room_id = $1" + type peekStatements struct { db *sql.DB insertPeekStmt *sql.Stmt @@ -73,6 +76,7 @@ type peekStatements struct { selectPeeksInRangeStmt *sql.Stmt selectPeekingDevicesStmt *sql.Stmt selectMaxPeekIDStmt *sql.Stmt + purgePeeksStmt *sql.Stmt } func NewPostgresPeeksTable(db *sql.DB) (tables.Peeks, error) { @@ -83,25 +87,15 @@ func NewPostgresPeeksTable(db *sql.DB) (tables.Peeks, error) { s := &peekStatements{ db: db, } - if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { - return nil, err - } - if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil { - return nil, err - } - if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil { - return nil, err - } - if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil { - return nil, err - } - if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil { - return nil, err - } - if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertPeekStmt, insertPeekSQL}, + {&s.deletePeekStmt, deletePeekSQL}, + {&s.deletePeeksStmt, deletePeeksSQL}, + {&s.selectPeeksInRangeStmt, selectPeeksInRangeSQL}, + {&s.selectPeekingDevicesStmt, selectPeekingDevicesSQL}, + {&s.selectMaxPeekIDStmt, selectMaxPeekIDSQL}, + {&s.purgePeeksStmt, purgePeeksSQL}, + }.Prepare(db) } func (s *peekStatements) InsertPeek( @@ -184,3 +178,10 @@ func (s *peekStatements) SelectMaxPeekID( } return } + +func (s *peekStatements) PurgePeeks( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgePeeksStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/presence_table.go b/syncapi/storage/postgres/presence_table.go index 6f0aa8991..92e603b37 100644 --- a/syncapi/storage/postgres/presence_table.go +++ b/syncapi/storage/postgres/presence_table.go @@ -19,10 +19,12 @@ import ( "database/sql" "time" + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const presenceSchema = ` @@ -67,9 +69,9 @@ SET last_active_ts = $1 WHERE user_id = $2` const selectPresenceForUserSQL = "" + - "SELECT presence, status_msg, last_active_ts" + + "SELECT user_id, presence, status_msg, last_active_ts" + " FROM syncapi_presence" + - " WHERE user_id = $1 LIMIT 1" + " WHERE user_id = ANY($1)" const selectMaxPresenceSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_presence" @@ -137,20 +139,28 @@ func (p *presenceStatements) UpsertPresence( return } -// GetPresenceForUser returns the current presence of a user. -func (p *presenceStatements) GetPresenceForUser( +// GetPresenceForUsers returns the current presence for a list of users. +// If the user doesn't have a presence status yet, it is omitted from the response. +func (p *presenceStatements) GetPresenceForUsers( ctx context.Context, txn *sql.Tx, - userID string, -) (*types.PresenceInternal, error) { - result := &types.PresenceInternal{ - UserID: userID, - } + userIDs []string, +) ([]*types.PresenceInternal, error) { + result := make([]*types.PresenceInternal, 0, len(userIDs)) stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) - err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) - if err == sql.ErrNoRows { - return nil, nil + rows, err := stmt.QueryContext(ctx, pq.Array(userIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed") + + for rows.Next() { + presence := &types.PresenceInternal{} + if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil { + return nil, err + } + presence.ClientFields.Presence = presence.Presence.String() + result = append(result, presence) } - result.ClientFields.Presence = result.Presence.String() return result, err } diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go index 327a7a372..0fcbebfcb 100644 --- a/syncapi/storage/postgres/receipt_table.go +++ b/syncapi/storage/postgres/receipt_table.go @@ -62,11 +62,15 @@ const selectRoomReceipts = "" + const selectMaxReceiptIDSQL = "" + "SELECT MAX(id) FROM syncapi_receipts" +const purgeReceiptsSQL = "" + + "DELETE FROM syncapi_receipts WHERE room_id = $1" + type receiptStatements struct { db *sql.DB upsertReceipt *sql.Stmt selectRoomReceipts *sql.Stmt selectMaxReceiptID *sql.Stmt + purgeReceiptsStmt *sql.Stmt } func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { @@ -86,16 +90,12 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { r := &receiptStatements{ db: db, } - if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { - return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) - } - if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - return r, nil + return r, sqlutil.StatementList{ + {&r.upsertReceipt, upsertReceipt}, + {&r.selectRoomReceipts, selectRoomReceipts}, + {&r.selectMaxReceiptID, selectMaxReceiptIDSQL}, + {&r.purgeReceiptsStmt, purgeReceiptsSQL}, + }.Prepare(db) } func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { @@ -138,3 +138,10 @@ func (s *receiptStatements) SelectMaxReceiptID( } return } + +func (s *receiptStatements) PurgeReceipts( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeReceiptsStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 7598a64c4..d5a1c564f 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -61,31 +61,23 @@ type Database struct { } func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransaction, error) { - return d.NewDatabaseTransaction(ctx) - - /* - TODO: Repeatable read is probably the right thing to do here, - but it seems to cause some problems with the invite tests, so - need to investigate that further. - - txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{ - // Set the isolation level so that we see a snapshot of the database. - // In PostgreSQL repeatable read transactions will see a snapshot taken - // at the first query, and since the transaction is read-only it can't - // run into any serialisation errors. - // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ - Isolation: sql.LevelRepeatableRead, - ReadOnly: true, - }) - if err != nil { - return nil, err - } - return &DatabaseTransaction{ - Database: d, - ctx: ctx, - txn: txn, - }, nil - */ + txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{ + // Set the isolation level so that we see a snapshot of the database. + // In PostgreSQL repeatable read transactions will see a snapshot taken + // at the first query, and since the transaction is read-only it can't + // run into any serialisation errors. + // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ + Isolation: sql.LevelRepeatableRead, + ReadOnly: true, + }) + if err != nil { + return nil, err + } + return &DatabaseTransaction{ + Database: d, + ctx: ctx, + txn: txn, + }, nil } func (d *Database) NewDatabaseTransaction(ctx context.Context) (*DatabaseTransaction, error) { @@ -254,20 +246,6 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e return nil } -func (d *Database) PurgeRoomState( - ctx context.Context, roomID string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - // If the event is a create event then we'll delete all of the existing - // data for the room. The only reason that a create event would be replayed - // to us in this way is if we're about to receive the entire room state. - if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { - return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) - } - return nil - }) -} - func (d *Database) WriteEvent( ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, @@ -583,8 +561,8 @@ func (d *Database) UpdatePresence(ctx context.Context, userID string, presence t return pos, err } -func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return d.Presence.GetPresenceForUser(ctx, nil, userID) +func (d *Database) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUsers(ctx, nil, userIDs) } func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { @@ -600,9 +578,14 @@ func (d *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64 } func (d *Database) UpdateRelations(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { + // No need to unmarshal if the event is a redaction + if event.Type() == gomatrixserverlib.MRoomRedaction { + return nil + } var content gomatrixserverlib.RelationContent if err := json.Unmarshal(event.Content(), &content); err != nil { - return fmt.Errorf("json.Unmarshal: %w", err) + logrus.WithError(err).Error("unable to unmarshal relation content") + return nil } switch { case content.Relations == nil: @@ -611,8 +594,6 @@ func (d *Database) UpdateRelations(ctx context.Context, event *gomatrixserverlib return nil case content.Relations.RelationType == "": return nil - case event.Type() == gomatrixserverlib.MRoomRedaction: - return nil default: return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Relations.InsertRelation( diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 1faafa3da..d69bbc67a 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -4,8 +4,10 @@ import ( "context" "database/sql" "fmt" + "math" "github.com/matrix-org/gomatrixserverlib" + "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/syncapi/types" @@ -92,12 +94,65 @@ func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membe return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos) } -func (d *DatabaseTransaction) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) { - return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships) +func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID string) (*types.Summary, error) { + summary := &types.Summary{Heroes: []string{}} + + joinCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Join) + if err != nil { + return summary, err + } + inviteCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Invite) + if err != nil { + return summary, err + } + summary.InvitedMemberCount = &inviteCount + summary.JoinedMemberCount = &joinCount + + // Get the room name and canonical alias, if any + filter := gomatrixserverlib.DefaultStateFilter() + filterTypes := []string{gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias} + filterRooms := []string{roomID} + + filter.Types = &filterTypes + filter.Rooms = &filterRooms + evs, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, &filter, nil) + if err != nil { + return summary, err + } + + for _, ev := range evs { + switch ev.Type() { + case gomatrixserverlib.MRoomName: + if gjson.GetBytes(ev.Content(), "name").Str != "" { + return summary, nil + } + case gomatrixserverlib.MRoomCanonicalAlias: + if gjson.GetBytes(ev.Content(), "alias").Str != "" { + return summary, nil + } + } + } + + // If there's no room name or canonical alias, get the room heroes, excluding the user + heroes, err := d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Join, gomatrixserverlib.Invite}) + if err != nil { + return summary, err + } + + // "When no joined or invited members are available, this should consist of the banned and left users" + if len(heroes) == 0 { + heroes, err = d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Leave, gomatrixserverlib.Ban}) + if err != nil { + return summary, err + } + } + summary.Heroes = heroes + + return summary, nil } -func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { - return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) +func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) { + return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomIDs, r, eventFilter, chronologicalOrder, onlySyncEvents) } func (d *DatabaseTransaction) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) { @@ -215,16 +270,6 @@ func (d *DatabaseTransaction) BackwardExtremitiesForRoom( return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID) } -func (d *DatabaseTransaction) MaxTopologicalPosition( - ctx context.Context, roomID string, -) (types.TopologyToken, error) { - depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) - if err != nil { - return types.TopologyToken{}, err - } - return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil -} - func (d *DatabaseTransaction) EventPositionInTopology( ctx context.Context, eventID string, ) (types.TopologyToken, error) { @@ -243,11 +288,7 @@ func (d *DatabaseTransaction) StreamToTopologicalPosition( case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward return types.TopologyToken{PDUPosition: streamPos}, nil case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward - topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) - if err != nil { - return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err) - } - return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil + return types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, nil case err != nil: // some other error happened return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err) default: @@ -329,19 +370,25 @@ func (d *DatabaseTransaction) GetStateDeltas( } // get all the state events ever (i.e. for all available rooms) between these two positions - stateNeededFiltered, eventMapFiltered, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil + stateFiltered := state + // avoid hitting the database if the result would be the same as above + if !isStatefilterEmpty(stateFilter) { + var stateNeededFiltered map[string]map[string]bool + var eventMapFiltered map[string]types.StreamEvent + stateNeededFiltered, eventMapFiltered, err = d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err } - return nil, nil, err - } - stateFiltered, err := d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil + stateFiltered, err = d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err } - return nil, nil, err } // find out which rooms this user is peeking, if any. @@ -596,8 +643,8 @@ func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx contex return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs) } -func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return d.Presence.GetPresenceForUser(ctx, d.txn, userID) +func (d *DatabaseTransaction) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUsers(ctx, d.txn, userIDs) } func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { @@ -608,11 +655,80 @@ func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) return d.Presence.GetMaxPresenceID(ctx, d.txn) } +func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.BackwardExtremities.PurgeBackwardExtremities(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge backward extremities: %w", err) + } + if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge current room state: %w", err) + } + if err := d.Invites.PurgeInvites(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge invites: %w", err) + } + if err := d.Memberships.PurgeMemberships(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge memberships: %w", err) + } + if err := d.NotificationData.PurgeNotificationData(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge notification data: %w", err) + } + if err := d.OutputEvents.PurgeEvents(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge events: %w", err) + } + if err := d.Topology.PurgeEventsTopology(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge events topology: %w", err) + } + if err := d.Peeks.PurgePeeks(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge peeks: %w", err) + } + if err := d.Receipts.PurgeReceipts(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge receipts: %w", err) + } + return nil + }) +} + +func (d *Database) PurgeRoomState( + ctx context.Context, roomID string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // If the event is a create event then we'll delete all of the existing + // data for the room. The only reason that a create event would be replayed + // to us in this way is if we're about to receive the entire room state. + if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) + } + return nil + }) +} + func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) { id, err := d.Relations.SelectMaxRelationID(ctx, d.txn) return types.StreamPosition(id), err } +func isStatefilterEmpty(filter *gomatrixserverlib.StateFilter) bool { + if filter == nil { + return true + } + switch { + case filter.NotTypes != nil && len(*filter.NotTypes) > 0: + return false + case filter.Types != nil && len(*filter.Types) > 0: + return false + case filter.Senders != nil && len(*filter.Senders) > 0: + return false + case filter.NotSenders != nil && len(*filter.NotSenders) > 0: + return false + case filter.NotRooms != nil && len(*filter.NotRooms) > 0: + return false + case filter.ContainsURL != nil: + return false + default: + return true + } +} + func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) ( events []types.StreamEvent, prevBatch, nextBatch string, err error, ) { diff --git a/syncapi/storage/shared/storage_sync_test.go b/syncapi/storage/shared/storage_sync_test.go new file mode 100644 index 000000000..c56720db7 --- /dev/null +++ b/syncapi/storage/shared/storage_sync_test.go @@ -0,0 +1,72 @@ +package shared + +import ( + "testing" + + "github.com/matrix-org/gomatrixserverlib" +) + +func Test_isStatefilterEmpty(t *testing.T) { + filterSet := []string{"a"} + boolValue := false + + tests := []struct { + name string + filter *gomatrixserverlib.StateFilter + want bool + }{ + { + name: "nil filter is empty", + filter: nil, + want: true, + }, + { + name: "Empty filter is empty", + filter: &gomatrixserverlib.StateFilter{}, + want: true, + }, + { + name: "NotTypes is set", + filter: &gomatrixserverlib.StateFilter{ + NotTypes: &filterSet, + }, + }, + { + name: "Types is set", + filter: &gomatrixserverlib.StateFilter{ + Types: &filterSet, + }, + }, + { + name: "Senders is set", + filter: &gomatrixserverlib.StateFilter{ + Senders: &filterSet, + }, + }, + { + name: "NotSenders is set", + filter: &gomatrixserverlib.StateFilter{ + NotSenders: &filterSet, + }, + }, + { + name: "NotRooms is set", + filter: &gomatrixserverlib.StateFilter{ + NotRooms: &filterSet, + }, + }, + { + name: "ContainsURL is set", + filter: &gomatrixserverlib.StateFilter{ + ContainsURL: &boolValue, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isStatefilterEmpty(tt.filter); got != tt.want { + t.Errorf("isStatefilterEmpty() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index 3a5fd6be3..2d8cf2ed2 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -47,11 +47,15 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" +const purgeBackwardExtremitiesSQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + type backwardExtremitiesStatements struct { db *sql.DB insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt + purgeBackwardExtremitiesStmt *sql.Stmt } func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -62,16 +66,12 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities if err != nil { return nil, err } - if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { - return nil, err - } - if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { - return nil, err - } - if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertBackwardExtremityStmt, insertBackwardExtremitySQL}, + {&s.selectBackwardExtremitiesForRoomStmt, selectBackwardExtremitiesForRoomSQL}, + {&s.deleteBackwardExtremityStmt, deleteBackwardExtremitySQL}, + {&s.purgeBackwardExtremitiesStmt, purgeBackwardExtremitiesSQL}, + }.Prepare(db) } func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( @@ -109,3 +109,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return err } + +func (s *backwardExtremitiesStatements) PurgeBackwardExtremities( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeBackwardExtremitiesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 7a381f68b..35b746c5c 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "strings" @@ -95,6 +96,15 @@ const selectSharedUsersSQL = "" + " SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" + ") AND type = 'm.room.member' AND state_key IN ($2) AND membership IN ('join', 'invite');" +const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2` + +const selectRoomHeroes = ` +SELECT state_key FROM syncapi_current_room_state +WHERE type = 'm.room.member' AND room_id = $1 AND state_key != $2 AND membership IN ($3) +ORDER BY added_at, state_key +LIMIT 5 +` + type currentRoomStateStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -107,6 +117,8 @@ type currentRoomStateStatements struct { //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic selectStateEventStmt *sql.Stmt //selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic + selectMembershipCountStmt *sql.Stmt + //selectRoomHeroes *sql.Stmt - prepared at runtime due to variadic } func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { @@ -129,31 +141,16 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t return nil, err } - if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { - return nil, err - } - if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { - return nil, err - } - if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil { - return nil, err - } - if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { - return nil, err - } - //if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { - // return nil, err - //} - if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertRoomStateStmt, upsertRoomStateSQL}, + {&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL}, + {&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL}, + {&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL}, + {&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL}, + {&s.selectJoinedUsersStmt, selectJoinedUsersSQL}, + {&s.selectStateEventStmt, selectStateEventSQL}, + {&s.selectMembershipCountStmt, selectMembershipCount}, + }.Prepare(db) } // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. @@ -270,6 +267,15 @@ func (s *currentRoomStateStatements) SelectCurrentState( stateFilter *gomatrixserverlib.StateFilter, excludeEventIDs []string, ) ([]*gomatrixserverlib.HeaderedEvent, error) { + // We're going to query members later, so remove them from this request + if stateFilter.LazyLoadMembers && !stateFilter.IncludeRedundantMembers { + notTypes := &[]string{gomatrixserverlib.MRoomMember} + if stateFilter.NotTypes != nil { + *stateFilter.NotTypes = append(*stateFilter.NotTypes, gomatrixserverlib.MRoomMember) + } else { + stateFilter.NotTypes = notTypes + } + } stmt, params, err := prepareWithFilters( s.db, txn, selectCurrentStateSQL, []interface{}{ @@ -485,3 +491,53 @@ func (s *currentRoomStateStatements) SelectSharedUsers( return result, err } + +func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) { + params := make([]interface{}, len(memberships)+2) + params[0] = roomID + params[1] = excludeUserID + for k, v := range memberships { + params[k+2] = v + } + + query := strings.Replace(selectRoomHeroes, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1) + var stmt *sql.Stmt + var err error + if txn != nil { + stmt, err = txn.Prepare(query) + } else { + stmt, err = s.db.Prepare(query) + } + if err != nil { + return []string{}, err + } + defer internal.CloseAndLogIfError(ctx, stmt, "selectRoomHeroes: stmt.close() failed") + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroes: rows.close() failed") + + var stateKey string + result := make([]string, 0, 5) + for rows.Next() { + if err = rows.Scan(&stateKey); err != nil { + return nil, err + } + result = append(result, stateKey) + } + return result, rows.Err() +} + +func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt) + err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + return 0, err + } + return count, nil +} diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index e2dbcd5c8..19450099a 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -57,6 +57,9 @@ const selectInviteEventsInRangeSQL = "" + const selectMaxInviteIDSQL = "" + "SELECT MAX(id) FROM syncapi_invite_events" +const purgeInvitesSQL = "" + + "DELETE FROM syncapi_invite_events WHERE room_id = $1" + type inviteEventsStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -64,6 +67,7 @@ type inviteEventsStatements struct { selectInviteEventsInRangeStmt *sql.Stmt deleteInviteEventStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt } func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) { @@ -75,19 +79,13 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Inv if err != nil { return nil, err } - if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { - return nil, err - } - if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { - return nil, err - } - if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { - return nil, err - } - if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertInviteEventStmt, insertInviteEventSQL}, + {&s.selectInviteEventsInRangeStmt, selectInviteEventsInRangeSQL}, + {&s.deleteInviteEventStmt, deleteInviteEventSQL}, + {&s.selectMaxInviteIDStmt, selectMaxInviteIDSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + }.Prepare(db) } func (s *inviteEventsStatements) InsertInviteEvent( @@ -192,3 +190,10 @@ func (s *inviteEventsStatements) SelectMaxInviteID( } return } + +func (s *inviteEventsStatements) PurgeInvites( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeInvitesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index 7e54fac17..2cc46a10a 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -18,11 +18,9 @@ import ( "context" "database/sql" "fmt" - "strings" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" @@ -64,9 +62,6 @@ const selectMembershipCountSQL = "" + " SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" + ") t WHERE t.membership = $3" -const selectHeroesSQL = "" + - "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5" - const selectMembershipBeforeSQL = "" + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" @@ -77,6 +72,9 @@ SELECT event_id FROM AND ($4 IS NULL OR t.membership <> $4) ` +const purgeMembershipsSQL = "" + + "DELETE FROM syncapi_memberships WHERE room_id = $1" + type membershipsStatements struct { db *sql.DB upsertMembershipStmt *sql.Stmt @@ -84,6 +82,7 @@ type membershipsStatements struct { //selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic selectMembershipForUserStmt *sql.Stmt selectMembersStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt } func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -99,7 +98,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, {&s.selectMembersStmt, selectMembersSQL}, - // {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, }.Prepare(db) } @@ -131,39 +130,6 @@ func (s *membershipsStatements) SelectMembershipCount( return } -func (s *membershipsStatements) SelectHeroes( - ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string, -) (heroes []string, err error) { - stmtSQL := strings.Replace(selectHeroesSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1) - stmt, err := s.db.PrepareContext(ctx, stmtSQL) - if err != nil { - return - } - defer internal.CloseAndLogIfError(ctx, stmt, "SelectHeroes: stmt.close() failed") - params := []interface{}{ - roomID, userID, - } - for _, membership := range memberships { - params = append(params, membership) - } - - stmt = sqlutil.TxStmt(txn, stmt) - var rows *sql.Rows - rows, err = stmt.QueryContext(ctx, params...) - if err != nil { - return - } - defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed") - var hero string - for rows.Next() { - if err = rows.Scan(&hero); err != nil { - return - } - heroes = append(heroes, hero) - } - return heroes, rows.Err() -} - // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty // string as the membership. @@ -181,6 +147,13 @@ func (s *membershipsStatements) SelectMembershipForUser( return membership, topologyPos, nil } +func (s *membershipsStatements) PurgeMemberships( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeMembershipsStmt).ExecContext(ctx, roomID) + return err +} + func (s *membershipsStatements) SelectMemberships( ctx context.Context, txn *sql.Tx, roomID string, pos types.TopologyToken, diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go index 6242898e1..af2b2c074 100644 --- a/syncapi/storage/sqlite3/notification_data_table.go +++ b/syncapi/storage/sqlite3/notification_data_table.go @@ -38,6 +38,7 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t return r, sqlutil.StatementList{ {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL}, + {&r.purgeNotificationData, purgeNotificationDataSQL}, // {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime }.Prepare(db) } @@ -47,6 +48,7 @@ type notificationDataStatements struct { streamIDStatements *StreamIDStatements upsertRoomUnreadCounts *sql.Stmt selectMaxID *sql.Stmt + purgeNotificationData *sql.Stmt //selectUserUnreadCountsForRooms *sql.Stmt } @@ -73,6 +75,9 @@ const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_coun const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` +const purgeNotificationDataSQL = "" + + "DELETE FROM syncapi_notification_data WHERE room_id = $1" + func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { pos, err = r.streamIDStatements.nextNotificationID(ctx, nil) if err != nil { @@ -124,3 +129,10 @@ func (r *notificationDataStatements) SelectMaxID(ctx context.Context, txn *sql.T err := sqlutil.TxStmt(txn, r.selectMaxID).QueryRowContext(ctx).Scan(&id) return id, err } + +func (s *notificationDataStatements) PurgeNotificationData( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeNotificationData).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 1aa4bfff7..23bc68a41 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -120,6 +120,9 @@ const selectContextAfterEventSQL = "" + const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE type IN ($1) AND id > $2 LIMIT $3 ORDER BY id ASC" +const purgeEventsSQL = "" + + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + type outputRoomEventsStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -130,6 +133,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + purgeEventsStmt *sql.Stmt //selectSearchStmt *sql.Stmt - prepared at runtime } @@ -163,6 +167,7 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, //{&s.selectSearchStmt, selectSearchSQL}, - prepared at runtime }.Prepare(db) } @@ -363,9 +368,9 @@ func (s *outputRoomEventsStatements) InsertEvent( func (s *outputRoomEventsStatements) SelectRecentEvents( ctx context.Context, txn *sql.Tx, - roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, + roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool, -) ([]types.StreamEvent, bool, error) { +) (map[string]types.RecentEvents, error) { var query string if onlySyncEvents { query = selectRecentEventsForSyncSQL @@ -373,49 +378,55 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( query = selectRecentEventsSQL } - stmt, params, err := prepareWithFilters( - s.db, txn, query, - []interface{}{ - roomID, r.Low(), r.High(), - }, - eventFilter.Senders, eventFilter.NotSenders, - eventFilter.Types, eventFilter.NotTypes, - nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc, - ) - if err != nil { - return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) - } - defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed") - - rows, err := stmt.QueryContext(ctx, params...) - if err != nil { - return nil, false, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") - events, err := rowsToStreamEvents(rows) - if err != nil { - return nil, false, err - } - if chronologicalOrder { - // The events need to be returned from oldest to latest, which isn't - // necessary the way the SQL query returns them, so a sort is necessary to - // ensure the events are in the right order in the slice. - sort.SliceStable(events, func(i int, j int) bool { - return events[i].StreamPosition < events[j].StreamPosition - }) - } - // we queried for 1 more than the limit, so if we returned one more mark limited=true - limited := false - if len(events) > eventFilter.Limit { - limited = true - // re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. - if chronologicalOrder { - events = events[1:] - } else { - events = events[:len(events)-1] + result := make(map[string]types.RecentEvents, len(roomIDs)) + for _, roomID := range roomIDs { + stmt, params, err := prepareWithFilters( + s.db, txn, query, + []interface{}{ + roomID, r.Low(), r.High(), + }, + eventFilter.Senders, eventFilter.NotSenders, + eventFilter.Types, eventFilter.NotTypes, + nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc, + ) + if err != nil { + return nil, fmt.Errorf("s.prepareWithFilters: %w", err) } + defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed") + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") + events, err := rowsToStreamEvents(rows) + if err != nil { + return nil, err + } + if chronologicalOrder { + // The events need to be returned from oldest to latest, which isn't + // necessary the way the SQL query returns them, so a sort is necessary to + // ensure the events are in the right order in the slice. + sort.SliceStable(events, func(i int, j int) bool { + return events[i].StreamPosition < events[j].StreamPosition + }) + } + res := types.RecentEvents{} + // we queried for 1 more than the limit, so if we returned one more mark limited=true + if len(events) > eventFilter.Limit { + res.Limited = true + // re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. + if chronologicalOrder { + events = events[1:] + } else { + events = events[:len(events)-1] + } + } + res.Events = events + result[roomID] = res } - return events, limited, nil + + return result, nil } func (s *outputRoomEventsStatements) SelectEarlyEvents( @@ -666,6 +677,13 @@ func unmarshalStateIDs(addIDsJSON, delIDsJSON string) (addIDs []string, delIDs [ return } +func (s *outputRoomEventsStatements) PurgeEvents( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsStmt).ExecContext(ctx, roomID) + return err +} + func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { params := make([]interface{}, len(types)) for i := range types { diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 81b264988..dc698de2d 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -18,10 +18,11 @@ import ( "context" "database/sql" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const outputRoomEventsTopologySchema = ` @@ -61,25 +62,24 @@ const selectPositionInTopologySQL = "" + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE event_id = $1" -const selectMaxPositionInTopologySQL = "" + - "SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1 ORDER BY stream_position DESC" - const selectStreamToTopologicalPositionAscSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;" const selectStreamToTopologicalPositionDescSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position <= $2 ORDER BY topological_position DESC LIMIT 1;" +const purgeEventsTopologySQL = "" + + "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + type outputRoomEventsTopologyStatements struct { db *sql.DB insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt - selectMaxPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt + purgeEventsTopologyStmt *sql.Stmt } func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { @@ -90,28 +90,15 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { if err != nil { return nil, err } - if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { - return nil, err - } - if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { - return nil, err - } - if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionDescStmt, err = db.Prepare(selectStreamToTopologicalPositionDescSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertEventInTopologyStmt, insertEventInTopologySQL}, + {&s.selectEventIDsInRangeASCStmt, selectEventIDsInRangeASCSQL}, + {&s.selectEventIDsInRangeDESCStmt, selectEventIDsInRangeDESCSQL}, + {&s.selectPositionInTopologyStmt, selectPositionInTopologySQL}, + {&s.selectStreamToTopologicalPositionAscStmt, selectStreamToTopologicalPositionAscSQL}, + {&s.selectStreamToTopologicalPositionDescStmt, selectStreamToTopologicalPositionDescSQL}, + {&s.purgeEventsTopologyStmt, purgeEventsTopologySQL}, + }.Prepare(db) } // insertEventInTopology inserts the given event in the room's topology, based @@ -183,10 +170,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( return } -func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( +func (s *outputRoomEventsTopologyStatements) PurgeEventsTopology( ctx context.Context, txn *sql.Tx, roomID string, -) (pos types.StreamPosition, spos types.StreamPosition, err error) { - stmt := sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt) - err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) - return +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsTopologyStmt).ExecContext(ctx, roomID) + return err } diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go index 4ef51b103..5d5200abc 100644 --- a/syncapi/storage/sqlite3/peeks_table.go +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -64,6 +64,9 @@ const selectPeekingDevicesSQL = "" + const selectMaxPeekIDSQL = "" + "SELECT MAX(id) FROM syncapi_peeks" +const purgePeeksSQL = "" + + "DELETE FROM syncapi_peeks WHERE room_id = $1" + type peekStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -73,6 +76,7 @@ type peekStatements struct { selectPeeksInRangeStmt *sql.Stmt selectPeekingDevicesStmt *sql.Stmt selectMaxPeekIDStmt *sql.Stmt + purgePeeksStmt *sql.Stmt } func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) { @@ -84,25 +88,15 @@ func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks db: db, streamIDStatements: streamID, } - if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { - return nil, err - } - if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil { - return nil, err - } - if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil { - return nil, err - } - if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil { - return nil, err - } - if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil { - return nil, err - } - if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertPeekStmt, insertPeekSQL}, + {&s.deletePeekStmt, deletePeekSQL}, + {&s.deletePeeksStmt, deletePeeksSQL}, + {&s.selectPeeksInRangeStmt, selectPeeksInRangeSQL}, + {&s.selectPeekingDevicesStmt, selectPeekingDevicesSQL}, + {&s.selectMaxPeekIDStmt, selectMaxPeekIDSQL}, + {&s.purgePeeksStmt, purgePeeksSQL}, + }.Prepare(db) } func (s *peekStatements) InsertPeek( @@ -204,3 +198,10 @@ func (s *peekStatements) SelectMaxPeekID( } return } + +func (s *peekStatements) PurgePeeks( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgePeeksStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go index fe6b3ce84..0373b0616 100644 --- a/syncapi/storage/sqlite3/presence_table.go +++ b/syncapi/storage/sqlite3/presence_table.go @@ -17,12 +17,14 @@ package sqlite3 import ( "context" "database/sql" + "strings" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const presenceSchema = ` @@ -62,9 +64,9 @@ const upsertPresenceFromSyncSQL = "" + " RETURNING id" const selectPresenceForUserSQL = "" + - "SELECT presence, status_msg, last_active_ts" + + "SELECT user_id, presence, status_msg, last_active_ts" + " FROM syncapi_presence" + - " WHERE user_id = $1 LIMIT 1" + " WHERE user_id IN ($1)" const selectMaxPresenceSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_presence" @@ -134,20 +136,38 @@ func (p *presenceStatements) UpsertPresence( return } -// GetPresenceForUser returns the current presence of a user. -func (p *presenceStatements) GetPresenceForUser( +// GetPresenceForUsers returns the current presence for a list of users. +// If the user doesn't have a presence status yet, it is omitted from the response. +func (p *presenceStatements) GetPresenceForUsers( ctx context.Context, txn *sql.Tx, - userID string, -) (*types.PresenceInternal, error) { - result := &types.PresenceInternal{ - UserID: userID, + userIDs []string, +) ([]*types.PresenceInternal, error) { + qry := strings.Replace(selectPresenceForUserSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1) + prepStmt, err := p.db.Prepare(qry) + if err != nil { + return nil, err } - stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) - err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) - if err == sql.ErrNoRows { - return nil, nil + defer internal.CloseAndLogIfError(ctx, prepStmt, "GetPresenceForUsers: stmt.close() failed") + + params := make([]interface{}, len(userIDs)) + for i := range userIDs { + params[i] = userIDs[i] + } + + rows, err := sqlutil.TxStmt(txn, prepStmt).QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed") + result := make([]*types.PresenceInternal, 0, len(userIDs)) + for rows.Next() { + presence := &types.PresenceInternal{} + if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil { + return nil, err + } + presence.ClientFields.Presence = presence.Presence.String() + result = append(result, presence) } - result.ClientFields.Presence = result.Presence.String() return result, err } diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index a4a9b4395..ca3d80fb4 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -58,12 +58,16 @@ const selectRoomReceipts = "" + const selectMaxReceiptIDSQL = "" + "SELECT MAX(id) FROM syncapi_receipts" +const purgeReceiptsSQL = "" + + "DELETE FROM syncapi_receipts WHERE room_id = $1" + type receiptStatements struct { db *sql.DB streamIDStatements *StreamIDStatements upsertReceipt *sql.Stmt selectRoomReceipts *sql.Stmt selectMaxReceiptID *sql.Stmt + purgeReceiptsStmt *sql.Stmt } func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) { @@ -84,16 +88,12 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Re db: db, streamIDStatements: streamID, } - if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { - return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) - } - if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - return r, nil + return r, sqlutil.StatementList{ + {&r.upsertReceipt, upsertReceipt}, + {&r.selectRoomReceipts, selectRoomReceipts}, + {&r.selectMaxReceiptID, selectMaxReceiptIDSQL}, + {&r.purgeReceiptsStmt, purgeReceiptsSQL}, + }.Prepare(db) } // UpsertReceipt creates new user receipts @@ -153,3 +153,10 @@ func (s *receiptStatements) SelectMaxReceiptID( } return } + +func (s *receiptStatements) PurgeReceipts( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeReceiptsStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 74f4c830f..ef1b9b376 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "math" "reflect" "testing" @@ -14,6 +15,7 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" ) var ctx = context.Background() @@ -154,12 +156,12 @@ func TestRecentEventsPDU(t *testing.T) { tc := testCases[i] t.Run(tc.Name, func(st *testing.T) { var filter gomatrixserverlib.RoomEventFilter - var gotEvents []types.StreamEvent + var gotEvents map[string]types.RecentEvents var limited bool filter.Limit = tc.Limit WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { var err error - gotEvents, limited, err = snapshot.RecentEvents(ctx, r.ID, types.Range{ + gotEvents, err = snapshot.RecentEvents(ctx, []string{r.ID}, types.Range{ From: tc.From, To: tc.To, }, &filter, !tc.ReverseOrder, true) @@ -167,15 +169,18 @@ func TestRecentEventsPDU(t *testing.T) { st.Fatalf("failed to do sync: %s", err) } }) + streamEvents := gotEvents[r.ID] + limited = streamEvents.Limited if limited != tc.WantLimited { st.Errorf("got limited=%v want %v", limited, tc.WantLimited) } - if len(gotEvents) != len(tc.WantEvents) { + if len(streamEvents.Events) != len(tc.WantEvents) { st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents)) } - for j := range gotEvents { - if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) { - st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON())) + + for j := range streamEvents.Events { + if !reflect.DeepEqual(streamEvents.Events[j].JSON(), tc.WantEvents[j].JSON()) { + st.Errorf("event %d got %s want %s", j, string(streamEvents.Events[j].JSON()), string(tc.WantEvents[j].JSON())) } } }) @@ -198,10 +203,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { _ = MustWriteEvents(t, db, events) WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { - from, err := snapshot.MaxTopologicalPosition(ctx, r.ID) - if err != nil { - t.Fatalf("failed to get MaxTopologicalPosition: %s", err) - } + from := types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} t.Logf("max topo pos = %+v", from) // head towards the beginning of time to := types.TopologyToken{} @@ -218,6 +220,88 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { }) } +func TestStreamToTopologicalPosition(t *testing.T) { + alice := test.NewUser(t) + r := test.NewRoom(t, alice) + + testCases := []struct { + name string + roomID string + streamPos types.StreamPosition + backwardOrdering bool + wantToken types.TopologyToken + }{ + { + name: "forward ordering found streamPos returns found position", + roomID: r.ID, + streamPos: 1, + backwardOrdering: false, + wantToken: types.TopologyToken{Depth: 1, PDUPosition: 1}, + }, + { + name: "forward ordering not found streamPos returns max position", + roomID: r.ID, + streamPos: 100, + backwardOrdering: false, + wantToken: types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, + }, + { + name: "backward ordering found streamPos returns found position", + roomID: r.ID, + streamPos: 1, + backwardOrdering: true, + wantToken: types.TopologyToken{Depth: 1, PDUPosition: 1}, + }, + { + name: "backward ordering not found streamPos returns maxDepth with param pduPosition", + roomID: r.ID, + streamPos: 100, + backwardOrdering: true, + wantToken: types.TopologyToken{Depth: 5, PDUPosition: 100}, + }, + { + name: "backward non-existent room returns zero token", + roomID: "!doesnotexist:localhost", + streamPos: 1, + backwardOrdering: true, + wantToken: types.TopologyToken{Depth: 0, PDUPosition: 1}, + }, + { + name: "forward non-existent room returns max token", + roomID: "!doesnotexist:localhost", + streamPos: 1, + backwardOrdering: false, + wantToken: types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close, closeBase := MustCreateDatabase(t, dbType) + defer close() + defer closeBase() + + txn, err := db.NewDatabaseTransaction(ctx) + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + MustWriteEvents(t, db, r.Events()) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + token, err := txn.StreamToTopologicalPosition(ctx, tc.roomID, tc.streamPos, tc.backwardOrdering) + if err != nil { + t.Fatal(err) + } + if tc.wantToken != token { + t.Fatalf("expected token %q, got %q", tc.wantToken, token) + } + }) + } + + }) +} + /* // The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth. // For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent @@ -664,3 +748,239 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ return &tok } */ + +func pointer[t any](s t) *t { + return &s +} + +func TestRoomSummary(t *testing.T) { + + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t) + + // Create some dummy users + moreUsers := []*test.User{} + moreUserIDs := []string{} + for i := 0; i < 10; i++ { + u := test.NewUser(t) + moreUsers = append(moreUsers, u) + moreUserIDs = append(moreUserIDs, u.ID) + } + + testCases := []struct { + name string + wantSummary *types.Summary + additionalEvents func(t *testing.T, room *test.Room) + }{ + { + name: "after initial creation", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{}}, + }, + { + name: "invited user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "invited user, but declined", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "joined user after invitation", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "multiple joined user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(3), InvitedMemberCount: pointer(0), Heroes: []string{charlie.ID, bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, charlie, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(charlie.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "multiple joined/invited user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID, bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(charlie.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "multiple joined/invited/left user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(charlie.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "leaving user after joining", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "many users", // heroes ordered by stream id + wantSummary: &types.Summary{JoinedMemberCount: pointer(len(moreUserIDs) + 1), InvitedMemberCount: pointer(0), Heroes: moreUserIDs[:5]}, + additionalEvents: func(t *testing.T, room *test.Room) { + for _, x := range moreUsers { + room.CreateAndInsert(t, x, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(x.ID)) + } + }, + }, + { + name: "canonical alias set", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomCanonicalAlias, map[string]interface{}{ + "alias": "myalias", + }, test.WithStateKey("")) + }, + }, + { + name: "room name set", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomName, map[string]interface{}{ + "name": "my room name", + }, test.WithStateKey("")) + }, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close, closeBase := MustCreateDatabase(t, dbType) + defer close() + defer closeBase() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + r := test.NewRoom(t, alice) + + if tc.additionalEvents != nil { + tc.additionalEvents(t, r) + } + + // write the room before creating a transaction + MustWriteEvents(t, db, r.Events()) + + transaction, err := db.NewDatabaseTransaction(ctx) + assert.NoError(t, err) + defer transaction.Rollback() + + summary, err := transaction.GetRoomSummary(ctx, r.ID, alice.ID) + assert.NoError(t, err) + assert.Equal(t, tc.wantSummary, summary) + }) + } + }) +} + +func TestRecentEvents(t *testing.T) { + alice := test.NewUser(t) + room1 := test.NewRoom(t, alice) + room2 := test.NewRoom(t, alice) + roomIDs := []string{room1.ID, room2.ID} + rooms := map[string]*test.Room{ + room1.ID: room1, + room2.ID: room2, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + filter := gomatrixserverlib.DefaultRoomEventFilter() + db, close, closeBase := MustCreateDatabase(t, dbType) + t.Cleanup(func() { + close() + closeBase() + }) + + MustWriteEvents(t, db, room1.Events()) + MustWriteEvents(t, db, room2.Events()) + + transaction, err := db.NewDatabaseTransaction(ctx) + assert.NoError(t, err) + defer transaction.Rollback() + + // get all recent events from 0 to 100 (we only created 5 events, so we should get 5 back) + roomEvs, err := transaction.RecentEvents(ctx, roomIDs, types.Range{From: 0, To: 100}, &filter, true, true) + assert.NoError(t, err) + assert.Equal(t, len(roomEvs), 2, "unexpected recent events response") + for _, recentEvents := range roomEvs { + assert.Equal(t, 5, len(recentEvents.Events), "unexpected recent events for room") + } + + // update the filter to only return one event + filter.Limit = 1 + roomEvs, err = transaction.RecentEvents(ctx, roomIDs, types.Range{From: 0, To: 100}, &filter, true, true) + assert.NoError(t, err) + assert.Equal(t, len(roomEvs), 2, "unexpected recent events response") + for roomID, recentEvents := range roomEvs { + origEvents := rooms[roomID].Events() + assert.Equal(t, true, recentEvents.Limited, "expected events to be limited") + assert.Equal(t, 1, len(recentEvents.Events), "unexpected recent events for room") + assert.Equal(t, origEvents[len(origEvents)-1].EventID(), recentEvents.Events[0].EventID()) + } + + // not chronologically ordered still returns the events in order (given ORDER BY id DESC) + roomEvs, err = transaction.RecentEvents(ctx, roomIDs, types.Range{From: 0, To: 100}, &filter, false, true) + assert.NoError(t, err) + assert.Equal(t, len(roomEvs), 2, "unexpected recent events response") + for roomID, recentEvents := range roomEvs { + origEvents := rooms[roomID].Events() + assert.Equal(t, true, recentEvents.Limited, "expected events to be limited") + assert.Equal(t, 1, len(recentEvents.Events), "unexpected recent events for room") + assert.Equal(t, origEvents[len(origEvents)-1].EventID(), recentEvents.Events[0].EventID()) + } + }) +} diff --git a/syncapi/storage/tables/current_room_state_test.go b/syncapi/storage/tables/current_room_state_test.go index 23287c500..c7af4f977 100644 --- a/syncapi/storage/tables/current_room_state_test.go +++ b/syncapi/storage/tables/current_room_state_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" ) func newCurrentRoomStateTable(t *testing.T, dbType test.DBType) (tables.CurrentRoomState, *sql.DB, func()) { @@ -79,6 +80,9 @@ func TestCurrentRoomStateTable(t *testing.T) { return fmt.Errorf("SelectEventsWithEventIDs\nexpected id %q not returned", id) } } + + testCurrentState(t, ctx, txn, tab, room) + return nil }) if err != nil { @@ -86,3 +90,39 @@ func TestCurrentRoomStateTable(t *testing.T) { } }) } + +func testCurrentState(t *testing.T, ctx context.Context, txn *sql.Tx, tab tables.CurrentRoomState, room *test.Room) { + t.Run("test currentState", func(t *testing.T) { + // returns the complete state of the room with a default filter + filter := gomatrixserverlib.DefaultStateFilter() + evs, err := tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil) + if err != nil { + t.Fatal(err) + } + expectCount := 5 + if gotCount := len(evs); gotCount != expectCount { + t.Fatalf("expected %d state events, got %d", expectCount, gotCount) + } + // When lazy loading, we expect no membership event, so only 4 events + filter.LazyLoadMembers = true + expectCount = 4 + evs, err = tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil) + if err != nil { + t.Fatal(err) + } + if gotCount := len(evs); gotCount != expectCount { + t.Fatalf("expected %d state events, got %d", expectCount, gotCount) + } + // same as above, but with existing NotTypes defined + notTypes := []string{gomatrixserverlib.MRoomMember} + filter.NotTypes = ¬Types + evs, err = tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil) + if err != nil { + t.Fatal(err) + } + if gotCount := len(evs); gotCount != expectCount { + t.Fatalf("expected %d state events, got %d", expectCount, gotCount) + } + }) + +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index e027cf59e..af986ccb0 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -39,6 +39,7 @@ type Invites interface { // for the room. SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, maxID types.StreamPosition, err error) SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error) + PurgeInvites(ctx context.Context, txn *sql.Tx, roomID string) error } type Peeks interface { @@ -48,6 +49,7 @@ type Peeks interface { SelectPeeksInRange(ctxt context.Context, txn *sql.Tx, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) SelectPeekingDevices(ctxt context.Context, txn *sql.Tx) (peekingDevices map[string][]types.PeekingDevice, err error) SelectMaxPeekID(ctx context.Context, txn *sql.Tx) (id int64, err error) + PurgePeeks(ctx context.Context, txn *sql.Tx, roomID string) error } type Events interface { @@ -64,7 +66,7 @@ type Events interface { // SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high. // If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync. // Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`. - SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) + SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) // SelectEarlyEvents returns the earliest events in the given room. SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error) @@ -75,6 +77,8 @@ type Events interface { SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) + + PurgeEvents(ctx context.Context, txn *sql.Tx, roomID string) error ReIndex(ctx context.Context, txn *sql.Tx, limit, offset int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) } @@ -91,10 +95,9 @@ type Topology interface { SelectEventIDsInRange(ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, err error) // SelectPositionInTopology returns the depth and stream position of a given event in the topology of the room it belongs to. SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) - // SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position. - SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error) // SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room. SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error) + PurgeEventsTopology(ctx context.Context, txn *sql.Tx, roomID string) error } type CurrentRoomState interface { @@ -115,6 +118,9 @@ type CurrentRoomState interface { SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error) // SelectSharedUsers returns a subset of otherUserIDs that share a room with userID. SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error) + + SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) + SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (int, error) } // BackwardsExtremities keeps track of backwards extremities for a room. @@ -145,6 +151,7 @@ type BackwardsExtremities interface { SelectBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (bwExtrems map[string][]string, err error) // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) + PurgeBackwardExtremities(ctx context.Context, txn *sql.Tx, roomID string) error } // SendToDevice tracks send-to-device messages which are sent to individual @@ -180,13 +187,14 @@ type Receipts interface { UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error) + PurgeReceipts(ctx context.Context, txn *sql.Tx, roomID string) error } type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) - SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error) SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) + PurgeMemberships(ctx context.Context, txn *sql.Tx, roomID string) error SelectMemberships( ctx context.Context, txn *sql.Tx, roomID string, pos types.TopologyToken, @@ -198,6 +206,7 @@ type NotificationData interface { UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) SelectUserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) + PurgeNotificationData(ctx context.Context, txn *sql.Tx, roomID string) error } type Ignores interface { @@ -207,7 +216,7 @@ type Ignores interface { type Presence interface { UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error) - GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error) + GetPresenceForUsers(ctx context.Context, txn *sql.Tx, userIDs []string) (presence []*types.PresenceInternal, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error) ExpirePresence(ctx context.Context) ([]types.PresenceNotify, error) diff --git a/syncapi/storage/tables/memberships_test.go b/syncapi/storage/tables/memberships_test.go index 0cee7f5a5..df593ae78 100644 --- a/syncapi/storage/tables/memberships_test.go +++ b/syncapi/storage/tables/memberships_test.go @@ -3,8 +3,6 @@ package tables_test import ( "context" "database/sql" - "reflect" - "sort" "testing" "time" @@ -88,43 +86,9 @@ func TestMembershipsTable(t *testing.T) { testUpsert(t, ctx, table, userEvents[0], alice, room) testMembershipCount(t, ctx, table, room) - testHeroes(t, ctx, table, alice, room, users) }) } -func testHeroes(t *testing.T, ctx context.Context, table tables.Memberships, user *test.User, room *test.Room, users []string) { - - // Re-slice and sort the expected users - users = users[1:] - sort.Strings(users) - type testCase struct { - name string - memberships []string - wantHeroes []string - } - - testCases := []testCase{ - {name: "no memberships queried", memberships: []string{}}, - {name: "joined memberships queried should be limited", memberships: []string{gomatrixserverlib.Join}, wantHeroes: users[:5]}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - got, err := table.SelectHeroes(ctx, nil, room.ID, user.ID, tc.memberships) - if err != nil { - t.Fatalf("unable to select heroes: %s", err) - } - if gotLen := len(got); gotLen != len(tc.wantHeroes) { - t.Fatalf("expected %d heroes, got %d", len(tc.wantHeroes), gotLen) - } - - if !reflect.DeepEqual(got, tc.wantHeroes) { - t.Fatalf("expected heroes to be %+v, got %+v", tc.wantHeroes, got) - } - }) - } -} - func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) { t.Run("membership counts are correct", func(t *testing.T) { // After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users) diff --git a/syncapi/storage/tables/presence_table_test.go b/syncapi/storage/tables/presence_table_test.go new file mode 100644 index 000000000..dce0c695a --- /dev/null +++ b/syncapi/storage/tables/presence_table_test.go @@ -0,0 +1,136 @@ +package tables_test + +import ( + "context" + "database/sql" + "reflect" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" +) + +func mustPresenceTable(t *testing.T, dbType test.DBType) (tables.Presence, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + var tab tables.Presence + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresPresenceTable(db) + case test.DBTypeSQLite: + var stream sqlite3.StreamIDStatements + if err = stream.Prepare(db); err != nil { + t.Fatalf("failed to prepare stream stmts: %s", err) + } + tab, err = sqlite3.NewSqlitePresenceTable(db, &stream) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + return tab, close +} + +func TestPresence(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + ctx := context.Background() + + statusMsg := "Hello World!" + timestamp := gomatrixserverlib.AsTimestamp(time.Now()) + + var txn *sql.Tx + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustPresenceTable(t, dbType) + defer closeDB() + + // Insert some presences + pos, err := tab.UpsertPresence(ctx, txn, alice.ID, &statusMsg, types.PresenceOnline, timestamp, false) + if err != nil { + t.Error(err) + } + wantPos := types.StreamPosition(1) + if pos != wantPos { + t.Errorf("expected pos to be %d, got %d", wantPos, pos) + } + pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, false) + if err != nil { + t.Error(err) + } + wantPos = 2 + if pos != wantPos { + t.Errorf("expected pos to be %d, got %d", wantPos, pos) + } + + // verify the expected max presence ID + maxPos, err := tab.GetMaxPresenceID(ctx, txn) + if err != nil { + t.Error(err) + } + if maxPos != wantPos { + t.Errorf("expected max pos to be %d, got %d", wantPos, maxPos) + } + + // This should increment the position + pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, true) + if err != nil { + t.Error(err) + } + wantPos = pos + if wantPos <= maxPos { + t.Errorf("expected pos to be %d incremented, got %d", wantPos, pos) + } + + // This should return only Bobs status + presences, err := tab.GetPresenceAfter(ctx, txn, maxPos, gomatrixserverlib.EventFilter{Limit: 10}) + if err != nil { + t.Error(err) + } + + if c := len(presences); c > 1 { + t.Errorf("expected only one presence, got %d", c) + } + + // Validate the response + wantPresence := &types.PresenceInternal{ + UserID: bob.ID, + Presence: types.PresenceOnline, + StreamPos: wantPos, + LastActiveTS: timestamp, + ClientFields: types.PresenceClientResponse{ + LastActiveAgo: 0, + Presence: types.PresenceOnline.String(), + StatusMsg: &statusMsg, + }, + } + if !reflect.DeepEqual(wantPresence, presences[bob.ID]) { + t.Errorf("unexpected presence result:\n%+v, want\n%+v", presences[bob.ID], wantPresence) + } + + // Try getting presences for existing and non-existing users + getUsers := []string{alice.ID, bob.ID, "@doesntexist:test"} + presencesForUsers, err := tab.GetPresenceForUsers(ctx, nil, getUsers) + if err != nil { + t.Error(err) + } + + if len(presencesForUsers) >= len(getUsers) { + t.Errorf("expected less presences, but they are the same/more as requested: %d >= %d", len(presencesForUsers), len(getUsers)) + } + }) + +} diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go index 7996c2038..e8189c352 100644 --- a/syncapi/streams/stream_devicelist.go +++ b/syncapi/streams/stream_devicelist.go @@ -3,17 +3,17 @@ package streams import ( "context" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" ) type DeviceListStreamProvider struct { DefaultStreamProvider - rsAPI api.SyncRoomserverAPI - keyAPI keyapi.SyncKeyAPI + rsAPI api.SyncRoomserverAPI + userAPI userapi.SyncKeyAPI } func (p *DeviceListStreamProvider) CompleteSync( @@ -31,12 +31,12 @@ func (p *DeviceListStreamProvider) IncrementalSync( from, to types.StreamPosition, ) types.StreamPosition { var err error - to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) + to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.userAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) if err != nil { req.Log.WithError(err).Error("internal.DeviceListCatchup failed") return from } - err = internal.DeviceOTKCounts(req.Context, p.keyAPI, req.Device.UserID, req.Device.ID, req.Response) + err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response) if err != nil { req.Log.WithError(err).Error("internal.DeviceListCatchup failed") return from diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 3d6e7a770..a65a64133 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "fmt" - "sort" "time" "github.com/matrix-org/dendrite/internal/caching" @@ -14,11 +13,9 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - - "github.com/matrix-org/dendrite/syncapi/notifier" ) // The max number of per-room goroutines to have running. @@ -86,19 +83,24 @@ func (p *PDUStreamProvider) CompleteSync( req.Log.WithError(err).Error("unable to update event filter with ignored users") } - // Invalidate the lazyLoadCache, otherwise we end up with missing displaynames/avatars - // TODO: This might be inefficient, when joined to many and/or large rooms. + recentEvents, err := snapshot.RecentEvents(ctx, joinedRoomIDs, r, &eventFilter, true, true) + if err != nil { + return from + } + // Build up a /sync response. Add joined rooms. for _, roomID := range joinedRoomIDs { + events := recentEvents[roomID] + // Invalidate the lazyLoadCache, otherwise we end up with missing displaynames/avatars + // TODO: This might be inefficient, when joined to many and/or large rooms. joinedUsers := p.notifier.JoinedUsers(roomID) for _, sharedUser := range joinedUsers { p.lazyLoadCache.InvalidateLazyLoadedUser(req.Device, roomID, sharedUser) } - } - // Build up a /sync response. Add joined rooms. - for _, roomID := range joinedRoomIDs { + // get the join response for each room jr, jerr := p.getJoinResponseForCompleteSync( - ctx, snapshot, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false, + ctx, snapshot, roomID, &stateFilter, req.WantFullState, req.Device, false, + events.Events, events.Limited, ) if jerr != nil { req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed") @@ -117,11 +119,25 @@ func (p *PDUStreamProvider) CompleteSync( req.Log.WithError(err).Error("p.DB.PeeksInRange failed") return from } - for _, peek := range peeks { - if !peek.Deleted { + if len(peeks) > 0 { + peekRooms := make([]string, 0, len(peeks)) + for _, peek := range peeks { + if !peek.Deleted { + peekRooms = append(peekRooms, peek.RoomID) + } + } + + recentEvents, err = snapshot.RecentEvents(ctx, peekRooms, r, &eventFilter, true, true) + if err != nil { + return from + } + + for _, roomID := range peekRooms { var jr *types.JoinResponse + events := recentEvents[roomID] jr, err = p.getJoinResponseForCompleteSync( - ctx, snapshot, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true, + ctx, snapshot, roomID, &stateFilter, req.WantFullState, req.Device, true, + events.Events, events.Limited, ) if err != nil { req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") @@ -130,7 +146,7 @@ func (p *PDUStreamProvider) CompleteSync( } continue } - req.Response.Rooms.Peek[peek.RoomID] = jr + req.Response.Rooms.Peek[roomID] = jr } } @@ -232,7 +248,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( stateFilter *gomatrixserverlib.StateFilter, req *types.SyncRequest, ) (types.StreamPosition, error) { - + var err error originalLimit := eventFilter.Limit // If we're going backwards, grep at least X events, this is mostly to satisfy Sytest if r.Backwards && originalLimit < recentEventBackwardsLimit { @@ -243,8 +259,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } - recentStreamEvents, limited, err := snapshot.RecentEvents( - ctx, delta.RoomID, r, + dbEvents, err := snapshot.RecentEvents( + ctx, []string{delta.RoomID}, r, eventFilter, true, true, ) if err != nil { @@ -253,6 +269,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err) } + + recentStreamEvents := dbEvents[delta.RoomID].Events + limited := dbEvents[delta.RoomID].Limited + recentEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering( snapshot.StreamEventsToEvents(device, recentStreamEvents), gomatrixserverlib.TopologicalOrderByPrevEvents, @@ -341,7 +361,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case gomatrixserverlib.Join: jr := types.NewJoinResponse() if hasMembershipChange { - p.addRoomSummary(ctx, snapshot, jr, delta.RoomID, device.UserID, latestPosition) + jr.Summary, err = snapshot.GetRoomSummary(ctx, delta.RoomID, device.UserID) + if err != nil { + logrus.WithError(err).Warn("failed to get room summary") + } } jr.Timeline.PrevBatch = &prevBatch jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) @@ -386,19 +409,32 @@ func applyHistoryVisibilityFilter( roomID, userID string, recentEvents []*gomatrixserverlib.HeaderedEvent, ) ([]*gomatrixserverlib.HeaderedEvent, error) { - // We need to make sure we always include the latest states events, if they are in the timeline. - // We grep at least limit * 2 events, to ensure we really get the needed events. - filter := gomatrixserverlib.DefaultStateFilter() - stateEvents, err := snapshot.CurrentState(ctx, roomID, &filter, nil) - if err != nil { - // Not a fatal error, we can continue without the stateEvents, - // they are only needed if there are state events in the timeline. - logrus.WithError(err).Warnf("Failed to get current room state for history visibility") + // We need to make sure we always include the latest state events, if they are in the timeline. + alwaysIncludeIDs := make(map[string]struct{}) + var stateTypes []string + var senders []string + for _, ev := range recentEvents { + if ev.StateKey() != nil { + stateTypes = append(stateTypes, ev.Type()) + senders = append(senders, ev.Sender()) + } } - alwaysIncludeIDs := make(map[string]struct{}, len(stateEvents)) - for _, ev := range stateEvents { - alwaysIncludeIDs[ev.EventID()] = struct{}{} + + // Only get the state again if there are state events in the timeline + if len(stateTypes) > 0 { + filter := gomatrixserverlib.DefaultStateFilter() + filter.Types = &stateTypes + filter.Senders = &senders + stateEvents, err := snapshot.CurrentState(ctx, roomID, &filter, nil) + if err != nil { + return nil, fmt.Errorf("failed to get current room state for history visibility calculation: %w", err) + } + + for _, ev := range stateEvents { + alwaysIncludeIDs[ev.EventID()] = struct{}{} + } } + startTime := time.Now() events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync") if err != nil { @@ -409,72 +445,24 @@ func applyHistoryVisibilityFilter( "room_id": roomID, "before": len(recentEvents), "after": len(events), - }).Trace("Applied history visibility (sync)") + }).Debugf("Applied history visibility (sync)") return events, nil } -func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage.DatabaseTransaction, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) { - // Work out how many members are in the room. - joinedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition) - invitedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition) - - jr.Summary.JoinedMemberCount = &joinedCount - jr.Summary.InvitedMemberCount = &invitedCount - - fetchStates := []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomName}, - {EventType: gomatrixserverlib.MRoomCanonicalAlias}, - } - // Check if the room has a name or a canonical alias - latestState := &roomserverAPI.QueryLatestEventsAndStateResponse{} - err := p.rsAPI.QueryLatestEventsAndState(ctx, &roomserverAPI.QueryLatestEventsAndStateRequest{StateToFetch: fetchStates, RoomID: roomID}, latestState) - if err != nil { - return - } - // Check if the room has a name or canonical alias, if so, return. - for _, ev := range latestState.StateEvents { - switch ev.Type() { - case gomatrixserverlib.MRoomName: - if gjson.GetBytes(ev.Content(), "name").Str != "" { - return - } - case gomatrixserverlib.MRoomCanonicalAlias: - if gjson.GetBytes(ev.Content(), "alias").Str != "" { - return - } - } - } - heroes, err := snapshot.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"}) - if err != nil { - return - } - sort.Strings(heroes) - jr.Summary.Heroes = heroes -} - func (p *PDUStreamProvider) getJoinResponseForCompleteSync( ctx context.Context, snapshot storage.DatabaseTransaction, roomID string, - r types.Range, stateFilter *gomatrixserverlib.StateFilter, - eventFilter *gomatrixserverlib.RoomEventFilter, wantFullState bool, device *userapi.Device, isPeek bool, + recentStreamEvents []types.StreamEvent, + limited bool, ) (jr *types.JoinResponse, err error) { jr = types.NewJoinResponse() // TODO: When filters are added, we may need to call this multiple times to get enough events. // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 - recentStreamEvents, limited, err := snapshot.RecentEvents( - ctx, roomID, r, eventFilter, true, true, - ) - if err != nil { - if err == sql.ErrNoRows { - return jr, nil - } - return - } // Work our way through the timeline events and pick out the event IDs // of any state events that appear in the timeline. We'll specifically @@ -495,7 +483,10 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( return } - p.addRoomSummary(ctx, snapshot, jr, roomID, device.UserID, r.From) + jr.Summary, err = snapshot.GetRoomSummary(ctx, roomID, device.UserID) + if err != nil { + logrus.WithError(err).Warn("failed to get room summary") + } // We don't include a device here as we don't need to send down // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index bbd12be0d..c6c8d6866 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -17,6 +17,7 @@ package streams import ( "context" "encoding/json" + "fmt" "github.com/matrix-org/gomatrixserverlib" "github.com/tidwall/gjson" @@ -67,39 +68,25 @@ func (p *PresenceStreamProvider) IncrementalSync( return from } - if len(presences) == 0 { + getPresenceForUsers, err := p.getNeededUsersFromRequest(ctx, req, presences) + if err != nil { + req.Log.WithError(err).Error("getNeededUsersFromRequest failed") + return from + } + + // Got no presence between range and no presence to get from the database + if len(getPresenceForUsers) == 0 && len(presences) == 0 { return to } - // add newly joined rooms user presences - newlyJoined := joinedRooms(req.Response, req.Device.UserID) - if len(newlyJoined) > 0 { - // TODO: Check if this is working better than before. - if err = p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil { - req.Log.WithError(err).Error("unable to refresh notifier lists") - return from - } - NewlyJoinedLoop: - for _, roomID := range newlyJoined { - roomUsers := p.notifier.JoinedUsers(roomID) - for i := range roomUsers { - // we already got a presence from this user - if _, ok := presences[roomUsers[i]]; ok { - continue - } - // Bear in mind that this might return nil, but at least populating - // a nil means that there's a map entry so we won't repeat this call. - presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i]) - if err != nil { - req.Log.WithError(err).Error("unable to query presence for user") - _ = snapshot.Rollback() - return from - } - if len(presences) > req.Filter.Presence.Limit { - break NewlyJoinedLoop - } - } - } + dbPresences, err := snapshot.GetPresences(ctx, getPresenceForUsers) + if err != nil { + req.Log.WithError(err).Error("unable to query presence for user") + _ = snapshot.Rollback() + return from + } + for _, presence := range dbPresences { + presences[presence.UserID] = presence } lastPos := from @@ -147,6 +134,39 @@ func (p *PresenceStreamProvider) IncrementalSync( return lastPos } +func (p *PresenceStreamProvider) getNeededUsersFromRequest(ctx context.Context, req *types.SyncRequest, presences map[string]*types.PresenceInternal) ([]string, error) { + getPresenceForUsers := []string{} + // Add presence for users which newly joined a room + for userID := range req.MembershipChanges { + if _, ok := presences[userID]; ok { + continue + } + getPresenceForUsers = append(getPresenceForUsers, userID) + } + + // add newly joined rooms user presences + newlyJoined := joinedRooms(req.Response, req.Device.UserID) + if len(newlyJoined) == 0 { + return getPresenceForUsers, nil + } + + // TODO: Check if this is working better than before. + if err := p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil { + return getPresenceForUsers, fmt.Errorf("unable to refresh notifier lists: %w", err) + } + for _, roomID := range newlyJoined { + roomUsers := p.notifier.JoinedUsers(roomID) + for i := range roomUsers { + // we already got a presence from this user + if _, ok := presences[roomUsers[i]]; ok { + continue + } + getPresenceForUsers = append(getPresenceForUsers, roomUsers[i]) + } + } + return getPresenceForUsers, nil +} + func joinedRooms(res *types.Response, userID string) []string { var roomIDs []string for roomID, join := range res.Rooms.Join { diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index 8cc028bdf..5c112ed45 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -5,7 +5,6 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" @@ -29,7 +28,7 @@ type Streams struct { func NewSyncStreamProviders( d storage.Database, userAPI userapi.SyncUserAPI, - rsAPI rsapi.SyncRoomserverAPI, keyAPI keyapi.SyncKeyAPI, + rsAPI rsapi.SyncRoomserverAPI, eduCache *caching.EDUCache, lazyLoadCache caching.LazyLoadCache, notifier *notifier.Notifier, mrdb *mrd.Queries, ) *Streams { @@ -63,7 +62,7 @@ func NewSyncStreamProviders( DeviceListStreamProvider: &DeviceListStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, rsAPI: rsAPI, - keyAPI: keyAPI, + userAPI: userAPI, }, PresenceStreamProvider: &PresenceStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 046913750..22ee340bb 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -32,7 +32,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/sqlutil" - keyapi "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/internal" @@ -48,7 +47,6 @@ type RequestPool struct { db storage.Database cfg *config.SyncAPI userAPI userapi.SyncUserAPI - keyAPI keyapi.SyncKeyAPI rsAPI roomserverAPI.SyncRoomserverAPI lastseen *sync.Map Presence *sync.Map @@ -69,7 +67,7 @@ type PresenceConsumer interface { // NewRequestPool makes a new RequestPool func NewRequestPool( db storage.Database, cfg *config.SyncAPI, - userAPI userapi.SyncUserAPI, keyAPI keyapi.SyncKeyAPI, + userAPI userapi.SyncUserAPI, rsAPI roomserverAPI.SyncRoomserverAPI, streams *streams.Streams, notifier *notifier.Notifier, producer PresencePublisher, consumer PresenceConsumer, enableMetrics bool, @@ -83,7 +81,6 @@ func NewRequestPool( db: db, cfg: cfg, userAPI: userAPI, - keyAPI: keyAPI, rsAPI: rsAPI, lastseen: &sync.Map{}, Presence: &sync.Map{}, @@ -145,12 +142,12 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user } // ensure we also send the current status_msg to federated servers and not nil - dbPresence, err := db.GetPresence(context.Background(), userID) + dbPresence, err := db.GetPresences(context.Background(), []string{userID}) if err != nil && err != sql.ErrNoRows { return } - if dbPresence != nil { - newPresence.ClientFields = dbPresence.ClientFields + if len(dbPresence) > 0 && dbPresence[0] != nil { + newPresence.ClientFields = dbPresence[0].ClientFields } newPresence.ClientFields.Presence = presenceID.String() @@ -159,17 +156,8 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user existingPresence, ok := rp.Presence.LoadOrStore(userID, newPresence) if ok { p := existingPresence.(types.PresenceInternal) - if dbPresence != nil { - if p.Presence == newPresence.Presence && newPresence.LastActiveTS-dbPresence.LastActiveTS < types.PresenceNoOpMs { - return - } - if dbPresence.Presence == types.PresenceOnline && presenceID == types.PresenceOnline && newPresence.LastActiveTS-dbPresence.LastActiveTS >= types.PresenceNoOpMs { - err := db.UpdateLastActive(context.Background(), userID, uint64(newPresence.LastActiveTS)) - if err != nil { - logrus.WithError(err).Error("failed to update last active") - } - return - } + if p.ClientFields.Presence == newPresence.ClientFields.Presence { + return } } @@ -293,7 +281,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. // https://github.com/matrix-org/synapse/blob/29f06704b8871a44926f7c99e73cf4a978fb8e81/synapse/rest/client/sync.py#L276-L281 // Only try to get OTKs if the context isn't already done. if syncReq.Context.Err() == nil { - err = internal.DeviceOTKCounts(syncReq.Context, rp.keyAPI, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Response) + err = internal.DeviceOTKCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Response) if err != nil && err != context.Canceled { syncReq.Log.WithError(err).Warn("failed to get OTK counts") } @@ -581,7 +569,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), snapshot, syncReq, fromToken.PDUPosition, toToken.PDUPosition) _, _, err = internal.DeviceListCatchup( - req.Context(), snapshot, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID, + req.Context(), snapshot, rp.userAPI, rp.rsAPI, syncReq.Device.UserID, syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition, ) if err != nil { diff --git a/syncapi/sync/requestpool_test.go b/syncapi/sync/requestpool_test.go index cdc658331..f14a6cb57 100644 --- a/syncapi/sync/requestpool_test.go +++ b/syncapi/sync/requestpool_test.go @@ -32,8 +32,8 @@ func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence typ return 0, nil } -func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return &types.PresenceInternal{}, nil +func (d dummyDB) GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) { + return []*types.PresenceInternal{}, nil } func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 8a7216228..c6b1f4190 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -22,7 +22,6 @@ import ( "github.com/matrix-org/dendrite/internal/caching" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" @@ -44,7 +43,6 @@ func AddPublicRoutes( base *base.BaseDendrite, userAPI userapi.SyncUserAPI, rsAPI api.SyncRoomserverAPI, - keyAPI keyapi.SyncKeyAPI, ) { cfg := &base.Cfg.SyncAPI @@ -69,7 +67,7 @@ func AddPublicRoutes( eduCache := caching.NewTypingCache() notifier := notifier.NewNotifier() - streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, base.Caches, notifier, mrq) + streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, base.Caches, notifier, mrq) notifier.SetCurrentPosition(streams.Latest(context.Background())) if err = notifier.Load(context.Background(), syncDB); err != nil { logrus.WithError(err).Panicf("failed to load notifier ") @@ -85,7 +83,7 @@ func AddPublicRoutes( userAPI, ) - requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics) + requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics) if err = presenceConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start presence consumer") @@ -131,7 +129,7 @@ func AddPublicRoutes( } sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer( - base.ProcessContext, cfg, js, syncDB, keyAPI, notifier, streams.SendToDeviceStreamProvider, + base.ProcessContext, cfg, js, syncDB, userAPI, notifier, streams.SendToDeviceStreamProvider, ) if err = sendToDeviceConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start send-to-device consumer") diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 483274481..1226b02b6 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -14,8 +14,10 @@ import ( "github.com/nats-io/nats.go" "github.com/tidwall/gjson" + "github.com/matrix-org/dendrite/syncapi/routing" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/clientapi/producers" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" @@ -82,21 +84,18 @@ func (s *syncUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAc return nil } +func (s *syncUserAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyChangesRequest, res *userapi.QueryKeyChangesResponse) error { + return nil +} + +func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error { + return nil +} + func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error { return nil } -type syncKeyAPI struct { - keyapi.SyncKeyAPI -} - -func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error { - return nil -} -func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error { - return nil -} - func TestSyncAPIAccessTokens(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { testSyncAccessTokens(t, dbType) @@ -120,7 +119,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) msgs := toNATSMsgs(t, base, room.Events()...) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}) testrig.MustPublishMsgs(t, jsctx, msgs...) testCases := []struct { @@ -219,7 +218,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { // m.room.history_visibility msgs := toNATSMsgs(t, base, room.Events()...) sinceTokens := make([]string, len(msgs)) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}) for i, msg := range msgs { testrig.MustPublishMsgs(t, jsctx, msg) time.Sleep(100 * time.Millisecond) @@ -303,7 +302,7 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) { jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}) w := httptest.NewRecorder() base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ "access_token": alice.AccessToken, @@ -421,7 +420,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { rsAPI := roomserver.NewInternalAPI(base) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI) for _, tc := range testCases { testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType) @@ -448,6 +447,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{ "access_token": bobDev.AccessToken, "dir": "b", + "filter": `{"lazy_load_members":true}`, // check that lazy loading doesn't break history visibility }))) if w.Code != 200 { t.Logf("%s", w.Body.String()) @@ -521,6 +521,252 @@ func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatr } } +func TestGetMembership(t *testing.T) { + alice := test.NewUser(t) + + aliceDev := userapi.Device{ + ID: "ALICEID", + UserID: alice.ID, + AccessToken: "ALICE_BEARER_TOKEN", + DisplayName: "Alice", + AccountType: userapi.AccountTypeUser, + } + + bob := test.NewUser(t) + bobDev := userapi.Device{ + ID: "BOBID", + UserID: bob.ID, + AccessToken: "notjoinedtoanyrooms", + } + + testCases := []struct { + name string + roomID string + additionalEvents func(t *testing.T, room *test.Room) + request func(t *testing.T, room *test.Room) *http.Request + wantOK bool + wantMemberCount int + useSleep bool // :/ + }{ + { + name: "/members - Alice joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "/members - Bob never joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": bobDev.AccessToken, + })) + }, + wantOK: false, + }, + { + name: "/joined_members - Bob never joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": bobDev.AccessToken, + })) + }, + wantOK: false, + }, + { + name: "/joined_members - Alice joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + wantOK: true, + }, + { + name: "Alice leaves before Bob joins, should not be able to see Bob", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + useSleep: true, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "Alice leaves after Bob joins, should be able to see Bob", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + }, + useSleep: true, + wantOK: true, + wantMemberCount: 2, + }, + { + name: "/joined_members - Alice leaves, shouldn't be able to see members ", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + }, + useSleep: true, + wantOK: false, + }, + { + name: "'at' specified, returns memberships before Bob joins", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "at": "t2_5", + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + useSleep: true, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "'membership=leave' specified, returns no memberships", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "membership": "leave", + })) + }, + wantOK: true, + wantMemberCount: 0, + }, + { + name: "'not_membership=join' specified, returns no memberships", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "not_membership": "join", + })) + }, + wantOK: true, + wantMemberCount: 0, + }, + { + name: "'not_membership=leave' & 'membership=join' specified, returns correct memberships", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "not_membership": "leave", + "membership": "join", + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "non-existent room ID", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", "!notavalidroom:test"), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + wantOK: false, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + + // Use an actual roomserver for this + rsAPI := roomserver.NewInternalAPI(base) + rsAPI.SetFederationAPI(nil, nil) + + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + room := test.NewRoom(t, alice) + t.Cleanup(func() { + t.Logf("running cleanup for %s", tc.name) + }) + // inject additional events + if tc.additionalEvents != nil { + tc.additionalEvents(t, room) + } + if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // wait for the events to come down sync + if tc.useSleep { + time.Sleep(time.Millisecond * 100) + } else { + syncUntil(t, base, aliceDev.AccessToken, false, func(syncBody string) bool { + // wait for the last sent eventID to come down sync + path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID()) + return gjson.Get(syncBody, path).Exists() + }) + } + + w := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(w, tc.request(t, room)) + if w.Code != 200 && tc.wantOK { + t.Logf("%s", w.Body.String()) + t.Fatalf("got HTTP %d want %d", w.Code, 200) + } + t.Logf("[%s] Resp: %s", tc.name, w.Body.String()) + + // check we got the expected events + if tc.wantOK { + memberCount := len(gjson.GetBytes(w.Body.Bytes(), "chunk").Array()) + if memberCount != tc.wantMemberCount { + t.Fatalf("expected %d members, got %d", tc.wantMemberCount, memberCount) + } + } + }) + } + }) +} + func TestSendToDevice(t *testing.T) { test.WithAllDatabases(t, testSendToDevice) } @@ -541,7 +787,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}) producer := producers.SyncAPIProducer{ TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), @@ -659,6 +905,261 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { } } +func TestContext(t *testing.T) { + test.WithAllDatabases(t, testContext) +} + +func testContext(t *testing.T, dbType test.DBType) { + + tests := []struct { + name string + roomID string + eventID string + params map[string]string + wantError bool + wantStateLength int + wantBeforeLength int + wantAfterLength int + }{ + { + name: "invalid filter", + params: map[string]string{ + "filter": "{", + }, + wantError: true, + }, + { + name: "invalid limit", + params: map[string]string{ + "limit": "abc", + }, + wantError: true, + }, + { + name: "high limit", + params: map[string]string{ + "limit": "100000", + }, + }, + { + name: "fine limit", + params: map[string]string{ + "limit": "10", + }, + }, + { + name: "last event without lazy loading", + wantStateLength: 5, + }, + { + name: "last event with lazy loading", + params: map[string]string{ + "filter": `{"lazy_load_members":true}`, + }, + wantStateLength: 1, + }, + { + name: "invalid room", + roomID: "!doesnotexist", + wantError: true, + }, + { + name: "invalid eventID", + eventID: "$doesnotexist", + wantError: true, + }, + { + name: "state is limited", + params: map[string]string{ + "limit": "1", + }, + wantStateLength: 1, + }, + { + name: "events are not limited", + wantBeforeLength: 7, + }, + { + name: "all events are limited", + params: map[string]string{ + "limit": "1", + }, + wantStateLength: 1, + wantBeforeLength: 1, + wantAfterLength: 1, + }, + } + + user := test.NewUser(t) + alice := userapi.Device{ + ID: "ALICEID", + UserID: user.ID, + AccessToken: "ALICE_BEARER_TOKEN", + DisplayName: "Alice", + AccountType: userapi.AccountTypeUser, + } + + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + // Use an actual roomserver for this + rsAPI := roomserver.NewInternalAPI(base) + rsAPI.SetFederationAPI(nil, nil) + + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI) + + room := test.NewRoom(t, user) + + room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world 1!"}) + room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world 2!"}) + thirdMsg := room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world3!"}) + room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world4!"}) + + if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + + syncUntil(t, base, alice.AccessToken, false, func(syncBody string) bool { + // wait for the last sent eventID to come down sync + path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, thirdMsg.EventID()) + return gjson.Get(syncBody, path).Exists() + }) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + params := map[string]string{ + "access_token": alice.AccessToken, + } + w := httptest.NewRecorder() + // test overrides + roomID := room.ID + if tc.roomID != "" { + roomID = tc.roomID + } + eventID := thirdMsg.EventID() + if tc.eventID != "" { + eventID = tc.eventID + } + requestPath := fmt.Sprintf("/_matrix/client/v3/rooms/%s/context/%s", roomID, eventID) + if tc.params != nil { + for k, v := range tc.params { + params[k] = v + } + } + base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", requestPath, test.WithQueryParams(params))) + + if tc.wantError && w.Code == 200 { + t.Fatalf("Expected an error, but got none") + } + t.Log(w.Body.String()) + resp := routing.ContextRespsonse{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if tc.wantStateLength > 0 && tc.wantStateLength != len(resp.State) { + t.Fatalf("expected %d state events, got %d", tc.wantStateLength, len(resp.State)) + } + if tc.wantBeforeLength > 0 && tc.wantBeforeLength != len(resp.EventsBefore) { + t.Fatalf("expected %d before events, got %d", tc.wantBeforeLength, len(resp.EventsBefore)) + } + if tc.wantAfterLength > 0 && tc.wantAfterLength != len(resp.EventsAfter) { + t.Fatalf("expected %d after events, got %d", tc.wantAfterLength, len(resp.EventsAfter)) + } + + if !tc.wantError && resp.Event.EventID != eventID { + t.Fatalf("unexpected eventID %s, expected %s", resp.Event.EventID, eventID) + } + }) + } +} + +func TestUpdateRelations(t *testing.T) { + testCases := []struct { + name string + eventContent map[string]interface{} + eventType string + }{ + { + name: "empty event content should not error", + }, + { + name: "unable to unmarshal event should not error", + eventContent: map[string]interface{}{ + "m.relates_to": map[string]interface{}{ + "event_id": map[string]interface{}{}, // this should be a string and not struct + }, + }, + }, + { + name: "empty event ID is ignored", + eventContent: map[string]interface{}{ + "m.relates_to": map[string]interface{}{ + "event_id": "", + }, + }, + }, + { + name: "empty rel_type is ignored", + eventContent: map[string]interface{}{ + "m.relates_to": map[string]interface{}{ + "event_id": "$randomEventID", + "rel_type": "", + }, + }, + }, + { + name: "redactions are ignored", + eventType: gomatrixserverlib.MRoomRedaction, + eventContent: map[string]interface{}{ + "m.relates_to": map[string]interface{}{ + "event_id": "$randomEventID", + "rel_type": "m.replace", + }, + }, + }, + { + name: "valid event is correctly written", + eventContent: map[string]interface{}{ + "m.relates_to": map[string]interface{}{ + "event_id": "$randomEventID", + "rel_type": "m.replace", + }, + }, + }, + } + + ctx := context.Background() + + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, shutdownBase := testrig.CreateBaseDendrite(t, dbType) + t.Cleanup(shutdownBase) + db, _, err := storage.NewSyncServerDatasource(base, &base.Cfg.SyncAPI.Database) + if err != nil { + t.Fatal(err) + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + evType := "m.room.message" + if tc.eventType != "" { + evType = tc.eventType + } + ev := room.CreateEvent(t, alice, evType, tc.eventContent) + err = db.UpdateRelations(ctx, ev) + if err != nil { + t.Fatal(err) + } + }) + } + }) +} + func syncUntil(t *testing.T, base *base.BaseDendrite, accessToken string, skip bool, diff --git a/syncapi/types/types.go b/syncapi/types/types.go index c4c7b39fb..a2a0b9fde 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -64,6 +64,11 @@ type StreamEvent struct { ExcludeFromSync bool } +type RecentEvents struct { + Limited bool + Events []StreamEvent +} + // Range represents a range between two stream positions. type Range struct { // From is the position the client has already received. diff --git a/sytest-blacklist b/sytest-blacklist index 74e670d0c..e7dfcfb0f 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -1,52 +1,22 @@ -# Relies on a rejected PL event which will never be accepted into the DAG - -# Caused by - -Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state - -# We don't implement lazy membership loading yet - +# Blacklisted due to https://github.com/matrix-org/matrix-spec/issues/942 The only membership state included in a gapped incremental sync is for senders in the timeline -# Blacklisted out of flakiness after #1479 - -Invited user can reject local invite after originator leaves -Invited user can reject invite for empty room -If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes - -# Blacklisted due to flakiness - -Forgotten room messages cannot be paginated - -# Blacklisted due to flakiness after #1774 - -Local device key changes get to remote servers with correct prev_id - -# we don't support groups - -Remove group category -Remove group role - # Flakey - AS-ghosted users can use rooms themselves AS-ghosted users can use rooms via AS Events in rooms with AS-hosted room aliases are sent to AS server Inviting an AS-hosted user asks the AS server Accesing an AS-hosted room alias asks the AS server - -# More flakey - -Guest users can join guest_access rooms +If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes # This will fail in HTTP API mode, so blacklisted for now - If a device list update goes missing, the server resyncs on the next one # Might be a bug in the test because leaves do appear :-( - Leaves are present in non-gapped incremental syncs +# We don't have any state to calculate m.room.guest_access when accepting invites +Guest users can accept invites to private rooms over federation # Below test was passing for the wrong reason, failing correctly since #2858 New federated private chats get full presence information (SYN-115) If a device list update goes missing, the server resyncs on the next one diff --git a/sytest-whitelist b/sytest-whitelist index 75a2da635..2bb0aa433 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -764,3 +764,23 @@ local user has tags copied to the new room remote user has tags copied to the new room /upgrade moves remote aliases to the new room Local and remote users' homeservers remove a room from their public directory on upgrade +Guest users denied access over federation if guest access prohibited +Guest users are kicked from guest_access rooms on revocation of guest_access +Guest users are kicked from guest_access rooms on revocation of guest_access over federation +User can create and send/receive messages in a room with version 10 +local user can join room with version 10 +User can invite local user to room with version 10 +remote user can join room with version 10 +User can invite remote user to room with version 10 +Remote user can backfill in a room with version 10 +Can reject invites over federation for rooms with version 10 +Can receive redactions from regular users over federation in room version 10 +New federated private chats get full presence information (SYN-115) +/state returns M_NOT_FOUND for an outlier +/state_ids returns M_NOT_FOUND for an outlier +Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state +Invited user can reject invite for empty room +Invited user can reject local invite after originator leaves +Guest users can join guest_access rooms +Forgotten room messages cannot be paginated +Local device key changes get to remote servers with correct prev_id diff --git a/test/db.go b/test/db.go index 3de3d267f..1b0ce6751 100644 --- a/test/db.go +++ b/test/db.go @@ -22,6 +22,7 @@ import ( "os" "os/exec" "os/user" + "path/filepath" "testing" "github.com/lib/pq" @@ -100,16 +101,12 @@ func currentUser() string { // Returns the connection string to use and a close function which must be called when the test finishes. // Calling this function twice will return the same database, which will have data from previous tests // unless close() is called. -// TODO: namespace for concurrent package tests func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) { if dbType == DBTypeSQLite { - // this will be made in the current working directory which namespaces concurrent package runs correctly - dbname := "dendrite_test.db" + // this will be made in the t.TempDir, which is unique per test + dbname := filepath.Join(t.TempDir(), "dendrite_test.db") return fmt.Sprintf("file:%s", dbname), func() { - err := os.Remove(dbname) - if err != nil { - t.Fatalf("failed to cleanup sqlite db '%s': %s", dbname, err) - } + t.Cleanup(func() {}) // removes the t.TempDir } } @@ -175,7 +172,7 @@ func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) { for dbName, dbType := range dbs { dbt := dbType t.Run(dbName, func(tt *testing.T) { - //tt.Parallel() + tt.Parallel() testFn(tt, dbt) }) } diff --git a/test/memory_federation_db.go b/test/memory_federation_db.go new file mode 100644 index 000000000..de0dc54eb --- /dev/null +++ b/test/memory_federation_db.go @@ -0,0 +1,511 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "context" + "encoding/json" + "errors" + "sync" + "time" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/dendrite/federationapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var nidMutex sync.Mutex +var nid = int64(0) + +type InMemoryFederationDatabase struct { + dbMutex sync.Mutex + pendingPDUServers map[gomatrixserverlib.ServerName]struct{} + pendingEDUServers map[gomatrixserverlib.ServerName]struct{} + blacklistedServers map[gomatrixserverlib.ServerName]struct{} + assumedOffline map[gomatrixserverlib.ServerName]struct{} + pendingPDUs map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent + pendingEDUs map[*receipt.Receipt]*gomatrixserverlib.EDU + associatedPDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{} + associatedEDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{} + relayServers map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName +} + +func NewInMemoryFederationDatabase() *InMemoryFederationDatabase { + return &InMemoryFederationDatabase{ + pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), + assumedOffline: make(map[gomatrixserverlib.ServerName]struct{}), + pendingPDUs: make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent), + pendingEDUs: make(map[*receipt.Receipt]*gomatrixserverlib.EDU), + associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}), + associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}), + relayServers: make(map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName), + } +} + +func (d *InMemoryFederationDatabase) StoreJSON( + ctx context.Context, + js string, +) (*receipt.Receipt, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal([]byte(js), &event); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + newReceipt := receipt.NewReceipt(nid) + d.pendingPDUs[&newReceipt] = &event + return &newReceipt, nil + } + + var edu gomatrixserverlib.EDU + if err := json.Unmarshal([]byte(js), &edu); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + newReceipt := receipt.NewReceipt(nid) + d.pendingEDUs[&newReceipt] = &edu + return &newReceipt, nil + } + + return nil, errors.New("Failed to determine type of json to store") +} + +func (d *InMemoryFederationDatabase) GetPendingPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + limit int, +) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + pduCount := 0 + pdus = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent) + if receipts, ok := d.associatedPDUs[serverName]; ok { + for dbReceipt := range receipts { + if event, ok := d.pendingPDUs[dbReceipt]; ok { + pdus[dbReceipt] = event + pduCount++ + if pduCount == limit { + break + } + } + } + } + return pdus, nil +} + +func (d *InMemoryFederationDatabase) GetPendingEDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + limit int, +) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + eduCount := 0 + edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU) + if receipts, ok := d.associatedEDUs[serverName]; ok { + for dbReceipt := range receipts { + if event, ok := d.pendingEDUs[dbReceipt]; ok { + edus[dbReceipt] = event + eduCount++ + if eduCount == limit { + break + } + } + } + } + return edus, nil +} + +func (d *InMemoryFederationDatabase) AssociatePDUWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.ServerName]struct{}, + dbReceipt *receipt.Receipt, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingPDUs[dbReceipt]; ok { + for destination := range destinations { + if _, ok := d.associatedPDUs[destination]; !ok { + d.associatedPDUs[destination] = make(map[*receipt.Receipt]struct{}) + } + d.associatedPDUs[destination][dbReceipt] = struct{}{} + } + + return nil + } else { + return errors.New("PDU doesn't exist") + } +} + +func (d *InMemoryFederationDatabase) AssociateEDUWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.ServerName]struct{}, + dbReceipt *receipt.Receipt, + eduType string, + expireEDUTypes map[string]time.Duration, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingEDUs[dbReceipt]; ok { + for destination := range destinations { + if _, ok := d.associatedEDUs[destination]; !ok { + d.associatedEDUs[destination] = make(map[*receipt.Receipt]struct{}) + } + d.associatedEDUs[destination][dbReceipt] = struct{}{} + } + + return nil + } else { + return errors.New("EDU doesn't exist") + } +} + +func (d *InMemoryFederationDatabase) CleanPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + receipts []*receipt.Receipt, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if pdus, ok := d.associatedPDUs[serverName]; ok { + for _, dbReceipt := range receipts { + delete(pdus, dbReceipt) + } + } + + return nil +} + +func (d *InMemoryFederationDatabase) CleanEDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + receipts []*receipt.Receipt, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if edus, ok := d.associatedEDUs[serverName]; ok { + for _, dbReceipt := range receipts { + delete(edus, dbReceipt) + } + } + + return nil +} + +func (d *InMemoryFederationDatabase) GetPendingPDUCount( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if pdus, ok := d.associatedPDUs[serverName]; ok { + count = int64(len(pdus)) + } + return count, nil +} + +func (d *InMemoryFederationDatabase) GetPendingEDUCount( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if edus, ok := d.associatedEDUs[serverName]; ok { + count = int64(len(edus)) + } + return count, nil +} + +func (d *InMemoryFederationDatabase) GetPendingPDUServerNames( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingPDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *InMemoryFederationDatabase) GetPendingEDUServerNames( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingEDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *InMemoryFederationDatabase) AddServerToBlacklist( + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers[serverName] = struct{}{} + return nil +} + +func (d *InMemoryFederationDatabase) RemoveServerFromBlacklist( + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + delete(d.blacklistedServers, serverName) + return nil +} + +func (d *InMemoryFederationDatabase) RemoveAllServersFromBlacklist() error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *InMemoryFederationDatabase) IsServerBlacklisted( + serverName gomatrixserverlib.ServerName, +) (bool, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + isBlacklisted := false + if _, ok := d.blacklistedServers[serverName]; ok { + isBlacklisted = true + } + + return isBlacklisted, nil +} + +func (d *InMemoryFederationDatabase) SetServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.assumedOffline[serverName] = struct{}{} + return nil +} + +func (d *InMemoryFederationDatabase) RemoveServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + delete(d.assumedOffline, serverName) + return nil +} + +func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine( + ctx context.Context, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.assumedOffline = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *InMemoryFederationDatabase) IsServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (bool, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + assumedOffline := false + if _, ok := d.assumedOffline[serverName]; ok { + assumedOffline = true + } + + return assumedOffline, nil +} + +func (d *InMemoryFederationDatabase) P2PGetRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + knownRelayServers := []gomatrixserverlib.ServerName{} + if relayServers, ok := d.relayServers[serverName]; ok { + knownRelayServers = relayServers + } + + return knownRelayServers, nil +} + +func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if knownRelayServers, ok := d.relayServers[serverName]; ok { + for _, relayServer := range relayServers { + alreadyKnown := false + for _, knownRelayServer := range knownRelayServers { + if relayServer == knownRelayServer { + alreadyKnown = true + } + } + if !alreadyKnown { + d.relayServers[serverName] = append(d.relayServers[serverName], relayServer) + } + } + } else { + d.relayServers[serverName] = relayServers + } + + return nil +} + +func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if knownRelayServers, ok := d.relayServers[serverName]; ok { + for _, relayServer := range relayServers { + for i, knownRelayServer := range knownRelayServers { + if relayServer == knownRelayServer { + d.relayServers[serverName] = append( + d.relayServers[serverName][:i], + d.relayServers[serverName][i+1:]..., + ) + break + } + } + } + } else { + d.relayServers[serverName] = relayServers + } + + return nil +} + +func (d *InMemoryFederationDatabase) FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) FetcherName() string { + return "" +} + +func (d *InMemoryFederationDatabase) StoreKeys(ctx context.Context, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) error { + return nil +} + +func (d *InMemoryFederationDatabase) UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline(ctx context.Context) error { + return nil +} + +func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error { + return nil +} + +func (d *InMemoryFederationDatabase) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { + return nil +} + +func (d *InMemoryFederationDatabase) GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) DeleteExpiredEDUs(ctx context.Context) error { + return nil +} + +func (d *InMemoryFederationDatabase) PurgeRoom(ctx context.Context, roomID string) error { + return nil +} diff --git a/test/memory_relay_db.go b/test/memory_relay_db.go new file mode 100644 index 000000000..db93919df --- /dev/null +++ b/test/memory_relay_db.go @@ -0,0 +1,140 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "context" + "database/sql" + "encoding/json" + "sync" + + "github.com/matrix-org/gomatrixserverlib" +) + +type InMemoryRelayDatabase struct { + nid int64 + nidMutex sync.Mutex + transactions map[int64]json.RawMessage + associations map[gomatrixserverlib.ServerName][]int64 +} + +func NewInMemoryRelayDatabase() *InMemoryRelayDatabase { + return &InMemoryRelayDatabase{ + nid: 1, + nidMutex: sync.Mutex{}, + transactions: make(map[int64]json.RawMessage), + associations: make(map[gomatrixserverlib.ServerName][]int64), + } +} + +func (d *InMemoryRelayDatabase) InsertQueueEntry( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + if _, ok := d.associations[serverName]; !ok { + d.associations[serverName] = []int64{} + } + d.associations[serverName] = append(d.associations[serverName], nid) + return nil +} + +func (d *InMemoryRelayDatabase) DeleteQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + for _, nid := range jsonNIDs { + for index, associatedNID := range d.associations[serverName] { + if associatedNID == nid { + d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...) + } + } + } + + return nil +} + +func (d *InMemoryRelayDatabase) SelectQueueEntries( + ctx context.Context, + txn *sql.Tx, serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + results := []int64{} + resultCount := limit + if limit > len(d.associations[serverName]) { + resultCount = len(d.associations[serverName]) + } + if resultCount > 0 { + for i := 0; i < resultCount; i++ { + results = append(results, d.associations[serverName][i]) + } + } + + return results, nil +} + +func (d *InMemoryRelayDatabase) SelectQueueEntryCount( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + return int64(len(d.associations[serverName])), nil +} + +func (d *InMemoryRelayDatabase) InsertQueueJSON( + ctx context.Context, + txn *sql.Tx, + json string, +) (int64, error) { + d.nidMutex.Lock() + defer d.nidMutex.Unlock() + + nid := d.nid + d.transactions[nid] = []byte(json) + d.nid++ + + return nid, nil +} + +func (d *InMemoryRelayDatabase) DeleteQueueJSON( + ctx context.Context, + txn *sql.Tx, + nids []int64, +) error { + for _, nid := range nids { + delete(d.transactions, nid) + } + + return nil +} + +func (d *InMemoryRelayDatabase) SelectQueueJSON( + ctx context.Context, + txn *sql.Tx, + jsonNIDs []int64, +) (map[int64][]byte, error) { + result := make(map[int64][]byte) + for _, nid := range jsonNIDs { + if transaction, ok := d.transactions[nid]; ok { + result[nid] = transaction + } + } + + return result, nil +} diff --git a/test/room.go b/test/room.go index 4328bf84f..685876cb0 100644 --- a/test/room.go +++ b/test/room.go @@ -38,11 +38,12 @@ var ( ) type Room struct { - ID string - Version gomatrixserverlib.RoomVersion - preset Preset - visibility gomatrixserverlib.HistoryVisibility - creator *User + ID string + Version gomatrixserverlib.RoomVersion + preset Preset + guestCanJoin bool + visibility gomatrixserverlib.HistoryVisibility + creator *User authEvents gomatrixserverlib.AuthEvents currentState map[string]*gomatrixserverlib.HeaderedEvent @@ -120,6 +121,11 @@ func (r *Room) insertCreateEvents(t *testing.T) { r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey("")) + if r.guestCanJoin { + r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{ + "guest_access": "can_join", + }, WithStateKey("")) + } } // Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe. @@ -268,3 +274,9 @@ func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier { r.Version = ver } } + +func GuestsCanJoin(canJoin bool) roomModifier { + return func(t *testing.T, r *Room) { + r.guestCanJoin = canJoin + } +} diff --git a/test/testrig/base.go b/test/testrig/base.go index 15fb5c370..2d101ec00 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -15,41 +15,37 @@ package testrig import ( - "errors" "fmt" - "io/fs" - "os" - "strings" + "path/filepath" "testing" - "github.com/nats-io/nats.go" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/nats-io/nats.go" ) func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, func()) { var cfg config.Dendrite cfg.Defaults(config.DefaultOpts{ - Generate: false, - Monolithic: true, + Generate: false, + SingleDatabase: true, }) cfg.Global.JetStream.InMemory = true cfg.FederationAPI.KeyPerspectives = nil switch dbType { case test.DBTypePostgres: cfg.Global.Defaults(config.DefaultOpts{ // autogen a signing key - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: true, }) cfg.MediaAPI.Defaults(config.DefaultOpts{ // autogen a media path - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: true, }) cfg.SyncAPI.Fulltext.Defaults(config.DefaultOpts{ // use in memory fts - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: true, }) cfg.Global.ServerName = "test" // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use @@ -62,34 +58,38 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f MaxIdleConnections: 2, ConnMaxLifetimeSeconds: 60, } - return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close + base := base.NewBaseDendrite(&cfg, base.DisableMetrics) + return base, func() { + base.ShutdownDendrite() + base.WaitForShutdown() + close() + } case test.DBTypeSQLite: cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: false, // because we need a database per component + Generate: true, + SingleDatabase: false, }) cfg.Global.ServerName = "test" + // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) - return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), func() { - // cleanup db files. This risks getting out of sync as we add more database strings :( - dbFiles := []config.DataSource{ - cfg.FederationAPI.Database.ConnectionString, - cfg.KeyServer.Database.ConnectionString, - cfg.MSCs.Database.ConnectionString, - cfg.MediaAPI.Database.ConnectionString, - cfg.RoomServer.Database.ConnectionString, - cfg.SyncAPI.Database.ConnectionString, - cfg.UserAPI.AccountDatabase.ConnectionString, - } - for _, fileURI := range dbFiles { - path := strings.TrimPrefix(string(fileURI), "file:") - err := os.Remove(path) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - t.Fatalf("failed to cleanup sqlite db '%s': %s", fileURI, err) - } - } + + // Use a temp dir provided by go for tests, this will be cleanup by a call to t.CleanUp() + tempDir := t.TempDir() + cfg.FederationAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "federationapi.db")) + cfg.KeyServer.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "keyserver.db")) + cfg.MSCs.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "mscs.db")) + cfg.MediaAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "mediaapi.db")) + cfg.RoomServer.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "roomserver.db")) + cfg.SyncAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "syncapi.db")) + cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "userapi.db")) + + base := base.NewBaseDendrite(&cfg, base.DisableMetrics) + return base, func() { + base.ShutdownDendrite() + base.WaitForShutdown() + t.Cleanup(func() {}) // removes t.TempDir, where all database files are created } default: t.Fatalf("unknown db type: %v", dbType) @@ -101,14 +101,14 @@ func Base(cfg *config.Dendrite) (*base.BaseDendrite, nats.JetStreamContext, *nat if cfg == nil { cfg = &config.Dendrite{} cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, + Generate: true, + SingleDatabase: false, }) } cfg.Global.JetStream.InMemory = true cfg.SyncAPI.Fulltext.InMemory = true cfg.FederationAPI.KeyPerspectives = nil - base := base.NewBaseDendrite(cfg, "Tests") + base := base.NewBaseDendrite(cfg, base.DisableMetrics) js, jc := base.NATS.Prepare(base.ProcessContext, &cfg.Global.JetStream) return base, js, jc } diff --git a/test/user.go b/test/user.go index 692eae351..95a8f83e6 100644 --- a/test/user.go +++ b/test/user.go @@ -47,7 +47,7 @@ var ( type User struct { ID string - accountType api.AccountType + AccountType api.AccountType // key ID and private key of the server who has this user, if known. keyID gomatrixserverlib.KeyID privKey ed25519.PrivateKey @@ -66,7 +66,7 @@ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserve func WithAccountType(accountType api.AccountType) UserOpt { return func(u *User) { - u.accountType = accountType + u.AccountType = accountType } } diff --git a/userapi/api/api.go b/userapi/api/api.go index d3f5aefc8..fa297f773 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -15,9 +15,13 @@ package api import ( + "bytes" "context" "encoding/json" + "strings" + "time" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -26,15 +30,12 @@ import ( // UserInternalAPI is the internal API for information about users and devices. type UserInternalAPI interface { - AppserviceUserAPI SyncUserAPI ClientUserAPI - MediaUserAPI FederationUserAPI - RoomserverUserAPI - KeyserverUserAPI QuerySearchProfilesAPI // used by p2p demos + QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) } // api functions required by the appservice api @@ -43,13 +44,9 @@ type AppserviceUserAPI interface { PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error } -type KeyserverUserAPI interface { - QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error - QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error -} - type RoomserverUserAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error + QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) } // api functions required by the media api @@ -59,13 +56,20 @@ type MediaUserAPI interface { // api functions required by the federation api type FederationUserAPI interface { + UploadDeviceKeysAPI QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error + QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error } // api functions required by the sync api type SyncUserAPI interface { QueryAcccessTokenAPI + SyncKeyAPI QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error @@ -78,6 +82,7 @@ type ClientUserAPI interface { QueryAcccessTokenAPI LoginTokenInternalAPI UserLoginAPI + ClientKeyAPI QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error @@ -671,3 +676,319 @@ type PerformSaveThreePIDAssociationRequest struct { ServerName gomatrixserverlib.ServerName Medium string } + +type QueryAccountByLocalpartRequest struct { + Localpart string + ServerName gomatrixserverlib.ServerName +} + +type QueryAccountByLocalpartResponse struct { + Account *Account +} + +// API functions required by the clientapi +type ClientKeyAPI interface { + UploadDeviceKeysAPI + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error + + PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error + // PerformClaimKeys claims one-time keys for use in pre-key messages + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error + PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error +} + +type UploadDeviceKeysAPI interface { + PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error +} + +// API functions required by the syncapi +type SyncKeyAPI interface { + QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error + QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error + PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error +} + +type FederationKeyAPI interface { + UploadDeviceKeysAPI + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error +} + +// KeyError is returned if there was a problem performing/querying the server +type KeyError struct { + Err string `json:"error"` + IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE + IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM + IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM +} + +func (k *KeyError) Error() string { + return k.Err +} + +type DeviceMessageType int + +const ( + TypeDeviceKeyUpdate DeviceMessageType = iota + TypeCrossSigningUpdate +) + +// DeviceMessage represents the message produced into Kafka by the key server. +type DeviceMessage struct { + Type DeviceMessageType `json:"Type,omitempty"` + *DeviceKeys `json:"DeviceKeys,omitempty"` + *OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` + // A monotonically increasing number which represents device changes for this user. + StreamID int64 + DeviceChangeID int64 +} + +// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log +type OutputCrossSigningKeyUpdate struct { + CrossSigningKeyUpdate `json:"signing_keys"` +} + +type CrossSigningKeyUpdate struct { + MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"` + SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"` + UserID string `json:"user_id"` +} + +// DeviceKeysEqual returns true if the device keys updates contain the +// same display name and key JSON. This will return false if either of +// the updates is not a device keys update, or if the user ID/device ID +// differ between the two. +func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool { + if m1.DeviceKeys == nil || m2.DeviceKeys == nil { + return false + } + if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID { + return false + } + if m1.DisplayName != m2.DisplayName { + return false // different display names + } + if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 { + return false // either is empty + } + return bytes.Equal(m1.KeyJSON, m2.KeyJSON) +} + +// DeviceKeys represents a set of device keys for a single device +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload +type DeviceKeys struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // The device display name + DisplayName string + // The raw device key JSON + KeyJSON []byte +} + +// WithStreamID returns a copy of this device message with the given stream ID +func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage { + return DeviceMessage{ + DeviceKeys: k, + StreamID: streamID, + } +} + +// OneTimeKeys represents a set of one-time keys for a single device +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload +type OneTimeKeys struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // A map of algorithm:key_id => key JSON + KeyJSON map[string]json.RawMessage +} + +// Split a key in KeyJSON into algorithm and key ID +func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) { + segments := strings.Split(keyIDWithAlgo, ":") + return segments[0], segments[1] +} + +// OneTimeKeysCount represents the counts of one-time keys for a single device +type OneTimeKeysCount struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // algorithm to count e.g: + // { + // "curve25519": 10, + // "signed_curve25519": 20 + // } + KeyCount map[string]int +} + +// PerformUploadKeysRequest is the request to PerformUploadKeys +type PerformUploadKeysRequest struct { + UserID string // Required - User performing the request + DeviceID string // Optional - Device performing the request, for fetching OTK count + DeviceKeys []DeviceKeys + OneTimeKeys []OneTimeKeys + // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update + // the display name for their respective device, and NOT to modify the keys. The key + // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. + // Without this flag, requests to modify device display names would delete device keys. + OnlyDisplayNameUpdates bool +} + +// PerformUploadKeysResponse is the response to PerformUploadKeys +type PerformUploadKeysResponse struct { + // A fatal error when processing e.g database failures + Error *KeyError + // A map of user_id -> device_id -> Error for tracking failures. + KeyErrors map[string]map[string]*KeyError + OneTimeKeyCounts []OneTimeKeysCount +} + +// PerformDeleteKeysRequest asks the keyserver to forget about certain +// keys, and signatures related to those keys. +type PerformDeleteKeysRequest struct { + UserID string + KeyIDs []gomatrixserverlib.KeyID +} + +// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest. +type PerformDeleteKeysResponse struct { + Error *KeyError +} + +// KeyError sets a key error field on KeyErrors +func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) { + if r.KeyErrors[userID] == nil { + r.KeyErrors[userID] = make(map[string]*KeyError) + } + r.KeyErrors[userID][deviceID] = err +} + +type PerformClaimKeysRequest struct { + // Map of user_id to device_id to algorithm name + OneTimeKeys map[string]map[string]string + Timeout time.Duration +} + +type PerformClaimKeysResponse struct { + // Map of user_id to device_id to algorithm:key_id to key JSON + OneTimeKeys map[string]map[string]map[string]json.RawMessage + // Map of remote server domain to error JSON + Failures map[string]interface{} + // Set if there was a fatal error processing this action + Error *KeyError +} + +type PerformUploadDeviceKeysRequest struct { + gomatrixserverlib.CrossSigningKeys + // The user that uploaded the key, should be populated by the clientapi. + UserID string +} + +type PerformUploadDeviceKeysResponse struct { + Error *KeyError +} + +type PerformUploadDeviceSignaturesRequest struct { + Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice + // The user that uploaded the sig, should be populated by the clientapi. + UserID string +} + +type PerformUploadDeviceSignaturesResponse struct { + Error *KeyError +} + +type QueryKeysRequest struct { + // The user ID asking for the keys, e.g. if from a client API request. + // Will not be populated if the key request came from federation. + UserID string + // Maps user IDs to a list of devices + UserToDevices map[string][]string + Timeout time.Duration +} + +type QueryKeysResponse struct { + // Map of remote server domain to error JSON + Failures map[string]interface{} + // Map of user_id to device_id to device_key + DeviceKeys map[string]map[string]json.RawMessage + // Maps of user_id to cross signing key + MasterKeys map[string]gomatrixserverlib.CrossSigningKey + SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey + UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // Set if there was a fatal error processing this query + Error *KeyError +} + +type QueryKeyChangesRequest struct { + // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning + Offset int64 + // The inclusive offset where to track key changes up to. Messages with this offset are included in the response. + // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). + ToOffset int64 +} + +type QueryKeyChangesResponse struct { + // The set of users who have had their keys change. + UserIDs []string + // The latest offset represented in this response. + Offset int64 + // Set if there was a problem handling the request. + Error *KeyError +} + +type QueryOneTimeKeysRequest struct { + // The local user to query OTK counts for + UserID string + // The device to query OTK counts for + DeviceID string +} + +type QueryOneTimeKeysResponse struct { + // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84 + Count OneTimeKeysCount + Error *KeyError +} + +type QueryDeviceMessagesRequest struct { + UserID string +} + +type QueryDeviceMessagesResponse struct { + // The latest stream ID + StreamID int64 + Devices []DeviceMessage + Error *KeyError +} + +type QuerySignaturesRequest struct { + // A map of target user ID -> target key/device IDs to retrieve signatures for + TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"` +} + +type QuerySignaturesResponse struct { + // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures + Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap + // A map of target user ID -> cross-signing master key + MasterKeys map[string]gomatrixserverlib.CrossSigningKey + // A map of target user ID -> cross-signing self-signing key + SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // A map of target user ID -> cross-signing user-signing key + UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // The request error, if any + Error *KeyError +} + +type PerformMarkAsStaleRequest struct { + UserID string + Domain gomatrixserverlib.ServerName + DeviceID string +} diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go deleted file mode 100644 index ce661770f..000000000 --- a/userapi/api/api_trace.go +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package api - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/matrix-org/util" -) - -// UserInternalAPITrace wraps a RoomserverInternalAPI and logs the -// complete request/response/error -type UserInternalAPITrace struct { - Impl UserInternalAPI -} - -func (t *UserInternalAPITrace) InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error { - err := t.Impl.InputAccountData(ctx, req, res) - util.GetLogger(ctx).Infof("InputAccountData req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error { - err := t.Impl.PerformAccountCreation(ctx, req, res) - util.GetLogger(ctx).Infof("PerformAccountCreation req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error { - err := t.Impl.PerformPasswordUpdate(ctx, req, res) - util.GetLogger(ctx).Infof("PerformPasswordUpdate req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error { - err := t.Impl.PerformDeviceCreation(ctx, req, res) - util.GetLogger(ctx).Infof("PerformDeviceCreation req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error { - err := t.Impl.PerformDeviceDeletion(ctx, req, res) - util.GetLogger(ctx).Infof("PerformDeviceDeletion req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error { - err := t.Impl.PerformLastSeenUpdate(ctx, req, res) - util.GetLogger(ctx).Infof("PerformLastSeenUpdate req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error { - err := t.Impl.PerformDeviceUpdate(ctx, req, res) - util.GetLogger(ctx).Infof("PerformDeviceUpdate req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error { - err := t.Impl.PerformAccountDeactivation(ctx, req, res) - util.GetLogger(ctx).Infof("PerformAccountDeactivation req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error { - err := t.Impl.PerformOpenIDTokenCreation(ctx, req, res) - util.GetLogger(ctx).Infof("PerformOpenIDTokenCreation req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error { - err := t.Impl.PerformKeyBackup(ctx, req, res) - util.GetLogger(ctx).Infof("PerformKeyBackup req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error { - err := t.Impl.PerformPusherSet(ctx, req, res) - util.GetLogger(ctx).Infof("PerformPusherSet req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error { - err := t.Impl.PerformPusherDeletion(ctx, req, res) - util.GetLogger(ctx).Infof("PerformPusherDeletion req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error { - err := t.Impl.PerformPushRulesPut(ctx, req, res) - util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error { - err := t.Impl.QueryKeyBackup(ctx, req, res) - util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error { - err := t.Impl.QueryProfile(ctx, req, res) - util.GetLogger(ctx).Infof("QueryProfile req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error { - err := t.Impl.QueryAccessToken(ctx, req, res) - util.GetLogger(ctx).Infof("QueryAccessToken req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error { - err := t.Impl.QueryDevices(ctx, req, res) - util.GetLogger(ctx).Infof("QueryDevices req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error { - err := t.Impl.QueryAccountData(ctx, req, res) - util.GetLogger(ctx).Infof("QueryAccountData req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error { - err := t.Impl.QueryDeviceInfos(ctx, req, res) - util.GetLogger(ctx).Infof("QueryDeviceInfos req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error { - err := t.Impl.QuerySearchProfiles(ctx, req, res) - util.GetLogger(ctx).Infof("QuerySearchProfiles req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error { - err := t.Impl.QueryOpenIDToken(ctx, req, res) - util.GetLogger(ctx).Infof("QueryOpenIDToken req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error { - err := t.Impl.QueryPushers(ctx, req, res) - util.GetLogger(ctx).Infof("QueryPushers req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error { - err := t.Impl.QueryPushRules(ctx, req, res) - util.GetLogger(ctx).Infof("QueryPushRules req=%+v res=%+v", js(req), js(res)) - return err -} -func (t *UserInternalAPITrace) QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error { - err := t.Impl.QueryNotifications(ctx, req, res) - util.GetLogger(ctx).Infof("QueryNotifications req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error { - err := t.Impl.SetAvatarURL(ctx, req, res) - util.GetLogger(ctx).Infof("SetAvatarURL req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error { - err := t.Impl.QueryNumericLocalpart(ctx, req, res) - util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res)) - return err -} - -func (t *UserInternalAPITrace) QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error { - err := t.Impl.QueryAccountAvailability(ctx, req, res) - util.GetLogger(ctx).Infof("QueryAccountAvailability req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error { - err := t.Impl.SetDisplayName(ctx, req, res) - util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) QueryAccountByPassword(ctx context.Context, req *QueryAccountByPasswordRequest, res *QueryAccountByPasswordResponse) error { - err := t.Impl.QueryAccountByPassword(ctx, req, res) - util.GetLogger(ctx).Infof("QueryAccountByPassword req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error { - err := t.Impl.QueryLocalpartForThreePID(ctx, req, res) - util.GetLogger(ctx).Infof("QueryLocalpartForThreePID req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error { - err := t.Impl.QueryThreePIDsForLocalpart(ctx, req, res) - util.GetLogger(ctx).Infof("QueryThreePIDsForLocalpart req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error { - err := t.Impl.PerformForgetThreePID(ctx, req, res) - util.GetLogger(ctx).Infof("PerformForgetThreePID req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error { - err := t.Impl.PerformSaveThreePIDAssociation(ctx, req, res) - util.GetLogger(ctx).Infof("PerformSaveThreePIDAssociation req=%+v res=%+v", js(req), js(res)) - return err -} - -func js(thing interface{}) string { - b, err := json.Marshal(thing) - if err != nil { - return fmt.Sprintf("Marshal error:%s", err) - } - return string(b) -} diff --git a/userapi/api/api_trace_logintoken.go b/userapi/api/api_trace_logintoken.go deleted file mode 100644 index e60dae594..000000000 --- a/userapi/api/api_trace_logintoken.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package api - -import ( - "context" - - "github.com/matrix-org/util" -) - -func (t *UserInternalAPITrace) PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error { - err := t.Impl.PerformLoginTokenCreation(ctx, req, res) - util.GetLogger(ctx).Infof("PerformLoginTokenCreation req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) PerformLoginTokenDeletion(ctx context.Context, req *PerformLoginTokenDeletionRequest, res *PerformLoginTokenDeletionResponse) error { - err := t.Impl.PerformLoginTokenDeletion(ctx, req, res) - util.GetLogger(ctx).Infof("PerformLoginTokenDeletion req=%+v res=%+v", js(req), js(res)) - return err -} - -func (t *UserInternalAPITrace) QueryLoginToken(ctx context.Context, req *QueryLoginTokenRequest, res *QueryLoginTokenResponse) error { - err := t.Impl.QueryLoginToken(ctx, req, res) - util.GetLogger(ctx).Infof("QueryLoginToken req=%+v res=%+v", js(req), js(res)) - return err -} diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go index 42ae72e77..51bd2753a 100644 --- a/userapi/consumers/clientapi.go +++ b/userapi/consumers/clientapi.go @@ -37,7 +37,7 @@ type OutputReceiptEventConsumer struct { jetstream nats.JetStreamContext durable string topic string - db storage.Database + db storage.UserDatabase serverName gomatrixserverlib.ServerName syncProducer *producers.SyncAPI pgClient pushgateway.Client @@ -49,7 +49,7 @@ func NewOutputReceiptEventConsumer( process *process.ProcessContext, cfg *config.UserAPI, js nats.JetStreamContext, - store storage.Database, + store storage.UserDatabase, syncProducer *producers.SyncAPI, pgClient pushgateway.Client, ) *OutputReceiptEventConsumer { diff --git a/keyserver/consumers/devicelistupdate.go b/userapi/consumers/devicelistupdate.go similarity index 97% rename from keyserver/consumers/devicelistupdate.go rename to userapi/consumers/devicelistupdate.go index cd911f8c6..a65889fcc 100644 --- a/keyserver/consumers/devicelistupdate.go +++ b/userapi/consumers/devicelistupdate.go @@ -18,11 +18,11 @@ import ( "context" "encoding/json" + "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/keyserver/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -41,7 +41,7 @@ type DeviceListUpdateConsumer struct { // NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. func NewDeviceListUpdateConsumer( process *process.ProcessContext, - cfg *config.KeyServer, + cfg *config.UserAPI, js nats.JetStreamContext, updater *internal.DeviceListUpdater, ) *DeviceListUpdateConsumer { diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 5d8924dda..47d330959 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct { rsAPI rsapi.UserRoomserverAPI jetstream nats.JetStreamContext durable string - db storage.Database + db storage.UserDatabase topic string pgClient pushgateway.Client syncProducer *producers.SyncAPI @@ -53,7 +53,7 @@ func NewOutputRoomEventConsumer( process *process.ProcessContext, cfg *config.UserAPI, js nats.JetStreamContext, - store storage.Database, + store storage.UserDatabase, pgClient pushgateway.Client, rsAPI rsapi.UserRoomserverAPI, syncProducer *producers.SyncAPI, @@ -385,7 +385,6 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s req := &rsapi.QueryMembershipsForRoomRequest{ RoomID: roomID, JoinedOnly: true, - LocalOnly: true, } var res rsapi.QueryMembershipsForRoomResponse @@ -396,8 +395,23 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s } var members []*localMembership - var ntotal int for _, event := range res.JoinEvents { + // Filter out invalid join events + if event.StateKey == nil { + continue + } + if *event.StateKey == "" { + continue + } + _, serverName, err := gomatrixserverlib.SplitID('@', *event.StateKey) + if err != nil { + log.WithError(err).Error("failed to get servername from statekey") + continue + } + // Only get memberships for our server + if serverName != s.serverName { + continue + } member, err := newLocalMembership(&event) if err != nil { log.WithError(err).Errorf("Parsing MemberContent") @@ -410,11 +424,10 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s continue } - ntotal++ members = append(members, member) } - return members, ntotal, nil + return members, len(res.JoinEvents), nil } // roomName returns the name in the event (if type==m.room.name), or @@ -641,7 +654,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * if rule == nil { // SPEC: If no rules match an event, the homeserver MUST NOT // notify the Push Gateway for that event. - return nil, err + return nil, nil } log.WithFields(log.Fields{ diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 265e3a3aa..bc5ae652d 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -18,11 +18,11 @@ import ( userAPITypes "github.com/matrix-org/dendrite/userapi/types" ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { base, baseclose := testrig.CreateBaseDendrite(t, dbType) t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "", 4, 0, 0, "") if err != nil { @@ -81,11 +81,6 @@ func Test_evaluatePushRules(t *testing.T) { wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{ {Kind: pushrules.NotifyAction}, - { - Kind: pushrules.SetTweakAction, - Tweak: pushrules.HighlightTweak, - Value: false, - }, }, }, { @@ -103,7 +98,6 @@ func Test_evaluatePushRules(t *testing.T) { { Kind: pushrules.SetTweakAction, Tweak: pushrules.HighlightTweak, - Value: true, }, }, }, diff --git a/keyserver/consumers/signingkeyupdate.go b/userapi/consumers/signingkeyupdate.go similarity index 87% rename from keyserver/consumers/signingkeyupdate.go rename to userapi/consumers/signingkeyupdate.go index bcceaad15..f4ff017db 100644 --- a/keyserver/consumers/signingkeyupdate.go +++ b/userapi/consumers/signingkeyupdate.go @@ -22,11 +22,10 @@ import ( "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" - keyapi "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/api" ) // SigningKeyUpdateConsumer consumes signing key updates that came in over federation. @@ -35,24 +34,24 @@ type SigningKeyUpdateConsumer struct { jetstream nats.JetStreamContext durable string topic string - keyAPI *internal.KeyInternalAPI - cfg *config.KeyServer + userAPI api.UploadDeviceKeysAPI + cfg *config.UserAPI isLocalServerName func(gomatrixserverlib.ServerName) bool } // NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. func NewSigningKeyUpdateConsumer( process *process.ProcessContext, - cfg *config.KeyServer, + cfg *config.UserAPI, js nats.JetStreamContext, - keyAPI *internal.KeyInternalAPI, + userAPI api.UploadDeviceKeysAPI, ) *SigningKeyUpdateConsumer { return &SigningKeyUpdateConsumer{ ctx: process.Context(), jetstream: js, durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - keyAPI: keyAPI, + userAPI: userAPI, cfg: cfg, isLocalServerName: cfg.Matrix.IsLocalServerName, } @@ -70,7 +69,7 @@ func (t *SigningKeyUpdateConsumer) Start() error { // signing key update events topic from the key server. func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { msg := msgs[0] // Guaranteed to exist if onMessage is called - var updatePayload keyapi.CrossSigningKeyUpdate + var updatePayload api.CrossSigningKeyUpdate if err := json.Unmarshal(msg.Data, &updatePayload); err != nil { logrus.WithError(err).Errorf("Failed to read from signing key update input topic") return true @@ -94,12 +93,12 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M if updatePayload.SelfSigningKey != nil { keys.SelfSigningKey = *updatePayload.SelfSigningKey } - uploadReq := &keyapi.PerformUploadDeviceKeysRequest{ + uploadReq := &api.PerformUploadDeviceKeysRequest{ CrossSigningKeys: keys, UserID: updatePayload.UserID, } - uploadRes := &keyapi.PerformUploadDeviceKeysResponse{} - if err := t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil { + uploadRes := &api.PerformUploadDeviceKeysResponse{} + if err := t.userAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil { logrus.WithError(err).Error("failed to upload device keys") return false } diff --git a/keyserver/internal/cross_signing.go b/userapi/internal/cross_signing.go similarity index 91% rename from keyserver/internal/cross_signing.go rename to userapi/internal/cross_signing.go index 99859dff6..8b9704d1b 100644 --- a/keyserver/internal/cross_signing.go +++ b/userapi/internal/cross_signing.go @@ -22,8 +22,8 @@ import ( "fmt" "strings" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" "golang.org/x/crypto/curve25519" @@ -103,7 +103,7 @@ func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpos } // nolint:gocyclo -func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { +func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { // Find the keys to store. byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} toStore := types.CrossSigningKeyMap{} @@ -169,7 +169,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P // something if any of the specified keys in the request are different // to what we've got in the database, to avoid generating key change // notifications unnecessarily. - existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) + existingKeys, err := a.KeyDatabase.CrossSigningKeysDataForUser(ctx, req.UserID) if err != nil { res.Error = &api.KeyError{ Err: "Retrieving cross-signing keys from database failed: " + err.Error(), @@ -216,7 +216,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P } // Store the keys. - if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil { + if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err), } @@ -234,7 +234,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P continue } for sigKeyID, sigBytes := range forSigUserID { - if err := a.DB.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil { + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err), } @@ -257,7 +257,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P if update.MasterKey == nil && update.SelfSigningKey == nil { return nil } - if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { + if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } @@ -266,7 +266,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P return nil } -func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error { +func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error { // Before we do anything, we need the master and self-signing keys for this user. // Then we can verify the signatures make sense. queryReq := &api.QueryKeysRequest{ @@ -342,7 +342,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req MasterKey: &masterKey, SelfSigningKey: &selfSigningKey, } - if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { + if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } @@ -352,7 +352,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req return nil } -func (a *KeyInternalAPI) processSelfSignatures( +func (a *UserInternalAPI) processSelfSignatures( ctx context.Context, signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, ) error { @@ -373,7 +373,7 @@ func (a *KeyInternalAPI) processSelfSignatures( } for originUserID, forOriginUserID := range sig.Signatures { for originKeyID, originSig := range forOriginUserID { - if err := a.DB.StoreCrossSigningSigsForTarget( + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, ); err != nil { return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) @@ -384,7 +384,7 @@ func (a *KeyInternalAPI) processSelfSignatures( case *gomatrixserverlib.DeviceKeys: for originUserID, forOriginUserID := range sig.Signatures { for originKeyID, originSig := range forOriginUserID { - if err := a.DB.StoreCrossSigningSigsForTarget( + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, ); err != nil { return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) @@ -401,7 +401,7 @@ func (a *KeyInternalAPI) processSelfSignatures( return nil } -func (a *KeyInternalAPI) processOtherSignatures( +func (a *UserInternalAPI) processOtherSignatures( ctx context.Context, userID string, queryRes *api.QueryKeysResponse, signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, ) error { @@ -442,7 +442,7 @@ func (a *KeyInternalAPI) processOtherSignatures( } for originKeyID, originSig := range userSigs { - if err := a.DB.StoreCrossSigningSigsForTarget( + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( ctx, userID, originKeyID, targetUserID, targetKeyID, originSig, ); err != nil { return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) @@ -461,11 +461,11 @@ func (a *KeyInternalAPI) processOtherSignatures( return nil } -func (a *KeyInternalAPI) crossSigningKeysFromDatabase( +func (a *UserInternalAPI) crossSigningKeysFromDatabase( ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse, ) { for targetUserID := range req.UserToDevices { - keys, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) + keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) if err != nil { logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID) continue @@ -478,7 +478,7 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase( break } - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID) continue @@ -522,9 +522,9 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase( } } -func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error { +func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error { for targetUserID, forTargetUser := range req.TargetIDs { - keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) + keyMap, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) if err != nil && err != sql.ErrNoRows { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err), @@ -556,7 +556,7 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign for _, targetKeyID := range forTargetUser { // Get own signatures only. - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID) if err != nil && err != sql.ErrNoRows { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err), diff --git a/keyserver/internal/device_list_update.go b/userapi/internal/device_list_update.go similarity index 94% rename from keyserver/internal/device_list_update.go rename to userapi/internal/device_list_update.go index 8ff9dfc31..3b4dcf98e 100644 --- a/keyserver/internal/device_list_update.go +++ b/userapi/internal/device_list_update.go @@ -24,6 +24,8 @@ import ( "sync" "time" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -31,8 +33,8 @@ import ( "github.com/sirupsen/logrus" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/api" ) var ( @@ -102,6 +104,7 @@ type DeviceListUpdater struct { // block on or timeout via a select. userIDToChan map[string]chan bool userIDToChanMu *sync.Mutex + rsAPI rsapi.KeyserverRoomserverAPI } // DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater. @@ -124,6 +127,8 @@ type DeviceListUpdaterDatabase interface { // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + + DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error } type DeviceListUpdaterAPI interface { @@ -140,7 +145,7 @@ func NewDeviceListUpdater( process *process.ProcessContext, db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, - thisServer gomatrixserverlib.ServerName, + rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName, ) *DeviceListUpdater { return &DeviceListUpdater{ process: process, @@ -154,6 +159,7 @@ func NewDeviceListUpdater( workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), userIDToChan: make(map[string]chan bool), userIDToChanMu: &sync.Mutex{}, + rsAPI: rsAPI, } } @@ -168,7 +174,7 @@ func (u *DeviceListUpdater) Start() error { go u.worker(ch) } - staleLists, err := u.db.StaleDeviceLists(context.Background(), []gomatrixserverlib.ServerName{}) + staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) if err != nil { return err } @@ -186,6 +192,25 @@ func (u *DeviceListUpdater) Start() error { return nil } +// CleanUp removes stale device entries for users we don't share a room with anymore +func (u *DeviceListUpdater) CleanUp() error { + staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) + if err != nil { + return err + } + + res := rsapi.QueryLeftUsersResponse{} + if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil { + return err + } + + if len(res.LeftUsers) == 0 { + return nil + } + logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers)) + return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers) +} + func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex { u.mu.Lock() defer u.mu.Unlock() @@ -452,10 +477,6 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go return e.RetryAfter, err } else if e.Blacklisted { return time.Hour * 8, err - } else if e.Code >= 300 { - // We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError - // are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds. - return hourWaitTime, err } case net.Error: // Use the default waitTime, if it's a timeout. diff --git a/keyserver/internal/device_list_update_default.go b/userapi/internal/device_list_update_default.go similarity index 100% rename from keyserver/internal/device_list_update_default.go rename to userapi/internal/device_list_update_default.go diff --git a/keyserver/internal/device_list_update_sytest.go b/userapi/internal/device_list_update_sytest.go similarity index 100% rename from keyserver/internal/device_list_update_sytest.go rename to userapi/internal/device_list_update_sytest.go diff --git a/keyserver/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go similarity index 81% rename from keyserver/internal/device_list_update_test.go rename to userapi/internal/device_list_update_test.go index a374c9516..868fc9be8 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/userapi/internal/device_list_update_test.go @@ -29,8 +29,13 @@ import ( "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/keyserver/api" + roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" ) var ( @@ -53,6 +58,10 @@ type mockDeviceListUpdaterDatabase struct { mu sync.Mutex // protect staleUsers } +func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error { + return nil +} + // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { @@ -153,7 +162,7 @@ func TestUpdateHavePrevID(t *testing.T) { } ap := &mockDeviceListUpdaterAPI{} producer := &mockKeyChangeProducer{} - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost") + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost") event := gomatrixserverlib.DeviceListUpdateEvent{ DeviceDisplayName: "Foo Bar", Deleted: false, @@ -225,7 +234,7 @@ func TestUpdateNoPrevID(t *testing.T) { `)), }, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test") + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -239,6 +248,7 @@ func TestUpdateNoPrevID(t *testing.T) { UserID: remoteUserID, } err := updater.Update(ctx, event) + if err != nil { t.Fatalf("Update returned an error: %s", err) } @@ -294,7 +304,7 @@ func TestDebounce(t *testing.T) { close(incomingFedReq) return <-fedCh, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost") + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -349,3 +359,73 @@ func TestDebounce(t *testing.T) { t.Errorf("user %s is marked as stale", userID) } } + +func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { + t.Helper() + + base, _, _ := testrig.Base(nil) + connStr, clearDB := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) + if err != nil { + t.Fatal(err) + } + + return db, clearDB +} + +type mockKeyserverRoomserverAPI struct { + leftUsers []string +} + +func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { + res.LeftUsers = m.leftUsers + return nil +} + +func TestDeviceListUpdater_CleanUp(t *testing.T) { + processCtx := process.NewProcessContext() + + alice := test.NewUser(t) + bob := test.NewUser(t) + + // Bob is not joined to any of our rooms + rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}} + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clearDB := mustCreateKeyserverDB(t, dbType) + defer clearDB() + + // This should not get deleted + if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil { + t.Error(err) + } + + // this one should get deleted + if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil { + t.Error(err) + } + + updater := NewDeviceListUpdater(processCtx, db, nil, + nil, nil, + 0, rsAPI, "test") + if err := updater.CleanUp(); err != nil { + t.Error(err) + } + + // check that we still have Alice in our stale list + staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Error(err) + } + + // There should only be Alice + wantCount := 1 + if count := len(staleUsers); count != wantCount { + t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count) + } + + if staleUsers[0] != alice.ID { + t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID) + } + }) +} diff --git a/keyserver/internal/internal.go b/userapi/internal/key_api.go similarity index 83% rename from keyserver/internal/internal.go rename to userapi/internal/key_api.go index 9a08a0bb7..be816fe5d 100644 --- a/keyserver/internal/internal.go +++ b/userapi/internal/key_api.go @@ -29,29 +29,11 @@ import ( "github.com/tidwall/gjson" "github.com/tidwall/sjson" - fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/producers" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/config" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/api" ) -type KeyInternalAPI struct { - DB storage.Database - Cfg *config.KeyServer - FedClient fedsenderapi.KeyserverFederationAPI - UserAPI userapi.KeyserverUserAPI - Producer *producers.KeyChange - Updater *DeviceListUpdater -} - -func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) { - a.UserAPI = i -} - -func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error { - userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset) +func (a *UserInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error { + userIDs, latest, err := a.KeyDatabase.KeyChanges(ctx, req.Offset, req.ToOffset) if err != nil { res.Error = &api.KeyError{ Err: err.Error(), @@ -63,7 +45,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC return nil } -func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error { +func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error { res.KeyErrors = make(map[string]map[string]*api.KeyError) if len(req.DeviceKeys) > 0 { a.uploadLocalDeviceKeys(ctx, req, res) @@ -71,7 +53,7 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform if len(req.OneTimeKeys) > 0 { a.uploadOneTimeKeys(ctx, req, res) } - otks, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { return err } @@ -79,7 +61,7 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform return nil } -func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error { +func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error { res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage) res.Failures = make(map[string]interface{}) // wrap request map in a top-level by-domain map @@ -97,11 +79,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC domainToDeviceKeys[string(serverName)] = nested } for domain, local := range domainToDeviceKeys { - if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + if !a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } // claim local keys - keys, err := a.DB.ClaimKeys(ctx, local) + keys, err := a.KeyDatabase.ClaimKeys(ctx, local) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err), @@ -129,7 +111,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC return nil } -func (a *KeyInternalAPI) claimRemoteKeys( +func (a *UserInternalAPI) claimRemoteKeys( ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string, ) { var wg sync.WaitGroup // Wait for fan-out goroutines to finish @@ -146,7 +128,7 @@ func (a *KeyInternalAPI) claimRemoteKeys( defer cancel() defer wg.Done() - claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) + claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) mu.Lock() defer mu.Unlock() @@ -177,8 +159,8 @@ func (a *KeyInternalAPI) claimRemoteKeys( }).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys)) } -func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error { - if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { +func (a *UserInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error { + if err := a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("Failed to delete device keys: %s", err), } @@ -186,8 +168,8 @@ func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.Perform return nil } -func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error { - count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) +func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error { + count, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("Failed to query OTK counts: %s", err), @@ -198,8 +180,8 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne return nil } -func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error { - msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false) +func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error { + msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query DB for device keys: %s", err), @@ -225,8 +207,8 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query // PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present // in our database. -func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error { - knownDevices, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{}, true) +func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error { + knownDevices, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, []string{}, true) if err != nil { return err } @@ -244,7 +226,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap } // nolint:gocyclo -func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { +func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { var respMu sync.Mutex res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) @@ -262,8 +244,8 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } domain := string(serverName) // query local devices - if a.Cfg.Matrix.IsLocalServerName(serverName) { - deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) + if a.Config.Matrix.IsLocalServerName(serverName) { + deviceKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query local device keys: %s", err), @@ -276,8 +258,8 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques for _, dk := range deviceKeys { dids = append(dids, dk.DeviceID) } - var queryRes userapi.QueryDeviceInfosResponse - err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{ + var queryRes api.QueryDeviceInfosResponse + err = a.QueryDeviceInfos(ctx, &api.QueryDeviceInfosRequest{ DeviceIDs: dids, }, &queryRes) if err != nil { @@ -341,14 +323,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} } for targetKeyID := range masterKey.Keys { - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) if err != nil { // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil } - logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") + logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") continue } if len(sigMap) == 0 { @@ -367,14 +349,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques for targetUserID, forUserID := range res.DeviceKeys { for targetKeyID, key := range forUserID { - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID)) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID)) if err != nil { // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil } - logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") + logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") continue } if len(sigMap) == 0 { @@ -403,7 +385,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques return nil } -func (a *KeyInternalAPI) remoteKeysFromDatabase( +func (a *UserInternalAPI) remoteKeysFromDatabase( ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string, ) map[string]map[string][]string { fetchRemote := make(map[string]map[string][]string) @@ -429,7 +411,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase( return fetchRemote } -func (a *KeyInternalAPI) queryRemoteKeys( +func (a *UserInternalAPI) queryRemoteKeys( ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{}, ) { @@ -441,13 +423,13 @@ func (a *KeyInternalAPI) queryRemoteKeys( domains := map[string]struct{}{} for domain := range domainToDeviceKeys { - if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } domains[domain] = struct{}{} } for domain := range domainToCrossSigningKeys { - if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } domains[domain] = struct{}{} @@ -499,7 +481,7 @@ func (a *KeyInternalAPI) queryRemoteKeys( } } -func (a *KeyInternalAPI) queryRemoteKeysOnServer( +func (a *UserInternalAPI) queryRemoteKeysOnServer( ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{}, wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys, res *api.QueryKeysResponse, @@ -559,7 +541,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( if len(devKeys) == 0 { return } - queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) + queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) if err == nil { resultCh <- &queryKeysResp return @@ -586,10 +568,10 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( respMu.Unlock() } -func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( +func (a *UserInternalAPI) populateResponseWithDeviceKeysFromDatabase( ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string, ) error { - keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) + keys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false) // if we can't query the db or there are fewer keys than requested, fetch from remote. if err != nil { return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err) @@ -621,11 +603,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( return nil } -func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { +func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { // get a list of devices from the user API that actually exist, as // we won't store keys for devices that don't exist - uapidevices := &userapi.QueryDevicesResponse{} - if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil { + uapidevices := &api.QueryDevicesResponse{} + if err := a.QueryDevices(ctx, &api.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil { res.Error = &api.KeyError{ Err: err.Error(), } @@ -643,7 +625,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per } // Get all of the user existing device keys so we can check for changes. - existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true) + existingKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, true) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), @@ -662,7 +644,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per } if len(toClean) > 0 { - if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { + if err = a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean)) } else { logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean)) @@ -693,7 +675,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per if err != nil { continue // ignore invalid users } - if !a.Cfg.Matrix.IsLocalServerName(serverName) { + if !a.Config.Matrix.IsLocalServerName(serverName) { continue // ignore remote users } if len(key.KeyJSON) == 0 { @@ -722,30 +704,30 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per } // store the device keys and emit changes - err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore) + err = a.KeyDatabase.StoreLocalDeviceKeys(ctx, keysToStore) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), } return } - err = emitDeviceKeyChanges(a.Producer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates) + err = emitDeviceKeyChanges(a.KeyChangeProducer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates) if err != nil { util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err) } } -func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { +func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { if req.UserID == "" { res.Error = &api.KeyError{ Err: "user ID missing", } } if req.DeviceID != "" && len(req.OneTimeKeys) == 0 { - counts, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + counts, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err), + Err: fmt.Sprintf("a.KeyDatabase.OneTimeKeysCount: %s", err), } } if counts != nil { @@ -761,7 +743,7 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform keyIDsWithAlgorithms[i] = keyIDWithAlgo i++ } - existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms) + existingKeys, err := a.KeyDatabase.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms) if err != nil { res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ Err: "failed to query existing one-time keys: " + err.Error(), @@ -778,7 +760,7 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform } } // store one-time keys - counts, err := a.DB.StoreOneTimeKeys(ctx, key) + counts, err := a.KeyDatabase.StoreOneTimeKeys(ctx, key) if err != nil { res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()), diff --git a/keyserver/internal/internal_test.go b/userapi/internal/key_api_test.go similarity index 89% rename from keyserver/internal/internal_test.go rename to userapi/internal/key_api_test.go index 8a2c9c5d9..fc7e7e0df 100644 --- a/keyserver/internal/internal_test.go +++ b/userapi/internal/key_api_test.go @@ -5,23 +5,28 @@ import ( "reflect" "testing" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/internal" - "github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/internal" + "github.com/matrix-org/dendrite/userapi/storage" ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewDatabase(nil, &config.DatabaseOptions{ + base, _, _ := testrig.Base(nil) + db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) if err != nil { t.Fatalf("failed to create new user db: %v", err) } - return db, close + return db, func() { + base.Close() + close() + } } func Test_QueryDeviceMessages(t *testing.T) { @@ -140,8 +145,8 @@ func Test_QueryDeviceMessages(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a := &internal.KeyInternalAPI{ - DB: db, + a := &internal.UserInternalAPI{ + KeyDatabase: db, } if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr { t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr) diff --git a/userapi/internal/api.go b/userapi/internal/user_api.go similarity index 94% rename from userapi/internal/api.go rename to userapi/internal/user_api.go index bde6707b7..8a194de09 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/user_api.go @@ -23,6 +23,7 @@ import ( "strconv" "time" + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -32,7 +33,6 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/sqlutil" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" synctypes "github.com/matrix-org/dendrite/syncapi/types" @@ -44,17 +44,19 @@ import ( ) type UserInternalAPI struct { - DB storage.Database - SyncProducer *producers.SyncAPI - Config *config.UserAPI + DB storage.UserDatabase + KeyDatabase storage.KeyDatabase + SyncProducer *producers.SyncAPI + KeyChangeProducer *producers.KeyChange + Config *config.UserAPI DisableTLSValidation bool // AppServices is the list of all registered AS AppServices []config.ApplicationService - KeyAPI keyapi.UserKeyAPI RSAPI rsapi.UserRoomserverAPI PgClient pushgateway.Client - Cfg *config.UserAPI + FedClient fedsenderapi.KeyserverFederationAPI + Updater *DeviceListUpdater } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { @@ -221,7 +223,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return fmt.Errorf("a.DB.SetDisplayName: %w", err) } - postRegisterJoinRooms(a.Cfg, acc, a.RSAPI) + postRegisterJoinRooms(a.Config, acc, a.RSAPI) res.AccountCreated = true res.Account = acc @@ -252,6 +254,17 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe if !a.Config.Matrix.IsLocalServerName(serverName) { return fmt.Errorf("server name %s is not local", serverName) } + // If a device ID was specified, check if it already exists and + // avoid sending an empty device list update which would remove + // existing device keys. + isExisting := false + if req.DeviceID != nil && *req.DeviceID != "" { + existingDev, err := a.DB.GetDeviceByID(ctx, req.Localpart, req.ServerName, *req.DeviceID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + isExisting = existingDev.ID == *req.DeviceID + } util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, "device_id": req.DeviceID, @@ -263,7 +276,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe } res.DeviceCreated = true res.Device = dev - if req.NoDeviceListUpdate { + if req.NoDeviceListUpdate || isExisting { return nil } // create empty device keys and upload them to trigger device list changes @@ -293,14 +306,14 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe return err } // Ask the keyserver to delete device keys and signatures for those devices - deleteReq := &keyapi.PerformDeleteKeysRequest{ + deleteReq := &api.PerformDeleteKeysRequest{ UserID: req.UserID, } for _, keyID := range req.DeviceIDs { deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID)) } - deleteRes := &keyapi.PerformDeleteKeysResponse{} - if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil { + deleteRes := &api.PerformDeleteKeysResponse{} + if err := a.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil { return err } if err := deleteRes.Error; err != nil { @@ -311,17 +324,17 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe } func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { - deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs)) + deviceKeys := make([]api.DeviceKeys, len(deviceIDs)) for i, did := range deviceIDs { - deviceKeys[i] = keyapi.DeviceKeys{ + deviceKeys[i] = api.DeviceKeys{ UserID: userID, DeviceID: did, KeyJSON: nil, } } - var uploadRes keyapi.PerformUploadKeysResponse - if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ + var uploadRes api.PerformUploadKeysResponse + if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{ UserID: userID, DeviceKeys: deviceKeys, }, &uploadRes); err != nil { @@ -385,10 +398,10 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf } if req.DisplayName != nil && dev.DisplayName != *req.DisplayName { // display name has changed: update the device key - var uploadRes keyapi.PerformUploadKeysResponse - if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ + var uploadRes api.PerformUploadKeysResponse + if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{ UserID: req.RequestingUserID, - DeviceKeys: []keyapi.DeviceKeys{ + DeviceKeys: []api.DeviceKeys{ { DeviceID: dev.ID, DisplayName: *req.DisplayName, @@ -548,6 +561,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc return nil } +func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) { + res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName) + return +} + // Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem // creating a 'device'. func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go deleted file mode 100644 index 87ae058c2..000000000 --- a/userapi/inthttp/client.go +++ /dev/null @@ -1,442 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package inthttp - -import ( - "context" - "errors" - "net/http" - - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/userapi/api" -) - -// HTTP paths for the internal HTTP APIs -const ( - InputAccountDataPath = "/userapi/inputAccountData" - - PerformDeviceCreationPath = "/userapi/performDeviceCreation" - PerformAccountCreationPath = "/userapi/performAccountCreation" - PerformPasswordUpdatePath = "/userapi/performPasswordUpdate" - PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" - PerformLastSeenUpdatePath = "/userapi/performLastSeenUpdate" - PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" - PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" - PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" - PerformKeyBackupPath = "/userapi/performKeyBackup" - PerformPusherSetPath = "/pushserver/performPusherSet" - PerformPusherDeletionPath = "/pushserver/performPusherDeletion" - PerformPushRulesPutPath = "/pushserver/performPushRulesPut" - PerformSetAvatarURLPath = "/userapi/performSetAvatarURL" - PerformSetDisplayNamePath = "/userapi/performSetDisplayName" - PerformForgetThreePIDPath = "/userapi/performForgetThreePID" - PerformSaveThreePIDAssociationPath = "/userapi/performSaveThreePIDAssociation" - - QueryKeyBackupPath = "/userapi/queryKeyBackup" - QueryProfilePath = "/userapi/queryProfile" - QueryAccessTokenPath = "/userapi/queryAccessToken" - QueryDevicesPath = "/userapi/queryDevices" - QueryAccountDataPath = "/userapi/queryAccountData" - QueryDeviceInfosPath = "/userapi/queryDeviceInfos" - QuerySearchProfilesPath = "/userapi/querySearchProfiles" - QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" - QueryPushersPath = "/pushserver/queryPushers" - QueryPushRulesPath = "/pushserver/queryPushRules" - QueryNotificationsPath = "/pushserver/queryNotifications" - QueryNumericLocalpartPath = "/userapi/queryNumericLocalpart" - QueryAccountAvailabilityPath = "/userapi/queryAccountAvailability" - QueryAccountByPasswordPath = "/userapi/queryAccountByPassword" - QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID" - QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart" -) - -// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. -// If httpClient is nil an error is returned -func NewUserAPIClient( - apiURL string, - httpClient *http.Client, -) (api.UserInternalAPI, error) { - if httpClient == nil { - return nil, errors.New("NewUserAPIClient: httpClient is ") - } - return &httpUserInternalAPI{ - apiURL: apiURL, - httpClient: httpClient, - }, nil -} - -type httpUserInternalAPI struct { - apiURL string - httpClient *http.Client -} - -func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { - return httputil.CallInternalRPCAPI( - "InputAccountData", h.apiURL+InputAccountDataPath, - h.httpClient, ctx, req, res, - ) -} - -func (h *httpUserInternalAPI) PerformAccountCreation( - ctx context.Context, - request *api.PerformAccountCreationRequest, - response *api.PerformAccountCreationResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformAccountCreation", h.apiURL+PerformAccountCreationPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformPasswordUpdate( - ctx context.Context, - request *api.PerformPasswordUpdateRequest, - response *api.PerformPasswordUpdateResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformPasswordUpdate", h.apiURL+PerformPasswordUpdatePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformDeviceCreation( - ctx context.Context, - request *api.PerformDeviceCreationRequest, - response *api.PerformDeviceCreationResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformDeviceCreation", h.apiURL+PerformDeviceCreationPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformDeviceDeletion( - ctx context.Context, - request *api.PerformDeviceDeletionRequest, - response *api.PerformDeviceDeletionResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformDeviceDeletion", h.apiURL+PerformDeviceDeletionPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformLastSeenUpdate( - ctx context.Context, - request *api.PerformLastSeenUpdateRequest, - response *api.PerformLastSeenUpdateResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformLastSeen", h.apiURL+PerformLastSeenUpdatePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformDeviceUpdate( - ctx context.Context, - request *api.PerformDeviceUpdateRequest, - response *api.PerformDeviceUpdateResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformDeviceUpdate", h.apiURL+PerformDeviceUpdatePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformAccountDeactivation( - ctx context.Context, - request *api.PerformAccountDeactivationRequest, - response *api.PerformAccountDeactivationResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformAccountDeactivation", h.apiURL+PerformAccountDeactivationPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformOpenIDTokenCreation( - ctx context.Context, - request *api.PerformOpenIDTokenCreationRequest, - response *api.PerformOpenIDTokenCreationResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformOpenIDTokenCreation", h.apiURL+PerformOpenIDTokenCreationPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryProfile( - ctx context.Context, - request *api.QueryProfileRequest, - response *api.QueryProfileResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryProfile", h.apiURL+QueryProfilePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryDeviceInfos( - ctx context.Context, - request *api.QueryDeviceInfosRequest, - response *api.QueryDeviceInfosResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryDeviceInfos", h.apiURL+QueryDeviceInfosPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryAccessToken( - ctx context.Context, - request *api.QueryAccessTokenRequest, - response *api.QueryAccessTokenResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryAccessToken", h.apiURL+QueryAccessTokenPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryDevices( - ctx context.Context, - request *api.QueryDevicesRequest, - response *api.QueryDevicesResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryDevices", h.apiURL+QueryDevicesPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryAccountData( - ctx context.Context, - request *api.QueryAccountDataRequest, - response *api.QueryAccountDataResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryAccountData", h.apiURL+QueryAccountDataPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QuerySearchProfiles( - ctx context.Context, - request *api.QuerySearchProfilesRequest, - response *api.QuerySearchProfilesResponse, -) error { - return httputil.CallInternalRPCAPI( - "QuerySearchProfiles", h.apiURL+QuerySearchProfilesPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryOpenIDToken( - ctx context.Context, - request *api.QueryOpenIDTokenRequest, - response *api.QueryOpenIDTokenResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryOpenIDToken", h.apiURL+QueryOpenIDTokenPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformKeyBackup( - ctx context.Context, - request *api.PerformKeyBackupRequest, - response *api.PerformKeyBackupResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformKeyBackup", h.apiURL+PerformKeyBackupPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryKeyBackup( - ctx context.Context, - request *api.QueryKeyBackupRequest, - response *api.QueryKeyBackupResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryKeyBackup", h.apiURL+QueryKeyBackupPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryNotifications( - ctx context.Context, - request *api.QueryNotificationsRequest, - response *api.QueryNotificationsResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryNotifications", h.apiURL+QueryNotificationsPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformPusherSet( - ctx context.Context, - request *api.PerformPusherSetRequest, - response *struct{}, -) error { - return httputil.CallInternalRPCAPI( - "PerformPusherSet", h.apiURL+PerformPusherSetPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformPusherDeletion( - ctx context.Context, - request *api.PerformPusherDeletionRequest, - response *struct{}, -) error { - return httputil.CallInternalRPCAPI( - "PerformPusherDeletion", h.apiURL+PerformPusherDeletionPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryPushers( - ctx context.Context, - request *api.QueryPushersRequest, - response *api.QueryPushersResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryPushers", h.apiURL+QueryPushersPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformPushRulesPut( - ctx context.Context, - request *api.PerformPushRulesPutRequest, - response *struct{}, -) error { - return httputil.CallInternalRPCAPI( - "PerformPushRulesPut", h.apiURL+PerformPushRulesPutPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryPushRules( - ctx context.Context, - request *api.QueryPushRulesRequest, - response *api.QueryPushRulesResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryPushRules", h.apiURL+QueryPushRulesPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) SetAvatarURL( - ctx context.Context, - request *api.PerformSetAvatarURLRequest, - response *api.PerformSetAvatarURLResponse, -) error { - return httputil.CallInternalRPCAPI( - "SetAvatarURL", h.apiURL+PerformSetAvatarURLPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryNumericLocalpart( - ctx context.Context, - request *api.QueryNumericLocalpartRequest, - response *api.QueryNumericLocalpartResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryAccountAvailability( - ctx context.Context, - request *api.QueryAccountAvailabilityRequest, - response *api.QueryAccountAvailabilityResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryAccountAvailability", h.apiURL+QueryAccountAvailabilityPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryAccountByPassword( - ctx context.Context, - request *api.QueryAccountByPasswordRequest, - response *api.QueryAccountByPasswordResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryAccountByPassword", h.apiURL+QueryAccountByPasswordPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) SetDisplayName( - ctx context.Context, - request *api.PerformUpdateDisplayNameRequest, - response *api.PerformUpdateDisplayNameResponse, -) error { - return httputil.CallInternalRPCAPI( - "SetDisplayName", h.apiURL+PerformSetDisplayNamePath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryLocalpartForThreePID( - ctx context.Context, - request *api.QueryLocalpartForThreePIDRequest, - response *api.QueryLocalpartForThreePIDResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryLocalpartForThreePID", h.apiURL+QueryLocalpartForThreePIDPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryThreePIDsForLocalpart( - ctx context.Context, - request *api.QueryThreePIDsForLocalpartRequest, - response *api.QueryThreePIDsForLocalpartResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryThreePIDsForLocalpart", h.apiURL+QueryThreePIDsForLocalpartPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformForgetThreePID( - ctx context.Context, - request *api.PerformForgetThreePIDRequest, - response *struct{}, -) error { - return httputil.CallInternalRPCAPI( - "PerformForgetThreePID", h.apiURL+PerformForgetThreePIDPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation( - ctx context.Context, - request *api.PerformSaveThreePIDAssociationRequest, - response *struct{}, -) error { - return httputil.CallInternalRPCAPI( - "PerformSaveThreePIDAssociation", h.apiURL+PerformSaveThreePIDAssociationPath, - h.httpClient, ctx, request, response, - ) -} diff --git a/userapi/inthttp/client_logintoken.go b/userapi/inthttp/client_logintoken.go deleted file mode 100644 index 211b1b7a1..000000000 --- a/userapi/inthttp/client_logintoken.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package inthttp - -import ( - "context" - - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/userapi/api" -) - -const ( - PerformLoginTokenCreationPath = "/userapi/performLoginTokenCreation" - PerformLoginTokenDeletionPath = "/userapi/performLoginTokenDeletion" - QueryLoginTokenPath = "/userapi/queryLoginToken" -) - -func (h *httpUserInternalAPI) PerformLoginTokenCreation( - ctx context.Context, - request *api.PerformLoginTokenCreationRequest, - response *api.PerformLoginTokenCreationResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformLoginTokenCreation", h.apiURL+PerformLoginTokenCreationPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) PerformLoginTokenDeletion( - ctx context.Context, - request *api.PerformLoginTokenDeletionRequest, - response *api.PerformLoginTokenDeletionResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformLoginTokenDeletion", h.apiURL+PerformLoginTokenDeletionPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpUserInternalAPI) QueryLoginToken( - ctx context.Context, - request *api.QueryLoginTokenRequest, - response *api.QueryLoginTokenResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryLoginToken", h.apiURL+QueryLoginTokenPath, - h.httpClient, ctx, request, response, - ) -} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go deleted file mode 100644 index 661fecfae..000000000 --- a/userapi/inthttp/server.go +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package inthttp - -import ( - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/userapi/api" -) - -// nolint: gocyclo -func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { - addRoutesLoginToken(internalAPIMux, s) - - internalAPIMux.Handle( - PerformAccountCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformAccountCreation", s.PerformAccountCreation), - ) - - internalAPIMux.Handle( - PerformPasswordUpdatePath, - httputil.MakeInternalRPCAPI("UserAPIPerformPasswordUpdate", s.PerformPasswordUpdate), - ) - - internalAPIMux.Handle( - PerformDeviceCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformDeviceCreation", s.PerformDeviceCreation), - ) - - internalAPIMux.Handle( - PerformLastSeenUpdatePath, - httputil.MakeInternalRPCAPI("UserAPIPerformLastSeenUpdate", s.PerformLastSeenUpdate), - ) - - internalAPIMux.Handle( - PerformDeviceUpdatePath, - httputil.MakeInternalRPCAPI("UserAPIPerformDeviceUpdate", s.PerformDeviceUpdate), - ) - - internalAPIMux.Handle( - PerformDeviceDeletionPath, - httputil.MakeInternalRPCAPI("UserAPIPerformDeviceDeletion", s.PerformDeviceDeletion), - ) - - internalAPIMux.Handle( - PerformAccountDeactivationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformAccountDeactivation", s.PerformAccountDeactivation), - ) - - internalAPIMux.Handle( - PerformOpenIDTokenCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformOpenIDTokenCreation", s.PerformOpenIDTokenCreation), - ) - - internalAPIMux.Handle( - QueryProfilePath, - httputil.MakeInternalRPCAPI("UserAPIQueryProfile", s.QueryProfile), - ) - - internalAPIMux.Handle( - QueryAccessTokenPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccessToken", s.QueryAccessToken), - ) - - internalAPIMux.Handle( - QueryDevicesPath, - httputil.MakeInternalRPCAPI("UserAPIQueryDevices", s.QueryDevices), - ) - - internalAPIMux.Handle( - QueryAccountDataPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccountData", s.QueryAccountData), - ) - - internalAPIMux.Handle( - QueryDeviceInfosPath, - httputil.MakeInternalRPCAPI("UserAPIQueryDeviceInfos", s.QueryDeviceInfos), - ) - - internalAPIMux.Handle( - QuerySearchProfilesPath, - httputil.MakeInternalRPCAPI("UserAPIQuerySearchProfiles", s.QuerySearchProfiles), - ) - - internalAPIMux.Handle( - QueryOpenIDTokenPath, - httputil.MakeInternalRPCAPI("UserAPIQueryOpenIDToken", s.QueryOpenIDToken), - ) - - internalAPIMux.Handle( - InputAccountDataPath, - httputil.MakeInternalRPCAPI("UserAPIInputAccountData", s.InputAccountData), - ) - - internalAPIMux.Handle( - QueryKeyBackupPath, - httputil.MakeInternalRPCAPI("UserAPIQueryKeyBackup", s.QueryKeyBackup), - ) - - internalAPIMux.Handle( - PerformKeyBackupPath, - httputil.MakeInternalRPCAPI("UserAPIPerformKeyBackup", s.PerformKeyBackup), - ) - - internalAPIMux.Handle( - QueryNotificationsPath, - httputil.MakeInternalRPCAPI("UserAPIQueryNotifications", s.QueryNotifications), - ) - - internalAPIMux.Handle( - PerformPusherSetPath, - httputil.MakeInternalRPCAPI("UserAPIPerformPusherSet", s.PerformPusherSet), - ) - - internalAPIMux.Handle( - PerformPusherDeletionPath, - httputil.MakeInternalRPCAPI("UserAPIPerformPusherDeletion", s.PerformPusherDeletion), - ) - - internalAPIMux.Handle( - QueryPushersPath, - httputil.MakeInternalRPCAPI("UserAPIQueryPushers", s.QueryPushers), - ) - - internalAPIMux.Handle( - PerformPushRulesPutPath, - httputil.MakeInternalRPCAPI("UserAPIPerformPushRulesPut", s.PerformPushRulesPut), - ) - - internalAPIMux.Handle( - QueryPushRulesPath, - httputil.MakeInternalRPCAPI("UserAPIQueryPushRules", s.QueryPushRules), - ) - - internalAPIMux.Handle( - PerformSetAvatarURLPath, - httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL), - ) - - internalAPIMux.Handle( - QueryNumericLocalpartPath, - httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart), - ) - - internalAPIMux.Handle( - QueryAccountAvailabilityPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccountAvailability", s.QueryAccountAvailability), - ) - - internalAPIMux.Handle( - QueryAccountByPasswordPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccountByPassword", s.QueryAccountByPassword), - ) - - internalAPIMux.Handle( - PerformSetDisplayNamePath, - httputil.MakeInternalRPCAPI("UserAPISetDisplayName", s.SetDisplayName), - ) - - internalAPIMux.Handle( - QueryLocalpartForThreePIDPath, - httputil.MakeInternalRPCAPI("UserAPIQueryLocalpartForThreePID", s.QueryLocalpartForThreePID), - ) - - internalAPIMux.Handle( - QueryThreePIDsForLocalpartPath, - httputil.MakeInternalRPCAPI("UserAPIQueryThreePIDsForLocalpart", s.QueryThreePIDsForLocalpart), - ) - - internalAPIMux.Handle( - PerformForgetThreePIDPath, - httputil.MakeInternalRPCAPI("UserAPIPerformForgetThreePID", s.PerformForgetThreePID), - ) - - internalAPIMux.Handle( - PerformSaveThreePIDAssociationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", s.PerformSaveThreePIDAssociation), - ) -} diff --git a/userapi/inthttp/server_logintoken.go b/userapi/inthttp/server_logintoken.go deleted file mode 100644 index b57348413..000000000 --- a/userapi/inthttp/server_logintoken.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package inthttp - -import ( - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/userapi/api" -) - -// addRoutesLoginToken adds routes for all login token API calls. -func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI) { - internalAPIMux.Handle( - PerformLoginTokenCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenCreation", s.PerformLoginTokenCreation), - ) - - internalAPIMux.Handle( - PerformLoginTokenDeletionPath, - httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenDeletion", s.PerformLoginTokenDeletion), - ) - - internalAPIMux.Handle( - QueryLoginTokenPath, - httputil.MakeInternalRPCAPI("UserAPIQueryLoginToken", s.QueryLoginToken), - ) -} diff --git a/keyserver/producers/keychange.go b/userapi/producers/keychange.go similarity index 94% rename from keyserver/producers/keychange.go rename to userapi/producers/keychange.go index f86c34177..da6cea31a 100644 --- a/keyserver/producers/keychange.go +++ b/userapi/producers/keychange.go @@ -18,9 +18,9 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) @@ -28,8 +28,8 @@ import ( // KeyChange produces key change events for the sync API and federation sender to consume type KeyChange struct { Topic string - JetStream nats.JetStreamContext - DB storage.Database + JetStream JetStreamPublisher + DB storage.KeyChangeDatabase } // ProduceKeyChanges creates new change events for each key diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go index 51eaa9856..165de8994 100644 --- a/userapi/producers/syncapi.go +++ b/userapi/producers/syncapi.go @@ -19,13 +19,13 @@ type JetStreamPublisher interface { // SyncAPI produces messages for the Sync API server to consume. type SyncAPI struct { - db storage.Database + db storage.Notification producer JetStreamPublisher clientDataTopic string notificationDataTopic string } -func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { +func NewSyncAPI(db storage.UserDatabase, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { return &SyncAPI{ db: db, producer: js, diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index c22b7658f..278378861 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -90,7 +90,7 @@ type KeyBackup interface { type LoginToken interface { // CreateLoginToken generates a token, stores and returns it. The lifetime is - // determined by the loginTokenLifetime given to the Database constructor. + // determined by the loginTokenLifetime given to the UserDatabase constructor. CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) // RemoveLoginToken removes the named token (and may clean up other expired tokens). @@ -130,7 +130,7 @@ type Notification interface { DeleteOldNotifications(ctx context.Context) error } -type Database interface { +type UserDatabase interface { Account AccountData Device @@ -144,6 +144,78 @@ type Database interface { ThreePID } +type KeyChangeDatabase interface { + // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change. + // `userID` is the the user who has changed their keys in some way. + StoreKeyChange(ctx context.Context, userID string) (int64, error) +} + +type KeyDatabase interface { + KeyChangeDatabase + // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination + // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. + ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + + // StoreOneTimeKeys persists the given one-time keys. + StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + + // OneTimeKeysCount returns a count of all OTKs for this device. + OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. + DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + + // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). + // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. + // Returns an error if there was a problem storing the keys. + StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + + // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior + // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. + StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error + + // PrevIDsExists returns true if all prev IDs exist for this user. + PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) + + // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. + // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. + DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) + + // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying + // cross-signing signatures relating to that device. + DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error + + // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key + // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. + ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) + + // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). + // A to offset of types.OffsetNewest means no upper limit. + // Returns the offset of the latest key change. + KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) + + // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. + // If no domains are given, all user IDs with stale device lists are returned. + StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + + // MarkDeviceListStale sets the stale bit for this user to isStale. + MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error + + CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) + CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) + CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) + + StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error + StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + + DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, + ) error +} + type Statistics interface { UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 2a4777d74..057160374 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData( roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) - _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content) + // Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage + // when passing the data to trigger "NOT NULL" constraint + var data *json.RawMessage + if len(content) > 0 { + data = &content + } + _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, data) return } diff --git a/keyserver/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go similarity index 96% rename from keyserver/storage/postgres/cross_signing_keys_table.go rename to userapi/storage/postgres/cross_signing_keys_table.go index 1022157e8..c0ecbd303 100644 --- a/keyserver/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -21,8 +21,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/userapi/storage/postgres/cross_signing_sigs_table.go similarity index 96% rename from keyserver/storage/postgres/cross_signing_sigs_table.go rename to userapi/storage/postgres/cross_signing_sigs_table.go index 4536b7d80..b0117145c 100644 --- a/keyserver/storage/postgres/cross_signing_sigs_table.go +++ b/userapi/storage/postgres/cross_signing_sigs_table.go @@ -21,9 +21,9 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go b/userapi/storage/postgres/deltas/2022012016470000_key_changes.go similarity index 100% rename from keyserver/storage/postgres/deltas/2022012016470000_key_changes.go rename to userapi/storage/postgres/deltas/2022012016470000_key_changes.go diff --git a/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go similarity index 100% rename from keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go rename to userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go diff --git a/keyserver/storage/postgres/device_keys_table.go b/userapi/storage/postgres/device_keys_table.go similarity index 87% rename from keyserver/storage/postgres/device_keys_table.go rename to userapi/storage/postgres/device_keys_table.go index 2aa11c520..a9203857f 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/userapi/storage/postgres/device_keys_table.go @@ -23,8 +23,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var deviceKeysSchema = ` @@ -92,31 +92,16 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if err != nil { return nil, err } - if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { - return nil, err - } - if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { - return nil, err - } - if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil { - return nil, err - } - if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil { - return nil, err - } - if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL}, + {&s.selectDeviceKeysStmt, selectDeviceKeysSQL}, + {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL}, + {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL}, + {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL}, + {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, + {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL}, + {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL}, + }.Prepare(db) } func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 2dd216189..88f8839c5 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -81,7 +81,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent, session_id FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4" @@ -96,7 +96,7 @@ const deleteDevicesSQL = "" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, server_name, display_name, last_seen_ts, session_id FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6" @@ -160,7 +160,7 @@ func (s *devicesStatements) InsertDevice( if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil { return nil, fmt.Errorf("insertDeviceStmt: %w", err) } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -168,7 +168,19 @@ func (s *devicesStatements) InsertDevice( LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil +} + +func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, + localpart string, serverName gomatrixserverlib.ServerName, + accessToken string, displayName *string, ipAddr, userAgent string, + sessionID int64, +) (*api.Device, error) { + return s.InsertDevice(ctx, txn, id, localpart, serverName, accessToken, displayName, ipAddr, userAgent) } // deleteDevice removes a single device by id and user localpart. @@ -271,7 +283,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var lastseents sql.NullInt64 var displayName sql.NullString for rows.Next() { - if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents, &dev.SessionID); err != nil { return nil, err } if displayName.Valid { @@ -303,7 +315,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart( var lastseents sql.NullInt64 var id, displayname, ip, useragent sql.NullString for rows.Next() { - err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) + err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent, &dev.SessionID) if err != nil { return devices, err } diff --git a/userapi/storage/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go index 7b58f7bae..91a34c357 100644 --- a/userapi/storage/postgres/key_backup_table.go +++ b/userapi/storage/postgres/key_backup_table.go @@ -52,7 +52,7 @@ const updateBackupKeySQL = "" + const countKeysSQL = "" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" -const selectKeysSQL = "" + +const selectBackupKeysSQL = "" + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" @@ -83,7 +83,7 @@ func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, - {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysStmt, selectBackupKeysSQL}, {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, }.Prepare(db) diff --git a/keyserver/storage/postgres/key_changes_table.go b/userapi/storage/postgres/key_changes_table.go similarity index 90% rename from keyserver/storage/postgres/key_changes_table.go rename to userapi/storage/postgres/key_changes_table.go index c0e3429c7..a00494140 100644 --- a/keyserver/storage/postgres/key_changes_table.go +++ b/userapi/storage/postgres/key_changes_table.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var keyChangesSchema = ` @@ -66,7 +66,10 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { if err = executeMigration(context.Background(), db); err != nil { return nil, err } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertKeyChangeStmt, upsertKeyChangeSQL}, + {&s.selectKeyChangesStmt, selectKeyChangesSQL}, + }.Prepare(db) } func executeMigration(ctx context.Context, db *sql.DB) error { @@ -95,16 +98,6 @@ func executeMigration(ctx context.Context, db *sql.DB) error { return m.Up(ctx) } -func (s *keyChangesStatements) Prepare() (err error) { - if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { - return err - } - if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil { - return err - } - return nil -} - func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) return diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/userapi/storage/postgres/one_time_keys_table.go similarity index 89% rename from keyserver/storage/postgres/one_time_keys_table.go rename to userapi/storage/postgres/one_time_keys_table.go index 2117efcae..972a59147 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/userapi/storage/postgres/one_time_keys_table.go @@ -23,8 +23,8 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var oneTimeKeysSchema = ` @@ -49,7 +49,7 @@ const upsertKeysSQL = "" + " ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" + " DO UPDATE SET key_json = $6" -const selectKeysSQL = "" + +const selectOneTimeKeysSQL = "" + "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);" const selectKeysCountSQL = "" + @@ -84,25 +84,14 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { if err != nil { return nil, err } - if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil { - return nil, err - } - if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { - return nil, err - } - if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { - return nil, err - } - if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertKeysStmt, upsertKeysSQL}, + {&s.selectKeysStmt, selectOneTimeKeysSQL}, + {&s.selectKeysCountStmt, selectKeysCountSQL}, + {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL}, + {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL}, + {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL}, + }.Prepare(db) } func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { diff --git a/keyserver/storage/postgres/stale_device_lists.go b/userapi/storage/postgres/stale_device_lists.go similarity index 79% rename from keyserver/storage/postgres/stale_device_lists.go rename to userapi/storage/postgres/stale_device_lists.go index d0fe50d00..c823b58c6 100644 --- a/keyserver/storage/postgres/stale_device_lists.go +++ b/userapi/storage/postgres/stale_device_lists.go @@ -19,8 +19,12 @@ import ( "database/sql" "time" + "github.com/lib/pq" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -48,10 +52,14 @@ const selectStaleDeviceListsWithDomainsSQL = "" + const selectStaleDeviceListsSQL = "" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)" + type staleDeviceListsStatements struct { upsertStaleDeviceListStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt + deleteStaleDeviceListsStmt *sql.Stmt } func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { @@ -60,16 +68,12 @@ func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, erro if err != nil { return nil, err } - if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + {&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, + }.Prepare(db) } func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { @@ -105,6 +109,15 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte return result, nil } +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt) + _, err := stmt.ExecContext(ctx, pq.Array(userIDs)) + return err +} + func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") for rows.Next() { diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 01f9e12e8..27d445bf8 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -140,3 +140,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, }, nil } + +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) + if err != nil { + return nil, err + } + otk, err := NewPostgresOneTimeKeysTable(db) + if err != nil { + return nil, err + } + dk, err := NewPostgresDeviceKeysTable(db) + if err != nil { + return nil, err + } + kc, err := NewPostgresKeyChangesTable(db) + if err != nil { + return nil, err + } + sdl, err := NewPostgresStaleDeviceListsTable(db) + if err != nil { + return nil, err + } + csk, err := NewPostgresCrossSigningKeysTable(db) + if err != nil { + return nil, err + } + css, err := NewPostgresCrossSigningSigsTable(db) + if err != nil { + return nil, err + } + + return &shared.KeyDatabase{ + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, + CrossSigningKeysTable: csk, + CrossSigningSigsTable: css, + Writer: writer, + }, nil +} diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 4bd8a04e7..5cb43507c 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -59,6 +59,17 @@ type Database struct { OpenIDTokenLifetimeMS int64 } +type KeyDatabase struct { + OneTimeKeysTable tables.OneTimeKeys + DeviceKeysTable tables.DeviceKeys + KeyChangesTable tables.KeyChanges + StaleDeviceListsTable tables.StaleDeviceLists + CrossSigningKeysTable tables.CrossSigningKeys + CrossSigningSigsTable tables.CrossSigningSigs + DB *sql.DB + Writer sqlutil.Writer +} + const ( // The length of generated device IDs deviceIDByteLength = 6 @@ -588,16 +599,41 @@ func (d *Database) CreateDevice( deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string, ) (dev *api.Device, returnErr error) { if deviceID != nil && *deviceID != "" { - returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil { - return err - } + _, ok := d.Writer.(*sqlutil.ExclusiveWriter) + if ok { // we're using most likely using SQLite, so do things a little different + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error + devices, err := d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, "") + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + // No devices yet, only create a new one + if len(devices) == 0 { + dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) + return err + } + sessionID := devices[0].SessionID + 1 - dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) - return err - }) + // Revoke existing tokens for this device + if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil { + return err + } + // Create a new device with the session ID incremented + dev, err = d.Devices.InsertDeviceWithSessionID(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent, sessionID) + return err + }) + } else { + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil { + return err + } + + dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) + return err + }) + } } else { // We generate device IDs in a loop in case its already taken. // We cap this at going round 5 times to ensure we don't spin forever @@ -618,7 +654,7 @@ func (d *Database) CreateDevice( } } } - return + return dev, returnErr } // generateDeviceID creates a new device id. Returns an error if failed to generate @@ -849,3 +885,227 @@ func (d *Database) DailyRoomsMessages( ) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) { return d.Stats.DailyRoomsMessages(ctx, nil, serverName) } + +// + +func (d *KeyDatabase) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms) +} + +func (d *KeyDatabase) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys) + return err + }) + return +} + +func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) +} + +func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) +} + +func (d *KeyDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { + count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs) + if err != nil { + return false, err + } + return count == len(prevIDs), nil +} + +func (d *KeyDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for _, userID := range clearUserIDs { + err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID) + if err != nil { + return err + } + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) +} + +func (d *KeyDatabase) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { + // work out the latest stream IDs for each user + userIDToStreamID := make(map[string]int64) + for _, k := range keys { + userIDToStreamID[k.UserID] = 0 + } + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for userID := range userIDToStreamID { + streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) + if err != nil { + return err + } + userIDToStreamID[userID] = streamID + } + // set the stream IDs for each key + for i := range keys { + k := keys[i] + userIDToStreamID[k.UserID]++ // start stream from 1 + k.StreamID = userIDToStreamID[k.UserID] + keys[i] = k + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) +} + +func (d *KeyDatabase) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) +} + +func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { + var result []api.OneTimeKeys + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for userID, deviceToAlgo := range userToDeviceToAlgorithm { + for deviceID, algo := range deviceToAlgo { + keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo) + if err != nil { + return err + } + if keyJSON != nil { + result = append(result, api.OneTimeKeys{ + UserID: userID, + DeviceID: deviceID, + KeyJSON: keyJSON, + }) + } + } + } + return nil + }) + return result, err +} + +func (d *KeyDatabase) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) { + err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID) + return err + }) + return +} + +func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { + return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset) +} + +// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. +// If no domains are given, all user IDs with stale device lists are returned. +func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) +} + +// MarkDeviceListStale sets the stale bit for this user to isStale. +func (d *KeyDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { + return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) + }) +} + +// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying +// cross-signing signatures relating to that device. +func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for _, deviceID := range deviceIDs { + if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err) + } + if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) + } + if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err) + } + } + return nil + }) +} + +// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. +func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { + keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) + if err != nil { + return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err) + } + results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} + for purpose, key := range keyMap { + keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode()) + result := gomatrixserverlib.CrossSigningKey{ + UserID: userID, + Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose}, + Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + keyID: key, + }, + } + sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID) + if err != nil { + continue + } + for sigUserID, forSigUserID := range sigMap { + if userID != sigUserID { + continue + } + if result.Signatures == nil { + result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + if _, ok := result.Signatures[sigUserID]; !ok { + result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + for sigKeyID, sigBytes := range forSigUserID { + result.Signatures[sigUserID][sigKeyID] = sigBytes + } + } + results[purpose] = result + } + return results, nil +} + +// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. +func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { + return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) +} + +// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. +func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { + return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID) +} + +// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. +func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for keyType, keyData := range keyMap { + if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { + return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err) + } + } + return nil + }) +} + +// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice. +func (d *KeyDatabase) StoreCrossSigningSigsForTarget( + ctx context.Context, + originUserID string, originKeyID gomatrixserverlib.KeyID, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, + signature gomatrixserverlib.Base64Bytes, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err) + } + return nil + }) +} + +// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore. +func (d *KeyDatabase) DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs) + }) +} diff --git a/keyserver/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go similarity index 96% rename from keyserver/storage/sqlite3/cross_signing_keys_table.go rename to userapi/storage/sqlite3/cross_signing_keys_table.go index e103d9883..10721fcc8 100644 --- a/keyserver/storage/sqlite3/cross_signing_keys_table.go +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -21,8 +21,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/userapi/storage/sqlite3/cross_signing_sigs_table.go similarity index 96% rename from keyserver/storage/sqlite3/cross_signing_sigs_table.go rename to userapi/storage/sqlite3/cross_signing_sigs_table.go index 7a153e8fb..2be00c9c1 100644 --- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go +++ b/userapi/storage/sqlite3/cross_signing_sigs_table.go @@ -21,9 +21,9 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go b/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go similarity index 100% rename from keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go rename to userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go diff --git a/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go similarity index 100% rename from keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go rename to userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/userapi/storage/sqlite3/device_keys_table.go similarity index 88% rename from keyserver/storage/sqlite3/device_keys_table.go rename to userapi/storage/sqlite3/device_keys_table.go index 73768da5b..15e69cc4c 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/userapi/storage/sqlite3/device_keys_table.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var deviceKeysSchema = ` @@ -86,28 +86,16 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if err != nil { return nil, err } - if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { - return nil, err - } - if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { - return nil, err - } - if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil { - return nil, err - } - if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL}, + {&s.selectDeviceKeysStmt, selectDeviceKeysSQL}, + {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL}, + {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL}, + {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL}, + // {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, // prepared at runtime + {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL}, + {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL}, + }.Prepare(db) } func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index c5db34bd7..65e17527d 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -65,7 +65,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent, session_id FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4" @@ -80,7 +80,7 @@ const deleteDevicesSQL = "" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, server_name, display_name, last_seen_ts, session_id FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6" @@ -151,7 +151,7 @@ func (s *devicesStatements) InsertDevice( if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { return nil, err } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -159,7 +159,36 @@ func (s *devicesStatements) InsertDevice( LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil +} + +func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, + localpart string, serverName gomatrixserverlib.ServerName, + accessToken string, displayName *string, ipAddr, userAgent string, + sessionID int64, +) (*api.Device, error) { + createdTimeMS := time.Now().UnixNano() / 1000000 + insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) + if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { + return nil, err + } + dev := &api.Device{ + ID: id, + UserID: userutil.MakeUserID(localpart, serverName), + AccessToken: accessToken, + SessionID: sessionID, + LastSeenTS: createdTimeMS, + LastSeenIP: ipAddr, + UserAgent: userAgent, + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil } func (s *devicesStatements) DeleteDevice( @@ -181,6 +210,7 @@ func (s *devicesStatements) DeleteDevices( if err != nil { return err } + defer internal.CloseAndLogIfError(ctx, prep, "DeleteDevices.StmtClose() failed") stmt := sqlutil.TxStmt(txn, prep) params := make([]interface{}, len(devices)+2) params[0] = localpart @@ -271,7 +301,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart( var lastseents sql.NullInt64 var id, displayname, ip, useragent sql.NullString for rows.Next() { - err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) + err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent, &dev.SessionID) if err != nil { return devices, err } @@ -317,7 +347,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var displayName sql.NullString var lastseents sql.NullInt64 for rows.Next() { - if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents, &dev.SessionID); err != nil { return nil, err } if displayName.Valid { diff --git a/userapi/storage/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go index 7883ffb19..ed2746310 100644 --- a/userapi/storage/sqlite3/key_backup_table.go +++ b/userapi/storage/sqlite3/key_backup_table.go @@ -52,7 +52,7 @@ const updateBackupKeySQL = "" + const countKeysSQL = "" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" -const selectKeysSQL = "" + +const selectBackupKeysSQL = "" + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" @@ -83,7 +83,7 @@ func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, - {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysStmt, selectBackupKeysSQL}, {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, }.Prepare(db) diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/userapi/storage/sqlite3/key_changes_table.go similarity index 90% rename from keyserver/storage/sqlite3/key_changes_table.go rename to userapi/storage/sqlite3/key_changes_table.go index 0c844d67a..923bb57eb 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/userapi/storage/sqlite3/key_changes_table.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var keyChangesSchema = ` @@ -65,7 +65,10 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { return nil, err } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertKeyChangeStmt, upsertKeyChangeSQL}, + {&s.selectKeyChangesStmt, selectKeyChangesSQL}, + }.Prepare(db) } func executeMigration(ctx context.Context, db *sql.DB) error { @@ -93,16 +96,6 @@ func executeMigration(ctx context.Context, db *sql.DB) error { return m.Up(ctx) } -func (s *keyChangesStatements) Prepare() (err error) { - if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { - return err - } - if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil { - return err - } - return nil -} - func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) return diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/userapi/storage/sqlite3/one_time_keys_table.go similarity index 89% rename from keyserver/storage/sqlite3/one_time_keys_table.go rename to userapi/storage/sqlite3/one_time_keys_table.go index 7a923d0e5..a992d399c 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/userapi/storage/sqlite3/one_time_keys_table.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var oneTimeKeysSchema = ` @@ -48,7 +48,7 @@ const upsertKeysSQL = "" + " ON CONFLICT (user_id, device_id, key_id, algorithm)" + " DO UPDATE SET key_json = $6" -const selectKeysSQL = "" + +const selectOneTimeKeysSQL = "" + "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" const selectKeysCountSQL = "" + @@ -83,25 +83,14 @@ func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { if err != nil { return nil, err } - if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil { - return nil, err - } - if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { - return nil, err - } - if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { - return nil, err - } - if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertKeysStmt, upsertKeysSQL}, + {&s.selectKeysStmt, selectOneTimeKeysSQL}, + {&s.selectKeysCountStmt, selectKeysCountSQL}, + {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL}, + {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL}, + {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL}, + }.Prepare(db) } func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/userapi/storage/sqlite3/stale_device_lists.go similarity index 74% rename from keyserver/storage/sqlite3/stale_device_lists.go rename to userapi/storage/sqlite3/stale_device_lists.go index 1e08b266c..f078fc99f 100644 --- a/keyserver/storage/sqlite3/stale_device_lists.go +++ b/userapi/storage/sqlite3/stale_device_lists.go @@ -17,10 +17,13 @@ package sqlite3 import ( "context" "database/sql" + "strings" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -48,11 +51,15 @@ const selectStaleDeviceListsWithDomainsSQL = "" + const selectStaleDeviceListsSQL = "" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)" + type staleDeviceListsStatements struct { db *sql.DB upsertStaleDeviceListStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt + // deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime } func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { @@ -63,16 +70,12 @@ func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) if err != nil { return nil, err } - if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + // { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime + }.Prepare(db) } func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { @@ -108,6 +111,27 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte return result, nil } +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1) + stmt, err := s.db.Prepare(qry) + if err != nil { + return err + } + defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed") + stmt = sqlutil.TxStmt(txn, stmt) + + params := make([]any, len(userIDs)) + for i := range userIDs { + params[i] = userIDs[i] + } + + _, err = stmt.ExecContext(ctx, params...) + return err +} + func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") for rows.Next() { diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go index a1365c944..72b3ba49d 100644 --- a/userapi/storage/sqlite3/stats_table.go +++ b/userapi/storage/sqlite3/stats_table.go @@ -256,6 +256,7 @@ func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int if err != nil { return 0, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "allUsers.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) err = stmt.QueryRowContext(ctx, 1, 2, 3, 4, @@ -269,6 +270,7 @@ func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (res if err != nil { return 0, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "nonBridgedUsers.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) err = stmt.QueryRowContext(ctx, 1, 2, 3, @@ -286,6 +288,7 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) if err != nil { return nil, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "registeredUserByType.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) registeredAfter := time.Now().AddDate(0, 0, -30) diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 85a1f7063..0f3eeed1b 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -30,8 +30,8 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" ) -// NewDatabase creates a new accounts and profiles database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { +// NewUserDatabase creates a new accounts and profiles database +func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) if err != nil { return nil, err @@ -134,3 +134,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, }, nil } + +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) + if err != nil { + return nil, err + } + otk, err := NewSqliteOneTimeKeysTable(db) + if err != nil { + return nil, err + } + dk, err := NewSqliteDeviceKeysTable(db) + if err != nil { + return nil, err + } + kc, err := NewSqliteKeyChangesTable(db) + if err != nil { + return nil, err + } + sdl, err := NewSqliteStaleDeviceListsTable(db) + if err != nil { + return nil, err + } + csk, err := NewSqliteCrossSigningKeysTable(db) + if err != nil { + return nil, err + } + css, err := NewSqliteCrossSigningSigsTable(db) + if err != nil { + return nil, err + } + + return &shared.KeyDatabase{ + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, + CrossSigningKeysTable: csk, + CrossSigningSigsTable: css, + Writer: writer, + }, nil +} diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go index 42221e752..0329fb46a 100644 --- a/userapi/storage/storage.go +++ b/userapi/storage/storage.go @@ -29,15 +29,36 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/sqlite3" ) -// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) +// NewUserDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters -func NewUserAPIDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) { +func NewUserDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + serverName gomatrixserverlib.ServerName, + bcryptCost int, + openIDTokenLifetimeMS int64, + loginTokenLifetime time.Duration, + serverNoticesLocalpart string, +) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) default: return nil, fmt.Errorf("unexpected database type") } } + +// NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme) +// and sets postgres connection parameters. +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) { + switch { + case dbProperties.ConnectionString.IsSQLite(): + return sqlite3.NewKeyDatabase(base, dbProperties) + case dbProperties.ConnectionString.IsPostgres(): + return postgres.NewKeyDatabase(base, dbProperties) + default: + return nil, fmt.Errorf("unexpected database type") + } +} diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 29a806e4a..62483595a 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -4,9 +4,12 @@ import ( "context" "encoding/json" "fmt" + "reflect" + "sync" "testing" "time" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/stretchr/testify/assert" @@ -29,14 +32,14 @@ var ( ctx = context.Background() ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") if err != nil { - t.Fatalf("NewUserAPIDatabase returned %s", err) + t.Fatalf("NewUserDatabase returned %s", err) } return db, func() { close() @@ -47,7 +50,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun // Tests storing and getting account data func Test_AccountData(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -78,7 +81,7 @@ func Test_AccountData(t *testing.T) { // Tests the creation of accounts func Test_Accounts(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -163,7 +166,7 @@ func Test_Devices(t *testing.T) { accessToken := util.RandomString(16) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "") @@ -243,7 +246,7 @@ func Test_KeyBackup(t *testing.T) { room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() wantAuthData := json.RawMessage("my auth data") @@ -320,7 +323,7 @@ func Test_KeyBackup(t *testing.T) { func Test_LoginToken(t *testing.T) { alice := test.NewUser(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // create a new token @@ -352,7 +355,7 @@ func Test_OpenID(t *testing.T) { token := util.RandomString(24) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS @@ -373,7 +376,7 @@ func Test_Profile(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // create account, which also creates a profile @@ -422,7 +425,7 @@ func Test_Pusher(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() appID := util.RandomString(8) @@ -473,7 +476,7 @@ func Test_ThreePID(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() threePID := util.RandomString(8) medium := util.RandomString(8) @@ -512,7 +515,7 @@ func Test_Notification(t *testing.T) { room := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // generate some dummy notifications for i := 0; i < 10; i++ { @@ -576,3 +579,184 @@ func Test_Notification(t *testing.T) { assert.Equal(t, int64(0), total) }) } + +func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { + base, close := testrig.CreateBaseDendrite(t, dbType) + db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + if err != nil { + t.Fatalf("failed to create new database: %v", err) + } + return db, close +} + +func MustNotError(t *testing.T, err error) { + t.Helper() + if err == nil { + return + } + t.Fatalf("operation failed: %s", err) +} + +func TestKeyChanges(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + _, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeIDC { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) + } + if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +func TestKeyChangesNoDupes(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + if deviceChangeIDA == deviceChangeIDB { + t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) + } + deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeID { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) + } + if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +func TestKeyChangesUpperLimit(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + _, err = db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeIDB { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) + } + if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +var dbLock sync.Mutex +var deviceArray = []string{"AAA", "another_device"} + +// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, +// and that they are returned correctly when querying for device keys. +func TestDeviceKeysStreamIDGeneration(t *testing.T) { + var err error + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + alice := "@alice:TestDeviceKeysStreamIDGeneration" + bob := "@bob:TestDeviceKeysStreamIDGeneration" + msgs := []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 + }, + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: bob, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 as this is a different user + }, + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "another_device", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 2 as this is a 2nd device key + }, + } + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 1 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) + } + if msgs[1].StreamID != 1 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) + } + if msgs[2].StreamID != 2 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) + } + + // updating a device sets the next stream ID for that user + msgs = []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v2"}`), + }, + // StreamID: 3 + }, + } + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 3 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) + } + + dbLock.Lock() + defer dbLock.Unlock() + // Querying for device keys returns the latest stream IDs + msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false) + + if err != nil { + t.Fatalf("DeviceKeysForUser returned error: %s", err) + } + wantStreamIDs := map[string]int64{ + "AAA": 3, + "another_device": 2, + } + if len(msgs) != len(wantStreamIDs) { + t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) + } + for _, m := range msgs { + if m.StreamID != wantStreamIDs[m.DeviceID] { + t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) + } + } + }) +} diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go index 5d5d292e6..163e3e173 100644 --- a/userapi/storage/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -32,10 +32,10 @@ func NewUserAPIDatabase( openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string, -) (Database, error) { +) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index e14776cf3..693e73038 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -20,10 +20,10 @@ import ( "encoding/json" "time" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/types" ) @@ -44,6 +44,7 @@ type AccountsTable interface { type DevicesTable interface { InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) + InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, sessionID int64) (*api.Device, error) DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error @@ -144,3 +145,47 @@ const ( // uint32. AllNotifications NotificationFilter = (1 << 31) - 1 ) + +type OneTimeKeys interface { + SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. + // Returns an empty map if the key does not exist. + SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) + DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error +} + +type DeviceKeys interface { + SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error + SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) + CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) + SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) + DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error + DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error +} + +type KeyChanges interface { + InsertKeyChange(ctx context.Context, userID string) (int64, error) + // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets. + // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset. + SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) +} + +type StaleDeviceLists interface { + InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error + SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error +} + +type CrossSigningKeys interface { + SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) + UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error +} + +type CrossSigningSigs interface { + SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) + UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error +} diff --git a/userapi/storage/tables/stale_device_lists_test.go b/userapi/storage/tables/stale_device_lists_test.go new file mode 100644 index 000000000..b9bdafdaa --- /dev/null +++ b/userapi/storage/tables/stale_device_lists_test.go @@ -0,0 +1,94 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/userapi/storage/postgres" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, nil) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresStaleDeviceListsTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db) + } + if err != nil { + t.Fatalf("failed to create new table: %s", err) + } + return tab, close +} + +func TestStaleDeviceLists(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := "@charlie:localhost" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateTable(t, dbType) + defer closeDB() + + if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + + // Query one server + wantStaleUsers := []string{alice.ID, bob.ID} + gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { + t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) + } + + // Query all servers + wantStaleUsers = []string{alice.ID, bob.ID, charlie} + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { + t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) + } + + // Delete stale devices + deleteUsers := []string{alice.ID, bob.ID} + if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil { + t.Fatalf("failed to delete stale device lists: %s", err) + } + + // Verify we don't get anything back after deleting + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + + if gotCount := len(gotStaleUsers); gotCount > 0 { + t.Fatalf("expected no stale users, got %d", gotCount) + } + }) +} diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go index b088d15cd..969bc5303 100644 --- a/userapi/storage/tables/stats_table_test.go +++ b/userapi/storage/tables/stats_table_test.go @@ -187,8 +187,8 @@ func Test_UserStatistics(t *testing.T) { }) t.Run("Users not active for one/two month", func(t *testing.T) { - mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now().AddDate(0, -2, 0)) - mustUpdateDeviceLastSeen(t, ctx, db, "user2", time.Now().AddDate(0, -1, 0)) + mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now().AddDate(0, 0, -60)) + mustUpdateDeviceLastSeen(t, ctx, db, "user2", time.Now().AddDate(0, 0, -30)) gotStats, _, err := statsDB.UserStatistics(ctx, nil) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -224,9 +224,9 @@ func Test_UserStatistics(t *testing.T) { - Where account creation and last_seen are > 30 days apart */ t.Run("R30Users tests", func(t *testing.T) { - mustUserUpdateRegistered(t, ctx, db, "user1", time.Now().AddDate(0, -2, 0)) + mustUserUpdateRegistered(t, ctx, db, "user1", time.Now().AddDate(0, 0, -60)) mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now()) - mustUserUpdateRegistered(t, ctx, db, "user4", time.Now().AddDate(0, -2, 0)) + mustUserUpdateRegistered(t, ctx, db, "user4", time.Now().AddDate(0, 0, -60)) mustUpdateDeviceLastSeen(t, ctx, db, "user4", time.Now()) startTime := time.Now().AddDate(0, 0, -2) err := statsDB.UpdateUserDailyVisits(ctx, nil, startTime, startTime.Truncate(time.Hour*24)) diff --git a/keyserver/types/storage.go b/userapi/types/storage.go similarity index 100% rename from keyserver/types/storage.go rename to userapi/types/storage.go diff --git a/userapi/userapi.go b/userapi/userapi.go index e46a8e76e..826bd7213 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -17,40 +17,34 @@ package userapi import ( "time" - "github.com/gorilla/mux" + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/internal/pushgateway" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/consumers" "github.com/matrix-org/dendrite/userapi/internal" - "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/util" ) -// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions -// on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { - inthttp.AddRoutes(router, intAPI) -} - -// NewInternalAPI returns a concerete implementation of the internal API. Callers +// NewInternalAPI returns a concrete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - base *base.BaseDendrite, cfg *config.UserAPI, - appServices []config.ApplicationService, keyAPI keyapi.UserKeyAPI, - rsAPI rsapi.UserRoomserverAPI, pgClient pushgateway.Client, -) api.UserInternalAPI { + base *base.BaseDendrite, + rsAPI rsapi.UserRoomserverAPI, + fedClient fedsenderapi.KeyserverFederationAPI, +) *internal.UserInternalAPI { + cfg := &base.Cfg.UserAPI js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + appServices := base.Cfg.Derived.ApplicationServices - db, err := storage.NewUserAPIDatabase( + pgClient := base.PushGatewayHTTPClient() + + db, err := storage.NewUserDatabase( base, &cfg.AccountDatabase, cfg.Matrix.ServerName, @@ -63,6 +57,11 @@ func NewInternalAPI( logrus.WithError(err).Panicf("failed to connect to accounts db") } + keyDB, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + if err != nil { + logrus.WithError(err).Panicf("failed to connect to key db") + } + syncProducer := producers.NewSyncAPI( db, js, // TODO: user API should handle syncs for account data. Right now, @@ -72,17 +71,50 @@ func NewInternalAPI( cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData), cfg.Matrix.JetStream.Prefixed(jetstream.OutputNotificationData), ) + keyChangeProducer := &producers.KeyChange{ + Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), + JetStream: js, + DB: keyDB, + } userAPI := &internal.UserInternalAPI{ DB: db, + KeyDatabase: keyDB, SyncProducer: syncProducer, + KeyChangeProducer: keyChangeProducer, Config: cfg, AppServices: appServices, - KeyAPI: keyAPI, RSAPI: rsAPI, DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, PgClient: pgClient, - Cfg: cfg, + FedClient: fedClient, + } + + updater := internal.NewDeviceListUpdater(base.ProcessContext, keyDB, userAPI, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable + userAPI.Updater = updater + // Remove users which we don't share a room with anymore + if err := updater.CleanUp(); err != nil { + logrus.WithError(err).Error("failed to cleanup stale device lists") + } + + go func() { + if err := updater.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start device list updater") + } + }() + + dlConsumer := consumers.NewDeviceListUpdateConsumer( + base.ProcessContext, cfg, js, updater, + ) + if err := dlConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start device list consumer") + } + + sigConsumer := consumers.NewSigningKeyUpdateConsumer( + base.ProcessContext, cfg, js, userAPI, + ) + if err := sigConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start signing key consumer") } receiptConsumer := consumers.NewOutputReceiptEventConsumer( diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 25fa75ee2..01e491cb6 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -17,22 +17,20 @@ package userapi_test import ( "context" "fmt" - "net/http" "reflect" + "sync" "testing" "time" - "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/nats-io/nats.go" "golang.org/x/crypto/bcrypt" - "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" - "github.com/matrix-org/dendrite/userapi" - "github.com/matrix-org/dendrite/userapi/inthttp" - - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/storage" @@ -44,32 +42,71 @@ const ( type apiTestOpts struct { loginTokenLifetime time.Duration + serverName string } -func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.Database, func()) { +type dummyProducer struct { + callCount sync.Map + t *testing.T +} + +func (d *dummyProducer) PublishMsg(msg *nats.Msg, opts ...nats.PubOpt) (*nats.PubAck, error) { + count, loaded := d.callCount.LoadOrStore(msg.Subject, 1) + if loaded { + c, ok := count.(int) + if !ok { + d.t.Fatalf("unexpected type: %T with value %q", c, c) + } + d.callCount.Store(msg.Subject, c+1) + d.t.Logf("Incrementing call counter for %s", msg.Subject) + } + return &nats.PubAck{}, nil +} + +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType, publisher producers.JetStreamPublisher) (api.UserInternalAPI, storage.UserDatabase, func()) { if opts.loginTokenLifetime == 0 { opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond } base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + sName := serverName + if opts.serverName != "" { + sName = gomatrixserverlib.ServerName(opts.serverName) + } + accountDB, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") + }, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") if err != nil { t.Fatalf("failed to create account DB: %s", err) } + keyDB, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("failed to create key DB: %s", err) + } + cfg := &config.UserAPI{ Matrix: &config.Global{ SigningIdentity: gomatrixserverlib.SigningIdentity{ - ServerName: serverName, + ServerName: sName, }, }, } + if publisher == nil { + publisher = &dummyProducer{t: t} + } + + syncProducer := producers.NewSyncAPI(accountDB, publisher, "client_data", "notification_data") + keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: publisher, Topic: "keychange"} return &internal.UserInternalAPI{ - DB: accountDB, - Config: cfg, + DB: accountDB, + KeyDatabase: keyDB, + Config: cfg, + SyncProducer: syncProducer, + KeyChangeProducer: keyChangeProducer, }, accountDB, func() { close() baseclose() @@ -79,19 +116,6 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap func TestQueryProfile(t *testing.T) { aliceAvatarURL := "mxc://example.com/alice" aliceDisplayName := "Alice" - // only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added - userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite) - defer close() - _, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser) - if err != nil { - t.Fatalf("failed to make account: %s", err) - } - if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil { - t.Fatalf("failed to set avatar url: %s", err) - } - if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil { - t.Fatalf("failed to set display name: %s", err) - } testCases := []struct { req api.QueryProfileRequest @@ -142,18 +166,20 @@ func TestQueryProfile(t *testing.T) { } } - t.Run("HTTP API", func(t *testing.T) { - router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - userapi.AddInternalRoutes(router, userAPI) - apiURL, cancel := test.ListenAndServe(t, router, false) - defer cancel() - httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) + defer close() + _, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser) if err != nil { - t.Fatalf("failed to create HTTP client") + t.Fatalf("failed to make account: %s", err) } - runCases(httpAPI, true) - }) - t.Run("Monolith", func(t *testing.T) { + if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil { + t.Fatalf("failed to set avatar url: %s", err) + } + if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil { + t.Fatalf("failed to set display name: %s", err) + } + runCases(userAPI, false) }) } @@ -164,7 +190,7 @@ func TestQueryProfile(t *testing.T) { func TestPasswordlessLoginFails(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() _, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService) if err != nil { @@ -190,7 +216,7 @@ func TestLoginToken(t *testing.T) { t.Run("tokenLoginFlow", func(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() _, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser) if err != nil { @@ -240,7 +266,7 @@ func TestLoginToken(t *testing.T) { t.Run("expiredTokenIsNotReturned", func(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}, dbType) + userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}, dbType, nil) defer close() creq := api.PerformLoginTokenCreationRequest{ @@ -265,7 +291,7 @@ func TestLoginToken(t *testing.T) { t.Run("deleteWorks", func(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() creq := api.PerformLoginTokenCreationRequest{ @@ -296,7 +322,7 @@ func TestLoginToken(t *testing.T) { t.Run("deleteUnknownIsNoOp", func(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() dreq := api.PerformLoginTokenDeletionRequest{Token: "non-existent token"} var dresp api.PerformLoginTokenDeletionResponse @@ -306,3 +332,346 @@ func TestLoginToken(t *testing.T) { }) }) } + +func TestQueryAccountByLocalpart(t *testing.T) { + alice := test.NewUser(t) + + localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) + defer close() + + createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType) + if err != nil { + t.Error(err) + } + + testCases := func(t *testing.T, internalAPI api.UserInternalAPI) { + // Query existing account + queryAccResp := &api.QueryAccountByLocalpartResponse{} + if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ + Localpart: localpart, + ServerName: userServername, + }, queryAccResp); err != nil { + t.Error(err) + } + if !reflect.DeepEqual(createdAcc, queryAccResp.Account) { + t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account) + } + + // Query non-existent account, this should result in an error + err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ + Localpart: "doesnotexist", + ServerName: userServername, + }, queryAccResp) + + if err == nil { + t.Fatalf("expected an error, but got none: %+v", queryAccResp) + } + } + + testCases(t, intAPI) + }) +} + +func TestAccountData(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + + testCases := []struct { + name string + inputData *api.InputAccountDataRequest + wantErr bool + }{ + { + name: "not a local user", + inputData: &api.InputAccountDataRequest{UserID: "@notlocal:example.com"}, + wantErr: true, + }, + { + name: "local user missing datatype", + inputData: &api.InputAccountDataRequest{UserID: alice.ID}, + wantErr: true, + }, + { + name: "missing json", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: nil}, + wantErr: true, + }, + { + name: "with json", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}")}, + }, + { + name: "room data", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}"), RoomID: "!dummy:test"}, + }, + { + name: "ignored users", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.ignored_user_list", AccountData: []byte("{}")}, + }, + { + name: "m.fully_read", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.fully_read", AccountData: []byte("{}")}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType, nil) + defer close() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res := api.InputAccountDataResponse{} + err := intAPI.InputAccountData(ctx, tc.inputData, &res) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + + // query the data again and compare + queryRes := api.QueryAccountDataResponse{} + queryReq := api.QueryAccountDataRequest{ + UserID: tc.inputData.UserID, + DataType: tc.inputData.DataType, + RoomID: tc.inputData.RoomID, + } + err = intAPI.QueryAccountData(ctx, &queryReq, &queryRes) + if err != nil && !tc.wantErr { + t.Fatal(err) + } + // verify global data + if tc.inputData.RoomID == "" { + if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.GlobalAccountData[tc.inputData.DataType]) { + t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.GlobalAccountData[tc.inputData.DataType])) + } + } else { + // verify room data + if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]) { + t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType])) + } + } + }) + } + }) +} + +func TestDevices(t *testing.T) { + ctx := context.Background() + + dupeAccessToken := util.RandomString(8) + + displayName := "testing" + + creationTests := []struct { + name string + inputData *api.PerformDeviceCreationRequest + wantErr bool + wantNewDevID bool + }{ + { + name: "not a local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", ServerName: "notlocal"}, + wantErr: true, + }, + { + name: "implicit local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, DeviceDisplayName: &displayName}, + }, + { + name: "explicit local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + { + name: "dupe token - ok", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, + }, + { + name: "dupe token - not ok", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, + wantErr: true, + }, + { + name: "test3 second device", // used to test deletion later + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + { + name: "test3 third device", // used to test deletion later + wantNewDevID: true, + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + } + + deletionTests := []struct { + name string + inputData *api.PerformDeviceDeletionRequest + wantErr bool + wantDevices int + }{ + { + name: "deletion - not a local user", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test:notlocalhost"}, + wantErr: true, + }, + { + name: "deleting not existing devices should not error", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test", DeviceIDs: []string{"iDontExist"}}, + wantDevices: 1, + }, + { + name: "delete all devices", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test"}, + wantDevices: 0, + }, + { + name: "delete all devices", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test3:test"}, + wantDevices: 0, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType, nil) + defer close() + + for _, tc := range creationTests { + t.Run(tc.name, func(t *testing.T) { + res := api.PerformDeviceCreationResponse{} + deviceID := util.RandomString(8) + tc.inputData.DeviceID = &deviceID + if tc.wantNewDevID { + tc.inputData.DeviceID = nil + } + err := intAPI.PerformDeviceCreation(ctx, tc.inputData, &res) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + if !res.DeviceCreated { + return + } + + queryDevicesRes := api.QueryDevicesResponse{} + queryDevicesReq := api.QueryDevicesRequest{UserID: res.Device.UserID} + if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil { + t.Fatal(err) + } + // We only want to verify one device + if len(queryDevicesRes.Devices) > 1 { + return + } + res.Device.AccessToken = "" + + // At this point, there should only be one device + if !reflect.DeepEqual(*res.Device, queryDevicesRes.Devices[0]) { + t.Fatalf("expected device to be\n%#v, got \n%#v", *res.Device, queryDevicesRes.Devices[0]) + } + + newDisplayName := "new name" + if tc.inputData.DeviceDisplayName == nil { + updateRes := api.PerformDeviceUpdateResponse{} + updateReq := api.PerformDeviceUpdateRequest{ + RequestingUserID: fmt.Sprintf("@%s:%s", tc.inputData.Localpart, "test"), + DeviceID: deviceID, + DisplayName: &newDisplayName, + } + + if err = intAPI.PerformDeviceUpdate(ctx, &updateReq, &updateRes); err != nil { + t.Fatal(err) + } + } + + queryDeviceInfosRes := api.QueryDeviceInfosResponse{} + queryDeviceInfosReq := api.QueryDeviceInfosRequest{DeviceIDs: []string{*tc.inputData.DeviceID}} + if err = intAPI.QueryDeviceInfos(ctx, &queryDeviceInfosReq, &queryDeviceInfosRes); err != nil { + t.Fatal(err) + } + gotDisplayName := queryDeviceInfosRes.DeviceInfo[*tc.inputData.DeviceID].DisplayName + if tc.inputData.DeviceDisplayName != nil { + wantDisplayName := *tc.inputData.DeviceDisplayName + if wantDisplayName != gotDisplayName { + t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName) + } + } else { + wantDisplayName := newDisplayName + if wantDisplayName != gotDisplayName { + t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName) + } + } + }) + } + + for _, tc := range deletionTests { + t.Run(tc.name, func(t *testing.T) { + delRes := api.PerformDeviceDeletionResponse{} + err := intAPI.PerformDeviceDeletion(ctx, tc.inputData, &delRes) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + if tc.wantErr { + return + } + + queryDevicesRes := api.QueryDevicesResponse{} + queryDevicesReq := api.QueryDevicesRequest{UserID: tc.inputData.UserID} + if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil { + t.Fatal(err) + } + + if len(queryDevicesRes.Devices) != tc.wantDevices { + t.Fatalf("expected %d devices, got %d", tc.wantDevices, len(queryDevicesRes.Devices)) + } + + }) + } + }) +} + +// Tests that the session ID of a device is not reused when reusing the same device ID. +func TestDeviceIDReuse(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + publisher := &dummyProducer{t: t} + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType, publisher) + defer close() + + res := api.PerformDeviceCreationResponse{} + // create a first device + deviceID := util.RandomString(8) + req := api.PerformDeviceCreationRequest{Localpart: "alice", ServerName: "test", DeviceID: &deviceID, NoDeviceListUpdate: true} + err := intAPI.PerformDeviceCreation(ctx, &req, &res) + if err != nil { + t.Fatal(err) + } + + // Do the same request again, we expect a different sessionID + res2 := api.PerformDeviceCreationResponse{} + // Set NoDeviceListUpdate to false, to verify we don't send device list updates when + // reusing the same device ID + req.NoDeviceListUpdate = false + err = intAPI.PerformDeviceCreation(ctx, &req, &res2) + if err != nil { + t.Fatalf("expected no error, but got: %v", err) + } + + if res2.Device.SessionID == res.Device.SessionID { + t.Fatalf("expected a different session ID, but they are the same") + } + + publisher.callCount.Range(func(key, value any) bool { + if value != nil { + t.Fatalf("expected publisher to not get called, but got value %d for subject %s", value, key) + } + return true + }) + }) +} diff --git a/userapi/util/devices.go b/userapi/util/devices.go index c55fc7999..31617d8c1 100644 --- a/userapi/util/devices.go +++ b/userapi/util/devices.go @@ -19,7 +19,7 @@ type PusherDevice struct { } // GetPushDevices pushes to the configured devices of a local user. -func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { +func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.UserDatabase) ([]*PusherDevice, error) { pushers, err := db.GetPushers(ctx, localpart, serverName) if err != nil { return nil, fmt.Errorf("db.GetPushers: %w", err) diff --git a/userapi/util/notify.go b/userapi/util/notify.go index fc0ab39bf..08d1371d6 100644 --- a/userapi/util/notify.go +++ b/userapi/util/notify.go @@ -13,11 +13,11 @@ import ( ) // NotifyUserCountsAsync sends notifications to a local user's -// notification destinations. Database lookups run synchronously, but +// notification destinations. UserDatabase lookups run synchronously, but // a single goroutine is started when talking to the Push // gateways. There is no way to know when the background goroutine has // finished. -func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error { +func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.UserDatabase) error { pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db) if err != nil { return err diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go new file mode 100644 index 000000000..421852d3f --- /dev/null +++ b/userapi/util/notify_test.go @@ -0,0 +1,119 @@ +package util_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "golang.org/x/crypto/bcrypt" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" + userUtil "github.com/matrix-org/dendrite/userapi/util" +) + +func TestNotifyUserCountsAsync(t *testing.T) { + alice := test.NewUser(t) + aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID) + if err != nil { + t.Error(err) + } + ctx := context.Background() + + // Create a test room, just used to provide events + room := test.NewRoom(t, alice) + dummyEvent := room.Events()[len(room.Events())-1] + + appID := util.RandomString(8) + pushKey := util.RandomString(8) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + receivedRequest := make(chan bool, 1) + // create a test server which responds to our /notify call + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var data pushgateway.NotifyRequest + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + t.Error(err) + } + notification := data.Notification + // Validate the request + if notification.Counts == nil { + t.Fatal("no unread notification counts in request") + } + if unread := notification.Counts.Unread; unread != 1 { + t.Errorf("expected one unread notification, got %d", unread) + } + + if len(notification.Devices) == 0 { + t.Fatal("expected devices in request") + } + + // We only created one push device, so access it directly + device := notification.Devices[0] + if device.AppID != appID { + t.Errorf("unexpected app_id: %s, want %s", device.AppID, appID) + } + if device.PushKey != pushKey { + t.Errorf("unexpected push_key: %s, want %s", device.PushKey, pushKey) + } + + // Return empty result, otherwise the call is handled as failed + if _, err := w.Write([]byte("{}")); err != nil { + t.Error(err) + } + close(receivedRequest) + })) + defer srv.Close() + + // Create DB and Dendrite base + connStr, close := test.PrepareDBConnectionString(t, dbType) + defer close() + base, _, _ := testrig.Base(nil) + defer base.Close() + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, "test", bcrypt.MinCost, 0, 0, "") + if err != nil { + t.Error(err) + } + + // Prepare pusher with our test server URL + if err := db.UpsertPusher(ctx, api.Pusher{ + Kind: api.HTTPKind, + AppID: appID, + PushKey: pushKey, + Data: map[string]interface{}{ + "url": srv.URL, + }, + }, aliceLocalpart, serverName); err != nil { + t.Error(err) + } + + // Insert a dummy event + if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ + Event: gomatrixserverlib.HeaderedToClientEvent(dummyEvent, gomatrixserverlib.FormatAll), + }); err != nil { + t.Error(err) + } + + // Notify the user about a new notification + if err := userUtil.NotifyUserCountsAsync(ctx, pushgateway.NewHTTPClient(true), aliceLocalpart, serverName, db); err != nil { + t.Error(err) + } + select { + case <-time.After(time.Second * 5): + t.Error("timed out waiting for response") + case <-receivedRequest: + } + }) + +} diff --git a/userapi/util/phonehomestats.go b/userapi/util/phonehomestats.go index 6f36568c9..21035e045 100644 --- a/userapi/util/phonehomestats.go +++ b/userapi/util/phonehomestats.go @@ -55,7 +55,7 @@ func StartPhoneHomeCollector(startTime time.Time, cfg *config.Dendrite, statsDB serverName: cfg.Global.ServerName, cfg: cfg, db: statsDB, - isMonolith: cfg.IsMonolith, + isMonolith: true, client: &http.Client{ Timeout: time.Second * 30, Transport: http.DefaultTransport, @@ -97,12 +97,10 @@ func (p *phoneHomeStats) collect() { // configuration information p.stats["federation_disabled"] = p.cfg.Global.DisableFederation - p.stats["nats_embedded"] = true - p.stats["nats_in_memory"] = p.cfg.Global.JetStream.InMemory - if len(p.cfg.Global.JetStream.Addresses) > 0 { - p.stats["nats_embedded"] = false - p.stats["nats_in_memory"] = false // probably - } + natsEmbedded := len(p.cfg.Global.JetStream.Addresses) == 0 + p.stats["nats_embedded"] = natsEmbedded + p.stats["nats_in_memory"] = p.cfg.Global.JetStream.InMemory && natsEmbedded + if len(p.cfg.Logging) > 0 { p.stats["log_level"] = p.cfg.Logging[0].Level } else { diff --git a/userapi/util/phonehomestats_test.go b/userapi/util/phonehomestats_test.go new file mode 100644 index 000000000..5f626b5bc --- /dev/null +++ b/userapi/util/phonehomestats_test.go @@ -0,0 +1,84 @@ +package util + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "golang.org/x/crypto/bcrypt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/storage" +) + +func TestCollect(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + b, _, _ := testrig.Base(nil) + connStr, closeDB := test.PrepareDBConnectionString(t, dbType) + defer closeDB() + db, err := storage.NewUserDatabase(b, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, "localhost", bcrypt.MinCost, 1000, 1000, "") + if err != nil { + t.Error(err) + } + + receivedRequest := make(chan struct{}, 1) + // create a test server which responds to our call + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var data map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + t.Error(err) + } + defer r.Body.Close() + if _, err := w.Write([]byte("{}")); err != nil { + t.Error(err) + } + + // verify the received data matches our expectations + dbEngine, ok := data["database_engine"] + if !ok { + t.Errorf("missing database_engine in JSON request: %+v", data) + } + version, ok := data["version"] + if !ok { + t.Errorf("missing version in JSON request: %+v", data) + } + if version != internal.VersionString() { + t.Errorf("unexpected version: %q, expected %q", version, internal.VersionString()) + } + switch { + case dbType == test.DBTypeSQLite && dbEngine != "SQLite": + t.Errorf("unexpected database_engine: %s", dbEngine) + case dbType == test.DBTypePostgres && dbEngine != "Postgres": + t.Errorf("unexpected database_engine: %s", dbEngine) + } + close(receivedRequest) + })) + defer srv.Close() + + b.Cfg.Global.ReportStats.Endpoint = srv.URL + stats := phoneHomeStats{ + prevData: timestampToRUUsage{}, + serverName: "localhost", + startTime: time.Now(), + cfg: b.Cfg, + db: db, + isMonolith: false, + client: &http.Client{Timeout: time.Second}, + } + + stats.collect() + + select { + case <-time.After(time.Second * 5): + t.Error("timed out waiting for response") + case <-receivedRequest: + } + }) +}