Merge branch 'main' into mailbox

This commit is contained in:
Devon Hudson 2022-12-02 16:13:30 -07:00
commit 98c7711b84
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
179 changed files with 3555 additions and 1501 deletions

View file

@ -26,22 +26,14 @@ jobs:
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: 1.18 go-version: 1.18
cache: true
- uses: actions/cache@v2
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go-wasm-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-wasm
- name: Install Node - name: Install Node
uses: actions/setup-node@v2 uses: actions/setup-node@v2
with: with:
node-version: 14 node-version: 14
- uses: actions/cache@v2 - uses: actions/cache@v3
with: with:
path: ~/.npm path: ~/.npm
key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }} key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }}
@ -76,7 +68,7 @@ jobs:
# run go test with different go versions # run go test with different go versions
test: test:
timeout-minutes: 5 timeout-minutes: 10
name: Unit tests (Go ${{ matrix.go }}) name: Unit tests (Go ${{ matrix.go }})
runs-on: ubuntu-latest runs-on: ubuntu-latest
# Service containers to run with `container-job` # Service containers to run with `container-job`
@ -102,26 +94,27 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
go: ["1.18", "1.19"] go: ["1.19"]
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Setup go - name: Setup go
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- uses: actions/cache@v3
# manually set up caches, as they otherwise clash with different steps using setup-go with cache=true
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go${{ matrix.go }}-unit-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-unit-
- name: Set up gotestfmt - name: Set up gotestfmt
uses: gotesttools/gotestfmt-action@v2 uses: gotesttools/gotestfmt-action@v2
with: with:
# Optional: pass GITHUB_TOKEN to avoid rate limiting. # Optional: pass GITHUB_TOKEN to avoid rate limiting.
token: ${{ secrets.GITHUB_TOKEN }} token: ${{ secrets.GITHUB_TOKEN }}
- uses: actions/cache@v3
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-test-
- run: go test -json -v ./... 2>&1 | gotestfmt - run: go test -json -v ./... 2>&1 | gotestfmt
env: env:
POSTGRES_HOST: localhost POSTGRES_HOST: localhost
@ -146,17 +139,17 @@ jobs:
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Install dependencies x86
if: ${{ matrix.goarch == '386' }}
run: sudo apt update && sudo apt-get install -y gcc-multilib
- uses: actions/cache@v3 - uses: actions/cache@v3
with: with:
path: | path: |
~/.cache/go-build ~/.cache/go-build
~/go/pkg/mod ~/go/pkg/mod
key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goarch }}- key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-
- name: Install dependencies x86
if: ${{ matrix.goarch == '386' }}
run: sudo apt update && sudo apt-get install -y gcc-multilib
- env: - env:
GOOS: ${{ matrix.goos }} GOOS: ${{ matrix.goos }}
GOARCH: ${{ matrix.goarch }} GOARCH: ${{ matrix.goarch }}
@ -180,16 +173,16 @@ jobs:
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Install dependencies
run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc
- uses: actions/cache@v3 - uses: actions/cache@v3
with: with:
path: | path: |
~/.cache/go-build ~/.cache/go-build
~/go/pkg/mod ~/go/pkg/mod
key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }}-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }} key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-
- name: Install dependencies
run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc
- env: - env:
GOOS: ${{ matrix.goos }} GOOS: ${{ matrix.goos }}
GOARCH: ${{ matrix.goarch }} GOARCH: ${{ matrix.goarch }}
@ -209,6 +202,66 @@ jobs:
with: with:
jobs: ${{ toJSON(needs) }} jobs: ${{ toJSON(needs) }}
# run go test with different go versions
integration:
timeout-minutes: 20
needs: initial-tests-done
name: Integration tests (Go ${{ matrix.go }})
runs-on: ubuntu-latest
# Service containers to run with `container-job`
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres:13-alpine
# Provide the password for postgres
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: dendrite
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
strategy:
fail-fast: false
matrix:
go: ["1.19"]
steps:
- uses: actions/checkout@v3
- name: Setup go
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- name: Set up gotestfmt
uses: gotesttools/gotestfmt-action@v2
with:
# Optional: pass GITHUB_TOKEN to avoid rate limiting.
token: ${{ secrets.GITHUB_TOKEN }}
- uses: actions/cache@v3
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-test-race-
- run: go test -race -json -v -coverpkg=./... -coverprofile=cover.out $(go list ./... | grep -v /cmd/dendrite*) 2>&1 | gotestfmt
env:
POSTGRES_HOST: localhost
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: dendrite
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
flags: unittests
# run database upgrade tests # run database upgrade tests
upgrade_test: upgrade_test:
name: Upgrade tests name: Upgrade tests
@ -221,18 +274,13 @@ jobs:
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: "1.18" go-version: "1.18"
- uses: actions/cache@v3 cache: true
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go-upgrade-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-upgrade
- name: Build upgrade-tests - name: Build upgrade-tests
run: go build ./cmd/dendrite-upgrade-tests run: go build ./cmd/dendrite-upgrade-tests
- name: Test upgrade - name: Test upgrade (PostgreSQL)
run: ./dendrite-upgrade-tests --head . run: ./dendrite-upgrade-tests --head .
- name: Test upgrade (SQLite)
run: ./dendrite-upgrade-tests --sqlite --head .
# run database upgrade tests, skipping over one version # run database upgrade tests, skipping over one version
upgrade_test_direct: upgrade_test_direct:
@ -246,17 +294,12 @@ jobs:
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: "1.18" go-version: "1.18"
- uses: actions/cache@v3 cache: true
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go-upgrade-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-upgrade
- name: Build upgrade-tests - name: Build upgrade-tests
run: go build ./cmd/dendrite-upgrade-tests run: go build ./cmd/dendrite-upgrade-tests
- name: Test upgrade - name: Test upgrade (PostgreSQL)
run: ./dendrite-upgrade-tests -direct -from HEAD-2 --head .
- name: Test upgrade (SQLite)
run: ./dendrite-upgrade-tests -direct -from HEAD-2 --head . run: ./dendrite-upgrade-tests -direct -from HEAD-2 --head .
# run Sytest in different variations # run Sytest in different variations
@ -291,6 +334,8 @@ jobs:
image: matrixdotorg/sytest-dendrite:latest image: matrixdotorg/sytest-dendrite:latest
volumes: volumes:
- ${{ github.workspace }}:/src - ${{ github.workspace }}:/src
- /root/.cache/go-build:/github/home/.cache/go-build
- /root/.cache/go-mod:/gopath/pkg/mod
env: env:
POSTGRES: ${{ matrix.postgres && 1}} POSTGRES: ${{ matrix.postgres && 1}}
API: ${{ matrix.api && 1 }} API: ${{ matrix.api && 1 }}
@ -298,6 +343,14 @@ jobs:
CGO_ENABLED: ${{ matrix.cgo && 1 }} CGO_ENABLED: ${{ matrix.cgo && 1 }}
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/cache@v3
with:
path: |
~/.cache/go-build
/gopath/pkg/mod
key: ${{ runner.os }}-go-sytest-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-sytest-
- name: Run Sytest - name: Run Sytest
run: /bootstrap.sh dendrite run: /bootstrap.sh dendrite
working-directory: /src working-directory: /src
@ -332,12 +385,14 @@ jobs:
matrix: matrix:
include: include:
- label: SQLite native - label: SQLite native
cgo: 0
- label: SQLite Cgo - label: SQLite Cgo
cgo: 1 cgo: 1
- label: SQLite native, full HTTP APIs - label: SQLite native, full HTTP APIs
api: full-http api: full-http
cgo: 0
- label: SQLite Cgo, full HTTP APIs - label: SQLite Cgo, full HTTP APIs
api: full-http api: full-http
@ -345,10 +400,12 @@ jobs:
- label: PostgreSQL - label: PostgreSQL
postgres: Postgres postgres: Postgres
cgo: 0
- label: PostgreSQL, full HTTP APIs - label: PostgreSQL, full HTTP APIs
postgres: Postgres postgres: Postgres
api: full-http api: full-http
cgo: 0
steps: steps:
# Env vars are set file a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on env to run Complement. # Env vars are set file a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on env to run Complement.
# See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path # See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path
@ -356,14 +413,12 @@ jobs:
run: | run: |
echo "$GOROOT_1_17_X64/bin" >> $GITHUB_PATH echo "$GOROOT_1_17_X64/bin" >> $GITHUB_PATH
echo "~/go/bin" >> $GITHUB_PATH echo "~/go/bin" >> $GITHUB_PATH
- name: "Install Complement Dependencies" - name: "Install Complement Dependencies"
# We don't need to install Go because it is included on the Ubuntu 20.04 image: # We don't need to install Go because it is included on the Ubuntu 20.04 image:
# See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md specifically GOROOT_1_17_X64 # See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md specifically GOROOT_1_17_X64
run: | run: |
sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev
go get -v github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest go get -v github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest
- name: Run actions/checkout@v3 for dendrite - name: Run actions/checkout@v3 for dendrite
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
@ -389,12 +444,10 @@ jobs:
if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then
continue continue
fi fi
(wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break (wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break
done done
# Build initial Dendrite image # Build initial Dendrite image
- run: docker build -t complement-dendrite -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . - run: docker build --build-arg=CGO=${{ matrix.cgo }} -t complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile .
working-directory: dendrite working-directory: dendrite
env: env:
DOCKER_BUILDKIT: 1 DOCKER_BUILDKIT: 1
@ -406,9 +459,8 @@ jobs:
shell: bash shell: bash
name: Run Complement Tests name: Run Complement Tests
env: env:
COMPLEMENT_BASE_IMAGE: complement-dendrite:latest COMPLEMENT_BASE_IMAGE: complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }}
API: ${{ matrix.api && 1 }} API: ${{ matrix.api && 1 }}
CGO_ENABLED: ${{ matrix.cgo && 1 }}
working-directory: complement working-directory: complement
integration-tests-done: integration-tests-done:
@ -420,6 +472,7 @@ jobs:
upgrade_test_direct, upgrade_test_direct,
sytest, sytest,
complement, complement,
integration
] ]
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: ${{ !cancelled() }} # Run this even if prior jobs were skipped if: ${{ !cancelled() }} # Run this even if prior jobs were skipped

View file

@ -32,7 +32,7 @@ jobs:
if: github.event_name == 'release' # Only for GitHub releases if: github.event_name == 'release' # Only for GitHub releases
run: | run: |
echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV
BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) BRANCH=$(git symbolic-ref --short HEAD | tr -d \/)
[ ${BRANCH} == "main" ] && BRANCH="" [ ${BRANCH} == "main" ] && BRANCH=""
echo "BRANCH=${BRANCH}" >> $GITHUB_ENV echo "BRANCH=${BRANCH}" >> $GITHUB_ENV
@ -68,18 +68,6 @@ jobs:
${{ env.DOCKER_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }} ${{ env.DOCKER_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }}
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }} ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }}
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
with:
image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }}
format: "sarif"
output: "trivy-results.sarif"
- name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@v2
with:
sarif_file: "trivy-results.sarif"
- name: Build release monolith image - name: Build release monolith image
if: github.event_name == 'release' # Only for GitHub releases if: github.event_name == 'release' # Only for GitHub releases
id: docker_build_monolith_release id: docker_build_monolith_release
@ -98,6 +86,18 @@ jobs:
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:latest ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:latest
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ env.RELEASE_VERSION }} ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ env.RELEASE_VERSION }}
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
with:
image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }}
format: "sarif"
output: "trivy-results.sarif"
- name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@v2
with:
sarif_file: "trivy-results.sarif"
polylith: polylith:
name: Polylith image name: Polylith image
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -112,7 +112,7 @@ jobs:
if: github.event_name == 'release' # Only for GitHub releases if: github.event_name == 'release' # Only for GitHub releases
run: | run: |
echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV
BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) BRANCH=$(git symbolic-ref --short HEAD | tr -d \/)
[ ${BRANCH} == "main" ] && BRANCH="" [ ${BRANCH} == "main" ] && BRANCH=""
echo "BRANCH=${BRANCH}" >> $GITHUB_ENV echo "BRANCH=${BRANCH}" >> $GITHUB_ENV
@ -148,18 +148,6 @@ jobs:
${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} ${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }}
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }}
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
with:
image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }}
format: "sarif"
output: "trivy-results.sarif"
- name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@v2
with:
sarif_file: "trivy-results.sarif"
- name: Build release polylith image - name: Build release polylith image
if: github.event_name == 'release' # Only for GitHub releases if: github.event_name == 'release' # Only for GitHub releases
id: docker_build_polylith_release id: docker_build_polylith_release
@ -178,6 +166,18 @@ jobs:
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:latest ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:latest
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }} ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }}
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
with:
image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }}
format: "sarif"
output: "trivy-results.sarif"
- name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@v2
with:
sarif_file: "trivy-results.sarif"
demo-pinecone: demo-pinecone:
name: Pinecone demo image name: Pinecone demo image
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -191,7 +191,7 @@ jobs:
if: github.event_name == 'release' # Only for GitHub releases if: github.event_name == 'release' # Only for GitHub releases
run: | run: |
echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV
BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) BRANCH=$(git symbolic-ref --short HEAD | tr -d \/)
[ ${BRANCH} == "main" ] && BRANCH="" [ ${BRANCH} == "main" ] && BRANCH=""
echo "BRANCH=${BRANCH}" >> $GITHUB_ENV echo "BRANCH=${BRANCH}" >> $GITHUB_ENV
@ -258,7 +258,7 @@ jobs:
if: github.event_name == 'release' # Only for GitHub releases if: github.event_name == 'release' # Only for GitHub releases
run: | run: |
echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV
BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) BRANCH=$(git symbolic-ref --short HEAD | tr -d \/)
[ ${BRANCH} == "main" ] && BRANCH="" [ ${BRANCH} == "main" ] && BRANCH=""
echo "BRANCH=${BRANCH}" >> $GITHUB_ENV echo "BRANCH=${BRANCH}" >> $GITHUB_ENV

View file

@ -10,72 +10,9 @@ concurrency:
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
# run go test with different go versions
test:
timeout-minutes: 20
name: Unit tests (Go ${{ matrix.go }})
runs-on: ubuntu-latest
# Service containers to run with `container-job`
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres:13-alpine
# Provide the password for postgres
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: dendrite
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
strategy:
fail-fast: false
matrix:
go: ["1.18", "1.19"]
steps:
- uses: actions/checkout@v3
- name: Setup go
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- uses: actions/cache@v3
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-test-race-
- run: go test -race ./...
env:
POSTGRES_HOST: localhost
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: dendrite
# Dummy step to gate other tests on without repeating the whole list
initial-tests-done:
name: Initial tests passed
needs: [test]
runs-on: ubuntu-latest
if: ${{ !cancelled() }} # Run this even if prior jobs were skipped
steps:
- name: Check initial tests passed
uses: re-actors/alls-green@release/v1
with:
jobs: ${{ toJSON(needs) }}
# run Sytest in different variations # run Sytest in different variations
sytest: sytest:
timeout-minutes: 60 timeout-minutes: 60
needs: initial-tests-done
name: "Sytest (${{ matrix.label }})" name: "Sytest (${{ matrix.label }})"
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
@ -97,13 +34,23 @@ jobs:
image: matrixdotorg/sytest-dendrite:latest image: matrixdotorg/sytest-dendrite:latest
volumes: volumes:
- ${{ github.workspace }}:/src - ${{ github.workspace }}:/src
- /root/.cache/go-build:/github/home/.cache/go-build
- /root/.cache/go-mod:/gopath/pkg/mod
env: env:
POSTGRES: ${{ matrix.postgres && 1}} POSTGRES: ${{ matrix.postgres && 1}}
API: ${{ matrix.api && 1 }} API: ${{ matrix.api && 1 }}
SYTEST_BRANCH: ${{ github.head_ref }} SYTEST_BRANCH: ${{ github.head_ref }}
RACE_DETECTION: 1 RACE_DETECTION: 1
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- uses: actions/cache@v3
with:
path: |
~/.cache/go-build
/gopath/pkg/mod
key: ${{ runner.os }}-go-sytest-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-sytest-
- name: Run Sytest - name: Run Sytest
run: /bootstrap.sh dendrite run: /bootstrap.sh dendrite
working-directory: /src working-directory: /src

View file

@ -1,5 +1,45 @@
# Changelog # Changelog
## Dendrite 0.10.8 (2022-11-29)
### Features
* The built-in NATS Server has been updated to version 2.9.8
* A number of under-the-hood changes have been merged for future virtual hosting support in Dendrite (running multiple domain names on the same Dendrite deployment)
### Fixes
* Event auth handling of invites has been refactored, which should fix some edge cases being handled incorrectly
* Fix a bug when returning an empty protocol list, which could cause Element to display "The homeserver may be too old to support third party networks" when opening the public room directory
* The sync API will no longer filter out the user's own membership when using lazy-loading
* Dendrite will now correctly detect JetStream consumers being deleted, stopping the consumer goroutine as needed
* A panic in the federation API where the server list could go out of bounds has been fixed
* Blacklisted servers will now be excluded when querying joined servers, which improves CPU usage and performs less unnecessary outbound requests
* A database writer will now be used to assign state key NIDs when requesting NIDs that may not exist yet
* Dendrite will now correctly move local aliases for an upgraded room when the room is upgraded remotely
* Dendrite will now correctly move account data for an upgraded room when the room is upgraded remotely
* Missing state key NIDs will now be allocated on request rather than returning an error
* Guest access is now correctly denied on a number of endpoints
* Presence information will now be correctly sent for new private chats
* A number of unspecced fields have been removed from outbound `/send` transactions
## Dendrite 0.10.7 (2022-11-04)
### Features
* Dendrite will now use a native SQLite port when building with `CGO_ENABLED=0`
* A number of `thirdparty` endpoints have been added, improving support for appservices
### Fixes
* The `"state"` section of the `/sync` response is no longer limited, so state events should not be dropped unexpectedly
* The deduplication of the `"timeline"` and `"state"` sections in `/sync` is now performed after applying history visibility, so state events should not be dropped unexpectedly
* The `prev_batch` token returned by `/sync` is now calculated after applying history visibility, so that the pagination boundaries are correct
* The room summary membership counts in `/sync` should now be calculated properly in more cases
* A false membership leave event should no longer be sent down `/sync` as a result of retiring an accepted invite (contributed by [tak-hntlabs](https://github.com/tak-hntlabs))
* Presence updates are now only sent to other servers for which the user shares rooms
* A bug which could cause a panic when converting events into the `ClientEvent` format has been fixed
## Dendrite 0.10.6 (2022-11-01) ## Dendrite 0.10.6 (2022-11-01)
### Features ### Features

View file

@ -24,7 +24,7 @@ RUN --mount=target=. \
go build -v -ldflags="${FLAGS}" -trimpath -o /out/ ./cmd/... go build -v -ldflags="${FLAGS}" -trimpath -o /out/ ./cmd/...
# #
# The dendrite base image; mainly creates a user and switches to it # The dendrite base image
# #
FROM alpine:latest AS dendrite-base FROM alpine:latest AS dendrite-base
LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go" LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go"
@ -32,8 +32,6 @@ LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite"
LABEL org.opencontainers.image.licenses="Apache-2.0" LABEL org.opencontainers.image.licenses="Apache-2.0"
LABEL org.opencontainers.image.documentation="https://matrix-org.github.io/dendrite/" LABEL org.opencontainers.image.documentation="https://matrix-org.github.io/dendrite/"
LABEL org.opencontainers.image.vendor="The Matrix.org Foundation C.I.C." LABEL org.opencontainers.image.vendor="The Matrix.org Foundation C.I.C."
RUN addgroup dendrite && adduser dendrite -G dendrite -u 1337 -D
USER dendrite
# #
# Builds the polylith image and only contains the polylith binary # Builds the polylith image and only contains the polylith binary

View file

@ -32,6 +32,7 @@ import (
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
) )
// AddInternalRoutes registers HTTP handlers for internal API calls // AddInternalRoutes registers HTTP handlers for internal API calls
@ -74,7 +75,7 @@ func NewInternalAPI(
// events to be sent out. // events to be sent out.
for _, appservice := range base.Cfg.Derived.ApplicationServices { for _, appservice := range base.Cfg.Derived.ApplicationServices {
// Create bot account for this AS if it doesn't already exist // Create bot account for this AS if it doesn't already exist
if err := generateAppServiceAccount(userAPI, appservice); err != nil { if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"appservice": appservice.ID, "appservice": appservice.ID,
}).WithError(err).Panicf("failed to generate bot account for appservice") }).WithError(err).Panicf("failed to generate bot account for appservice")
@ -101,11 +102,13 @@ func NewInternalAPI(
func generateAppServiceAccount( func generateAppServiceAccount(
userAPI userapi.AppserviceUserAPI, userAPI userapi.AppserviceUserAPI,
as config.ApplicationService, as config.ApplicationService,
serverName gomatrixserverlib.ServerName,
) error { ) error {
var accRes userapi.PerformAccountCreationResponse var accRes userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeAppService, AccountType: userapi.AccountTypeAppService,
Localpart: as.SenderLocalpart, Localpart: as.SenderLocalpart,
ServerName: serverName,
AppServiceID: as.ID, AppServiceID: as.ID,
OnConflict: userapi.ConflictUpdate, OnConflict: userapi.ConflictUpdate,
}, &accRes) }, &accRes)
@ -115,6 +118,7 @@ func generateAppServiceAccount(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{
Localpart: as.SenderLocalpart, Localpart: as.SenderLocalpart,
ServerName: serverName,
AccessToken: as.ASToken, AccessToken: as.ASToken,
DeviceID: &as.SenderLocalpart, DeviceID: &as.SenderLocalpart,
DeviceDisplayName: &as.SenderLocalpart, DeviceDisplayName: &as.SenderLocalpart,

View file

@ -0,0 +1,224 @@
package appservice_test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"reflect"
"regexp"
"strings"
"testing"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/appservice/inthttp"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/dendrite/test/testrig"
)
func TestAppserviceInternalAPI(t *testing.T) {
// Set expected results
existingProtocol := "irc"
wantLocationResponse := []api.ASLocationResponse{{Protocol: existingProtocol, Fields: []byte("{}")}}
wantUserResponse := []api.ASUserResponse{{Protocol: existingProtocol, Fields: []byte("{}")}}
wantProtocolResponse := api.ASProtocolResponse{Instances: []api.ProtocolInstance{{Fields: []byte("{}")}}}
wantProtocolResult := map[string]api.ASProtocolResponse{
existingProtocol: wantProtocolResponse,
}
// create a dummy AS url, handling some cases
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "location"):
// Check if we've got an existing protocol, if so, return a proper response.
if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol {
if err := json.NewEncoder(w).Encode(wantLocationResponse); err != nil {
t.Fatalf("failed to encode response: %s", err)
}
return
}
if err := json.NewEncoder(w).Encode([]api.ASLocationResponse{}); err != nil {
t.Fatalf("failed to encode response: %s", err)
}
return
case strings.Contains(r.URL.Path, "user"):
if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol {
if err := json.NewEncoder(w).Encode(wantUserResponse); err != nil {
t.Fatalf("failed to encode response: %s", err)
}
return
}
if err := json.NewEncoder(w).Encode([]api.UserResponse{}); err != nil {
t.Fatalf("failed to encode response: %s", err)
}
return
case strings.Contains(r.URL.Path, "protocol"):
if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol {
if err := json.NewEncoder(w).Encode(wantProtocolResponse); err != nil {
t.Fatalf("failed to encode response: %s", err)
}
return
}
if err := json.NewEncoder(w).Encode(nil); err != nil {
t.Fatalf("failed to encode response: %s", err)
}
return
default:
t.Logf("hit location: %s", r.URL.Path)
}
}))
// TODO: use test.WithAllDatabases
// only one DBType, since appservice.AddInternalRoutes complains about multiple prometheus counters added
base, closeBase := testrig.CreateBaseDendrite(t, test.DBTypeSQLite)
defer closeBase()
// Create a dummy application service
base.Cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{
{
ID: "someID",
URL: srv.URL,
ASToken: "",
HSToken: "",
SenderLocalpart: "senderLocalPart",
NamespaceMap: map[string][]config.ApplicationServiceNamespace{
"users": {{RegexpObject: regexp.MustCompile("as-.*")}},
"aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}},
},
Protocols: []string{existingProtocol},
},
}
// Create required internal APIs
rsAPI := roomserver.NewInternalAPI(base)
usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil)
asAPI := appservice.NewInternalAPI(base, usrAPI, rsAPI)
// The test cases to run
runCases := func(t *testing.T, testAPI api.AppServiceInternalAPI) {
t.Run("UserIDExists", func(t *testing.T) {
testUserIDExists(t, testAPI, "@as-testing:test", true)
testUserIDExists(t, testAPI, "@as1-testing:test", false)
})
t.Run("AliasExists", func(t *testing.T) {
testAliasExists(t, testAPI, "@asroom-testing:test", true)
testAliasExists(t, testAPI, "@asroom1-testing:test", false)
})
t.Run("Locations", func(t *testing.T) {
testLocations(t, testAPI, existingProtocol, wantLocationResponse)
testLocations(t, testAPI, "abc", nil)
})
t.Run("User", func(t *testing.T) {
testUser(t, testAPI, existingProtocol, wantUserResponse)
testUser(t, testAPI, "abc", nil)
})
t.Run("Protocols", func(t *testing.T) {
testProtocol(t, testAPI, existingProtocol, wantProtocolResult)
testProtocol(t, testAPI, existingProtocol, wantProtocolResult) // tests the cache
testProtocol(t, testAPI, "", wantProtocolResult) // tests getting all protocols
testProtocol(t, testAPI, "abc", nil)
})
}
// Finally execute the tests
t.Run("HTTP API", func(t *testing.T) {
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
appservice.AddInternalRoutes(router, asAPI)
apiURL, cancel := test.ListenAndServe(t, router, false)
defer cancel()
asHTTPApi, err := inthttp.NewAppserviceClient(apiURL, &http.Client{})
if err != nil {
t.Fatalf("failed to create HTTP client: %s", err)
}
runCases(t, asHTTPApi)
})
t.Run("Monolith", func(t *testing.T) {
runCases(t, asAPI)
})
}
func testUserIDExists(t *testing.T, asAPI api.AppServiceInternalAPI, userID string, wantExists bool) {
ctx := context.Background()
userResp := &api.UserIDExistsResponse{}
if err := asAPI.UserIDExists(ctx, &api.UserIDExistsRequest{
UserID: userID,
}, userResp); err != nil {
t.Errorf("failed to get userID: %s", err)
}
if userResp.UserIDExists != wantExists {
t.Errorf("unexpected result for UserIDExists(%s): %v, expected %v", userID, userResp.UserIDExists, wantExists)
}
}
func testAliasExists(t *testing.T, asAPI api.AppServiceInternalAPI, alias string, wantExists bool) {
ctx := context.Background()
aliasResp := &api.RoomAliasExistsResponse{}
if err := asAPI.RoomAliasExists(ctx, &api.RoomAliasExistsRequest{
Alias: alias,
}, aliasResp); err != nil {
t.Errorf("failed to get alias: %s", err)
}
if aliasResp.AliasExists != wantExists {
t.Errorf("unexpected result for RoomAliasExists(%s): %v, expected %v", alias, aliasResp.AliasExists, wantExists)
}
}
func testLocations(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult []api.ASLocationResponse) {
ctx := context.Background()
locationResp := &api.LocationResponse{}
if err := asAPI.Locations(ctx, &api.LocationRequest{
Protocol: proto,
}, locationResp); err != nil {
t.Errorf("failed to get locations: %s", err)
}
if !reflect.DeepEqual(locationResp.Locations, wantResult) {
t.Errorf("unexpected result for Locations(%s): %+v, expected %+v", proto, locationResp.Locations, wantResult)
}
}
func testUser(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult []api.ASUserResponse) {
ctx := context.Background()
userResp := &api.UserResponse{}
if err := asAPI.User(ctx, &api.UserRequest{
Protocol: proto,
}, userResp); err != nil {
t.Errorf("failed to get user: %s", err)
}
if !reflect.DeepEqual(userResp.Users, wantResult) {
t.Errorf("unexpected result for User(%s): %+v, expected %+v", proto, userResp.Users, wantResult)
}
}
func testProtocol(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult map[string]api.ASProtocolResponse) {
ctx := context.Background()
protoResp := &api.ProtocolResponse{}
if err := asAPI.Protocols(ctx, &api.ProtocolRequest{
Protocol: proto,
}, protoResp); err != nil {
t.Errorf("failed to get Protocols: %s", err)
}
if !reflect.DeepEqual(protoResp.Protocols, wantResult) {
t.Errorf("unexpected result for Protocols(%s): %+v, expected %+v", proto, protoResp.Protocols[proto], wantResult)
}
}

View file

@ -40,6 +40,7 @@ import (
"github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
"github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver"
@ -58,6 +59,7 @@ import (
pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeMulticast "github.com/matrix-org/pinecone/multicast"
pineconeRouter "github.com/matrix-org/pinecone/router" pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeEvents "github.com/matrix-org/pinecone/router/events"
pineconeSessions "github.com/matrix-org/pinecone/sessions" pineconeSessions "github.com/matrix-org/pinecone/sessions"
"github.com/matrix-org/pinecone/types" "github.com/matrix-org/pinecone/types"
@ -295,7 +297,12 @@ func (m *DendriteMonolith) Start() {
m.logger.SetOutput(BindLogger{}) m.logger.SetOutput(BindLogger{})
logrus.SetOutput(BindLogger{}) logrus.SetOutput(BindLogger{})
pineconeEventChannel := make(chan pineconeEvents.Event)
m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk)
m.PineconeRouter.EnableHopLimiting()
m.PineconeRouter.EnableWakeupBroadcasts()
m.PineconeRouter.Subscribe(pineconeEventChannel)
m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"}) m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"})
m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter) m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter)
m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter, nil) m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter, nil)
@ -329,6 +336,7 @@ func (m *DendriteMonolith) Start() {
} }
base := base.NewBaseDendrite(cfg, "Monolith") base := base.NewBaseDendrite(cfg, "Monolith")
base.ConfigureAdminEndpoints()
defer base.Close() // nolint: errcheck defer base.Close() // nolint: errcheck
federation := conn.CreateFederationClient(base, m.PineconeQUIC) federation := conn.CreateFederationClient(base, m.PineconeQUIC)
@ -375,6 +383,8 @@ func (m *DendriteMonolith) Start() {
httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux)
httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux)
httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux)
httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux)
httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux)
httpRouter.HandleFunc("/pinecone", m.PineconeRouter.ManholeHandler) httpRouter.HandleFunc("/pinecone", m.PineconeRouter.ManholeHandler)
pMux := mux.NewRouter().SkipClean(true).UseEncodedPath() pMux := mux.NewRouter().SkipClean(true).UseEncodedPath()
@ -423,6 +433,34 @@ func (m *DendriteMonolith) Start() {
m.logger.Fatal(err) m.logger.Fatal(err)
} }
}() }()
go func(ch <-chan pineconeEvents.Event) {
eLog := logrus.WithField("pinecone", "events")
for event := range ch {
switch e := event.(type) {
case pineconeEvents.PeerAdded:
case pineconeEvents.PeerRemoved:
case pineconeEvents.TreeParentUpdate:
case pineconeEvents.SnakeDescUpdate:
case pineconeEvents.TreeRootAnnUpdate:
case pineconeEvents.SnakeEntryAdded:
case pineconeEvents.SnakeEntryRemoved:
case pineconeEvents.BroadcastReceived:
eLog.Info("Broadcast received from: ", e.PeerID)
req := &api.PerformWakeupServersRequest{
ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)},
}
res := &api.PerformWakeupServersResponse{}
if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil {
logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID)
}
case pineconeEvents.BandwidthReport:
default:
}
}
}(pineconeEventChannel)
} }
func (m *DendriteMonolith) Stop() { func (m *DendriteMonolith) Stop() {

View file

@ -150,6 +150,7 @@ func (m *DendriteMonolith) Start() {
} }
base := base.NewBaseDendrite(cfg, "Monolith") base := base.NewBaseDendrite(cfg, "Monolith")
base.ConfigureAdminEndpoints()
m.processContext = base.ProcessContext m.processContext = base.ProcessContext
defer base.Close() // nolint: errcheck defer base.Close() // nolint: errcheck
@ -196,6 +197,8 @@ func (m *DendriteMonolith) Start() {
httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux)
httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux)
httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux)
httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux)
httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux)
yggRouter := mux.NewRouter() yggRouter := mux.NewRouter()
yggRouter.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) yggRouter.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux)

View file

@ -10,12 +10,13 @@ RUN mkdir /dendrite
# Utilise Docker caching when downloading dependencies, this stops us needlessly # Utilise Docker caching when downloading dependencies, this stops us needlessly
# downloading dependencies every time. # downloading dependencies every time.
ARG CGO
RUN --mount=target=. \ RUN --mount=target=. \
--mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,target=/root/.cache/go-build \ --mount=type=cache,target=/root/.cache/go-build \
go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \
go build -o /dendrite ./cmd/generate-keys && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \
go build -o /dendrite ./cmd/dendrite-monolith-server CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server
WORKDIR /dendrite WORKDIR /dendrite
RUN ./generate-keys --private-key matrix_key.pem RUN ./generate-keys --private-key matrix_key.pem

View file

@ -28,12 +28,13 @@ RUN mkdir /dendrite
# Utilise Docker caching when downloading dependencies, this stops us needlessly # Utilise Docker caching when downloading dependencies, this stops us needlessly
# downloading dependencies every time. # downloading dependencies every time.
ARG CGO
RUN --mount=target=. \ RUN --mount=target=. \
--mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,target=/root/.cache/go-build \ --mount=type=cache,target=/root/.cache/go-build \
go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \
go build -o /dendrite ./cmd/generate-keys && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \
go build -o /dendrite ./cmd/dendrite-monolith-server CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server
WORKDIR /dendrite WORKDIR /dendrite
RUN ./generate-keys --private-key matrix_key.pem RUN ./generate-keys --private-key matrix_key.pem

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
uapi "github.com/matrix-org/dendrite/userapi/api" uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -66,8 +67,10 @@ func TestLoginFromJSONReader(t *testing.T) {
var userAPI fakeUserInternalAPI var userAPI fakeUserInternalAPI
cfg := &config.ClientAPI{ cfg := &config.ClientAPI{
Matrix: &config.Global{ Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
}, },
},
} }
login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg)
if err != nil { if err != nil {
@ -144,8 +147,10 @@ func TestBadLoginFromJSONReader(t *testing.T) {
var userAPI fakeUserInternalAPI var userAPI fakeUserInternalAPI
cfg := &config.ClientAPI{ cfg := &config.ClientAPI{
Matrix: &config.Global{ Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
}, },
},
} }
_, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg)
if errRes == nil { if errRes == nil {

View file

@ -61,7 +61,7 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte)
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
r := req.(*PasswordRequest) r := req.(*PasswordRequest)
username := strings.ToLower(r.Username()) username := r.Username()
if username == "" { if username == "" {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,
@ -74,32 +74,43 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
JSON: jsonerror.BadJSON("A password must be supplied."), JSON: jsonerror.BadJSON("A password must be supplied."),
} }
} }
localpart, _, err := userutil.ParseUsernameParam(username, t.Config.Matrix) localpart, domain, err := userutil.ParseUsernameParam(username, t.Config.Matrix)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,
JSON: jsonerror.InvalidUsername(err.Error()), JSON: jsonerror.InvalidUsername(err.Error()),
} }
} }
if !t.Config.Matrix.IsLocalServerName(domain) {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.InvalidUsername("The server name is not known."),
}
}
// Squash username to all lowercase letters // Squash username to all lowercase letters
res := &api.QueryAccountByPasswordResponse{} res := &api.QueryAccountByPasswordResponse{}
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res) err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
Localpart: strings.ToLower(localpart),
ServerName: domain,
PlaintextPassword: r.Password,
}, res)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to fetch account by password"), JSON: jsonerror.Unknown("Unable to fetch account by password."),
} }
} }
if !res.Exists { if !res.Exists {
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
PlaintextPassword: r.Password, PlaintextPassword: r.Password,
}, res) }, res)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to fetch account by password"), JSON: jsonerror.Unknown("Unable to fetch account by password."),
} }
} }
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows

View file

@ -47,8 +47,10 @@ func (d *fakeAccountDatabase) QueryAccountByPassword(ctx context.Context, req *a
func setup() *UserInteractive { func setup() *UserInteractive {
cfg := &config.ClientAPI{ cfg := &config.ClientAPI{
Matrix: &config.Global{ Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
}, },
},
} }
return NewUserInteractive(&fakeAccountDatabase{}, cfg) return NewUserInteractive(&fakeAccountDatabase{}, cfg)
} }

View file

@ -102,6 +102,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
serverName := cfg.Matrix.ServerName
localpart, ok := vars["localpart"] localpart, ok := vars["localpart"]
if !ok { if !ok {
return util.JSONResponse{ return util.JSONResponse{
@ -109,6 +110,9 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
JSON: jsonerror.MissingArgument("Expecting user localpart."), JSON: jsonerror.MissingArgument("Expecting user localpart."),
} }
} }
if l, s, err := cfg.Matrix.SplitLocalID('@', localpart); err == nil {
localpart, serverName = l, s
}
request := struct { request := struct {
Password string `json:"password"` Password string `json:"password"`
}{} }{}
@ -126,6 +130,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
} }
updateReq := &userapi.PerformPasswordUpdateRequest{ updateReq := &userapi.PerformPasswordUpdateRequest{
Localpart: localpart, Localpart: localpart,
ServerName: serverName,
Password: request.Password, Password: request.Password,
LogoutDevices: true, LogoutDevices: true,
} }

View file

@ -477,7 +477,7 @@ func createRoom(
SendAsServer: roomserverAPI.DoNotSendToOtherServers, SendAsServer: roomserverAPI.DoNotSendToOtherServers,
}) })
} }
if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, inputs, false); err != nil { if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, device.UserDomain(), inputs, false); err != nil {
util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -77,7 +77,7 @@ func DirectoryRoom(
// If we don't know it locally, do a federation query. // If we don't know it locally, do a federation query.
// But don't send the query to ourselves. // But don't send the query to ourselves.
if !cfg.Matrix.IsLocalServerName(domain) { if !cfg.Matrix.IsLocalServerName(domain) {
fedRes, fedErr := federation.LookupRoomAlias(req.Context(), domain, roomAlias) fedRes, fedErr := federation.LookupRoomAlias(req.Context(), cfg.Matrix.ServerName, domain, roomAlias)
if fedErr != nil { if fedErr != nil {
// TODO: Return 502 if the remote server errored. // TODO: Return 502 if the remote server errored.
// TODO: Return 504 if the remote server timed out. // TODO: Return 504 if the remote server timed out.

View file

@ -74,7 +74,7 @@ func GetPostPublicRooms(
serverName := gomatrixserverlib.ServerName(request.Server) serverName := gomatrixserverlib.ServerName(request.Server)
if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) { if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) {
res, err := federation.GetPublicRoomsFiltered( res, err := federation.GetPublicRoomsFiltered(
req.Context(), serverName, req.Context(), cfg.Matrix.ServerName, serverName,
int(request.Limit), request.Since, int(request.Limit), request.Since,
request.Filter.SearchTerms, false, request.Filter.SearchTerms, false,
"", "",

View file

@ -100,6 +100,7 @@ func completeAuth(
DeviceID: login.DeviceID, DeviceID: login.DeviceID,
AccessToken: token, AccessToken: token,
Localpart: localpart, Localpart: localpart,
ServerName: serverName,
IPAddr: ipAddr, IPAddr: ipAddr,
UserAgent: userAgent, UserAgent: userAgent,
}, &performRes) }, &performRes)

View file

@ -110,6 +110,7 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic
ctx, rsAPI, ctx, rsAPI,
roomserverAPI.KindNew, roomserverAPI.KindNew,
[]*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)},
device.UserDomain(),
serverName, serverName,
serverName, serverName,
nil, nil,
@ -322,7 +323,12 @@ func buildMembershipEvent(
return nil, err return nil, err
} }
return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return nil, err
}
return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil)
} }
// loadProfile lookups the profile of a given user from the database and returns // loadProfile lookups the profile of a given user from the database and returns

View file

@ -40,13 +40,14 @@ func GetNotifications(
} }
var queryRes userapi.QueryNotificationsResponse var queryRes userapi.QueryNotificationsResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{ err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
From: req.URL.Query().Get("from"), From: req.URL.Query().Get("from"),
Limit: int(limit), Limit: int(limit),
Only: req.URL.Query().Get("only"), Only: req.URL.Query().Get("only"),

View file

@ -86,7 +86,7 @@ func Password(
} }
// Get the local part. // Get the local part.
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -95,6 +95,7 @@ func Password(
// Ask the user API to perform the password change. // Ask the user API to perform the password change.
passwordReq := &api.PerformPasswordUpdateRequest{ passwordReq := &api.PerformPasswordUpdateRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
Password: r.NewPassword, Password: r.NewPassword,
} }
passwordRes := &api.PerformPasswordUpdateResponse{} passwordRes := &api.PerformPasswordUpdateResponse{}
@ -123,6 +124,7 @@ func Password(
pushersReq := &api.PerformPusherDeletionRequest{ pushersReq := &api.PerformPusherDeletionRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
SessionID: device.SessionID, SessionID: device.SessionID,
} }
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {

View file

@ -284,7 +284,7 @@ func updateProfile(
} }
events, err := buildMembershipEvents( events, err := buildMembershipEvents(
ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, ctx, device, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI,
) )
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:
@ -298,7 +298,7 @@ func updateProfile(
return jsonerror.InternalServerError(), e return jsonerror.InternalServerError(), e
} }
if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, domain, domain, nil, true); err != nil { if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, device.UserDomain(), domain, domain, nil, true); err != nil {
util.GetLogger(ctx).WithError(err).Error("SendEvents failed") util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError(), err return jsonerror.InternalServerError(), err
} }
@ -321,7 +321,7 @@ func getProfile(
} }
if !cfg.Matrix.IsLocalServerName(domain) { if !cfg.Matrix.IsLocalServerName(domain) {
profile, fedErr := federation.LookupProfile(ctx, domain, userID, "") profile, fedErr := federation.LookupProfile(ctx, cfg.Matrix.ServerName, domain, userID, "")
if fedErr != nil { if fedErr != nil {
if x, ok := fedErr.(gomatrix.HTTPError); ok { if x, ok := fedErr.(gomatrix.HTTPError); ok {
if x.Code == http.StatusNotFound { if x.Code == http.StatusNotFound {
@ -349,6 +349,7 @@ func getProfile(
func buildMembershipEvents( func buildMembershipEvents(
ctx context.Context, ctx context.Context,
device *userapi.Device,
roomIDs []string, roomIDs []string,
newProfile authtypes.Profile, userID string, cfg *config.ClientAPI, newProfile authtypes.Profile, userID string, cfg *config.ClientAPI,
evTime time.Time, rsAPI api.ClientRoomserverAPI, evTime time.Time, rsAPI api.ClientRoomserverAPI,
@ -380,7 +381,12 @@ func buildMembershipEvents(
return nil, err return nil, err
} }
event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return nil, err
}
event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -31,13 +31,14 @@ func GetPushers(
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
) util.JSONResponse { ) util.JSONResponse {
var queryRes userapi.QueryPushersResponse var queryRes userapi.QueryPushersResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{ err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed") util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
@ -59,7 +60,7 @@ func SetPusher(
req *http.Request, device *userapi.Device, req *http.Request, device *userapi.Device,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -93,6 +94,7 @@ func SetPusher(
} }
body.Localpart = localpart body.Localpart = localpart
body.ServerName = domain
body.SessionID = device.SessionID body.SessionID = device.SessionID
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{}) err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
if err != nil { if err != nil {

View file

@ -123,8 +123,13 @@ func SendRedaction(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return jsonerror.InternalServerError()
}
var queryRes roomserverAPI.QueryLatestEventsAndStateResponse var queryRes roomserverAPI.QueryLatestEventsAndStateResponse
e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
@ -132,7 +137,7 @@ func SendRedaction(
} }
} }
domain := device.UserDomain() domain := device.UserDomain()
if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, domain, domain, nil, false); err != nil { if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, device.UserDomain(), domain, domain, nil, false); err != nil {
util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -213,6 +213,7 @@ type registerRequest struct {
// registration parameters // registration parameters
Password string `json:"password"` Password string `json:"password"`
Username string `json:"username"` Username string `json:"username"`
ServerName gomatrixserverlib.ServerName `json:"-"`
Admin bool `json:"admin"` Admin bool `json:"admin"`
// user-interactive auth params // user-interactive auth params
Auth authDict `json:"auth"` Auth authDict `json:"auth"`
@ -550,6 +551,12 @@ func Register(
} }
var r registerRequest var r registerRequest
host := gomatrixserverlib.ServerName(req.Host)
if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil {
r.ServerName = v.ServerName
} else {
r.ServerName = cfg.Matrix.ServerName
}
sessionID := gjson.GetBytes(reqBody, "auth.session").String() sessionID := gjson.GetBytes(reqBody, "auth.session").String()
if sessionID == "" { if sessionID == "" {
// Generate a new, random session ID // Generate a new, random session ID
@ -559,6 +566,7 @@ func Register(
// Some of these might end up being overwritten if the // Some of these might end up being overwritten if the
// values are specified again in the request body. // values are specified again in the request body.
r.Username = data.Username r.Username = data.Username
r.ServerName = data.ServerName
r.Password = data.Password r.Password = data.Password
r.DeviceID = data.DeviceID r.DeviceID = data.DeviceID
r.InitialDisplayName = data.InitialDisplayName r.InitialDisplayName = data.InitialDisplayName
@ -570,11 +578,13 @@ func Register(
JSON: response, JSON: response,
} }
} }
} }
if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil {
return *resErr return *resErr
} }
if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil {
r.Username, r.ServerName = l, d
}
if req.URL.Query().Get("kind") == "guest" { if req.URL.Query().Get("kind") == "guest" {
return handleGuestRegistration(req, r, cfg, userAPI) return handleGuestRegistration(req, r, cfg, userAPI)
} }
@ -588,12 +598,15 @@ func Register(
} }
// Auto generate a numeric username if r.Username is empty // Auto generate a numeric username if r.Username is empty
if r.Username == "" { if r.Username == "" {
res := &userapi.QueryNumericLocalpartResponse{} nreq := &userapi.QueryNumericLocalpartRequest{
if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil { ServerName: r.ServerName,
}
nres := &userapi.QueryNumericLocalpartResponse{}
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
r.Username = strconv.FormatInt(res.ID, 10) r.Username = strconv.FormatInt(nres.ID, 10)
} }
// Is this an appservice registration? It will be if the access // Is this an appservice registration? It will be if the access
@ -606,7 +619,7 @@ func Register(
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
// Spec-compliant case (the access_token is specified and the login type // Spec-compliant case (the access_token is specified and the login type
// is correctly set, so it's an appservice registration) // is correctly set, so it's an appservice registration)
if resErr := validateApplicationServiceUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil {
return *resErr return *resErr
} }
case accessTokenErr == nil: case accessTokenErr == nil:
@ -619,7 +632,7 @@ func Register(
default: default:
// Spec-compliant case (neither the access_token nor the login type are // Spec-compliant case (neither the access_token nor the login type are
// specified, so it's a normal user registration) // specified, so it's a normal user registration)
if resErr := validateUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { if resErr := validateUsername(r.Username, r.ServerName); resErr != nil {
return *resErr return *resErr
} }
} }
@ -643,16 +656,25 @@ func handleGuestRegistration(
cfg *config.ClientAPI, cfg *config.ClientAPI,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
) util.JSONResponse { ) util.JSONResponse {
if cfg.RegistrationDisabled || cfg.GuestsDisabled { registrationEnabled := !cfg.RegistrationDisabled
guestsEnabled := !cfg.GuestsDisabled
if v := cfg.Matrix.VirtualHost(r.ServerName); v != nil {
registrationEnabled, guestsEnabled = v.RegistrationAllowed()
}
if !registrationEnabled || !guestsEnabled {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Guest registration is disabled"), JSON: jsonerror.Forbidden(
fmt.Sprintf("Guest registration is disabled on %q", r.ServerName),
),
} }
} }
var res userapi.PerformAccountCreationResponse var res userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeGuest, AccountType: userapi.AccountTypeGuest,
ServerName: r.ServerName,
}, &res) }, &res)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -676,6 +698,7 @@ func handleGuestRegistration(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{
Localpart: res.Account.Localpart, Localpart: res.Account.Localpart,
ServerName: res.Account.ServerName,
DeviceDisplayName: r.InitialDisplayName, DeviceDisplayName: r.InitialDisplayName,
AccessToken: token, AccessToken: token,
IPAddr: req.RemoteAddr, IPAddr: req.RemoteAddr,
@ -728,10 +751,16 @@ func handleRegistrationFlow(
) )
} }
if cfg.RegistrationDisabled && r.Auth.Type != authtypes.LoginTypeSharedSecret { registrationEnabled := !cfg.RegistrationDisabled
if v := cfg.Matrix.VirtualHost(r.ServerName); v != nil {
registrationEnabled, _ = v.RegistrationAllowed()
}
if !registrationEnabled && r.Auth.Type != authtypes.LoginTypeSharedSecret {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Registration is disabled"), JSON: jsonerror.Forbidden(
fmt.Sprintf("Registration is disabled on %q", r.ServerName),
),
} }
} }
@ -819,8 +848,9 @@ func handleApplicationServiceRegistration(
// Don't need to worry about appending to registration stages as // Don't need to worry about appending to registration stages as
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, req.Context(), userAPI, r.Username, r.ServerName, "", appserviceID, req.RemoteAddr,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, req.UserAgent(), r.Auth.Session, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
userapi.AccountTypeAppService,
) )
} }
@ -838,8 +868,9 @@ func checkAndCompleteFlow(
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue // This flow was completed, registration can continue
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, req.Context(), userAPI, r.Username, r.ServerName, r.Password, "", req.RemoteAddr,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, req.UserAgent(), sessionID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
userapi.AccountTypeUser,
) )
} }
sessions.addParams(sessionID, r) sessions.addParams(sessionID, r)
@ -861,7 +892,8 @@ func checkAndCompleteFlow(
func completeRegistration( func completeRegistration(
ctx context.Context, ctx context.Context,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
username, password, appserviceID, ipAddr, userAgent, sessionID string, username string, serverName gomatrixserverlib.ServerName,
password, appserviceID, ipAddr, userAgent, sessionID string,
inhibitLogin eventutil.WeakBoolean, inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string, displayName, deviceID *string,
accType userapi.AccountType, accType userapi.AccountType,
@ -883,6 +915,7 @@ func completeRegistration(
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
AppServiceID: appserviceID, AppServiceID: appserviceID,
Localpart: username, Localpart: username,
ServerName: serverName,
Password: password, Password: password,
AccountType: accType, AccountType: accType,
OnConflict: userapi.ConflictAbort, OnConflict: userapi.ConflictAbort,
@ -926,6 +959,7 @@ func completeRegistration(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
Localpart: username, Localpart: username,
ServerName: serverName,
AccessToken: token, AccessToken: token,
DeviceDisplayName: displayName, DeviceDisplayName: displayName,
DeviceID: deviceID, DeviceID: deviceID,
@ -1019,13 +1053,31 @@ func RegisterAvailable(
// Squash username to all lowercase letters // Squash username to all lowercase letters
username = strings.ToLower(username) username = strings.ToLower(username)
domain := cfg.Matrix.ServerName
host := gomatrixserverlib.ServerName(req.Host)
if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil {
domain = v.ServerName
}
if u, l, err := cfg.Matrix.SplitLocalID('@', username); err == nil {
username, domain = u, l
}
for _, v := range cfg.Matrix.VirtualHosts {
if v.ServerName == domain && !v.AllowRegistration {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(
fmt.Sprintf("Registration is not allowed on %q", string(v.ServerName)),
),
}
}
}
if err := validateUsername(username, cfg.Matrix.ServerName); err != nil { if err := validateUsername(username, domain); err != nil {
return *err return *err
} }
// Check if this username is reserved by an application service // Check if this username is reserved by an application service
userID := userutil.MakeUserID(username, cfg.Matrix.ServerName) userID := userutil.MakeUserID(username, domain)
for _, appservice := range cfg.Derived.ApplicationServices { for _, appservice := range cfg.Derived.ApplicationServices {
if appservice.OwnsNamespaceCoveringUserId(userID) { if appservice.OwnsNamespaceCoveringUserId(userID) {
return util.JSONResponse{ return util.JSONResponse{
@ -1038,6 +1090,7 @@ func RegisterAvailable(
res := &userapi.QueryAccountAvailabilityResponse{} res := &userapi.QueryAccountAvailabilityResponse{}
err := registerAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{ err := registerAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{
Localpart: username, Localpart: username,
ServerName: domain,
}, res) }, res)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -1094,5 +1147,5 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien
if ssrr.Admin { if ssrr.Admin {
accType = userapi.AccountTypeAdmin accType = userapi.AccountTypeAdmin
} }
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) return completeRegistration(req.Context(), userAPI, ssrr.User, cfg.Matrix.ServerName, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType)
} }

View file

@ -157,7 +157,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/resetPassword/{localpart}", dendriteAdminRouter.Handle("/admin/resetPassword/{userID}",
httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminResetPassword(req, cfg, device, userAPI) return AdminResetPassword(req, cfg, device, userAPI)
}), }),
@ -252,7 +252,7 @@ func Setup(
return JoinRoomByIDOrAlias( return JoinRoomByIDOrAlias(
req, device, rsAPI, userAPI, vars["roomIDOrAlias"], req, device, rsAPI, userAPI, vars["roomIDOrAlias"],
) )
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
if mscCfg.Enabled("msc2753") { if mscCfg.Enabled("msc2753") {
@ -274,7 +274,7 @@ func Setup(
v3mux.Handle("/joined_rooms", v3mux.Handle("/joined_rooms",
httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetJoinedRooms(req, device, rsAPI) return GetJoinedRooms(req, device, rsAPI)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/join", v3mux.Handle("/rooms/{roomID}/join",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -288,7 +288,7 @@ func Setup(
return JoinRoomByIDOrAlias( return JoinRoomByIDOrAlias(
req, device, rsAPI, userAPI, vars["roomID"], req, device, rsAPI, userAPI, vars["roomID"],
) )
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/leave", v3mux.Handle("/rooms/{roomID}/leave",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -302,7 +302,7 @@ func Setup(
return LeaveRoomByID( return LeaveRoomByID(
req, device, rsAPI, vars["roomID"], req, device, rsAPI, vars["roomID"],
) )
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/unpeek", v3mux.Handle("/rooms/{roomID}/unpeek",
httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -361,7 +361,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil) return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -372,7 +372,7 @@ func Setup(
txnID := vars["txnID"] txnID := vars["txnID"]
return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID, return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID,
nil, cfg, rsAPI, transactionsCache) nil, cfg, rsAPI, transactionsCache)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -381,7 +381,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"]) return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions) }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -400,7 +400,7 @@ func Setup(
eventType := strings.TrimSuffix(vars["type"], "/") eventType := strings.TrimSuffix(vars["type"], "/")
eventFormat := req.URL.Query().Get("format") == "event" eventFormat := req.URL.Query().Get("format") == "event"
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat) return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat)
})).Methods(http.MethodGet, http.MethodOptions) }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -409,7 +409,7 @@ func Setup(
} }
eventFormat := req.URL.Query().Get("format") == "event" eventFormat := req.URL.Query().Get("format") == "event"
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat) return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat)
})).Methods(http.MethodGet, http.MethodOptions) }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -420,7 +420,7 @@ func Setup(
emptyString := "" emptyString := ""
eventType := strings.TrimSuffix(vars["eventType"], "/") eventType := strings.TrimSuffix(vars["eventType"], "/")
return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil) return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
@ -431,7 +431,7 @@ func Setup(
} }
stateKey := vars["stateKey"] stateKey := vars["stateKey"]
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil) return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
@ -575,7 +575,7 @@ func Setup(
} }
txnID := vars["txnID"] txnID := vars["txnID"]
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID) return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// This is only here because sytest refers to /unstable for this endpoint // This is only here because sytest refers to /unstable for this endpoint
@ -589,7 +589,7 @@ func Setup(
} }
txnID := vars["txnID"] txnID := vars["txnID"]
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID) return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/account/whoami", v3mux.Handle("/account/whoami",
@ -598,7 +598,7 @@ func Setup(
return *r return *r
} }
return Whoami(req, device) return Whoami(req, device)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/account/password", v3mux.Handle("/account/password",
@ -830,7 +830,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI) return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows // Browsers use the OPTIONS HTTP method to check if the CORS policy allows
// PUT requests, so we need to allow this method // PUT requests, so we need to allow this method
@ -871,7 +871,7 @@ func Setup(
v3mux.Handle("/thirdparty/protocols", v3mux.Handle("/thirdparty/protocols",
httputil.MakeAuthAPI("thirdparty_protocols", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("thirdparty_protocols", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return Protocols(req, asAPI, device, "") return Protocols(req, asAPI, device, "")
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thirdparty/protocol/{protocolID}", v3mux.Handle("/thirdparty/protocol/{protocolID}",
@ -881,7 +881,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return Protocols(req, asAPI, device, vars["protocolID"]) return Protocols(req, asAPI, device, vars["protocolID"])
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thirdparty/user/{protocolID}", v3mux.Handle("/thirdparty/user/{protocolID}",
@ -891,13 +891,13 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return User(req, asAPI, device, vars["protocolID"], req.URL.Query()) return User(req, asAPI, device, vars["protocolID"], req.URL.Query())
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thirdparty/user", v3mux.Handle("/thirdparty/user",
httputil.MakeAuthAPI("thirdparty_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("thirdparty_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return User(req, asAPI, device, "", req.URL.Query()) return User(req, asAPI, device, "", req.URL.Query())
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thirdparty/location/{protocolID}", v3mux.Handle("/thirdparty/location/{protocolID}",
@ -907,13 +907,13 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return Location(req, asAPI, device, vars["protocolID"], req.URL.Query()) return Location(req, asAPI, device, vars["protocolID"], req.URL.Query())
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thirdparty/location", v3mux.Handle("/thirdparty/location",
httputil.MakeAuthAPI("thirdparty_location", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("thirdparty_location", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return Location(req, asAPI, device, "", req.URL.Query()) return Location(req, asAPI, device, "", req.URL.Query())
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/initialSync", v3mux.Handle("/rooms/{roomID}/initialSync",
@ -1054,7 +1054,7 @@ func Setup(
v3mux.Handle("/devices", v3mux.Handle("/devices",
httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetDevicesByLocalpart(req, userAPI, device) return GetDevicesByLocalpart(req, userAPI, device)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/devices/{deviceID}", v3mux.Handle("/devices/{deviceID}",
@ -1064,7 +1064,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetDeviceByID(req, userAPI, device, vars["deviceID"]) return GetDeviceByID(req, userAPI, device, vars["deviceID"])
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/devices/{deviceID}", v3mux.Handle("/devices/{deviceID}",
@ -1074,7 +1074,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return UpdateDeviceByID(req, userAPI, device, vars["deviceID"]) return UpdateDeviceByID(req, userAPI, device, vars["deviceID"])
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/devices/{deviceID}", v3mux.Handle("/devices/{deviceID}",
@ -1116,21 +1116,21 @@ func Setup(
// Stub implementations for sytest // Stub implementations for sytest
v3mux.Handle("/events", v3mux.Handle("/events",
httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { httputil.MakeAuthAPI("events", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
"chunk": []interface{}{}, "chunk": []interface{}{},
"start": "", "start": "",
"end": "", "end": "",
}} }}
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/initialSync", v3mux.Handle("/initialSync",
httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse { httputil.MakeAuthAPI("initial_sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
"end": "", "end": "",
}} }}
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags", v3mux.Handle("/user/{userId}/rooms/{roomId}/tags",
@ -1169,7 +1169,7 @@ func Setup(
return *r return *r
} }
return GetCapabilities(req, rsAPI) return GetCapabilities(req, rsAPI)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
// Key Backup Versions (Metadata) // Key Backup Versions (Metadata)
@ -1350,7 +1350,7 @@ func Setup(
postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadCrossSigningDeviceSignatures(req, keyAPI, device) return UploadCrossSigningDeviceSignatures(req, keyAPI, device)
}) }, httputil.WithAllowGuests())
v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
@ -1362,22 +1362,22 @@ func Setup(
v3mux.Handle("/keys/upload/{deviceID}", v3mux.Handle("/keys/upload/{deviceID}",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device) return UploadKeys(req, keyAPI, device)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/upload", v3mux.Handle("/keys/upload",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device) return UploadKeys(req, keyAPI, device)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/query", v3mux.Handle("/keys/query",
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return QueryKeys(req, keyAPI, device) return QueryKeys(req, keyAPI, device)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/claim", v3mux.Handle("/keys/claim",
httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return ClaimKeys(req, keyAPI) return ClaimKeys(req, keyAPI)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {

View file

@ -186,6 +186,7 @@ func SendEvent(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
e.Headered(verRes.RoomVersion), e.Headered(verRes.RoomVersion),
}, },
device.UserDomain(),
domain, domain,
domain, domain,
txnAndSessionID, txnAndSessionID,
@ -275,8 +276,14 @@ func generateSendEvent(
return nil, &resErr return nil, &resErr
} }
identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
resErr := jsonerror.InternalServerError()
return nil, &resErr
}
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, &queryRes) e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,

View file

@ -231,6 +231,7 @@ func SendServerNotice(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
e.Headered(roomVersion), e.Headered(roomVersion),
}, },
device.UserDomain(),
cfgClient.Matrix.ServerName, cfgClient.Matrix.ServerName,
cfgClient.Matrix.ServerName, cfgClient.Matrix.ServerName,
txnAndSessionID, txnAndSessionID,
@ -286,6 +287,7 @@ func getSenderDevice(
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeUser, AccountType: userapi.AccountTypeUser,
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
OnConflict: userapi.ConflictUpdate, OnConflict: userapi.ConflictUpdate,
}, &accRes) }, &accRes)
if err != nil { if err != nil {
@ -296,6 +298,7 @@ func getSenderDevice(
avatarRes := &userapi.PerformSetAvatarURLResponse{} avatarRes := &userapi.PerformSetAvatarURLResponse{}
if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{ if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
}, avatarRes); err != nil { }, avatarRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed") util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed")
@ -308,6 +311,7 @@ func getSenderDevice(
displayNameRes := &userapi.PerformUpdateDisplayNameResponse{} displayNameRes := &userapi.PerformUpdateDisplayNameResponse{}
if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{ if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
DisplayName: cfg.Matrix.ServerNotices.DisplayName, DisplayName: cfg.Matrix.ServerNotices.DisplayName,
}, displayNameRes); err != nil { }, displayNameRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed") util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed")
@ -353,6 +357,7 @@ func getSenderDevice(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart, DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart,
AccessToken: token, AccessToken: token,
NoDeviceListUpdate: true, NoDeviceListUpdate: true,

View file

@ -36,11 +36,17 @@ func Protocols(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, dev
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if !resp.Exists { if !resp.Exists {
if protocol != "" {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
JSON: jsonerror.NotFound("The protocol is unknown."), JSON: jsonerror.NotFound("The protocol is unknown."),
} }
} }
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
if protocol != "" { if protocol != "" {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,

View file

@ -136,7 +136,7 @@ func CheckAndSave3PIDAssociation(
} }
// Save the association in the database // Save the association in the database
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -145,6 +145,7 @@ func CheckAndSave3PIDAssociation(
if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{ if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{
ThreePID: address, ThreePID: address,
Localpart: localpart, Localpart: localpart,
ServerName: domain,
Medium: medium, Medium: medium,
}, &struct{}{}); err != nil { }, &struct{}{}); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed") util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed")
@ -161,7 +162,7 @@ func CheckAndSave3PIDAssociation(
func GetAssociated3PIDs( func GetAssociated3PIDs(
req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device, req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -170,6 +171,7 @@ func GetAssociated3PIDs(
res := &api.QueryThreePIDsForLocalpartResponse{} res := &api.QueryThreePIDsForLocalpartResponse{}
err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{ err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
}, res) }, res)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed") util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed")

View file

@ -106,7 +106,7 @@ knownUsersLoop:
continue continue
} }
// TODO: We should probably cache/store this // TODO: We should probably cache/store this
fedProfile, fedErr := federation.LookupProfile(ctx, serverName, userID, "") fedProfile, fedErr := federation.LookupProfile(ctx, localServerName, serverName, userID, "")
if fedErr != nil { if fedErr != nil {
if x, ok := fedErr.(gomatrix.HTTPError); ok { if x, ok := fedErr.(gomatrix.HTTPError); ok {
if x.Code == http.StatusNotFound { if x.Code == http.StatusNotFound {

View file

@ -359,8 +359,13 @@ func emit3PIDInviteEvent(
return err return err
} }
identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return err
}
queryRes := api.QueryLatestEventsAndStateResponse{} queryRes := api.QueryLatestEventsAndStateResponse{}
event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, evTime, rsAPI, &queryRes) event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes)
if err != nil { if err != nil {
return err return err
} }
@ -371,6 +376,7 @@ func emit3PIDInviteEvent(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
event.Headered(queryRes.RoomVersion), event.Headered(queryRes.RoomVersion),
}, },
device.UserDomain(),
cfg.Matrix.ServerName, cfg.Matrix.ServerName,
cfg.Matrix.ServerName, cfg.Matrix.ServerName,
nil, nil,

View file

@ -30,7 +30,9 @@ var (
// TestGoodUserID checks that correct localpart is returned for a valid user ID. // TestGoodUserID checks that correct localpart is returned for a valid user ID.
func TestGoodUserID(t *testing.T) { func TestGoodUserID(t *testing.T) {
cfg := &config.Global{ cfg := &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
},
} }
lp, _, err := ParseUsernameParam(goodUserID, cfg) lp, _, err := ParseUsernameParam(goodUserID, cfg)
@ -47,7 +49,9 @@ func TestGoodUserID(t *testing.T) {
// TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart. // TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart.
func TestWithLocalpartOnly(t *testing.T) { func TestWithLocalpartOnly(t *testing.T) {
cfg := &config.Global{ cfg := &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
},
} }
lp, _, err := ParseUsernameParam(localpart, cfg) lp, _, err := ParseUsernameParam(localpart, cfg)
@ -64,7 +68,9 @@ func TestWithLocalpartOnly(t *testing.T) {
// TestIncorrectDomain checks for error when there's server name mismatch. // TestIncorrectDomain checks for error when there's server name mismatch.
func TestIncorrectDomain(t *testing.T) { func TestIncorrectDomain(t *testing.T) {
cfg := &config.Global{ cfg := &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: invalidServerName, ServerName: invalidServerName,
},
} }
_, _, err := ParseUsernameParam(goodUserID, cfg) _, _, err := ParseUsernameParam(goodUserID, cfg)
@ -77,7 +83,9 @@ func TestIncorrectDomain(t *testing.T) {
// TestBadUserID checks that ParseUsernameParam fails for invalid user ID // TestBadUserID checks that ParseUsernameParam fails for invalid user ID
func TestBadUserID(t *testing.T) { func TestBadUserID(t *testing.T) {
cfg := &config.Global{ cfg := &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
},
} }
_, _, err := ParseUsernameParam(badUserID, cfg) _, _, err := ParseUsernameParam(badUserID, cfg)

View file

@ -177,7 +177,7 @@ func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, a
defer regResp.Body.Close() // nolint: errcheck defer regResp.Body.Close() // nolint: errcheck
if regResp.StatusCode < 200 || regResp.StatusCode >= 300 { if regResp.StatusCode < 200 || regResp.StatusCode >= 300 {
body, _ = io.ReadAll(regResp.Body) body, _ = io.ReadAll(regResp.Body)
return "", fmt.Errorf(gjson.GetBytes(body, "error").Str) return "", fmt.Errorf("got HTTP %d error from server: %s", regResp.StatusCode, string(body))
} }
r, err := io.ReadAll(regResp.Body) r, err := io.ReadAll(regResp.Body)
if err != nil { if err != nil {

View file

@ -101,9 +101,7 @@ func CreateFederationClient(
base *base.BaseDendrite, s *pineconeSessions.Sessions, base *base.BaseDendrite, s *pineconeSessions.Sessions,
) *gomatrixserverlib.FederationClient { ) *gomatrixserverlib.FederationClient {
return gomatrixserverlib.NewFederationClient( return gomatrixserverlib.NewFederationClient(
base.Cfg.Global.ServerName, base.Cfg.Global.SigningIdentities(),
base.Cfg.Global.KeyID,
base.Cfg.Global.PrivateKey,
gomatrixserverlib.WithTransport(createTransport(s)), gomatrixserverlib.WithTransport(createTransport(s)),
) )
} }

View file

@ -37,6 +37,7 @@ import (
"github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
"github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/keyserver"
@ -51,6 +52,7 @@ import (
pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeMulticast "github.com/matrix-org/pinecone/multicast"
pineconeRouter "github.com/matrix-org/pinecone/router" pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeEvents "github.com/matrix-org/pinecone/router/events"
pineconeSessions "github.com/matrix-org/pinecone/sessions" pineconeSessions "github.com/matrix-org/pinecone/sessions"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -153,9 +155,15 @@ func main() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
base := base.NewBaseDendrite(cfg, "Monolith") base := base.NewBaseDendrite(cfg, "Monolith")
base.ConfigureAdminEndpoints()
defer base.Close() // nolint: errcheck defer base.Close() // nolint: errcheck
pineconeEventChannel := make(chan pineconeEvents.Event)
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk)
pRouter.EnableHopLimiting()
pRouter.EnableWakeupBroadcasts()
pRouter.Subscribe(pineconeEventChannel)
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
pManager := pineconeConnections.NewConnectionManager(pRouter, nil) pManager := pineconeConnections.NewConnectionManager(pRouter, nil)
@ -241,6 +249,8 @@ func main() {
httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux)
httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux)
httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux)
httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux)
httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux)
httpRouter.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { httpRouter.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
c, err := wsUpgrader.Upgrade(w, r, nil) c, err := wsUpgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
@ -293,5 +303,33 @@ func main() {
logrus.Fatal(http.ListenAndServe(httpBindAddr, httpRouter)) logrus.Fatal(http.ListenAndServe(httpBindAddr, httpRouter))
}() }()
go func(ch <-chan pineconeEvents.Event) {
eLog := logrus.WithField("pinecone", "events")
for event := range ch {
switch e := event.(type) {
case pineconeEvents.PeerAdded:
case pineconeEvents.PeerRemoved:
case pineconeEvents.TreeParentUpdate:
case pineconeEvents.SnakeDescUpdate:
case pineconeEvents.TreeRootAnnUpdate:
case pineconeEvents.SnakeEntryAdded:
case pineconeEvents.SnakeEntryRemoved:
case pineconeEvents.BroadcastReceived:
eLog.Info("Broadcast received from: ", e.PeerID)
req := &api.PerformWakeupServersRequest{
ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)},
}
res := &api.PerformWakeupServersResponse{}
if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil {
logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID)
}
case pineconeEvents.BandwidthReport:
default:
}
}
}(pineconeEventChannel)
base.WaitForShutdown() base.WaitForShutdown()
} }

View file

@ -58,13 +58,17 @@ func (p *PineconeRoomProvider) Rooms() []gomatrixserverlib.PublicRoom {
for _, k := range p.r.Peers() { for _, k := range p.r.Peers() {
list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{} list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{}
} }
return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, list) return bulkFetchPublicRoomsFromServers(
context.Background(), p.fedClient,
gomatrixserverlib.ServerName(p.r.PublicKey().String()), list,
)
} }
// bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers.
// Returns a list of public rooms. // Returns a list of public rooms.
func bulkFetchPublicRoomsFromServers( func bulkFetchPublicRoomsFromServers(
ctx context.Context, fedClient *gomatrixserverlib.FederationClient, ctx context.Context, fedClient *gomatrixserverlib.FederationClient,
origin gomatrixserverlib.ServerName,
homeservers map[gomatrixserverlib.ServerName]struct{}, homeservers map[gomatrixserverlib.ServerName]struct{},
) (publicRooms []gomatrixserverlib.PublicRoom) { ) (publicRooms []gomatrixserverlib.PublicRoom) {
limit := 200 limit := 200
@ -82,7 +86,7 @@ func bulkFetchPublicRoomsFromServers(
go func(homeserverDomain gomatrixserverlib.ServerName) { go func(homeserverDomain gomatrixserverlib.ServerName) {
defer wg.Done() defer wg.Done()
util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms")
fres, err := fedClient.GetPublicRooms(reqctx, homeserverDomain, int(limit), "", false, "") fres, err := fedClient.GetPublicRooms(reqctx, origin, homeserverDomain, int(limit), "", false, "")
if err != nil { if err != nil {
util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn( util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn(
"bulkFetchPublicRoomsFromServers: failed to query hs", "bulkFetchPublicRoomsFromServers: failed to query hs",

View file

@ -144,6 +144,7 @@ func main() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
base := base.NewBaseDendrite(cfg, "Monolith") base := base.NewBaseDendrite(cfg, "Monolith")
base.ConfigureAdminEndpoints()
defer base.Close() // nolint: errcheck defer base.Close() // nolint: errcheck
ygg, err := yggconn.Setup(sk, *instanceName, ".", *instancePeer, *instanceListen) ygg, err := yggconn.Setup(sk, *instanceName, ".", *instancePeer, *instanceListen)
@ -198,6 +199,8 @@ func main() {
httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux)
httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux)
httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux)
httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux)
httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux)
embed.Embed(httpRouter, *instancePort, "Yggdrasil Demo") embed.Embed(httpRouter, *instancePort, "Yggdrasil Demo")
yggRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() yggRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()

View file

@ -55,8 +55,7 @@ func (n *Node) CreateFederationClient(
}, },
) )
return gomatrixserverlib.NewFederationClient( return gomatrixserverlib.NewFederationClient(
base.Cfg.Global.ServerName, base.Cfg.Global.KeyID, base.Cfg.Global.SigningIdentities(),
base.Cfg.Global.PrivateKey,
gomatrixserverlib.WithTransport(tr), gomatrixserverlib.WithTransport(tr),
) )
} }

View file

@ -43,13 +43,18 @@ func NewYggdrasilRoomProvider(
} }
func (p *YggdrasilRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { func (p *YggdrasilRoomProvider) Rooms() []gomatrixserverlib.PublicRoom {
return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, p.node.KnownNodes()) return bulkFetchPublicRoomsFromServers(
context.Background(), p.fedClient,
gomatrixserverlib.ServerName(p.node.DerivedServerName()),
p.node.KnownNodes(),
)
} }
// bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers.
// Returns a list of public rooms. // Returns a list of public rooms.
func bulkFetchPublicRoomsFromServers( func bulkFetchPublicRoomsFromServers(
ctx context.Context, fedClient *gomatrixserverlib.FederationClient, ctx context.Context, fedClient *gomatrixserverlib.FederationClient,
origin gomatrixserverlib.ServerName,
homeservers []gomatrixserverlib.ServerName, homeservers []gomatrixserverlib.ServerName,
) (publicRooms []gomatrixserverlib.PublicRoom) { ) (publicRooms []gomatrixserverlib.PublicRoom) {
limit := 200 limit := 200
@ -66,7 +71,7 @@ func bulkFetchPublicRoomsFromServers(
go func(homeserverDomain gomatrixserverlib.ServerName) { go func(homeserverDomain gomatrixserverlib.ServerName) {
defer wg.Done() defer wg.Done()
util.GetLogger(ctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") util.GetLogger(ctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms")
fres, err := fedClient.GetPublicRooms(ctx, homeserverDomain, int(limit), "", false, "") fres, err := fedClient.GetPublicRooms(ctx, origin, homeserverDomain, int(limit), "", false, "")
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).WithField("hs", homeserverDomain).Warn( util.GetLogger(ctx).WithError(err).WithField("hs", homeserverDomain).Warn(
"bulkFetchPublicRoomsFromServers: failed to query hs", "bulkFetchPublicRoomsFromServers: failed to query hs",

View file

@ -7,6 +7,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"log" "log"
"net/http" "net/http"
"os" "os"
@ -38,6 +39,7 @@ var (
flagHead = flag.String("head", "", "Location to a dendrite repository to treat as HEAD instead of Github") flagHead = flag.String("head", "", "Location to a dendrite repository to treat as HEAD instead of Github")
flagDockerHost = flag.String("docker-host", "localhost", "The hostname of the docker client. 'localhost' if running locally, 'host.docker.internal' if running in Docker.") flagDockerHost = flag.String("docker-host", "localhost", "The hostname of the docker client. 'localhost' if running locally, 'host.docker.internal' if running in Docker.")
flagDirect = flag.Bool("direct", false, "If a direct upgrade from the defined FROM version to TO should be done") flagDirect = flag.Bool("direct", false, "If a direct upgrade from the defined FROM version to TO should be done")
flagSqlite = flag.Bool("sqlite", false, "Test SQLite instead of PostgreSQL")
alphaNumerics = regexp.MustCompile("[^a-zA-Z0-9]+") alphaNumerics = regexp.MustCompile("[^a-zA-Z0-9]+")
) )
@ -49,7 +51,7 @@ const HEAD = "HEAD"
// due to the error: // due to the error:
// When using COPY with more than one source file, the destination must be a directory and end with a / // When using COPY with more than one source file, the destination must be a directory and end with a /
// We need to run a postgres anyway, so use the dockerfile associated with Complement instead. // We need to run a postgres anyway, so use the dockerfile associated with Complement instead.
const Dockerfile = `FROM golang:1.18-stretch as build const DockerfilePostgreSQL = `FROM golang:1.18-stretch as build
RUN apt-get update && apt-get install -y postgresql RUN apt-get update && apt-get install -y postgresql
WORKDIR /build WORKDIR /build
@ -60,6 +62,7 @@ COPY . .
RUN go build ./cmd/dendrite-monolith-server RUN go build ./cmd/dendrite-monolith-server
RUN go build ./cmd/generate-keys RUN go build ./cmd/generate-keys
RUN go build ./cmd/generate-config RUN go build ./cmd/generate-config
RUN go build ./cmd/create-account
RUN ./generate-config --ci > dendrite.yaml RUN ./generate-config --ci > dendrite.yaml
RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key
@ -92,6 +95,43 @@ ENV SERVER_NAME=localhost
EXPOSE 8008 8448 EXPOSE 8008 8448
CMD /build/run_dendrite.sh ` CMD /build/run_dendrite.sh `
const DockerfileSQLite = `FROM golang:1.18-stretch as build
RUN apt-get update && apt-get install -y postgresql
WORKDIR /build
# Copy the build context to the repo as this is the right dendrite code. This is different to the
# Complement Dockerfile which wgets a branch.
COPY . .
RUN go build ./cmd/dendrite-monolith-server
RUN go build ./cmd/generate-keys
RUN go build ./cmd/generate-config
RUN go build ./cmd/create-account
RUN ./generate-config --ci > dendrite.yaml
RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key
# Make sure the SQLite databases are in a persistent location, we're already mapping
# the postgresql folder so let's just use that for simplicity
RUN sed -i "s%connection_string:.file:%connection_string: file:\/var\/lib\/postgresql\/9.6\/main\/%g" dendrite.yaml
# This entry script starts postgres, waits for it to be up then starts dendrite
RUN echo '\
sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml \n\
PARAMS="--tls-cert server.crt --tls-key server.key --config dendrite.yaml" \n\
./dendrite-monolith-server --really-enable-open-registration ${PARAMS} || ./dendrite-monolith-server ${PARAMS} \n\
' > run_dendrite.sh && chmod +x run_dendrite.sh
ENV SERVER_NAME=localhost
EXPOSE 8008 8448
CMD /build/run_dendrite.sh `
func dockerfile() []byte {
if *flagSqlite {
return []byte(DockerfileSQLite)
}
return []byte(DockerfilePostgreSQL)
}
const dendriteUpgradeTestLabel = "dendrite_upgrade_test" const dendriteUpgradeTestLabel = "dendrite_upgrade_test"
// downloadArchive downloads an arbitrary github archive of the form: // downloadArchive downloads an arbitrary github archive of the form:
@ -150,7 +190,7 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir,
if branchOrTagName == HEAD && *flagHead != "" { if branchOrTagName == HEAD && *flagHead != "" {
log.Printf("%s: Using %s as HEAD", branchOrTagName, *flagHead) log.Printf("%s: Using %s as HEAD", branchOrTagName, *flagHead)
// add top level Dockerfile // add top level Dockerfile
err = os.WriteFile(path.Join(*flagHead, "Dockerfile"), []byte(Dockerfile), os.ModePerm) err = os.WriteFile(path.Join(*flagHead, "Dockerfile"), dockerfile(), os.ModePerm)
if err != nil { if err != nil {
return "", fmt.Errorf("custom HEAD: failed to inject /Dockerfile: %w", err) return "", fmt.Errorf("custom HEAD: failed to inject /Dockerfile: %w", err)
} }
@ -166,7 +206,7 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir,
// pull an archive, this contains a top-level directory which screws with the build context // pull an archive, this contains a top-level directory which screws with the build context
// which we need to fix up post download // which we need to fix up post download
u := fmt.Sprintf("https://github.com/matrix-org/dendrite/archive/%s.tar.gz", branchOrTagName) u := fmt.Sprintf("https://github.com/matrix-org/dendrite/archive/%s.tar.gz", branchOrTagName)
tarball, err = downloadArchive(httpClient, tmpDir, u, []byte(Dockerfile)) tarball, err = downloadArchive(httpClient, tmpDir, u, dockerfile())
if err != nil { if err != nil {
return "", fmt.Errorf("failed to download archive %s: %w", u, err) return "", fmt.Errorf("failed to download archive %s: %w", u, err)
} }
@ -367,7 +407,8 @@ func runImage(dockerClient *client.Client, volumeName, version, imageID string)
// hit /versions to check it is up // hit /versions to check it is up
var lastErr error var lastErr error
for i := 0; i < 500; i++ { for i := 0; i < 500; i++ {
res, err := http.Get(versionsURL) var res *http.Response
res, err = http.Get(versionsURL)
if err != nil { if err != nil {
lastErr = fmt.Errorf("GET %s => error: %s", versionsURL, err) lastErr = fmt.Errorf("GET %s => error: %s", versionsURL, err)
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
@ -381,18 +422,22 @@ func runImage(dockerClient *client.Client, volumeName, version, imageID string)
lastErr = nil lastErr = nil
break break
} }
if lastErr != nil {
logs, err := dockerClient.ContainerLogs(context.Background(), containerID, types.ContainerLogsOptions{ logs, err := dockerClient.ContainerLogs(context.Background(), containerID, types.ContainerLogsOptions{
ShowStdout: true, ShowStdout: true,
ShowStderr: true, ShowStderr: true,
Follow: true,
}) })
// ignore errors when cannot get logs, it's just for debugging anyways // ignore errors when cannot get logs, it's just for debugging anyways
if err == nil { if err == nil {
logbody, err := io.ReadAll(logs) go func() {
if err == nil { for {
log.Printf("Container logs:\n\n%s\n\n", string(logbody)) if body, err := io.ReadAll(logs); err == nil && len(body) > 0 {
log.Printf("%s: %s", version, string(body))
} else {
return
} }
} }
}()
} }
return baseURL, containerID, lastErr return baseURL, containerID, lastErr
} }
@ -416,6 +461,46 @@ func loadAndRunTests(dockerClient *client.Client, volumeName, v string, branchTo
if err = runTests(csAPIURL, v); err != nil { if err = runTests(csAPIURL, v); err != nil {
return fmt.Errorf("failed to run tests on version %s: %s", v, err) return fmt.Errorf("failed to run tests on version %s: %s", v, err)
} }
err = testCreateAccount(dockerClient, v, containerID)
if err != nil {
return err
}
return nil
}
// test that create-account is working
func testCreateAccount(dockerClient *client.Client, v string, containerID string) error {
createUser := strings.ToLower("createaccountuser-" + v)
log.Printf("%s: Creating account %s with create-account\n", v, createUser)
respID, err := dockerClient.ContainerExecCreate(context.Background(), containerID, types.ExecConfig{
AttachStderr: true,
AttachStdout: true,
Cmd: []string{
"/build/create-account",
"-username", createUser,
"-password", "someRandomPassword",
},
})
if err != nil {
return fmt.Errorf("failed to ContainerExecCreate: %w", err)
}
response, err := dockerClient.ContainerExecAttach(context.Background(), respID.ID, types.ExecStartCheck{})
if err != nil {
return fmt.Errorf("failed to attach to container: %w", err)
}
defer response.Close()
data, err := ioutil.ReadAll(response.Reader)
if err != nil {
return err
}
if !bytes.Contains(data, []byte("AccessToken")) {
return fmt.Errorf("failed to create-account: %s", string(data))
}
return nil return nil
} }

View file

@ -48,10 +48,15 @@ func main() {
panic("unexpected key block") panic("unexpected key block")
} }
serverName := gomatrixserverlib.ServerName(*requestFrom)
client := gomatrixserverlib.NewFederationClient( client := gomatrixserverlib.NewFederationClient(
gomatrixserverlib.ServerName(*requestFrom), []*gomatrixserverlib.SigningIdentity{
gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]), {
privateKey, ServerName: serverName,
KeyID: gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]),
PrivateKey: privateKey,
},
},
) )
u, err := url.Parse(flag.Arg(0)) u, err := url.Parse(flag.Arg(0))
@ -79,6 +84,7 @@ func main() {
req := gomatrixserverlib.NewFederationRequest( req := gomatrixserverlib.NewFederationRequest(
method, method,
serverName,
gomatrixserverlib.ServerName(u.Host), gomatrixserverlib.ServerName(u.Host),
u.RequestURI(), u.RequestURI(),
) )

View file

@ -75,7 +75,20 @@ comment. Please avoid doing this if you can.
We also have unit tests which we run via: We also have unit tests which we run via:
```bash ```bash
go test --race ./... DENDRITE_TEST_SKIP_NODB=1 go test --race ./...
```
This only runs SQLite database tests. If you wish to execute Postgres tests as well, you'll either need to
have Postgres installed locally (`createdb` will be used) or have a remote/containerized Postgres instance
available.
To configure the connection to a remote Postgres, you can use the following enviroment variables:
```bash
POSTGRES_USER=postgres
POSTGERS_PASSWORD=yourPostgresPassword
POSTGRES_HOST=localhost
POSTGRES_DB=postgres # the superuser database to use
``` ```
In general, we like submissions that come with tests. Anything that proves that the In general, we like submissions that come with tests. Anything that proves that the

View file

@ -90,7 +90,7 @@ For example, this can be done with the following Caddy config:
handle /.well-known/matrix/server { handle /.well-known/matrix/server {
header Content-Type application/json header Content-Type application/json
header Access-Control-Allow-Origin * header Access-Control-Allow-Origin *
respond `"m.server": "matrix.example.com:8448"` respond `{"m.server": "matrix.example.com:8448"}`
} }
handle /.well-known/matrix/client { handle /.well-known/matrix/client {

View file

@ -21,8 +21,8 @@ type FederationInternalAPI interface {
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
// Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos.
PerformBroadcastEDU( PerformBroadcastEDU(
@ -30,7 +30,6 @@ type FederationInternalAPI interface {
request *PerformBroadcastEDURequest, request *PerformBroadcastEDURequest,
response *PerformBroadcastEDUResponse, response *PerformBroadcastEDUResponse,
) error ) error
PerformStoreAsync( PerformStoreAsync(
ctx context.Context, ctx context.Context,
request *PerformStoreAsyncRequest, request *PerformStoreAsyncRequest,
@ -41,6 +40,11 @@ type FederationInternalAPI interface {
request *QueryAsyncTransactionsRequest, request *QueryAsyncTransactionsRequest,
response *QueryAsyncTransactionsResponse, response *QueryAsyncTransactionsResponse,
) error ) error
PerformWakeupServers(
ctx context.Context,
request *PerformWakeupServersRequest,
response *PerformWakeupServersResponse,
) error
} }
type ClientFederationAPI interface { type ClientFederationAPI interface {
@ -71,18 +75,18 @@ type RoomserverFederationAPI interface {
// containing only the server names (without information for membership events). // containing only the server names (without information for membership events).
// The response will include this server if they are joined to the room. // The response will include this server if they are joined to the room.
QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
// KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
// this interface are of type FederationClientError // this interface are of type FederationClientError
type KeyserverFederationAPI interface { type KeyserverFederationAPI interface {
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error)
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
} }
// an interface for gmsl.FederationClient - contains functions called by federationapi only. // an interface for gmsl.FederationClient - contains functions called by federationapi only.
@ -94,28 +98,28 @@ type FederationClient interface {
GetAsyncEvents(ctx context.Context, u gomatrixserverlib.UserID, mailserver gomatrixserverlib.ServerName) (res gomatrixserverlib.RespGetAsyncEvents, err error) GetAsyncEvents(ctx context.Context, u gomatrixserverlib.UserID, mailserver gomatrixserverlib.ServerName) (res gomatrixserverlib.RespGetAsyncEvents, err error)
// Perform operations // Perform operations
LookupRoomAlias(ctx context.Context, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error)
Peek(ctx context.Context, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error) Peek(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error)
MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error)
SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error)
MakeLeave(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error) MakeLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error)
SendLeave(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) SendLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error)
SendInviteV2(ctx context.Context, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error) SendInviteV2(ctx context.Context, origin, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error) GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error)
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error)
Backfill(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) Backfill(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
ExchangeThirdPartyInvite(ctx context.Context, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) ExchangeThirdPartyInvite(ctx context.Context, origin, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error)
LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error)
LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
// FederationClientError is returned from FederationClient methods in the event of a problem. // FederationClientError is returned from FederationClient methods in the event of a problem.
@ -214,6 +218,7 @@ type PerformInviteResponse struct {
type QueryJoinedHostServerNamesInRoomRequest struct { type QueryJoinedHostServerNamesInRoomRequest struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
ExcludeSelf bool `json:"exclude_self"` ExcludeSelf bool `json:"exclude_self"`
ExcludeBlacklisted bool `json:"exclude_blacklisted"`
} }
// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames // QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames
@ -244,6 +249,13 @@ type QueryAsyncTransactionsResponse struct {
RemainingCount uint32 `json:"remaining"` RemainingCount uint32 `json:"remaining"`
} }
type PerformWakeupServersRequest struct {
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
}
type PerformWakeupServersResponse struct {
}
type InputPublicKeysRequest struct { type InputPublicKeysRequest struct {
Keys map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"keys"` Keys map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"keys"`
} }

View file

@ -128,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
} }
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
@ -189,7 +189,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
return true return true
} }
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")

View file

@ -111,7 +111,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
} }
// send this presence to all servers who share rooms with this user. // send this presence to all servers who share rooms with this user.
joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
if err != nil { if err != nil {
log.WithError(err).Error("failed to get joined hosts") log.WithError(err).Error("failed to get joined hosts")
return true return true

View file

@ -18,6 +18,10 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv"
"time"
syncAPITypes "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
@ -38,10 +42,12 @@ type OutputRoomEventConsumer struct {
cfg *config.FederationAPI cfg *config.FederationAPI
rsAPI api.FederationRoomserverAPI rsAPI api.FederationRoomserverAPI
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
natsClient *nats.Conn
durable string durable string
db storage.Database db storage.Database
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
topic string topic string
topicPresence string
} }
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -49,6 +55,7 @@ func NewOutputRoomEventConsumer(
process *process.ProcessContext, process *process.ProcessContext,
cfg *config.FederationAPI, cfg *config.FederationAPI,
js nats.JetStreamContext, js nats.JetStreamContext,
natsClient *nats.Conn,
queues *queue.OutgoingQueues, queues *queue.OutgoingQueues,
store storage.Database, store storage.Database,
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
@ -57,11 +64,13 @@ func NewOutputRoomEventConsumer(
ctx: process.Context(), ctx: process.Context(),
cfg: cfg, cfg: cfg,
jetstream: js, jetstream: js,
natsClient: natsClient,
db: store, db: store,
queues: queues, queues: queues,
rsAPI: rsAPI, rsAPI: rsAPI,
durable: cfg.Matrix.JetStream.Durable("FederationAPIRoomServerConsumer"), durable: cfg.Matrix.JetStream.Durable("FederationAPIRoomServerConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent),
topicPresence: cfg.Matrix.JetStream.Prefixed(jetstream.RequestPresence),
} }
} }
@ -146,6 +155,7 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee
// processMessage updates the list of currently joined hosts in the room // processMessage updates the list of currently joined hosts in the room
// and then sends the event to the hosts that were joined before the event. // and then sends the event to the hosts that were joined before the event.
func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rewritesState bool) error { func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rewritesState bool) error {
addsStateEvents, missingEventIDs := ore.NeededStateEventIDs() addsStateEvents, missingEventIDs := ore.NeededStateEventIDs()
// Ask the roomserver and add in the rest of the results into the set. // Ask the roomserver and add in the rest of the results into the set.
@ -184,6 +194,14 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
return err return err
} }
// If we added new hosts, inform them about our known presence events for this room
if len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil {
membership, _ := ore.Event.Membership()
if membership == gomatrixserverlib.Join {
s.sendPresence(ore.Event.RoomID(), addsJoinedHosts)
}
}
if oldJoinedHosts == nil { if oldJoinedHosts == nil {
// This means that there is nothing to update as this is a duplicate // This means that there is nothing to update as this is a duplicate
// message. // message.
@ -213,6 +231,76 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
) )
} }
func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []types.JoinedHost) {
joined := make([]gomatrixserverlib.ServerName, len(addedJoined))
for _, added := range addedJoined {
joined = append(joined, added.ServerName)
}
// get our locally joined users
var queryRes api.QueryMembershipsForRoomResponse
err := s.rsAPI.QueryMembershipsForRoom(s.ctx, &api.QueryMembershipsForRoomRequest{
JoinedOnly: true,
LocalOnly: true,
RoomID: roomID,
}, &queryRes)
if err != nil {
log.WithError(err).Error("failed to calculate joined rooms for user")
return
}
// send every presence we know about to the remote server
content := types.Presence{}
for _, ev := range queryRes.JoinEvents {
msg := nats.NewMsg(s.topicPresence)
msg.Header.Set(jetstream.UserID, ev.Sender)
var presence *nats.Msg
presence, err = s.natsClient.RequestMsg(msg, time.Second*10)
if err != nil {
log.WithError(err).Errorf("unable to get presence")
continue
}
statusMsg := presence.Header.Get("status_msg")
e := presence.Header.Get("error")
if e != "" {
continue
}
var lastActive int
lastActive, err = strconv.Atoi(presence.Header.Get("last_active_ts"))
if err != nil {
continue
}
p := syncAPITypes.PresenceInternal{LastActiveTS: gomatrixserverlib.Timestamp(lastActive)}
content.Push = append(content.Push, types.PresenceContent{
CurrentlyActive: p.CurrentlyActive(),
LastActiveAgo: p.LastActiveAgo(),
Presence: presence.Header.Get("presence"),
StatusMsg: &statusMsg,
UserID: ev.Sender,
})
}
if len(content.Push) == 0 {
return
}
edu := &gomatrixserverlib.EDU{
Type: gomatrixserverlib.MPresence,
Origin: string(s.cfg.Matrix.ServerName),
}
if edu.Content, err = json.Marshal(content); err != nil {
log.WithError(err).Error("failed to marshal EDU JSON")
return
}
if err := s.queues.SendEDU(edu, s.cfg.Matrix.ServerName, joined); err != nil {
log.WithError(err).Error("failed to send EDU")
}
}
// joinedHostsAtEvent works out a list of matrix servers that were joined to // joinedHostsAtEvent works out a list of matrix servers that were joined to
// the room at the event (including peeking ones) // the room at the event (including peeking ones)
// It is important to use the state at the event for sending messages because: // It is important to use the state at the event for sending messages because:

View file

@ -121,21 +121,19 @@ func NewInternalAPI(
cfg.FederationMaxRetries+1, cfg.FederationMaxRetries+1,
cfg.FederationRetriesUntilAssumedOffline+1) cfg.FederationRetriesUntilAssumedOffline+1)
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
signingInfo := base.Cfg.Global.SigningIdentities()
queues := queue.NewOutgoingQueues( queues := queue.NewOutgoingQueues(
federationDB, base.ProcessContext, federationDB, base.ProcessContext,
cfg.Matrix.DisableFederation, cfg.Matrix.DisableFederation,
cfg.Matrix.ServerName, federation, rsAPI, &stats, cfg.Matrix.ServerName, federation, rsAPI, &stats,
&queue.SigningInfo{ signingInfo,
KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey,
ServerName: cfg.Matrix.ServerName,
},
) )
rsConsumer := consumers.NewOutputRoomEventConsumer( rsConsumer := consumers.NewOutputRoomEventConsumer(
base.ProcessContext, cfg, js, queues, base.ProcessContext, cfg, js, nats, queues,
federationDB, rsAPI, federationDB, rsAPI,
) )
if err = rsConsumer.Start(); err != nil { if err = rsConsumer.Start(); err != nil {

View file

@ -104,7 +104,7 @@ func TestMain(m *testing.M) {
// Create the federation client. // Create the federation client.
s.fedclient = gomatrixserverlib.NewFederationClient( s.fedclient = gomatrixserverlib.NewFederationClient(
s.config.Matrix.ServerName, serverKeyID, testPriv, s.config.Matrix.SigningIdentities(),
gomatrixserverlib.WithTransport(transport), gomatrixserverlib.WithTransport(transport),
) )
@ -137,7 +137,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err
} }
// Get the keys and JSON-ify them. // Get the keys and JSON-ify them.
keys := routing.LocalKeys(s.config) keys := routing.LocalKeys(s.config, gomatrixserverlib.ServerName(req.Host))
body, err := json.MarshalIndent(keys.JSON, "", " ") body, err := json.MarshalIndent(keys.JSON, "", " ")
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -103,7 +103,7 @@ func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserv
return keys, nil return keys, nil
} }
func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) { func (f *fedClient) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) {
for _, r := range f.allowJoins { for _, r := range f.allowJoins {
if r.ID == roomID { if r.ID == roomID {
res.RoomVersion = r.Version res.RoomVersion = r.Version
@ -127,7 +127,7 @@ func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName
} }
return return
} }
func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) { func (f *fedClient) SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) {
f.fedClientMutex.Lock() f.fedClientMutex.Lock()
defer f.fedClientMutex.Unlock() defer f.fedClientMutex.Unlock()
for _, r := range f.allowJoins { for _, r := range f.allowJoins {
@ -283,7 +283,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://"))
fedCli := gomatrixserverlib.NewFederationClient( fedCli := gomatrixserverlib.NewFederationClient(
serverName, cfg.Global.KeyID, cfg.Global.PrivateKey, cfg.Global.SigningIdentities(),
gomatrixserverlib.WithSkipVerify(true), gomatrixserverlib.WithSkipVerify(true),
) )
@ -326,7 +326,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
t.Errorf("failed to create invite v2 request: %s", err) t.Errorf("failed to create invite v2 request: %s", err)
continue continue
} }
_, err = fedCli.SendInviteV2(context.Background(), serverName, invReq) _, err = fedCli.SendInviteV2(context.Background(), cfg.Global.ServerName, serverName, invReq)
if err == nil { if err == nil {
t.Errorf("expected an error, got none") t.Errorf("expected an error, got none")
continue continue

View file

@ -11,13 +11,13 @@ import (
// client. // client.
func (a *FederationInternalAPI) GetEventAuth( func (a *FederationInternalAPI) GetEventAuth(
ctx context.Context, s gomatrixserverlib.ServerName, ctx context.Context, origin, s gomatrixserverlib.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (res gomatrixserverlib.RespEventAuth, err error) { ) (res gomatrixserverlib.RespEventAuth, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEventAuth(ctx, s, roomVersion, roomID, eventID) return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespEventAuth{}, err return gomatrixserverlib.RespEventAuth{}, err
@ -26,12 +26,12 @@ func (a *FederationInternalAPI) GetEventAuth(
} }
func (a *FederationInternalAPI) GetUserDevices( func (a *FederationInternalAPI) GetUserDevices(
ctx context.Context, s gomatrixserverlib.ServerName, userID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string,
) (gomatrixserverlib.RespUserDevices, error) { ) (gomatrixserverlib.RespUserDevices, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetUserDevices(ctx, s, userID) return a.federation.GetUserDevices(ctx, origin, s, userID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespUserDevices{}, err return gomatrixserverlib.RespUserDevices{}, err
@ -40,12 +40,12 @@ func (a *FederationInternalAPI) GetUserDevices(
} }
func (a *FederationInternalAPI) ClaimKeys( func (a *FederationInternalAPI) ClaimKeys(
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
) (gomatrixserverlib.RespClaimKeys, error) { ) (gomatrixserverlib.RespClaimKeys, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.ClaimKeys(ctx, s, oneTimeKeys) return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespClaimKeys{}, err return gomatrixserverlib.RespClaimKeys{}, err
@ -54,10 +54,10 @@ func (a *FederationInternalAPI) ClaimKeys(
} }
func (a *FederationInternalAPI) QueryKeys( func (a *FederationInternalAPI) QueryKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string,
) (gomatrixserverlib.RespQueryKeys, error) { ) (gomatrixserverlib.RespQueryKeys, error) {
ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) {
return a.federation.QueryKeys(ctx, s, keys) return a.federation.QueryKeys(ctx, origin, s, keys)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespQueryKeys{}, err return gomatrixserverlib.RespQueryKeys{}, err
@ -66,12 +66,12 @@ func (a *FederationInternalAPI) QueryKeys(
} }
func (a *FederationInternalAPI) Backfill( func (a *FederationInternalAPI) Backfill(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.Backfill(ctx, s, roomID, limit, eventIDs) return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.Transaction{}, err return gomatrixserverlib.Transaction{}, err
@ -80,12 +80,12 @@ func (a *FederationInternalAPI) Backfill(
} }
func (a *FederationInternalAPI) LookupState( func (a *FederationInternalAPI) LookupState(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespState, err error) { ) (res gomatrixserverlib.RespState, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion) return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespState{}, err return gomatrixserverlib.RespState{}, err
@ -94,12 +94,12 @@ func (a *FederationInternalAPI) LookupState(
} }
func (a *FederationInternalAPI) LookupStateIDs( func (a *FederationInternalAPI) LookupStateIDs(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string,
) (res gomatrixserverlib.RespStateIDs, err error) { ) (res gomatrixserverlib.RespStateIDs, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupStateIDs(ctx, s, roomID, eventID) return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespStateIDs{}, err return gomatrixserverlib.RespStateIDs{}, err
@ -108,13 +108,13 @@ func (a *FederationInternalAPI) LookupStateIDs(
} }
func (a *FederationInternalAPI) LookupMissingEvents( func (a *FederationInternalAPI) LookupMissingEvents(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string,
missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespMissingEvents, err error) { ) (res gomatrixserverlib.RespMissingEvents, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupMissingEvents(ctx, s, roomID, missing, roomVersion) return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespMissingEvents{}, err return gomatrixserverlib.RespMissingEvents{}, err
@ -123,12 +123,12 @@ func (a *FederationInternalAPI) LookupMissingEvents(
} }
func (a *FederationInternalAPI) GetEvent( func (a *FederationInternalAPI) GetEvent(
ctx context.Context, s gomatrixserverlib.ServerName, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string,
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEvent(ctx, s, eventID) return a.federation.GetEvent(ctx, origin, s, eventID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.Transaction{}, err return gomatrixserverlib.Transaction{}, err
@ -151,13 +151,13 @@ func (a *FederationInternalAPI) LookupServerKeys(
} }
func (a *FederationInternalAPI) MSC2836EventRelationships( func (a *FederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.MSC2836EventRelationships(ctx, s, r, roomVersion) return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion)
}) })
if err != nil { if err != nil {
return res, err return res, err
@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
} }
func (a *FederationInternalAPI) MSC2946Spaces( func (a *FederationInternalAPI) MSC2946Spaces(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly) return a.federation.MSC2946Spaces(ctx, origin, s, roomID, suggestedOnly)
}) })
if err != nil { if err != nil {
return res, err return res, err

View file

@ -32,6 +32,7 @@ func (r *FederationInternalAPI) PerformDirectoryLookup(
dir, err := r.federation.LookupRoomAlias( dir, err := r.federation.LookupRoomAlias(
ctx, ctx,
r.cfg.Matrix.ServerName,
request.ServerName, request.ServerName,
request.RoomAlias, request.RoomAlias,
) )
@ -154,10 +155,16 @@ func (r *FederationInternalAPI) performJoinUsingServer(
return fmt.Errorf("not performing federation since server is assumed offline with known mailboxes") return fmt.Errorf("not performing federation since server is assumed offline with known mailboxes")
} }
_, origin, err := r.cfg.Matrix.SplitLocalID('@', userID)
if err != nil {
return err
}
// Try to perform a make_join using the information supplied in the // Try to perform a make_join using the information supplied in the
// request. // request.
respMakeJoin, err := r.federation.MakeJoin( respMakeJoin, err := r.federation.MakeJoin(
ctx, ctx,
origin,
serverName, serverName,
roomID, roomID,
userID, userID,
@ -203,7 +210,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
// Build the join event. // Build the join event.
event, err := respMakeJoin.JoinEvent.Build( event, err := respMakeJoin.JoinEvent.Build(
time.Now(), time.Now(),
r.cfg.Matrix.ServerName, origin,
r.cfg.Matrix.KeyID, r.cfg.Matrix.KeyID,
r.cfg.Matrix.PrivateKey, r.cfg.Matrix.PrivateKey,
respMakeJoin.RoomVersion, respMakeJoin.RoomVersion,
@ -215,6 +222,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
// Try to perform a send_join using the newly built event. // Try to perform a send_join using the newly built event.
respSendJoin, err := r.federation.SendJoin( respSendJoin, err := r.federation.SendJoin(
context.Background(), context.Background(),
origin,
serverName, serverName,
event, event,
) )
@ -257,7 +265,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
respMakeJoin.RoomVersion, respMakeJoin.RoomVersion,
r.keyRing, r.keyRing,
event, event,
federatedAuthProvider(ctx, r.federation, r.keyRing, serverName), federatedAuthProvider(ctx, r.federation, r.keyRing, origin, serverName),
) )
if err != nil { if err != nil {
return fmt.Errorf("respSendJoin.Check: %w", err) return fmt.Errorf("respSendJoin.Check: %w", err)
@ -292,6 +300,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
if err = roomserverAPI.SendEventWithState( if err = roomserverAPI.SendEventWithState(
context.Background(), context.Background(),
r.rsAPI, r.rsAPI,
origin,
roomserverAPI.KindNew, roomserverAPI.KindNew,
respState, respState,
event.Headered(respMakeJoin.RoomVersion), event.Headered(respMakeJoin.RoomVersion),
@ -443,6 +452,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// request. // request.
respPeek, err := r.federation.Peek( respPeek, err := r.federation.Peek(
ctx, ctx,
r.cfg.Matrix.ServerName,
serverName, serverName,
roomID, roomID,
peekID, peekID,
@ -469,7 +479,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// authenticate the state returned (check its auth events etc) // authenticate the state returned (check its auth events etc)
// the equivalent of CheckSendJoinResponse() // the equivalent of CheckSendJoinResponse()
authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)) authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName))
if err != nil { if err != nil {
return fmt.Errorf("error checking state returned from peeking: %w", err) return fmt.Errorf("error checking state returned from peeking: %w", err)
} }
@ -491,7 +501,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// logrus.Warnf("got respPeek %#v", respPeek) // logrus.Warnf("got respPeek %#v", respPeek)
// Send the newly returned state to the roomserver to update our local view. // Send the newly returned state to the roomserver to update our local view.
if err = roomserverAPI.SendEventWithState( if err = roomserverAPI.SendEventWithState(
ctx, r.rsAPI, ctx, r.rsAPI, r.cfg.Matrix.ServerName,
roomserverAPI.KindNew, roomserverAPI.KindNew,
&respState, &respState,
respPeek.LatestEvent.Headered(respPeek.RoomVersion), respPeek.LatestEvent.Headered(respPeek.RoomVersion),
@ -511,6 +521,11 @@ func (r *FederationInternalAPI) PerformLeave(
request *api.PerformLeaveRequest, request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse, response *api.PerformLeaveResponse,
) (err error) { ) (err error) {
_, origin, err := r.cfg.Matrix.SplitLocalID('@', request.UserID)
if err != nil {
return err
}
// Deduplicate the server names we were provided. // Deduplicate the server names we were provided.
util.SortAndUnique(request.ServerNames) util.SortAndUnique(request.ServerNames)
@ -526,6 +541,7 @@ func (r *FederationInternalAPI) PerformLeave(
// request. // request.
respMakeLeave, err := r.federation.MakeLeave( respMakeLeave, err := r.federation.MakeLeave(
ctx, ctx,
origin,
serverName, serverName,
request.RoomID, request.RoomID,
request.UserID, request.UserID,
@ -567,7 +583,7 @@ func (r *FederationInternalAPI) PerformLeave(
// Build the leave event. // Build the leave event.
event, err := respMakeLeave.LeaveEvent.Build( event, err := respMakeLeave.LeaveEvent.Build(
time.Now(), time.Now(),
r.cfg.Matrix.ServerName, origin,
r.cfg.Matrix.KeyID, r.cfg.Matrix.KeyID,
r.cfg.Matrix.PrivateKey, r.cfg.Matrix.PrivateKey,
respMakeLeave.RoomVersion, respMakeLeave.RoomVersion,
@ -580,6 +596,7 @@ func (r *FederationInternalAPI) PerformLeave(
// Try to perform a send_leave using the newly built event. // Try to perform a send_leave using the newly built event.
err = r.federation.SendLeave( err = r.federation.SendLeave(
ctx, ctx,
origin,
serverName, serverName,
event, event,
) )
@ -606,6 +623,11 @@ func (r *FederationInternalAPI) PerformInvite(
request *api.PerformInviteRequest, request *api.PerformInviteRequest,
response *api.PerformInviteResponse, response *api.PerformInviteResponse,
) (err error) { ) (err error) {
_, origin, err := r.cfg.Matrix.SplitLocalID('@', request.Event.Sender())
if err != nil {
return err
}
if request.Event.StateKey() == nil { if request.Event.StateKey() == nil {
return errors.New("invite must be a state event") return errors.New("invite must be a state event")
} }
@ -633,7 +655,7 @@ func (r *FederationInternalAPI) PerformInvite(
return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err)
} }
inviteRes, err := r.federation.SendInviteV2(ctx, destination, inviteReq) inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq)
if err != nil { if err != nil {
return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err)
} }
@ -728,9 +750,23 @@ func (r *FederationInternalAPI) QueryAsyncTransactions(
return nil return nil
} }
// PerformWakeupServers implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformWakeupServers(
ctx context.Context,
request *api.PerformWakeupServersRequest,
response *api.PerformWakeupServersResponse,
) (err error) {
r.MarkServersAlive(request.ServerNames)
return nil
}
func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) { func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) {
for _, srv := range destinations { for _, srv := range destinations {
// Check the statistics cache for the blacklist status to prevent hitting
// the database unnecessarily.
if r.queues.IsServerBlacklisted(srv) {
_ = r.db.RemoveServerFromBlacklist(srv) _ = r.db.RemoveServerFromBlacklist(srv)
}
r.queues.RetryServer(srv) r.queues.RetryServer(srv)
} }
} }
@ -788,7 +824,7 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder
// FederatedAuthProvider is an auth chain provider which fetches events from the server provided // FederatedAuthProvider is an auth chain provider which fetches events from the server provided
func federatedAuthProvider( func federatedAuthProvider(
ctx context.Context, federation api.FederationClient, ctx context.Context, federation api.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName, keyRing gomatrixserverlib.JSONVerifier, origin, server gomatrixserverlib.ServerName,
) gomatrixserverlib.AuthChainProvider { ) gomatrixserverlib.AuthChainProvider {
// A list of events that we have retried, if they were not included in // A list of events that we have retried, if they were not included in
// the auth events supplied in the send_join. // the auth events supplied in the send_join.
@ -818,7 +854,7 @@ func federatedAuthProvider(
// Try to retrieve the event from the server that sent us the send // Try to retrieve the event from the server that sent us the send
// join response. // join response.
tx, txerr := federation.GetEvent(ctx, server, eventID) tx, txerr := federation.GetEvent(ctx, origin, server, eventID)
if txerr != nil { if txerr != nil {
return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr) return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr)
} }

View file

@ -16,7 +16,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom(
request *api.QueryJoinedHostServerNamesInRoomRequest, request *api.QueryJoinedHostServerNamesInRoomRequest,
response *api.QueryJoinedHostServerNamesInRoomResponse, response *api.QueryJoinedHostServerNamesInRoomResponse,
) (err error) { ) (err error) {
joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf) joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf, request.ExcludeBlacklisted)
if err != nil { if err != nil {
return return
} }

View file

@ -25,6 +25,7 @@ const (
FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU"
FederationAPIPerformStoreAsyncPath = "/federationapi/performStoreAsync" FederationAPIPerformStoreAsyncPath = "/federationapi/performStoreAsync"
FederationAPIQueryAsyncTransactionsPath = "/federationapi/queryAsyncTransactions" FederationAPIQueryAsyncTransactionsPath = "/federationapi/queryAsyncTransactions"
FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers"
FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices"
FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" FederationAPIClaimKeysPath = "/federationapi/client/claimKeys"
@ -174,18 +175,32 @@ func (h *httpFederationInternalAPI) QueryAsyncTransactions(
) )
} }
// Handle an instruction to remove the respective servers from being blacklisted.
func (h *httpFederationInternalAPI) PerformWakeupServers(
ctx context.Context,
request *api.PerformWakeupServersRequest,
response *api.PerformWakeupServersResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformWakeupServers", h.federationAPIURL+FederationAPIPerformWakeupServers,
h.httpClient, ctx, request, response,
)
}
type getUserDevices struct { type getUserDevices struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
UserID string UserID string
} }
func (h *httpFederationInternalAPI) GetUserDevices( func (h *httpFederationInternalAPI) GetUserDevices(
ctx context.Context, s gomatrixserverlib.ServerName, userID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string,
) (gomatrixserverlib.RespUserDevices, error) { ) (gomatrixserverlib.RespUserDevices, error) {
return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError]( return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError](
"GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient, "GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient,
ctx, &getUserDevices{ ctx, &getUserDevices{
S: s, S: s,
Origin: origin,
UserID: userID, UserID: userID,
}, },
) )
@ -193,16 +208,18 @@ func (h *httpFederationInternalAPI) GetUserDevices(
type claimKeys struct { type claimKeys struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
OneTimeKeys map[string]map[string]string OneTimeKeys map[string]map[string]string
} }
func (h *httpFederationInternalAPI) ClaimKeys( func (h *httpFederationInternalAPI) ClaimKeys(
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
) (gomatrixserverlib.RespClaimKeys, error) { ) (gomatrixserverlib.RespClaimKeys, error) {
return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError]( return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError](
"ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient, "ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient,
ctx, &claimKeys{ ctx, &claimKeys{
S: s, S: s,
Origin: origin,
OneTimeKeys: oneTimeKeys, OneTimeKeys: oneTimeKeys,
}, },
) )
@ -210,16 +227,18 @@ func (h *httpFederationInternalAPI) ClaimKeys(
type queryKeys struct { type queryKeys struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
Keys map[string][]string Keys map[string][]string
} }
func (h *httpFederationInternalAPI) QueryKeys( func (h *httpFederationInternalAPI) QueryKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string,
) (gomatrixserverlib.RespQueryKeys, error) { ) (gomatrixserverlib.RespQueryKeys, error) {
return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError]( return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError](
"QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient, "QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient,
ctx, &queryKeys{ ctx, &queryKeys{
S: s, S: s,
Origin: origin,
Keys: keys, Keys: keys,
}, },
) )
@ -227,18 +246,20 @@ func (h *httpFederationInternalAPI) QueryKeys(
type backfill struct { type backfill struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
Limit int Limit int
EventIDs []string EventIDs []string
} }
func (h *httpFederationInternalAPI) Backfill( func (h *httpFederationInternalAPI) Backfill(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
) (gomatrixserverlib.Transaction, error) { ) (gomatrixserverlib.Transaction, error) {
return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError]( return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError](
"Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient, "Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient,
ctx, &backfill{ ctx, &backfill{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
Limit: limit, Limit: limit,
EventIDs: eventIDs, EventIDs: eventIDs,
@ -248,18 +269,20 @@ func (h *httpFederationInternalAPI) Backfill(
type lookupState struct { type lookupState struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
EventID string EventID string
RoomVersion gomatrixserverlib.RoomVersion RoomVersion gomatrixserverlib.RoomVersion
} }
func (h *httpFederationInternalAPI) LookupState( func (h *httpFederationInternalAPI) LookupState(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (gomatrixserverlib.RespState, error) { ) (gomatrixserverlib.RespState, error) {
return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError]( return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError](
"LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient, "LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient,
ctx, &lookupState{ ctx, &lookupState{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
RoomVersion: roomVersion, RoomVersion: roomVersion,
@ -269,17 +292,19 @@ func (h *httpFederationInternalAPI) LookupState(
type lookupStateIDs struct { type lookupStateIDs struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
EventID string EventID string
} }
func (h *httpFederationInternalAPI) LookupStateIDs( func (h *httpFederationInternalAPI) LookupStateIDs(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string,
) (gomatrixserverlib.RespStateIDs, error) { ) (gomatrixserverlib.RespStateIDs, error) {
return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError]( return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError](
"LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient, "LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient,
ctx, &lookupStateIDs{ ctx, &lookupStateIDs{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
}, },
@ -288,19 +313,21 @@ func (h *httpFederationInternalAPI) LookupStateIDs(
type lookupMissingEvents struct { type lookupMissingEvents struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
Missing gomatrixserverlib.MissingEvents Missing gomatrixserverlib.MissingEvents
RoomVersion gomatrixserverlib.RoomVersion RoomVersion gomatrixserverlib.RoomVersion
} }
func (h *httpFederationInternalAPI) LookupMissingEvents( func (h *httpFederationInternalAPI) LookupMissingEvents(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string,
missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespMissingEvents, err error) { ) (res gomatrixserverlib.RespMissingEvents, err error) {
return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError]( return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError](
"LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient, "LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient,
ctx, &lookupMissingEvents{ ctx, &lookupMissingEvents{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
Missing: missing, Missing: missing,
RoomVersion: roomVersion, RoomVersion: roomVersion,
@ -310,16 +337,18 @@ func (h *httpFederationInternalAPI) LookupMissingEvents(
type getEvent struct { type getEvent struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
EventID string EventID string
} }
func (h *httpFederationInternalAPI) GetEvent( func (h *httpFederationInternalAPI) GetEvent(
ctx context.Context, s gomatrixserverlib.ServerName, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string,
) (gomatrixserverlib.Transaction, error) { ) (gomatrixserverlib.Transaction, error) {
return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError]( return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError](
"GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient, "GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient,
ctx, &getEvent{ ctx, &getEvent{
S: s, S: s,
Origin: origin,
EventID: eventID, EventID: eventID,
}, },
) )
@ -327,19 +356,21 @@ func (h *httpFederationInternalAPI) GetEvent(
type getEventAuth struct { type getEventAuth struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomVersion gomatrixserverlib.RoomVersion RoomVersion gomatrixserverlib.RoomVersion
RoomID string RoomID string
EventID string EventID string
} }
func (h *httpFederationInternalAPI) GetEventAuth( func (h *httpFederationInternalAPI) GetEventAuth(
ctx context.Context, s gomatrixserverlib.ServerName, ctx context.Context, origin, s gomatrixserverlib.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (gomatrixserverlib.RespEventAuth, error) { ) (gomatrixserverlib.RespEventAuth, error) {
return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError]( return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError](
"GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient, "GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient,
ctx, &getEventAuth{ ctx, &getEventAuth{
S: s, S: s,
Origin: origin,
RoomVersion: roomVersion, RoomVersion: roomVersion,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
@ -375,18 +406,20 @@ func (h *httpFederationInternalAPI) LookupServerKeys(
type eventRelationships struct { type eventRelationships struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
Req gomatrixserverlib.MSC2836EventRelationshipsRequest Req gomatrixserverlib.MSC2836EventRelationshipsRequest
RoomVer gomatrixserverlib.RoomVersion RoomVer gomatrixserverlib.RoomVersion
} }
func (h *httpFederationInternalAPI) MSC2836EventRelationships( func (h *httpFederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError]( return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError](
"MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient, "MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient,
ctx, &eventRelationships{ ctx, &eventRelationships{
S: s, S: s,
Origin: origin,
Req: r, Req: r,
RoomVer: roomVersion, RoomVer: roomVersion,
}, },
@ -395,17 +428,19 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships(
type spacesReq struct { type spacesReq struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
SuggestedOnly bool SuggestedOnly bool
RoomID string RoomID string
} }
func (h *httpFederationInternalAPI) MSC2946Spaces( func (h *httpFederationInternalAPI) MSC2946Spaces(
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError]( return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError](
"MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient, "MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient,
ctx, &spacesReq{ ctx, &spacesReq{
S: dst, S: dst,
Origin: origin,
SuggestedOnly: suggestedOnly, SuggestedOnly: suggestedOnly,
RoomID: roomID, RoomID: roomID,
}, },

View file

@ -53,6 +53,11 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalRPCAPI("FederationAPIQueryAsyncTransactions", intAPI.QueryAsyncTransactions), httputil.MakeInternalRPCAPI("FederationAPIQueryAsyncTransactions", intAPI.QueryAsyncTransactions),
) )
internalAPIMux.Handle(
FederationAPIPerformWakeupServers,
httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", intAPI.PerformWakeupServers),
)
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIPerformJoinRequestPath, FederationAPIPerformJoinRequestPath,
httputil.MakeInternalRPCAPI( httputil.MakeInternalRPCAPI(
@ -69,7 +74,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIGetUserDevices", "FederationAPIGetUserDevices",
func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) { func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) {
res, err := intAPI.GetUserDevices(ctx, req.S, req.UserID) res, err := intAPI.GetUserDevices(ctx, req.Origin, req.S, req.UserID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -80,7 +85,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIClaimKeys", "FederationAPIClaimKeys",
func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) { func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) {
res, err := intAPI.ClaimKeys(ctx, req.S, req.OneTimeKeys) res, err := intAPI.ClaimKeys(ctx, req.Origin, req.S, req.OneTimeKeys)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -91,7 +96,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIQueryKeys", "FederationAPIQueryKeys",
func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) { func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) {
res, err := intAPI.QueryKeys(ctx, req.S, req.Keys) res, err := intAPI.QueryKeys(ctx, req.Origin, req.S, req.Keys)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -102,7 +107,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIBackfill", "FederationAPIBackfill",
func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) { func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) {
res, err := intAPI.Backfill(ctx, req.S, req.RoomID, req.Limit, req.EventIDs) res, err := intAPI.Backfill(ctx, req.Origin, req.S, req.RoomID, req.Limit, req.EventIDs)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -113,7 +118,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPILookupState", "FederationAPILookupState",
func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) { func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) {
res, err := intAPI.LookupState(ctx, req.S, req.RoomID, req.EventID, req.RoomVersion) res, err := intAPI.LookupState(ctx, req.Origin, req.S, req.RoomID, req.EventID, req.RoomVersion)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -124,7 +129,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPILookupStateIDs", "FederationAPILookupStateIDs",
func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) { func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) {
res, err := intAPI.LookupStateIDs(ctx, req.S, req.RoomID, req.EventID) res, err := intAPI.LookupStateIDs(ctx, req.Origin, req.S, req.RoomID, req.EventID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -135,7 +140,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPILookupMissingEvents", "FederationAPILookupMissingEvents",
func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) { func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) {
res, err := intAPI.LookupMissingEvents(ctx, req.S, req.RoomID, req.Missing, req.RoomVersion) res, err := intAPI.LookupMissingEvents(ctx, req.Origin, req.S, req.RoomID, req.Missing, req.RoomVersion)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -146,7 +151,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIGetEvent", "FederationAPIGetEvent",
func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) { func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) {
res, err := intAPI.GetEvent(ctx, req.S, req.EventID) res, err := intAPI.GetEvent(ctx, req.Origin, req.S, req.EventID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -157,7 +162,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIGetEventAuth", "FederationAPIGetEventAuth",
func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) { func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) {
res, err := intAPI.GetEventAuth(ctx, req.S, req.RoomVersion, req.RoomID, req.EventID) res, err := intAPI.GetEventAuth(ctx, req.Origin, req.S, req.RoomVersion, req.RoomID, req.EventID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -184,7 +189,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIMSC2836EventRelationships", "FederationAPIMSC2836EventRelationships",
func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) {
res, err := intAPI.MSC2836EventRelationships(ctx, req.S, req.Req, req.RoomVer) res, err := intAPI.MSC2836EventRelationships(ctx, req.Origin, req.S, req.Req, req.RoomVer)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -195,7 +200,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIMSC2946SpacesSummary", "FederationAPIMSC2946SpacesSummary",
func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) { func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
res, err := intAPI.MSC2946Spaces(ctx, req.S, req.RoomID, req.SuggestedOnly) res, err := intAPI.MSC2946Spaces(ctx, req.Origin, req.S, req.RoomID, req.SuggestedOnly)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),

View file

@ -50,7 +50,7 @@ type destinationQueue struct {
queues *OutgoingQueues queues *OutgoingQueues
db storage.Database db storage.Database
process *process.ProcessContext process *process.ProcessContext
signing *SigningInfo signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity
rsAPI api.FederationRoomserverAPI rsAPI api.FederationRoomserverAPI
client fedapi.FederationClient // federation client client fedapi.FederationClient // federation client
origin gomatrixserverlib.ServerName // origin of requests origin gomatrixserverlib.ServerName // origin of requests
@ -141,23 +141,44 @@ func (oq *destinationQueue) handleBackoffNotifier() {
} }
} }
// wakeQueueIfEventsPending calls wakeQueueAndNotify only if there are
// pending events or if forceWakeup is true. This prevents starting the
// queue unnecessarily.
func (oq *destinationQueue) wakeQueueIfEventsPending(forceWakeup bool) {
eventsPending := func() bool {
oq.pendingMutex.Lock()
defer oq.pendingMutex.Unlock()
return len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0
}
// NOTE : Only wakeup and notify the queue if there are pending events
// or if forceWakeup is true. Otherwise there is no reason to start the
// queue goroutine and waste resources.
if forceWakeup || eventsPending() {
logrus.Info("Starting queue due to pending events or forceWakeup")
oq.wakeQueueAndNotify()
}
}
// wakeQueueAndNotify ensures the destination queue is running and notifies it // wakeQueueAndNotify ensures the destination queue is running and notifies it
// that there is pending work. // that there is pending work.
func (oq *destinationQueue) wakeQueueAndNotify() { func (oq *destinationQueue) wakeQueueAndNotify() {
// Wake up the queue if it's asleep. // NOTE : Send notification before waking queue to prevent a race
oq.wakeQueueIfNeeded() // where the queue was running and stops due to a timeout in between
// checking it and sending the notification.
// Notify the queue that there are events ready to send. // Notify the queue that there are events ready to send.
select { select {
case oq.notify <- struct{}{}: case oq.notify <- struct{}{}:
default: default:
} }
// Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded()
} }
// wakeQueueIfNeeded will wake up the destination queue if it is // wakeQueueIfNeeded will wake up the destination queue if it is
// not already running. If it is running but it is backing off // not already running.
// then we will interrupt the backoff, causing any federation
// requests to retry.
func (oq *destinationQueue) wakeQueueIfNeeded() { func (oq *destinationQueue) wakeQueueIfNeeded() {
// Clear the backingOff flag and update the backoff metrics if it was set. // Clear the backingOff flag and update the backoff metrics if it was set.
if oq.backingOff.CompareAndSwap(true, false) { if oq.backingOff.CompareAndSwap(true, false) {

View file

@ -15,7 +15,6 @@
package queue package queue
import ( import (
"crypto/ed25519"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sync" "sync"
@ -46,7 +45,7 @@ type OutgoingQueues struct {
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
client fedapi.FederationClient client fedapi.FederationClient
statistics *statistics.Statistics statistics *statistics.Statistics
signing *SigningInfo signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity
queuesMutex sync.Mutex // protects the below queuesMutex sync.Mutex // protects the below
queues map[gomatrixserverlib.ServerName]*destinationQueue queues map[gomatrixserverlib.ServerName]*destinationQueue
} }
@ -91,7 +90,7 @@ func NewOutgoingQueues(
client fedapi.FederationClient, client fedapi.FederationClient,
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
statistics *statistics.Statistics, statistics *statistics.Statistics,
signing *SigningInfo, signing []*gomatrixserverlib.SigningIdentity,
) *OutgoingQueues { ) *OutgoingQueues {
queues := &OutgoingQueues{ queues := &OutgoingQueues{
disabled: disabled, disabled: disabled,
@ -101,9 +100,12 @@ func NewOutgoingQueues(
origin: origin, origin: origin,
client: client, client: client,
statistics: statistics, statistics: statistics,
signing: signing, signing: map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity{},
queues: map[gomatrixserverlib.ServerName]*destinationQueue{}, queues: map[gomatrixserverlib.ServerName]*destinationQueue{},
} }
for _, identity := range signing {
queues.signing[identity.ServerName] = identity
}
// Look up which servers we have pending items for and then rehydrate those queues. // Look up which servers we have pending items for and then rehydrate those queues.
if !disabled { if !disabled {
serverNames := map[gomatrixserverlib.ServerName]struct{}{} serverNames := map[gomatrixserverlib.ServerName]struct{}{}
@ -135,14 +137,6 @@ func NewOutgoingQueues(
return queues return queues
} }
// TODO: Move this somewhere useful for other components as we often need to ferry these 3 variables
// around together
type SigningInfo struct {
ServerName gomatrixserverlib.ServerName
KeyID gomatrixserverlib.KeyID
PrivateKey ed25519.PrivateKey
}
type queuedPDU struct { type queuedPDU struct {
receipt *shared.Receipt receipt *shared.Receipt
pdu *gomatrixserverlib.HeaderedEvent pdu *gomatrixserverlib.HeaderedEvent
@ -199,11 +193,10 @@ func (oqs *OutgoingQueues) SendEvent(
log.Trace("Federation is disabled, not sending event") log.Trace("Federation is disabled, not sending event")
return nil return nil
} }
if origin != oqs.origin { if _, ok := oqs.signing[origin]; !ok {
// TODO: Support virtual hosting; gh issue #577.
return fmt.Errorf( return fmt.Errorf(
"sendevent: unexpected server to send as: got %q expected %q", "sendevent: unexpected server to send as %q",
origin, oqs.origin, origin,
) )
} }
@ -214,7 +207,9 @@ func (oqs *OutgoingQueues) SendEvent(
destmap[d] = struct{}{} destmap[d] = struct{}{}
} }
delete(destmap, oqs.origin) delete(destmap, oqs.origin)
delete(destmap, oqs.signing.ServerName) for local := range oqs.signing {
delete(destmap, local)
}
// Check if any of the destinations are prohibited by server ACLs. // Check if any of the destinations are prohibited by server ACLs.
for destination := range destmap { for destination := range destmap {
@ -288,11 +283,10 @@ func (oqs *OutgoingQueues) SendEDU(
log.Trace("Federation is disabled, not sending EDU") log.Trace("Federation is disabled, not sending EDU")
return nil return nil
} }
if origin != oqs.origin { if _, ok := oqs.signing[origin]; !ok {
// TODO: Support virtual hosting; gh issue #577.
return fmt.Errorf( return fmt.Errorf(
"sendevent: unexpected server to send as: got %q expected %q", "sendevent: unexpected server to send as %q",
origin, oqs.origin, origin,
) )
} }
@ -303,7 +297,9 @@ func (oqs *OutgoingQueues) SendEDU(
destmap[d] = struct{}{} destmap[d] = struct{}{}
} }
delete(destmap, oqs.origin) delete(destmap, oqs.origin)
delete(destmap, oqs.signing.ServerName) for local := range oqs.signing {
delete(destmap, local)
}
// There is absolutely no guarantee that the EDU will have a room_id // There is absolutely no guarantee that the EDU will have a room_id
// field, as it is not required by the spec. However, if it *does* // field, as it is not required by the spec. However, if it *does*
@ -378,14 +374,24 @@ func (oqs *OutgoingQueues) SendEDU(
return nil return nil
} }
// IsServerBlacklisted returns whether or not the provided server is currently
// blacklisted.
func (oqs *OutgoingQueues) IsServerBlacklisted(srv gomatrixserverlib.ServerName) bool {
return oqs.statistics.ForServer(srv).Blacklisted()
}
// RetryServer attempts to resend events to the given server if we had given up. // RetryServer attempts to resend events to the given server if we had given up.
func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
if oqs.disabled { if oqs.disabled {
return return
} }
oqs.statistics.ForServer(srv).RemoveBlacklist()
serverStatistics := oqs.statistics.ForServer(srv)
forceWakeup := serverStatistics.Blacklisted()
serverStatistics.RemoveBlacklist()
serverStatistics.ClearBackoff()
if queue := oqs.getQueue(srv); queue != nil { if queue := oqs.getQueue(srv); queue != nil {
queue.statistics.ClearBackoff() queue.wakeQueueIfEventsPending(forceWakeup)
queue.wakeQueueIfNeeded()
} }
} }

View file

@ -437,12 +437,13 @@ func testSetup(failuresUntilBlacklist uint32, failuresUntilAssumedOffline uint32
} }
rs := &stubFederationRoomServerAPI{} rs := &stubFederationRoomServerAPI{}
stats := statistics.NewStatistics( stats := statistics.NewStatistics(db, failuresUntilBlacklist, failuresUntilAssumedOffline)
db, failuresUntilBlacklist, failuresUntilAssumedOffline) signingInfo := []*gomatrixserverlib.SigningIdentity{
signingInfo := &SigningInfo{ {
KeyID: "ed21019:auto", KeyID: "ed21019:auto",
PrivateKey: test.PrivateKeyA, PrivateKey: test.PrivateKeyA,
ServerName: "localhost", ServerName: "localhost",
},
} }
queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo) queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo)

View file

@ -83,6 +83,7 @@ func Backfill(
"": eIDs, "": eIDs,
}, },
ServerName: request.Origin(), ServerName: request.Origin(),
VirtualHost: request.Destination(),
} }
if req.Limit, err = strconv.Atoi(limit); err != nil { if req.Limit, err = strconv.Atoi(limit); err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed")
@ -123,7 +124,7 @@ func Backfill(
} }
txn := gomatrixserverlib.Transaction{ txn := gomatrixserverlib.Transaction{
Origin: cfg.Matrix.ServerName, Origin: request.Destination(),
PDUs: eventJSONs, PDUs: eventJSONs,
OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()),
} }

View file

@ -123,7 +123,7 @@ func createFederationRequest(userID gomatrixserverlib.UserID) (gomatrixserverlib
txn := createTransaction() txn := createTransaction()
var federationPathPrefixV1 = "/_matrix/federation/v1" var federationPathPrefixV1 = "/_matrix/federation/v1"
path := federationPathPrefixV1 + "/forward_async/" + string(txn.TransactionID) + "/" + userID.Raw() path := federationPathPrefixV1 + "/forward_async/" + string(txn.TransactionID) + "/" + userID.Raw()
request := gomatrixserverlib.NewFederationRequest("PUT", txn.Destination, path) request := gomatrixserverlib.NewFederationRequest("PUT", txn.Origin, txn.Destination, path)
request.SetContent(txn) request.SetContent(txn)
return txn, request return txn, request

View file

@ -140,6 +140,21 @@ func processInvite(
} }
} }
if event.StateKey() == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The invite event has no state key"),
}
}
_, domain, err := cfg.Matrix.SplitLocalID('@', *event.StateKey())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("The user ID is invalid or domain %q does not belong to this server", domain)),
}
}
// Check that the event is signed by the server sending the request. // Check that the event is signed by the server sending the request.
redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version()) redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version())
if err != nil { if err != nil {
@ -175,7 +190,7 @@ func processInvite(
// Sign the event so that other servers will know that we have received the invite. // Sign the event so that other servers will know that we have received the invite.
signedEvent := event.Sign( signedEvent := event.Sign(
string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, string(domain), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey,
) )
// Add the invite event to the roomserver. // Add the invite event to the roomserver.

View file

@ -131,10 +131,20 @@ func MakeJoin(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
identity, err := cfg.Matrix.SigningIdentityFor(request.Destination())
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(
fmt.Sprintf("Server name %q does not exist", request.Destination()),
),
}
}
queryRes := api.QueryLatestEventsAndStateResponse{ queryRes := api.QueryLatestEventsAndStateResponse{
RoomVersion: verRes.RoomVersion, RoomVersion: verRes.RoomVersion,
} }
event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,

View file

@ -134,28 +134,30 @@ func ClaimOneTimeKeys(
// LocalKeys returns the local keys for the server. // LocalKeys returns the local keys for the server.
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
func LocalKeys(cfg *config.FederationAPI) util.JSONResponse { func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse {
keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)) keys, err := localKeys(cfg, serverName)
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.MessageResponse(http.StatusNotFound, err.Error())
} }
return util.JSONResponse{Code: http.StatusOK, JSON: keys} return util.JSONResponse{Code: http.StatusOK, JSON: keys}
} }
func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) { func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) {
var keys gomatrixserverlib.ServerKeys var keys gomatrixserverlib.ServerKeys
var identity *gomatrixserverlib.SigningIdentity
keys.ServerName = cfg.Matrix.ServerName var err error
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil) if virtualHost := cfg.Matrix.VirtualHostForHTTPHost(serverName); virtualHost == nil {
if identity, err = cfg.Matrix.SigningIdentityFor(cfg.Matrix.ServerName); err != nil {
return nil, err
}
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
keys.ServerName = cfg.Matrix.ServerName
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(cfg.Matrix.KeyValidityPeriod))
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
cfg.Matrix.KeyID: { cfg.Matrix.KeyID: {
Key: gomatrixserverlib.Base64Bytes(publicKey), Key: gomatrixserverlib.Base64Bytes(publicKey),
}, },
} }
keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{} keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{}
for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys { for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys {
keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{ keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{
@ -165,6 +167,21 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
ExpiredTS: oldVerifyKey.ExpiredAt, ExpiredTS: oldVerifyKey.ExpiredAt,
} }
} }
} else {
if identity, err = cfg.Matrix.SigningIdentityFor(virtualHost.ServerName); err != nil {
return nil, err
}
publicKey := virtualHost.PrivateKey.Public().(ed25519.PublicKey)
keys.ServerName = virtualHost.ServerName
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(virtualHost.KeyValidityPeriod))
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
virtualHost.KeyID: {
Key: gomatrixserverlib.Base64Bytes(publicKey),
},
}
// TODO: Virtual hosts probably want to be able to specify old signing
// keys too, just in case
}
toSign, err := json.Marshal(keys.ServerKeyFields) toSign, err := json.Marshal(keys.ServerKeyFields)
if err != nil { if err != nil {
@ -172,13 +189,9 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
} }
keys.Raw, err = gomatrixserverlib.SignJSON( keys.Raw, err = gomatrixserverlib.SignJSON(
string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign, string(identity.ServerName), identity.KeyID, identity.PrivateKey, toSign,
) )
if err != nil { return &keys, err
return nil, err
}
return &keys, nil
} }
func NotaryKeys( func NotaryKeys(
@ -186,6 +199,14 @@ func NotaryKeys(
fsAPI federationAPI.FederationInternalAPI, fsAPI federationAPI.FederationInternalAPI,
req *gomatrixserverlib.PublicKeyNotaryLookupRequest, req *gomatrixserverlib.PublicKeyNotaryLookupRequest,
) util.JSONResponse { ) util.JSONResponse {
serverName := gomatrixserverlib.ServerName(httpReq.Host) // TODO: this is not ideal
if !cfg.Matrix.IsLocalServerName(serverName) {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Server name not known"),
}
}
if req == nil { if req == nil {
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{} req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil { if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
@ -201,7 +222,7 @@ func NotaryKeys(
for serverName, kidToCriteria := range req.ServerKeys { for serverName, kidToCriteria := range req.ServerKeys {
var keyList []gomatrixserverlib.ServerKeys var keyList []gomatrixserverlib.ServerKeys
if serverName == cfg.Matrix.ServerName { if serverName == cfg.Matrix.ServerName {
if k, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil { if k, err := localKeys(cfg, serverName); err == nil {
keyList = append(keyList, *k) keyList = append(keyList, *k)
} else { } else {
return util.ErrorResponse(err) return util.ErrorResponse(err)

View file

@ -13,6 +13,7 @@
package routing package routing
import ( import (
"fmt"
"net/http" "net/http"
"time" "time"
@ -60,8 +61,18 @@ func MakeLeave(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
identity, err := cfg.Matrix.SigningIdentityFor(request.Destination())
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(
fmt.Sprintf("Server name %q does not exist", request.Destination()),
),
}
}
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,

View file

@ -22,7 +22,6 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -42,16 +41,9 @@ func GetProfile(
} }
} }
_, domain, err := gomatrixserverlib.SplitID('@', userID) _, domain, err := cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(httpReq.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument(fmt.Sprintf("Format of user ID %q is invalid", userID)),
}
}
if domain != cfg.Matrix.ServerName {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("Domain %q does not match this server", domain)), JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("Domain %q does not match this server", domain)),

View file

@ -83,7 +83,7 @@ func RoomAliasToID(
} }
} }
} else { } else {
resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, roomAlias) resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, cfg.Matrix.ServerName, roomAlias)
if err != nil { if err != nil {
switch x := err.(type) { switch x := err.(type) {
case gomatrix.HTTPError: case gomatrix.HTTPError:

View file

@ -74,7 +74,7 @@ func Setup(
} }
localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse { localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse {
return LocalKeys(cfg) return LocalKeys(cfg, gomatrixserverlib.ServerName(req.Host))
}) })
notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse { notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse {

View file

@ -197,12 +197,12 @@ type txnReq struct {
// A subset of FederationClient functionality that txn requires. Useful for testing. // A subset of FederationClient functionality that txn requires. Useful for testing.
type txnFederationClient interface { type txnFederationClient interface {
LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
res gomatrixserverlib.RespState, err error, res gomatrixserverlib.RespState, err error,
) )
LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents,
roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
@ -287,6 +287,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
event.Headered(roomVersion), event.Headered(roomVersion),
}, },
t.Destination,
t.Origin, t.Origin,
api.DoNotSendToOtherServers, api.DoNotSendToOtherServers,
nil, nil,

View file

@ -147,7 +147,7 @@ type txnFedClient struct {
getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error)
} }
func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( func (c *txnFedClient) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
res gomatrixserverlib.RespState, err error, res gomatrixserverlib.RespState, err error,
) { ) {
fmt.Println("testFederationClient.LookupState", eventID) fmt.Println("testFederationClient.LookupState", eventID)
@ -159,7 +159,7 @@ func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.Serv
res = r res = r
return return
} }
func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { func (c *txnFedClient) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) {
fmt.Println("testFederationClient.LookupStateIDs", eventID) fmt.Println("testFederationClient.LookupStateIDs", eventID)
r, ok := c.stateIDs[eventID] r, ok := c.stateIDs[eventID]
if !ok { if !ok {
@ -169,7 +169,7 @@ func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.S
res = r res = r
return return
} }
func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { func (c *txnFedClient) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) {
fmt.Println("testFederationClient.GetEvent", eventID) fmt.Println("testFederationClient.GetEvent", eventID)
r, ok := c.getEvent[eventID] r, ok := c.getEvent[eventID]
if !ok { if !ok {
@ -179,7 +179,7 @@ func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerN
res = r res = r
return return
} }
func (c *txnFedClient) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, func (c *txnFedClient) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents,
roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) { roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) {
return c.getMissingEvents(missing) return c.getMissingEvents(missing)
} }

View file

@ -90,7 +90,17 @@ func CreateInvitesFrom3PIDInvites(
} }
// Send all the events // Send all the events
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, evs, "TODO", cfg.Matrix.ServerName, nil, false); err != nil { if err := api.SendEvents(
req.Context(),
rsAPI,
api.KindNew,
evs,
cfg.Matrix.ServerName, // TODO: which virtual host?
"TODO",
cfg.Matrix.ServerName,
nil,
false,
); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -126,6 +136,14 @@ func ExchangeThirdPartyInvite(
} }
} }
_, senderDomain, err := cfg.Matrix.SplitLocalID('@', builder.Sender)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Invalid sender ID: " + err.Error()),
}
}
// Check that the state key is correct. // Check that the state key is correct.
_, targetDomain, err := gomatrixserverlib.SplitID('@', *builder.StateKey) _, targetDomain, err := gomatrixserverlib.SplitID('@', *builder.StateKey)
if err != nil { if err != nil {
@ -171,7 +189,7 @@ func ExchangeThirdPartyInvite(
util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request") util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
signedEvent, err := federation.SendInviteV2(httpReq.Context(), request.Origin(), inviteReq) signedEvent, err := federation.SendInviteV2(httpReq.Context(), senderDomain, request.Origin(), inviteReq)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -189,6 +207,7 @@ func ExchangeThirdPartyInvite(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
inviteEvent.Headered(verRes.RoomVersion), inviteEvent.Headered(verRes.RoomVersion),
}, },
request.Destination(),
request.Origin(), request.Origin(),
cfg.Matrix.ServerName, cfg.Matrix.ServerName,
nil, nil,
@ -341,7 +360,7 @@ func buildMembershipEvent(
// them responded with an error. // them responded with an error.
func sendToRemoteServer( func sendToRemoteServer(
ctx context.Context, inv invite, ctx context.Context, inv invite,
federation federationAPI.FederationClient, _ *config.FederationAPI, federation federationAPI.FederationClient, cfg *config.FederationAPI,
builder gomatrixserverlib.EventBuilder, builder gomatrixserverlib.EventBuilder,
) (err error) { ) (err error) {
remoteServers := make([]gomatrixserverlib.ServerName, 2) remoteServers := make([]gomatrixserverlib.ServerName, 2)
@ -357,7 +376,7 @@ func sendToRemoteServer(
} }
for _, server := range remoteServers { for _, server := range remoteServers {
err = federation.ExchangeThirdPartyInvite(ctx, server, builder) err = federation.ExchangeThirdPartyInvite(ctx, cfg.Matrix.ServerName, server, builder)
if err == nil { if err == nil {
return return
} }

View file

@ -32,7 +32,7 @@ type Database interface {
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given. // GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error)
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)

View file

@ -66,6 +66,11 @@ const selectAllJoinedHostsSQL = "" +
const selectJoinedHostsForRoomsSQL = "" + const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)" "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)"
const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id = ANY($1) AND NOT EXISTS (" +
" SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" +
");"
type joinedHostsStatements struct { type joinedHostsStatements struct {
db *sql.DB db *sql.DB
insertJoinedHostsStmt *sql.Stmt insertJoinedHostsStmt *sql.Stmt
@ -74,6 +79,7 @@ type joinedHostsStatements struct {
selectJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt
selectJoinedHostsForRoomsStmt *sql.Stmt selectJoinedHostsForRoomsStmt *sql.Stmt
selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt
} }
func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@ -102,6 +108,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro
if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
return return
} }
if s.selectJoinedHostsForRoomsExcludingBlacklistedStmt, err = s.db.Prepare(selectJoinedHostsForRoomsExcludingBlacklistedSQL); err != nil {
return
}
return return
} }
@ -167,9 +176,13 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
} }
func (s *joinedHostsStatements) SelectJoinedHostsForRooms( func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string, ctx context.Context, roomIDs []string, excludingBlacklisted bool,
) ([]gomatrixserverlib.ServerName, error) { ) ([]gomatrixserverlib.ServerName, error) {
rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs)) stmt := s.selectJoinedHostsForRoomsStmt
if excludingBlacklisted {
stmt = s.selectJoinedHostsForRoomsExcludingBlacklistedStmt
}
rows, err := stmt.QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -42,6 +42,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
return nil, err return nil, err
} }
blacklist, err := NewPostgresBlacklistTable(d.db)
if err != nil {
return nil, err
}
joinedHosts, err := NewPostgresJoinedHostsTable(d.db) joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -66,10 +70,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
blacklist, err := NewPostgresBlacklistTable(d.db)
if err != nil {
return nil, err
}
assumedOffline, err := NewPostgresAssumedOfflineTable(d.db) assumedOffline, err := NewPostgresAssumedOfflineTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -122,15 +122,17 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
} }
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) { func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) {
servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if excludeSelf { if excludeSelf {
for i, server := range servers { for i, server := range servers {
if d.IsLocalServerName(server) { if d.IsLocalServerName(server) {
servers = append(servers[:i], servers[i+1:]...) copy(servers[i:], servers[i+1:])
servers = servers[:len(servers)-1]
break
} }
} }
} }

View file

@ -66,6 +66,11 @@ const selectAllJoinedHostsSQL = "" +
const selectJoinedHostsForRoomsSQL = "" + const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id IN ($1) AND NOT EXISTS (" +
" SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" +
");"
type joinedHostsStatements struct { type joinedHostsStatements struct {
db *sql.DB db *sql.DB
insertJoinedHostsStmt *sql.Stmt insertJoinedHostsStmt *sql.Stmt
@ -74,6 +79,7 @@ type joinedHostsStatements struct {
selectJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
// selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt - prepared at runtime due to variadic
} }
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@ -168,14 +174,17 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
} }
func (s *joinedHostsStatements) SelectJoinedHostsForRooms( func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string, ctx context.Context, roomIDs []string, excludingBlacklisted bool,
) ([]gomatrixserverlib.ServerName, error) { ) ([]gomatrixserverlib.ServerName, error) {
iRoomIDs := make([]interface{}, len(roomIDs)) iRoomIDs := make([]interface{}, len(roomIDs))
for i := range roomIDs { for i := range roomIDs {
iRoomIDs[i] = roomIDs[i] iRoomIDs[i] = roomIDs[i]
} }
query := selectJoinedHostsForRoomsSQL
sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) if excludingBlacklisted {
query = selectJoinedHostsForRoomsExcludingBlacklistedSQL
}
sql := strings.Replace(query, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...) rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -35,6 +35,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err return nil, err
} }
blacklist, err := NewSQLiteBlacklistTable(d.db)
if err != nil {
return nil, err
}
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -59,10 +63,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
blacklist, err := NewSQLiteBlacklistTable(d.db)
if err != nil {
return nil, err
}
assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db) assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -71,7 +71,7 @@ type FederationJoinedHosts interface {
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludingBlacklisted bool) ([]gomatrixserverlib.ServerName, error)
} }
type FederationBlacklist interface { type FederationBlacklist interface {

10
go.mod
View file

@ -8,7 +8,7 @@ require (
github.com/Masterminds/semver/v3 v3.1.1 github.com/Masterminds/semver/v3 v3.1.1
github.com/blevesearch/bleve/v2 v2.3.4 github.com/blevesearch/bleve/v2 v2.3.4
github.com/codeclysm/extract v2.2.0+incompatible github.com/codeclysm/extract v2.2.0+incompatible
github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d github.com/dgraph-io/ristretto v0.1.1
github.com/docker/docker v20.10.19+incompatible github.com/docker/docker v20.10.19+incompatible
github.com/docker/go-connections v0.4.0 github.com/docker/go-connections v0.4.0
github.com/getsentry/sentry-go v0.14.0 github.com/getsentry/sentry-go v0.14.0
@ -22,12 +22,12 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20221202221557-069343f050b8 github.com/matrix-org/gomatrixserverlib v0.0.0-20221202231116-4b30931352d0
github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.15 github.com/mattn/go-sqlite3 v1.14.15
github.com/nats-io/nats-server/v2 v2.9.4 github.com/nats-io/nats-server/v2 v2.9.8
github.com/nats-io/nats.go v1.19.0 github.com/nats-io/nats.go v1.20.0
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79 github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79

22
go.sum
View file

@ -141,8 +141,8 @@ github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d h1:Wrc3UKTS+cffkOx0xRGFC+ZesNuTfn0ThvEC72N0krk= github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8=
github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d/go.mod h1:RAy2GVV4sTWVlNMavv3xhLsk18rxhfhDnombTe6EF5c= github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA=
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA=
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
github.com/docker/distribution v2.8.1+incompatible h1:Q50tZOPR6T/hjNsyc9g8/syEs6bk8XXApsHjKukMl68= github.com/docker/distribution v2.8.1+incompatible h1:Q50tZOPR6T/hjNsyc9g8/syEs6bk8XXApsHjKukMl68=
@ -348,10 +348,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221202221557-069343f050b8 h1:tlCuFX+80b3CN4GljEf8GKIDtjGxa73RTR3X+/H5IrU= github.com/matrix-org/gomatrixserverlib v0.0.0-20221202231116-4b30931352d0 h1:C+R7wsFgk8zgMm8esnGsa4+854IutZWfcP1p2Q8KqzE=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221202221557-069343f050b8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/gomatrixserverlib v0.0.0-20221202231116-4b30931352d0/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM=
github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66/go.mod h1:iBI1foelCqA09JJgPV0FYz4qA5dUXYOxMi57FxKBdd4= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66/go.mod h1:iBI1foelCqA09JJgPV0FYz4qA5dUXYOxMi57FxKBdd4=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
@ -385,10 +385,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI=
github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
github.com/nats-io/nats-server/v2 v2.9.4 h1:GvRgv1936J/zYUwMg/cqtYaJ6L+bgeIOIvPslbesdow= github.com/nats-io/nats-server/v2 v2.9.8 h1:jgxZsv+A3Reb3MgwxaINcNq/za8xZInKhDg9Q0cGN1o=
github.com/nats-io/nats-server/v2 v2.9.4/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= github.com/nats-io/nats-server/v2 v2.9.8/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g=
github.com/nats-io/nats.go v1.19.0 h1:H6j8aBnTQFoVrTGB6Xjd903UMdE7jz6DS4YkmAqgZ9Q= github.com/nats-io/nats.go v1.20.0 h1:T8JJnQfVSdh1CzGiwAOv5hEobYCBho/0EupGznYw0oM=
github.com/nats-io/nats.go v1.19.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA= github.com/nats-io/nats.go v1.20.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA=
github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8=
github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
@ -527,7 +527,6 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ=
go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
@ -698,6 +697,7 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=

View file

@ -38,7 +38,8 @@ var ErrRoomNoExists = errors.New("room does not exist")
// Returns an error if something else went wrong // Returns an error if something else went wrong
func QueryAndBuildEvent( func QueryAndBuildEvent(
ctx context.Context, ctx context.Context,
builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, builder *gomatrixserverlib.EventBuilder, cfg *config.Global,
identity *gomatrixserverlib.SigningIdentity, evTime time.Time,
rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse, rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
if queryRes == nil { if queryRes == nil {
@ -50,24 +51,24 @@ func QueryAndBuildEvent(
// This can pass through a ErrRoomNoExists to the caller // This can pass through a ErrRoomNoExists to the caller
return nil, err return nil, err
} }
return BuildEvent(ctx, builder, cfg, evTime, eventsNeeded, queryRes) return BuildEvent(ctx, builder, cfg, identity, evTime, eventsNeeded, queryRes)
} }
// BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse // BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse
// provided. // provided.
func BuildEvent( func BuildEvent(
ctx context.Context, ctx context.Context,
builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, builder *gomatrixserverlib.EventBuilder, cfg *config.Global,
identity *gomatrixserverlib.SigningIdentity, evTime time.Time,
eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse, eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
err := addPrevEventsToEvent(builder, eventsNeeded, queryRes) if err := addPrevEventsToEvent(builder, eventsNeeded, queryRes); err != nil {
if err != nil {
return nil, err return nil, err
} }
event, err := builder.Build( event, err := builder.Build(
evTime, cfg.ServerName, cfg.KeyID, evTime, identity.ServerName, identity.KeyID,
cfg.PrivateKey, queryRes.RoomVersion, identity.PrivateKey, queryRes.RoomVersion,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -42,10 +42,26 @@ type BasicAuth struct {
Password string `yaml:"password"` Password string `yaml:"password"`
} }
type AuthAPIOpts struct {
GuestAccessAllowed bool
}
// AuthAPIOption is an option to MakeAuthAPI to add additional checks (e.g. guest access) to verify
// the user is allowed to do specific things.
type AuthAPIOption func(opts *AuthAPIOpts)
// WithAllowGuests checks that guest users have access to this endpoint
func WithAllowGuests() AuthAPIOption {
return func(opts *AuthAPIOpts) {
opts.GuestAccessAllowed = true
}
}
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
func MakeAuthAPI( func MakeAuthAPI(
metricsName string, userAPI userapi.QueryAcccessTokenAPI, metricsName string, userAPI userapi.QueryAcccessTokenAPI,
f func(*http.Request, *userapi.Device) util.JSONResponse, f func(*http.Request, *userapi.Device) util.JSONResponse,
checks ...AuthAPIOption,
) http.Handler { ) http.Handler {
h := func(req *http.Request) util.JSONResponse { h := func(req *http.Request) util.JSONResponse {
logger := util.GetLogger(req.Context()) logger := util.GetLogger(req.Context())
@ -76,6 +92,19 @@ func MakeAuthAPI(
} }
}() }()
// apply additional checks, if any
opts := AuthAPIOpts{}
for _, opt := range checks {
opt(&opts)
}
if !opts.GuestAccessAllowed && device.AccountType == userapi.AccountTypeGuest {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.GuestAccessForbidden("Guest access not allowed"),
}
}
jsonRes := f(req, device) jsonRes := f(req, device)
// do not log 4xx as errors as they are client fails, not server fails // do not log 4xx as errors as they are client fails, not server fails
if hub != nil && jsonRes.Code >= 500 { if hub != nil && jsonRes.Code >= 500 {

View file

@ -9,6 +9,8 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/matrix-org/dendrite/internal"
"github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go"
) )
@ -50,8 +52,7 @@ func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest,
return err return err
} }
//nolint:errcheck defer internal.CloseAndLogIfError(ctx, hresp.Body, "failed to close response body")
defer hresp.Body.Close()
if hresp.StatusCode == http.StatusOK { if hresp.StatusCode == http.StatusOK {
return json.NewDecoder(hresp.Body).Decode(resp) return json.NewDecoder(hresp.Body).Decode(resp)

View file

@ -0,0 +1,54 @@
package pushgateway
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"reflect"
"testing"
)
func TestNotify(t *testing.T) {
wantResponse := NotifyResponse{
Rejected: []string{"testing"},
}
var i = 0
svr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// /notify only accepts POST requests
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusNotImplemented)
return
}
if i != 0 { // error path
w.WriteHeader(http.StatusBadRequest)
return
}
// happy path
json.NewEncoder(w).Encode(wantResponse)
}))
defer svr.Close()
cl := NewHTTPClient(true)
gotResponse := NotifyResponse{}
// Test happy path
err := cl.Notify(context.Background(), svr.URL, &NotifyRequest{}, &gotResponse)
if err != nil {
t.Errorf("failed to notify client")
}
if !reflect.DeepEqual(gotResponse, wantResponse) {
t.Errorf("expected response %+v, got %+v", wantResponse, gotResponse)
}
// Test error path
i++
err = cl.Notify(context.Background(), svr.URL, &NotifyRequest{}, &gotResponse)
if err == nil {
t.Errorf("expected notifying the pushgateway to fail, but it succeeded")
}
}

View file

@ -145,6 +145,11 @@ func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec Evalua
} }
func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) { func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) {
// It doesn't make sense for an empty pattern to match anything.
if pattern == "" {
return false, nil
}
re, err := globToRegexp(pattern) re, err := globToRegexp(pattern)
if err != nil { if err != nil {
return false, err return false, err
@ -154,12 +159,20 @@ func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool,
if err = json.Unmarshal(event.JSON(), &eventMap); err != nil { if err = json.Unmarshal(event.JSON(), &eventMap); err != nil {
return false, fmt.Errorf("parsing event: %w", err) return false, fmt.Errorf("parsing event: %w", err)
} }
// From the spec:
// "If the property specified by key is completely absent from
// the event, or does not have a string value, then the condition
// will not match, even if pattern is *."
v, err := lookupMapPath(strings.Split(key, "."), eventMap) v, err := lookupMapPath(strings.Split(key, "."), eventMap)
if err != nil { if err != nil {
// An unknown path is a benign error that shouldn't stop rule // An unknown path is a benign error that shouldn't stop rule
// processing. It's just a non-match. // processing. It's just a non-match.
return false, nil return false, nil
} }
if _, ok := v.(string); !ok {
// A non-string never matches.
return false, nil
}
return re.MatchString(fmt.Sprint(v)), nil return re.MatchString(fmt.Sprint(v)), nil
} }

View file

@ -111,7 +111,10 @@ func TestConditionMatches(t *testing.T) {
{"empty", Condition{}, `{}`, false}, {"empty", Condition{}, `{}`, false},
{"empty", Condition{Kind: "unknownstring"}, `{}`, false}, {"empty", Condition{Kind: "unknownstring"}, `{}`, false},
{"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, true}, // Neither of these should match because `content` is not a full string match,
// and `content.body` is not a string value.
{"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, false},
{"eventBodyMatch", Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3"}, `{"content":{"body": 3}}`, false},
{"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false}, {"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false},
{"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true}, {"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true},
@ -137,7 +140,7 @@ func TestConditionMatches(t *testing.T) {
t.Fatalf("conditionMatches failed: %v", err) t.Fatalf("conditionMatches failed: %v", err)
} }
if got != tst.Want { if got != tst.Want {
t.Errorf("conditionMatches: got %v, want %v", got, tst.Want) t.Errorf("conditionMatches: got %v, want %v on %s", got, tst.Want, tst.Name)
} }
}) })
} }
@ -161,9 +164,7 @@ func TestPatternMatches(t *testing.T) {
}{ }{
{"empty", "", "", `{}`, false}, {"empty", "", "", `{}`, false},
// Note that an empty pattern contains no wildcard characters, {"patternEmpty", "content", "", `{"content":{}}`, false},
// which implicitly means "*".
{"patternEmpty", "content", "", `{"content":{}}`, true},
{"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true}, {"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true},
{"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true}, {"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true},
@ -178,7 +179,7 @@ func TestPatternMatches(t *testing.T) {
t.Fatalf("patternMatches failed: %v", err) t.Fatalf("patternMatches failed: %v", err)
} }
if got != tst.Want { if got != tst.Want {
t.Errorf("patternMatches: got %v, want %v", got, tst.Want) t.Errorf("patternMatches: got %v, want %v on %s", got, tst.Want, tst.Name)
} }
}) })
} }

View file

@ -11,22 +11,27 @@ import (
// kind and a tweaks map. Returns a nil map if it would have been // kind and a tweaks map. Returns a nil map if it would have been
// empty. // empty.
func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) { func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) {
kind := UnknownAction var kind ActionKind
tweaks := map[string]interface{}{} var tweaks map[string]interface{}
for _, a := range as { for _, a := range as {
if a.Kind == SetTweakAction { switch a.Kind {
tweaks[string(a.Tweak)] = a.Value case DontNotifyAction:
continue // Don't bother processing any further
return DontNotifyAction, nil, nil
case SetTweakAction:
if tweaks == nil {
tweaks = map[string]interface{}{}
} }
tweaks[string(a.Tweak)] = a.Value
default:
if kind != UnknownAction { if kind != UnknownAction {
return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind) return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind)
} }
kind = a.Kind kind = a.Kind
} }
if len(tweaks) == 0 {
tweaks = nil
} }
return kind, tweaks, nil return kind, tweaks, nil

View file

@ -17,6 +17,7 @@ func TestActionsToTweaks(t *testing.T) {
{"empty", nil, UnknownAction, nil}, {"empty", nil, UnknownAction, nil},
{"zero", []*Action{{}}, UnknownAction, nil}, {"zero", []*Action{{}}, UnknownAction, nil},
{"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil}, {"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil},
{"onlyPrimaryDontNotify", []*Action{{Kind: DontNotifyAction}}, DontNotifyAction, nil},
{"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}}, {"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}},
{"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}}, {"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}},
{ {

View file

@ -17,7 +17,7 @@ var build string
const ( const (
VersionMajor = 0 VersionMajor = 0
VersionMinor = 10 VersionMinor = 10
VersionPatch = 6 VersionPatch = 8
VersionTag = "" // example: "rc1" VersionTag = "" // example: "rc1"
) )

View file

@ -35,7 +35,7 @@ type DeviceListUpdateConsumer struct {
durable string durable string
topic string topic string
updater *internal.DeviceListUpdater updater *internal.DeviceListUpdater
serverName gomatrixserverlib.ServerName isLocalServerName func(gomatrixserverlib.ServerName) bool
} }
// NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. // NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers.
@ -51,7 +51,7 @@ func NewDeviceListUpdateConsumer(
durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
updater: updater, updater: updater,
serverName: cfg.Matrix.ServerName, isLocalServerName: cfg.Matrix.IsLocalServerName,
} }
} }
@ -75,7 +75,7 @@ func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M
origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) origin := gomatrixserverlib.ServerName(msg.Header.Get("origin"))
if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil { if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil {
return true return true
} else if serverName == t.serverName { } else if t.isLocalServerName(serverName) {
return true return true
} else if serverName != origin { } else if serverName != origin {
return true return true

View file

@ -37,6 +37,7 @@ type SigningKeyUpdateConsumer struct {
topic string topic string
keyAPI *internal.KeyInternalAPI keyAPI *internal.KeyInternalAPI
cfg *config.KeyServer cfg *config.KeyServer
isLocalServerName func(gomatrixserverlib.ServerName) bool
} }
// NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. // NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers.
@ -53,6 +54,7 @@ func NewSigningKeyUpdateConsumer(
topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
keyAPI: keyAPI, keyAPI: keyAPI,
cfg: cfg, cfg: cfg,
isLocalServerName: cfg.Matrix.IsLocalServerName,
} }
} }
@ -77,7 +79,7 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M
if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil { if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil {
logrus.WithError(err).Error("failed to split user id") logrus.WithError(err).Error("failed to split user id")
return true return true
} else if serverName == t.cfg.Matrix.ServerName { } else if t.isLocalServerName(serverName) {
logrus.Warn("dropping device key update from ourself") logrus.Warn("dropping device key update from ourself")
return true return true
} else if serverName != origin { } else if serverName != origin {

View file

@ -47,7 +47,6 @@ var (
) )
) )
const defaultWaitTime = time.Minute
const requestTimeout = time.Second * 30 const requestTimeout = time.Second * 30
func init() { func init() {
@ -97,6 +96,7 @@ type DeviceListUpdater struct {
producer KeyChangeProducer producer KeyChangeProducer
fedClient fedsenderapi.KeyserverFederationAPI fedClient fedsenderapi.KeyserverFederationAPI
workerChans []chan gomatrixserverlib.ServerName workerChans []chan gomatrixserverlib.ServerName
thisServer gomatrixserverlib.ServerName
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
// block on or timeout via a select. // block on or timeout via a select.
@ -140,6 +140,7 @@ func NewDeviceListUpdater(
process *process.ProcessContext, db DeviceListUpdaterDatabase, process *process.ProcessContext, db DeviceListUpdaterDatabase,
api DeviceListUpdaterAPI, producer KeyChangeProducer, api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
thisServer gomatrixserverlib.ServerName,
) *DeviceListUpdater { ) *DeviceListUpdater {
return &DeviceListUpdater{ return &DeviceListUpdater{
process: process, process: process,
@ -149,6 +150,7 @@ func NewDeviceListUpdater(
api: api, api: api,
producer: producer, producer: producer,
fedClient: fedClient, fedClient: fedClient,
thisServer: thisServer,
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
userIDToChan: make(map[string]chan bool), userIDToChan: make(map[string]chan bool),
userIDToChanMu: &sync.Mutex{}, userIDToChanMu: &sync.Mutex{},
@ -436,8 +438,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go
"server_name": serverName, "server_name": serverName,
"user_id": userID, "user_id": userID,
}) })
res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID)
res, err := u.fedClient.GetUserDevices(ctx, serverName, userID)
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {
return time.Minute * 10, err return time.Minute * 10, err
@ -454,7 +455,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go
} else if e.Code >= 300 { } else if e.Code >= 300 {
// We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError // We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError
// are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds. // are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds.
return time.Hour, err return hourWaitTime, err
} }
case net.Error: case net.Error:
// Use the default waitTime, if it's a timeout. // Use the default waitTime, if it's a timeout.
@ -468,7 +469,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go
// This is to avoid spamming remote servers, which may not be Matrix servers anymore. // This is to avoid spamming remote servers, which may not be Matrix servers anymore.
if e.Code >= 300 { if e.Code >= 300 {
logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError") logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError")
return time.Hour, err return hourWaitTime, err
} }
default: default:
// Something else failed // Something else failed

View file

@ -0,0 +1,22 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !vw
package internal
import "time"
const defaultWaitTime = time.Minute
const hourWaitTime = time.Hour

View file

@ -0,0 +1,25 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build vw
package internal
import "time"
// Sytest is expecting to receive a `/devices` request. The way it is implemented in Dendrite
// results in a one-hour wait time from a previous device so the test times out. This is fine for
// production, but makes an otherwise passing test fail.
const defaultWaitTime = time.Second
const hourWaitTime = time.Second

View file

@ -129,7 +129,13 @@ func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient { func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient {
_, pkey, _ := ed25519.GenerateKey(nil) _, pkey, _ := ed25519.GenerateKey(nil)
fedClient := gomatrixserverlib.NewFederationClient( fedClient := gomatrixserverlib.NewFederationClient(
gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey, []*gomatrixserverlib.SigningIdentity{
{
ServerName: gomatrixserverlib.ServerName("example.test"),
KeyID: gomatrixserverlib.KeyID("ed25519:test"),
PrivateKey: pkey,
},
},
) )
fedClient.Client = *gomatrixserverlib.NewClient( fedClient.Client = *gomatrixserverlib.NewClient(
gomatrixserverlib.WithTransport(&roundTripper{tripper}), gomatrixserverlib.WithTransport(&roundTripper{tripper}),
@ -147,7 +153,7 @@ func TestUpdateHavePrevID(t *testing.T) {
} }
ap := &mockDeviceListUpdaterAPI{} ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{} producer := &mockKeyChangeProducer{}
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost")
event := gomatrixserverlib.DeviceListUpdateEvent{ event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Foo Bar", DeviceDisplayName: "Foo Bar",
Deleted: false, Deleted: false,
@ -219,7 +225,7 @@ func TestUpdateNoPrevID(t *testing.T) {
`)), `)),
}, nil }, nil
}) })
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test")
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err) t.Fatalf("failed to start updater: %s", err)
} }
@ -288,7 +294,7 @@ func TestDebounce(t *testing.T) {
close(incomingFedReq) close(incomingFedReq)
return <-fedCh, nil return <-fedCh, nil
}) })
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost")
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err) t.Fatalf("failed to start updater: %s", err)
} }

View file

@ -33,12 +33,13 @@ import (
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/producers" "github.com/matrix-org/dendrite/keyserver/producers"
"github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
) )
type KeyInternalAPI struct { type KeyInternalAPI struct {
DB storage.Database DB storage.Database
ThisServer gomatrixserverlib.ServerName Cfg *config.KeyServer
FedClient fedsenderapi.KeyserverFederationAPI FedClient fedsenderapi.KeyserverFederationAPI
UserAPI userapi.KeyserverUserAPI UserAPI userapi.KeyserverUserAPI
Producer *producers.KeyChange Producer *producers.KeyChange
@ -95,8 +96,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
nested[userID] = val nested[userID] = val
domainToDeviceKeys[string(serverName)] = nested domainToDeviceKeys[string(serverName)] = nested
} }
for domain, local := range domainToDeviceKeys {
if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue
}
// claim local keys // claim local keys
if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok {
keys, err := a.DB.ClaimKeys(ctx, local) keys, err := a.DB.ClaimKeys(ctx, local)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
@ -117,7 +121,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
} }
} }
delete(domainToDeviceKeys, string(a.ThisServer)) delete(domainToDeviceKeys, domain)
} }
if len(domainToDeviceKeys) > 0 { if len(domainToDeviceKeys) > 0 {
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
@ -142,7 +146,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
defer cancel() defer cancel()
defer wg.Done() defer wg.Done()
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim)
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
@ -258,7 +262,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
} }
domain := string(serverName) domain := string(serverName)
// query local devices // query local devices
if serverName == a.ThisServer { if a.Cfg.Matrix.IsLocalServerName(serverName) {
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
@ -437,13 +441,13 @@ func (a *KeyInternalAPI) queryRemoteKeys(
domains := map[string]struct{}{} domains := map[string]struct{}{}
for domain := range domainToDeviceKeys { for domain := range domainToDeviceKeys {
if domain == string(a.ThisServer) { if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue continue
} }
domains[domain] = struct{}{} domains[domain] = struct{}{}
} }
for domain := range domainToCrossSigningKeys { for domain := range domainToCrossSigningKeys {
if domain == string(a.ThisServer) { if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue continue
} }
domains[domain] = struct{}{} domains[domain] = struct{}{}
@ -555,7 +559,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
if len(devKeys) == 0 { if len(devKeys) == 0 {
return return
} }
queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys) queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys)
if err == nil { if err == nil {
resultCh <- &queryKeysResp resultCh <- &queryKeysResp
return return
@ -689,7 +693,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
if err != nil { if err != nil {
continue // ignore invalid users continue // ignore invalid users
} }
if serverName != a.ThisServer { if !a.Cfg.Matrix.IsLocalServerName(serverName) {
continue // ignore remote users continue // ignore remote users
} }
if len(key.KeyJSON) == 0 { if len(key.KeyJSON) == 0 {

View file

@ -54,11 +54,11 @@ func NewInternalAPI(
} }
ap := &internal.KeyInternalAPI{ ap := &internal.KeyInternalAPI{
DB: db, DB: db,
ThisServer: cfg.Matrix.ServerName, Cfg: cfg,
FedClient: fedClient, FedClient: fedClient,
Producer: keyChangeProducer, Producer: keyChangeProducer,
} }
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable
ap.Updater = updater ap.Updater = updater
go func() { go func() {
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {

View file

@ -43,10 +43,10 @@ const upsertStaleDeviceListSQL = "" +
" DO UPDATE SET is_stale = $3, ts_added_secs = $4" " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
const selectStaleDeviceListsWithDomainsSQL = "" + const selectStaleDeviceListsWithDomainsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
const selectStaleDeviceListsSQL = "" + const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
type staleDeviceListsStatements struct { type staleDeviceListsStatements struct {
upsertStaleDeviceListStmt *sql.Stmt upsertStaleDeviceListStmt *sql.Stmt
@ -77,7 +77,7 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context,
if err != nil { if err != nil {
return err return err
} }
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
return err return err
} }

View file

@ -43,10 +43,10 @@ const upsertStaleDeviceListSQL = "" +
" DO UPDATE SET is_stale = $3, ts_added_secs = $4" " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
const selectStaleDeviceListsWithDomainsSQL = "" + const selectStaleDeviceListsWithDomainsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
const selectStaleDeviceListsSQL = "" + const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
type staleDeviceListsStatements struct { type staleDeviceListsStatements struct {
db *sql.DB db *sql.DB
@ -80,7 +80,7 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context,
if err != nil { if err != nil {
return err return err
} }
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
return err return err
} }

Some files were not shown because too many files have changed in this diff Show more