Merge upstream 0.10.8

This commit is contained in:
Daniel Aloni 2023-01-09 09:46:00 +02:00
parent 6cc9ea3642
commit 3b1d3b75ea
166 changed files with 3064 additions and 1329 deletions

View file

@ -14,6 +14,43 @@ concurrency:
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
wasm:
name: WASM build test
timeout-minutes: 5
runs-on: ubuntu-latest
if: ${{ false }} # disable for now
steps:
- uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v3
with:
go-version: 1.18
cache: true
- name: Install Node
uses: actions/setup-node@v2
with:
node-version: 14
- uses: actions/cache@v3
with:
path: ~/.npm
key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-
- name: Reconfigure Git to use HTTPS auth for repo packages
run: >
git config --global url."https://github.com/".insteadOf
ssh://git@github.com/
- name: Install test dependencies
working-directory: ./test/wasm
run: npm ci
- name: Test
run: ./test-dendritejs.sh
# Run golangci-lint # Run golangci-lint
lint: lint:
@ -64,19 +101,12 @@ jobs:
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
cache: true
- name: Set up gotestfmt - name: Set up gotestfmt
uses: gotesttools/gotestfmt-action@v2 uses: gotesttools/gotestfmt-action@v2
with: with:
# Optional: pass GITHUB_TOKEN to avoid rate limiting. # Optional: pass GITHUB_TOKEN to avoid rate limiting.
token: ${{ secrets.GITHUB_TOKEN }} token: ${{ secrets.GITHUB_TOKEN }}
- uses: actions/cache@v3
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-test-
- run: go test -json -v ./... 2>&1 | gotestfmt - run: go test -json -v ./... 2>&1 | gotestfmt
env: env:
POSTGRES_HOST: localhost POSTGRES_HOST: localhost
@ -101,17 +131,17 @@ jobs:
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Install dependencies x86
if: ${{ matrix.goarch == '386' }}
run: sudo apt update && sudo apt-get install -y gcc-multilib
- uses: actions/cache@v3 - uses: actions/cache@v3
with: with:
path: | path: |
~/.cache/go-build ~/.cache/go-build
~/go/pkg/mod ~/go/pkg/mod
key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goarch }}- key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-
- name: Install dependencies x86
if: ${{ matrix.goarch == '386' }}
run: sudo apt update && sudo apt-get install -y gcc-multilib
- env: - env:
GOOS: ${{ matrix.goos }} GOOS: ${{ matrix.goos }}
GOARCH: ${{ matrix.goarch }} GOARCH: ${{ matrix.goarch }}
@ -119,6 +149,39 @@ jobs:
CGO_CFLAGS: -fno-stack-protector CGO_CFLAGS: -fno-stack-protector
run: go build -trimpath -v -o "bin/" ./cmd/... run: go build -trimpath -v -o "bin/" ./cmd/...
# build for Windows 64-bit
build_windows:
name: Build for Windows
timeout-minutes: 10
runs-on: ubuntu-latest
strategy:
matrix:
go: ["1.18", "1.19"]
goos: ["windows"]
goarch: ["amd64"]
steps:
- uses: actions/checkout@v3
- name: Setup Go ${{ matrix.go }}
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- uses: actions/cache@v3
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }}
restore-keys: |
key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-
- name: Install dependencies
run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc
- env:
GOOS: ${{ matrix.goos }}
GOARCH: ${{ matrix.goarch }}
CGO_ENABLED: 1
CC: "/usr/bin/x86_64-w64-mingw32-gcc"
run: go build -trimpath -v -o "bin/" ./cmd/...
# Dummy step to gate other tests on without repeating the whole list # Dummy step to gate other tests on without repeating the whole list
initial-tests-done: initial-tests-done:
name: Initial tests passed name: Initial tests passed
@ -151,6 +214,8 @@ jobs:
image: matrixdotorg/sytest-dendrite:latest image: matrixdotorg/sytest-dendrite:latest
volumes: volumes:
- ${{ github.workspace }}:/src - ${{ github.workspace }}:/src
- /root/.cache/go-build:/github/home/.cache/go-build
- /root/.cache/go-mod:/gopath/pkg/mod
env: env:
POSTGRES: ${{ matrix.postgres && 1}} POSTGRES: ${{ matrix.postgres && 1}}
API: ${{ matrix.api && 1 }} API: ${{ matrix.api && 1 }}
@ -158,6 +223,14 @@ jobs:
CGO_ENABLED: ${{ matrix.cgo && 1 }} CGO_ENABLED: ${{ matrix.cgo && 1 }}
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/cache@v3
with:
path: |
~/.cache/go-build
/gopath/pkg/mod
key: ${{ runner.os }}-go-sytest-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-sytest-
- name: Run Sytest - name: Run Sytest
run: /bootstrap.sh dendrite run: /bootstrap.sh dendrite
working-directory: /src working-directory: /src
@ -192,10 +265,12 @@ jobs:
include: include:
- label: PostgreSQL - label: PostgreSQL
postgres: Postgres postgres: Postgres
cgo: 0
- label: PostgreSQL, full HTTP APIs - label: PostgreSQL, full HTTP APIs
postgres: Postgres postgres: Postgres
api: full-http api: full-http
cgo: 0
steps: steps:
# Env vars are set file a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on env to run Complement. # Env vars are set file a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on env to run Complement.
# See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path # See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path
@ -203,14 +278,12 @@ jobs:
run: | run: |
echo "$GOROOT_1_17_X64/bin" >> $GITHUB_PATH echo "$GOROOT_1_17_X64/bin" >> $GITHUB_PATH
echo "~/go/bin" >> $GITHUB_PATH echo "~/go/bin" >> $GITHUB_PATH
- name: "Install Complement Dependencies" - name: "Install Complement Dependencies"
# We don't need to install Go because it is included on the Ubuntu 20.04 image: # We don't need to install Go because it is included on the Ubuntu 20.04 image:
# See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md specifically GOROOT_1_17_X64 # See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md specifically GOROOT_1_17_X64
run: | run: |
sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev
go get -v github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest go get -v github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest
- name: Run actions/checkout@v3 for dendrite - name: Run actions/checkout@v3 for dendrite
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
@ -239,9 +312,8 @@ jobs:
(wget -O - "https://github.com/globekeeper/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break (wget -O - "https://github.com/globekeeper/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break
done done
# Build initial Dendrite image # Build initial Dendrite image
- run: docker build -t complement-dendrite -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . - run: docker build --build-arg=CGO=${{ matrix.cgo }} -t complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile .
working-directory: dendrite working-directory: dendrite
env: env:
DOCKER_BUILDKIT: 1 DOCKER_BUILDKIT: 1
@ -253,9 +325,8 @@ jobs:
shell: bash shell: bash
name: Run Complement Tests name: Run Complement Tests
env: env:
COMPLEMENT_BASE_IMAGE: complement-dendrite:latest COMPLEMENT_BASE_IMAGE: complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }}
API: ${{ matrix.api && 1 }} API: ${{ matrix.api && 1 }}
CGO_ENABLED: ${{ matrix.cgo && 1 }}
working-directory: complement working-directory: complement
integration-tests-done: integration-tests-done:

View file

@ -45,6 +45,11 @@ jobs:
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Set up gotestfmt
uses: gotesttools/gotestfmt-action@v2
with:
# Optional: pass GITHUB_TOKEN to avoid rate limiting.
token: ${{ secrets.GITHUB_TOKEN }}
- uses: actions/cache@v3 - uses: actions/cache@v3
with: with:
path: | path: |
@ -53,12 +58,14 @@ jobs:
key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-test-race- ${{ runner.os }}-go${{ matrix.go }}-test-race-
- run: go test -race ./... - run: go test -race -json -v -coverpkg=./... -coverprofile=cover.out $(go list ./... | grep -v /cmd/dendrite*) 2>&1 | gotestfmt
env: env:
POSTGRES_HOST: localhost POSTGRES_HOST: localhost
POSTGRES_USER: postgres POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres POSTGRES_PASSWORD: postgres
POSTGRES_DB: dendrite POSTGRES_DB: dendrite
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
# Dummy step to gate other tests on without repeating the whole list # Dummy step to gate other tests on without repeating the whole list
initial-tests-done: initial-tests-done:

View file

@ -1,5 +1,28 @@
# Changelog # Changelog
## Dendrite 0.10.8 (2022-11-29)
### Features
* The built-in NATS Server has been updated to version 2.9.8
* A number of under-the-hood changes have been merged for future virtual hosting support in Dendrite (running multiple domain names on the same Dendrite deployment)
### Fixes
* Event auth handling of invites has been refactored, which should fix some edge cases being handled incorrectly
* Fix a bug when returning an empty protocol list, which could cause Element to display "The homeserver may be too old to support third party networks" when opening the public room directory
* The sync API will no longer filter out the user's own membership when using lazy-loading
* Dendrite will now correctly detect JetStream consumers being deleted, stopping the consumer goroutine as needed
* A panic in the federation API where the server list could go out of bounds has been fixed
* Blacklisted servers will now be excluded when querying joined servers, which improves CPU usage and performs less unnecessary outbound requests
* A database writer will now be used to assign state key NIDs when requesting NIDs that may not exist yet
* Dendrite will now correctly move local aliases for an upgraded room when the room is upgraded remotely
* Dendrite will now correctly move account data for an upgraded room when the room is upgraded remotely
* Missing state key NIDs will now be allocated on request rather than returning an error
* Guest access is now correctly denied on a number of endpoints
* Presence information will now be correctly sent for new private chats
* A number of unspecced fields have been removed from outbound `/send` transactions
## Dendrite 0.10.7 (2022-11-04) ## Dendrite 0.10.7 (2022-11-04)
### Features ### Features

View file

@ -4,18 +4,16 @@ RUN apk --update --no-cache add bash build-base
WORKDIR /build WORKDIR /build
COPY . /build #
# The dendrite base image
RUN mkdir -p bin #
RUN go build -trimpath -o bin/ ./cmd/dendrite-monolith-server FROM alpine:latest AS dendrite-base
RUN go build -trimpath -o bin/ ./cmd/create-account
RUN go build -trimpath -o bin/ ./cmd/generate-keys
FROM alpine:latest
LABEL org.opencontainers.image.title="Dendrite (Monolith)" LABEL org.opencontainers.image.title="Dendrite (Monolith)"
LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go" LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go"
LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite" LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite"
LABEL org.opencontainers.image.licenses="Apache-2.0" LABEL org.opencontainers.image.licenses="Apache-2.0"
LABEL org.opencontainers.image.documentation="https://matrix-org.github.io/dendrite/"
LABEL org.opencontainers.image.vendor="The Matrix.org Foundation C.I.C."
COPY --from=base /build/bin/* /usr/bin/ COPY --from=base /build/bin/* /usr/bin/

View file

@ -32,6 +32,7 @@ import (
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
) )
// AddInternalRoutes registers HTTP handlers for internal API calls // AddInternalRoutes registers HTTP handlers for internal API calls
@ -74,7 +75,7 @@ func NewInternalAPI(
// events to be sent out. // events to be sent out.
for _, appservice := range base.Cfg.Derived.ApplicationServices { for _, appservice := range base.Cfg.Derived.ApplicationServices {
// Create bot account for this AS if it doesn't already exist // Create bot account for this AS if it doesn't already exist
if err := generateAppServiceAccount(userAPI, appservice); err != nil { if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"appservice": appservice.ID, "appservice": appservice.ID,
}).WithError(err).Panicf("failed to generate bot account for appservice") }).WithError(err).Panicf("failed to generate bot account for appservice")
@ -101,11 +102,13 @@ func NewInternalAPI(
func generateAppServiceAccount( func generateAppServiceAccount(
userAPI userapi.AppserviceUserAPI, userAPI userapi.AppserviceUserAPI,
as config.ApplicationService, as config.ApplicationService,
serverName gomatrixserverlib.ServerName,
) error { ) error {
var accRes userapi.PerformAccountCreationResponse var accRes userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeAppService, AccountType: userapi.AccountTypeAppService,
Localpart: as.SenderLocalpart, Localpart: as.SenderLocalpart,
ServerName: serverName,
AppServiceID: as.ID, AppServiceID: as.ID,
OnConflict: userapi.ConflictUpdate, OnConflict: userapi.ConflictUpdate,
}, &accRes) }, &accRes)
@ -115,6 +118,7 @@ func generateAppServiceAccount(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{
Localpart: as.SenderLocalpart, Localpart: as.SenderLocalpart,
ServerName: serverName,
AccessToken: as.ASToken, AccessToken: as.ASToken,
DeviceID: &as.SenderLocalpart, DeviceID: &as.SenderLocalpart,
DeviceDisplayName: &as.SenderLocalpart, DeviceDisplayName: &as.SenderLocalpart,

View file

@ -40,6 +40,7 @@ import (
"github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
"github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver"
@ -58,6 +59,7 @@ import (
pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeMulticast "github.com/matrix-org/pinecone/multicast"
pineconeRouter "github.com/matrix-org/pinecone/router" pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeEvents "github.com/matrix-org/pinecone/router/events"
pineconeSessions "github.com/matrix-org/pinecone/sessions" pineconeSessions "github.com/matrix-org/pinecone/sessions"
"github.com/matrix-org/pinecone/types" "github.com/matrix-org/pinecone/types"
@ -295,7 +297,12 @@ func (m *DendriteMonolith) Start() {
m.logger.SetOutput(BindLogger{}) m.logger.SetOutput(BindLogger{})
logrus.SetOutput(BindLogger{}) logrus.SetOutput(BindLogger{})
pineconeEventChannel := make(chan pineconeEvents.Event)
m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk)
m.PineconeRouter.EnableHopLimiting()
m.PineconeRouter.EnableWakeupBroadcasts()
m.PineconeRouter.Subscribe(pineconeEventChannel)
m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"}) m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"})
m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter) m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter)
m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter, nil) m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter, nil)
@ -423,6 +430,34 @@ func (m *DendriteMonolith) Start() {
m.logger.Fatal(err) m.logger.Fatal(err)
} }
}() }()
go func(ch <-chan pineconeEvents.Event) {
eLog := logrus.WithField("pinecone", "events")
for event := range ch {
switch e := event.(type) {
case pineconeEvents.PeerAdded:
case pineconeEvents.PeerRemoved:
case pineconeEvents.TreeParentUpdate:
case pineconeEvents.SnakeDescUpdate:
case pineconeEvents.TreeRootAnnUpdate:
case pineconeEvents.SnakeEntryAdded:
case pineconeEvents.SnakeEntryRemoved:
case pineconeEvents.BroadcastReceived:
eLog.Info("Broadcast received from: ", e.PeerID)
req := &api.PerformWakeupServersRequest{
ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)},
}
res := &api.PerformWakeupServersResponse{}
if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil {
logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID)
}
case pineconeEvents.BandwidthReport:
default:
}
}
}(pineconeEventChannel)
} }
func (m *DendriteMonolith) Stop() { func (m *DendriteMonolith) Stop() {

View file

@ -10,12 +10,13 @@ RUN mkdir /dendrite
# Utilise Docker caching when downloading dependencies, this stops us needlessly # Utilise Docker caching when downloading dependencies, this stops us needlessly
# downloading dependencies every time. # downloading dependencies every time.
ARG CGO
RUN --mount=target=. \ RUN --mount=target=. \
--mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,target=/root/.cache/go-build \ --mount=type=cache,target=/root/.cache/go-build \
go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \
go build -o /dendrite ./cmd/generate-keys && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \
go build -o /dendrite ./cmd/dendrite-monolith-server CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server
WORKDIR /dendrite WORKDIR /dendrite
RUN ./generate-keys --private-key matrix_key.pem RUN ./generate-keys --private-key matrix_key.pem

View file

@ -28,12 +28,13 @@ RUN mkdir /dendrite
# Utilise Docker caching when downloading dependencies, this stops us needlessly # Utilise Docker caching when downloading dependencies, this stops us needlessly
# downloading dependencies every time. # downloading dependencies every time.
ARG CGO
RUN --mount=target=. \ RUN --mount=target=. \
--mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,target=/root/.cache/go-build \ --mount=type=cache,target=/root/.cache/go-build \
go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \
go build -o /dendrite ./cmd/generate-keys && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \
go build -o /dendrite ./cmd/dendrite-monolith-server CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server
WORKDIR /dendrite WORKDIR /dendrite
RUN ./generate-keys --private-key matrix_key.pem RUN ./generate-keys --private-key matrix_key.pem

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/ratelimit" "github.com/matrix-org/dendrite/clientapi/ratelimit"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
uapi "github.com/matrix-org/dendrite/userapi/api" uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -67,8 +68,10 @@ func TestLoginFromJSONReader(t *testing.T) {
var userAPI fakeUserInternalAPI var userAPI fakeUserInternalAPI
cfg := &config.ClientAPI{ cfg := &config.ClientAPI{
Matrix: &config.Global{ Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
}, },
},
RtFailedLogin: ratelimit.RtFailedLoginConfig{ RtFailedLogin: ratelimit.RtFailedLoginConfig{
Enabled: false, Enabled: false,
}, },
@ -148,8 +151,10 @@ func TestBadLoginFromJSONReader(t *testing.T) {
var userAPI fakeUserInternalAPI var userAPI fakeUserInternalAPI
cfg := &config.ClientAPI{ cfg := &config.ClientAPI{
Matrix: &config.Global{ Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
}, },
},
} }
_, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, cfg, nil) _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, cfg, nil)
if errRes == nil { if errRes == nil {

View file

@ -110,13 +110,19 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
JSON: jsonerror.BadJSON("A password must be supplied."), JSON: jsonerror.BadJSON("A password must be supplied."),
} }
} }
localpart, _, err := userutil.ParseUsernameParam(username, t.Config.Matrix) localpart, domain, err := userutil.ParseUsernameParam(username, t.Config.Matrix)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,
JSON: jsonerror.InvalidUsername(err.Error()), JSON: jsonerror.InvalidUsername(err.Error()),
} }
} }
if !t.Config.Matrix.IsLocalServerName(domain) {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.InvalidUsername("The server name is not known."),
}
}
// Squash username to all lowercase letters // Squash username to all lowercase letters
res := &api.QueryAccountByPasswordResponse{} res := &api.QueryAccountByPasswordResponse{}
localpart = strings.ToLower(localpart) localpart = strings.ToLower(localpart)
@ -129,23 +135,28 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
} }
} }
} }
err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: localpart, PlaintextPassword: r.Password}, res) err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
Localpart: localpart,
ServerName: domain,
PlaintextPassword: r.Password,
}, res)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to fetch account by password"), JSON: jsonerror.Unknown("Unable to fetch account by password."),
} }
} }
if !res.Exists { if !res.Exists {
err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
PlaintextPassword: r.Password, PlaintextPassword: r.Password,
}, res) }, res)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to fetch account by password"), JSON: jsonerror.Unknown("Unable to fetch account by password."),
} }
} }
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows

View file

@ -49,8 +49,10 @@ func (d *fakeAccountDatabase) QueryAccountByPassword(ctx context.Context, req *a
func setup() *UserInteractive { func setup() *UserInteractive {
cfg := &config.ClientAPI{ cfg := &config.ClientAPI{
Matrix: &config.Global{ Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
}, },
},
} }
return NewUserInteractive(&fakeAccountDatabase{}, cfg) return NewUserInteractive(&fakeAccountDatabase{}, cfg)
} }

View file

@ -102,6 +102,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
serverName := cfg.Matrix.ServerName
localpart, ok := vars["localpart"] localpart, ok := vars["localpart"]
if !ok { if !ok {
return util.JSONResponse{ return util.JSONResponse{
@ -109,6 +110,9 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
JSON: jsonerror.MissingArgument("Expecting user localpart."), JSON: jsonerror.MissingArgument("Expecting user localpart."),
} }
} }
if l, s, err := cfg.Matrix.SplitLocalID('@', localpart); err == nil {
localpart, serverName = l, s
}
request := struct { request := struct {
Password string `json:"password"` Password string `json:"password"`
}{} }{}
@ -126,6 +130,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
} }
updateReq := &userapi.PerformPasswordUpdateRequest{ updateReq := &userapi.PerformPasswordUpdateRequest{
Localpart: localpart, Localpart: localpart,
ServerName: serverName,
Password: request.Password, Password: request.Password,
LogoutDevices: true, LogoutDevices: true,
} }

View file

@ -477,7 +477,7 @@ func createRoom(
SendAsServer: roomserverAPI.DoNotSendToOtherServers, SendAsServer: roomserverAPI.DoNotSendToOtherServers,
}) })
} }
if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, inputs, false); err != nil { if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, device.UserDomain(), inputs, false); err != nil {
util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -77,7 +77,7 @@ func DirectoryRoom(
// If we don't know it locally, do a federation query. // If we don't know it locally, do a federation query.
// But don't send the query to ourselves. // But don't send the query to ourselves.
if !cfg.Matrix.IsLocalServerName(domain) { if !cfg.Matrix.IsLocalServerName(domain) {
fedRes, fedErr := federation.LookupRoomAlias(req.Context(), domain, roomAlias) fedRes, fedErr := federation.LookupRoomAlias(req.Context(), cfg.Matrix.ServerName, domain, roomAlias)
if fedErr != nil { if fedErr != nil {
// TODO: Return 502 if the remote server errored. // TODO: Return 502 if the remote server errored.
// TODO: Return 504 if the remote server timed out. // TODO: Return 504 if the remote server timed out.

View file

@ -74,7 +74,7 @@ func GetPostPublicRooms(
serverName := gomatrixserverlib.ServerName(request.Server) serverName := gomatrixserverlib.ServerName(request.Server)
if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) { if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) {
res, err := federation.GetPublicRoomsFiltered( res, err := federation.GetPublicRoomsFiltered(
req.Context(), serverName, req.Context(), cfg.Matrix.ServerName, serverName,
int(request.Limit), request.Since, int(request.Limit), request.Since,
request.Filter.SearchTerms, false, request.Filter.SearchTerms, false,
"", "",

View file

@ -113,6 +113,7 @@ func completeAuth(
DeviceID: login.DeviceID, DeviceID: login.DeviceID,
AccessToken: token, AccessToken: token,
Localpart: localpart, Localpart: localpart,
ServerName: serverName,
IPAddr: ipAddr, IPAddr: ipAddr,
UserAgent: userAgent, UserAgent: userAgent,
}, &performRes) }, &performRes)

View file

@ -110,6 +110,7 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic
ctx, rsAPI, ctx, rsAPI,
roomserverAPI.KindNew, roomserverAPI.KindNew,
[]*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)},
device.UserDomain(),
serverName, serverName,
serverName, serverName,
nil, nil,
@ -322,7 +323,12 @@ func buildMembershipEvent(
return nil, err return nil, err
} }
return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return nil, err
}
return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil)
} }
// loadProfile lookups the profile of a given user from the database and returns // loadProfile lookups the profile of a given user from the database and returns

View file

@ -40,13 +40,14 @@ func GetNotifications(
} }
var queryRes userapi.QueryNotificationsResponse var queryRes userapi.QueryNotificationsResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{ err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
From: req.URL.Query().Get("from"), From: req.URL.Query().Get("from"),
Limit: int(limit), Limit: int(limit),
Only: req.URL.Query().Get("only"), Only: req.URL.Query().Get("only"),

View file

@ -61,6 +61,7 @@ func Password(
sessionID = util.RandomString(sessionIDLength) sessionID = util.RandomString(sessionIDLength)
} }
var localpart string var localpart string
var domain gomatrixserverlib.ServerName
switch r.Auth.Type { switch r.Auth.Type {
case authtypes.LoginTypePassword: case authtypes.LoginTypePassword:
// Check if the existing password is correct. // Check if the existing password is correct.
@ -73,7 +74,7 @@ func Password(
} }
// Get the local part. // Get the local part.
var err error var err error
localpart, _, err = gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err = gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -119,6 +120,7 @@ func Password(
} }
} }
localpart = res.Localpart localpart = res.Localpart
domain = res.ServerName
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeEmail) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeEmail)
default: default:
flows := []authtypes.Flow{ flows := []authtypes.Flow{
@ -152,6 +154,7 @@ func Password(
// Ask the user API to perform the password change. // Ask the user API to perform the password change.
passwordReq := &api.PerformPasswordUpdateRequest{ passwordReq := &api.PerformPasswordUpdateRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
Password: r.NewPassword, Password: r.NewPassword,
} }
passwordRes := &api.PerformPasswordUpdateResponse{} passwordRes := &api.PerformPasswordUpdateResponse{}
@ -192,6 +195,7 @@ func Password(
pushersReq := &api.PerformPusherDeletionRequest{ pushersReq := &api.PerformPusherDeletionRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
SessionID: sessionId, SessionID: sessionId,
} }
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {

View file

@ -278,7 +278,7 @@ func updateProfile(
} }
events, err := buildMembershipEvents( events, err := buildMembershipEvents(
ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, ctx, device, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI,
) )
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:
@ -292,7 +292,7 @@ func updateProfile(
return jsonerror.InternalServerError(), e return jsonerror.InternalServerError(), e
} }
if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, domain, domain, nil, true); err != nil { if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, device.UserDomain(), domain, domain, nil, true); err != nil {
util.GetLogger(ctx).WithError(err).Error("SendEvents failed") util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError(), err return jsonerror.InternalServerError(), err
} }
@ -315,7 +315,7 @@ func getProfile(
} }
if !cfg.Matrix.IsLocalServerName(domain) { if !cfg.Matrix.IsLocalServerName(domain) {
profile, fedErr := federation.LookupProfile(ctx, domain, userID, "") profile, fedErr := federation.LookupProfile(ctx, cfg.Matrix.ServerName, domain, userID, "")
if fedErr != nil { if fedErr != nil {
if x, ok := fedErr.(gomatrix.HTTPError); ok { if x, ok := fedErr.(gomatrix.HTTPError); ok {
if x.Code == http.StatusNotFound { if x.Code == http.StatusNotFound {
@ -343,6 +343,7 @@ func getProfile(
func buildMembershipEvents( func buildMembershipEvents(
ctx context.Context, ctx context.Context,
device *userapi.Device,
roomIDs []string, roomIDs []string,
newProfile authtypes.Profile, userID string, cfg *config.ClientAPI, newProfile authtypes.Profile, userID string, cfg *config.ClientAPI,
evTime time.Time, rsAPI api.ClientRoomserverAPI, evTime time.Time, rsAPI api.ClientRoomserverAPI,
@ -374,7 +375,12 @@ func buildMembershipEvents(
return nil, err return nil, err
} }
event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return nil, err
}
event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -31,13 +31,14 @@ func GetPushers(
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
) util.JSONResponse { ) util.JSONResponse {
var queryRes userapi.QueryPushersResponse var queryRes userapi.QueryPushersResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{ err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed") util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
@ -59,7 +60,7 @@ func SetPusher(
req *http.Request, device *userapi.Device, req *http.Request, device *userapi.Device,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -93,6 +94,7 @@ func SetPusher(
} }
body.Localpart = localpart body.Localpart = localpart
body.ServerName = domain
body.SessionID = device.SessionID body.SessionID = device.SessionID
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{}) err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
if err != nil { if err != nil {

View file

@ -123,8 +123,13 @@ func SendRedaction(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return jsonerror.InternalServerError()
}
var queryRes roomserverAPI.QueryLatestEventsAndStateResponse var queryRes roomserverAPI.QueryLatestEventsAndStateResponse
e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
@ -132,7 +137,7 @@ func SendRedaction(
} }
} }
domain := device.UserDomain() domain := device.UserDomain()
if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, domain, domain, nil, false); err != nil { if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, device.UserDomain(), domain, domain, nil, false); err != nil {
util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -214,6 +214,7 @@ type registerRequest struct {
// registration parameters // registration parameters
Password string `json:"password"` Password string `json:"password"`
Username string `json:"username"` Username string `json:"username"`
ServerName gomatrixserverlib.ServerName `json:"-"`
Admin bool `json:"admin"` Admin bool `json:"admin"`
// user-interactive auth params // user-interactive auth params
Auth authDict `json:"auth"` Auth authDict `json:"auth"`
@ -552,6 +553,12 @@ func Register(
} }
var r registerRequest var r registerRequest
host := gomatrixserverlib.ServerName(req.Host)
if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil {
r.ServerName = v.ServerName
} else {
r.ServerName = cfg.Matrix.ServerName
}
sessionID := gjson.GetBytes(reqBody, "auth.session").String() sessionID := gjson.GetBytes(reqBody, "auth.session").String()
if sessionID == "" { if sessionID == "" {
// Generate a new, random session ID // Generate a new, random session ID
@ -561,6 +568,7 @@ func Register(
// Some of these might end up being overwritten if the // Some of these might end up being overwritten if the
// values are specified again in the request body. // values are specified again in the request body.
r.Username = data.Username r.Username = data.Username
r.ServerName = data.ServerName
r.Password = data.Password r.Password = data.Password
r.DeviceID = data.DeviceID r.DeviceID = data.DeviceID
r.InitialDisplayName = data.InitialDisplayName r.InitialDisplayName = data.InitialDisplayName
@ -572,11 +580,13 @@ func Register(
JSON: response, JSON: response,
} }
} }
} }
if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil {
return *resErr return *resErr
} }
if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil {
r.Username, r.ServerName = l, d
}
if req.URL.Query().Get("kind") == "guest" { if req.URL.Query().Get("kind") == "guest" {
return handleGuestRegistration(req, r, cfg, userAPI) return handleGuestRegistration(req, r, cfg, userAPI)
} }
@ -590,12 +600,15 @@ func Register(
} }
// Auto generate a numeric username if r.Username is empty // Auto generate a numeric username if r.Username is empty
if r.Username == "" { if r.Username == "" {
res := &userapi.QueryNumericLocalpartResponse{} nreq := &userapi.QueryNumericLocalpartRequest{
if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil { ServerName: r.ServerName,
}
nres := &userapi.QueryNumericLocalpartResponse{}
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
r.Username = strconv.FormatInt(res.ID, 10) r.Username = strconv.FormatInt(nres.ID, 10)
} }
// Is this an appservice registration? It will be if the access // Is this an appservice registration? It will be if the access
@ -608,7 +621,7 @@ func Register(
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
// Spec-compliant case (the access_token is specified and the login type // Spec-compliant case (the access_token is specified and the login type
// is correctly set, so it's an appservice registration) // is correctly set, so it's an appservice registration)
if resErr := validateApplicationServiceUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil {
return *resErr return *resErr
} }
case accessTokenErr == nil: case accessTokenErr == nil:
@ -621,7 +634,7 @@ func Register(
default: default:
// Spec-compliant case (neither the access_token nor the login type are // Spec-compliant case (neither the access_token nor the login type are
// specified, so it's a normal user registration) // specified, so it's a normal user registration)
if resErr := validateUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { if resErr := validateUsername(r.Username, r.ServerName); resErr != nil {
return *resErr return *resErr
} }
} }
@ -645,16 +658,25 @@ func handleGuestRegistration(
cfg *config.ClientAPI, cfg *config.ClientAPI,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
) util.JSONResponse { ) util.JSONResponse {
if cfg.RegistrationDisabled || cfg.GuestsDisabled { registrationEnabled := !cfg.RegistrationDisabled
guestsEnabled := !cfg.GuestsDisabled
if v := cfg.Matrix.VirtualHost(r.ServerName); v != nil {
registrationEnabled, guestsEnabled = v.RegistrationAllowed()
}
if !registrationEnabled || !guestsEnabled {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Guest registration is disabled"), JSON: jsonerror.Forbidden(
fmt.Sprintf("Guest registration is disabled on %q", r.ServerName),
),
} }
} }
var res userapi.PerformAccountCreationResponse var res userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeGuest, AccountType: userapi.AccountTypeGuest,
ServerName: r.ServerName,
}, &res) }, &res)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -678,6 +700,7 @@ func handleGuestRegistration(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{
Localpart: res.Account.Localpart, Localpart: res.Account.Localpart,
ServerName: res.Account.ServerName,
DeviceDisplayName: r.InitialDisplayName, DeviceDisplayName: r.InitialDisplayName,
AccessToken: token, AccessToken: token,
IPAddr: req.RemoteAddr, IPAddr: req.RemoteAddr,
@ -730,10 +753,16 @@ func handleRegistrationFlow(
) )
} }
if cfg.RegistrationDisabled && r.Auth.Type != authtypes.LoginTypeSharedSecret { registrationEnabled := !cfg.RegistrationDisabled
if v := cfg.Matrix.VirtualHost(r.ServerName); v != nil {
registrationEnabled, _ = v.RegistrationAllowed()
}
if !registrationEnabled && r.Auth.Type != authtypes.LoginTypeSharedSecret {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Registration is disabled"), JSON: jsonerror.Forbidden(
fmt.Sprintf("Registration is disabled on %q", r.ServerName),
),
} }
} }
@ -845,7 +874,7 @@ func handleApplicationServiceRegistration(
// Don't need to worry about appending to registration stages as // Don't need to worry about appending to registration stages as
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, req.Context(), userAPI, r.Username, r.ServerName, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, nil, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, nil,
) )
} }
@ -865,7 +894,7 @@ func checkAndCompleteFlow(
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue // This flow was completed, registration can continue
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, req.Context(), userAPI, r.Username, r.ServerName, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, threePid, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, threePid,
) )
} }
@ -888,7 +917,8 @@ func checkAndCompleteFlow(
func completeRegistration( func completeRegistration(
ctx context.Context, ctx context.Context,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
username, password, appserviceID, ipAddr, userAgent, sessionID string, username string, serverName gomatrixserverlib.ServerName,
password, appserviceID, ipAddr, userAgent, sessionID string,
inhibitLogin eventutil.WeakBoolean, inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string, displayName, deviceID *string,
accType userapi.AccountType, accType userapi.AccountType,
@ -911,6 +941,7 @@ func completeRegistration(
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
AppServiceID: appserviceID, AppServiceID: appserviceID,
Localpart: username, Localpart: username,
ServerName: serverName,
Password: password, Password: password,
AccountType: accType, AccountType: accType,
OnConflict: userapi.ConflictAbort, OnConflict: userapi.ConflictAbort,
@ -969,6 +1000,7 @@ func completeRegistration(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
Localpart: username, Localpart: username,
ServerName: serverName,
AccessToken: token, AccessToken: token,
DeviceDisplayName: displayName, DeviceDisplayName: displayName,
DeviceID: deviceID, DeviceID: deviceID,
@ -1062,13 +1094,31 @@ func RegisterAvailable(
// Squash username to all lowercase letters // Squash username to all lowercase letters
username = strings.ToLower(username) username = strings.ToLower(username)
domain := cfg.Matrix.ServerName
host := gomatrixserverlib.ServerName(req.Host)
if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil {
domain = v.ServerName
}
if u, l, err := cfg.Matrix.SplitLocalID('@', username); err == nil {
username, domain = u, l
}
for _, v := range cfg.Matrix.VirtualHosts {
if v.ServerName == domain && !v.AllowRegistration {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(
fmt.Sprintf("Registration is not allowed on %q", string(v.ServerName)),
),
}
}
}
if err := validateUsername(username, cfg.Matrix.ServerName); err != nil { if err := validateUsername(username, domain); err != nil {
return *err return *err
} }
// Check if this username is reserved by an application service // Check if this username is reserved by an application service
userID := userutil.MakeUserID(username, cfg.Matrix.ServerName) userID := userutil.MakeUserID(username, domain)
for _, appservice := range cfg.Derived.ApplicationServices { for _, appservice := range cfg.Derived.ApplicationServices {
if appservice.OwnsNamespaceCoveringUserId(userID) { if appservice.OwnsNamespaceCoveringUserId(userID) {
return util.JSONResponse{ return util.JSONResponse{
@ -1081,6 +1131,7 @@ func RegisterAvailable(
res := &userapi.QueryAccountAvailabilityResponse{} res := &userapi.QueryAccountAvailabilityResponse{}
err := registerAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{ err := registerAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{
Localpart: username, Localpart: username,
ServerName: domain,
}, res) }, res)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -1137,5 +1188,5 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien
if ssrr.Admin { if ssrr.Admin {
accType = userapi.AccountTypeAdmin accType = userapi.AccountTypeAdmin
} }
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType, nil) return completeRegistration(req.Context(), userAPI, ssrr.User, cfg.Matrix.ServerName, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType, nil)
} }

View file

@ -159,7 +159,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/resetPassword/{localpart}", dendriteAdminRouter.Handle("/admin/resetPassword/{userID}",
httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminResetPassword(req, cfg, device, userAPI) return AdminResetPassword(req, cfg, device, userAPI)
}), }),
@ -254,7 +254,7 @@ func Setup(
return JoinRoomByIDOrAlias( return JoinRoomByIDOrAlias(
req, device, rsAPI, userAPI, vars["roomIDOrAlias"], req, device, rsAPI, userAPI, vars["roomIDOrAlias"],
) )
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
if mscCfg.Enabled("msc2753") { if mscCfg.Enabled("msc2753") {
@ -276,7 +276,7 @@ func Setup(
v3mux.Handle("/joined_rooms", v3mux.Handle("/joined_rooms",
httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetJoinedRooms(req, device, rsAPI) return GetJoinedRooms(req, device, rsAPI)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/join", v3mux.Handle("/rooms/{roomID}/join",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -290,7 +290,7 @@ func Setup(
return JoinRoomByIDOrAlias( return JoinRoomByIDOrAlias(
req, device, rsAPI, userAPI, vars["roomID"], req, device, rsAPI, userAPI, vars["roomID"],
) )
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/leave", v3mux.Handle("/rooms/{roomID}/leave",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -304,7 +304,7 @@ func Setup(
return LeaveRoomByID( return LeaveRoomByID(
req, device, rsAPI, vars["roomID"], req, device, rsAPI, vars["roomID"],
) )
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/unpeek", v3mux.Handle("/rooms/{roomID}/unpeek",
httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -363,7 +363,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil) return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -374,7 +374,7 @@ func Setup(
txnID := vars["txnID"] txnID := vars["txnID"]
return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID, return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID,
nil, cfg, rsAPI, transactionsCache) nil, cfg, rsAPI, transactionsCache)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -383,7 +383,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"]) return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions) }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -402,7 +402,7 @@ func Setup(
eventType := strings.TrimSuffix(vars["type"], "/") eventType := strings.TrimSuffix(vars["type"], "/")
eventFormat := req.URL.Query().Get("format") == "event" eventFormat := req.URL.Query().Get("format") == "event"
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat) return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat)
})).Methods(http.MethodGet, http.MethodOptions) }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -411,7 +411,7 @@ func Setup(
} }
eventFormat := req.URL.Query().Get("format") == "event" eventFormat := req.URL.Query().Get("format") == "event"
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat) return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat)
})).Methods(http.MethodGet, http.MethodOptions) }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -422,7 +422,7 @@ func Setup(
emptyString := "" emptyString := ""
eventType := strings.TrimSuffix(vars["eventType"], "/") eventType := strings.TrimSuffix(vars["eventType"], "/")
return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil) return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
@ -433,7 +433,7 @@ func Setup(
} }
stateKey := vars["stateKey"] stateKey := vars["stateKey"]
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil) return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/multiroom/{dataType}", v3mux.Handle("/multiroom/{dataType}",
@ -588,7 +588,7 @@ func Setup(
} }
txnID := vars["txnID"] txnID := vars["txnID"]
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID) return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// This is only here because sytest refers to /unstable for this endpoint // This is only here because sytest refers to /unstable for this endpoint
@ -602,7 +602,7 @@ func Setup(
} }
txnID := vars["txnID"] txnID := vars["txnID"]
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID) return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/account/whoami", v3mux.Handle("/account/whoami",
@ -611,7 +611,7 @@ func Setup(
return *r return *r
} }
return Whoami(req, device) return Whoami(req, device)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/account/password", v3mux.Handle("/account/password",
@ -843,7 +843,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI) return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows // Browsers use the OPTIONS HTTP method to check if the CORS policy allows
// PUT requests, so we need to allow this method // PUT requests, so we need to allow this method
@ -884,7 +884,7 @@ func Setup(
v3mux.Handle("/thirdparty/protocols", v3mux.Handle("/thirdparty/protocols",
httputil.MakeAuthAPI("thirdparty_protocols", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("thirdparty_protocols", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return Protocols(req, asAPI, device, "") return Protocols(req, asAPI, device, "")
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thirdparty/protocol/{protocolID}", v3mux.Handle("/thirdparty/protocol/{protocolID}",
@ -894,7 +894,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return Protocols(req, asAPI, device, vars["protocolID"]) return Protocols(req, asAPI, device, vars["protocolID"])
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thirdparty/user/{protocolID}", v3mux.Handle("/thirdparty/user/{protocolID}",
@ -904,13 +904,13 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return User(req, asAPI, device, vars["protocolID"], req.URL.Query()) return User(req, asAPI, device, vars["protocolID"], req.URL.Query())
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thirdparty/user", v3mux.Handle("/thirdparty/user",
httputil.MakeAuthAPI("thirdparty_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("thirdparty_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return User(req, asAPI, device, "", req.URL.Query()) return User(req, asAPI, device, "", req.URL.Query())
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thirdparty/location/{protocolID}", v3mux.Handle("/thirdparty/location/{protocolID}",
@ -920,13 +920,13 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return Location(req, asAPI, device, vars["protocolID"], req.URL.Query()) return Location(req, asAPI, device, vars["protocolID"], req.URL.Query())
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thirdparty/location", v3mux.Handle("/thirdparty/location",
httputil.MakeAuthAPI("thirdparty_location", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("thirdparty_location", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return Location(req, asAPI, device, "", req.URL.Query()) return Location(req, asAPI, device, "", req.URL.Query())
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/initialSync", v3mux.Handle("/rooms/{roomID}/initialSync",
@ -1067,7 +1067,7 @@ func Setup(
v3mux.Handle("/devices", v3mux.Handle("/devices",
httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetDevicesByLocalpart(req, userAPI, device) return GetDevicesByLocalpart(req, userAPI, device)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/devices/{deviceID}", v3mux.Handle("/devices/{deviceID}",
@ -1077,7 +1077,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetDeviceByID(req, userAPI, device, vars["deviceID"]) return GetDeviceByID(req, userAPI, device, vars["deviceID"])
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/devices/{deviceID}", v3mux.Handle("/devices/{deviceID}",
@ -1087,7 +1087,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return UpdateDeviceByID(req, userAPI, device, vars["deviceID"]) return UpdateDeviceByID(req, userAPI, device, vars["deviceID"])
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/devices/{deviceID}", v3mux.Handle("/devices/{deviceID}",
@ -1129,21 +1129,21 @@ func Setup(
// Stub implementations for sytest // Stub implementations for sytest
v3mux.Handle("/events", v3mux.Handle("/events",
httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { httputil.MakeAuthAPI("events", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
"chunk": []interface{}{}, "chunk": []interface{}{},
"start": "", "start": "",
"end": "", "end": "",
}} }}
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/initialSync", v3mux.Handle("/initialSync",
httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse { httputil.MakeAuthAPI("initial_sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
"end": "", "end": "",
}} }}
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags", v3mux.Handle("/user/{userId}/rooms/{roomId}/tags",
@ -1182,7 +1182,7 @@ func Setup(
return *r return *r
} }
return GetCapabilities(req, rsAPI) return GetCapabilities(req, rsAPI)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
// Key Backup Versions (Metadata) // Key Backup Versions (Metadata)
@ -1363,7 +1363,7 @@ func Setup(
postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadCrossSigningDeviceSignatures(req, keyAPI, device) return UploadCrossSigningDeviceSignatures(req, keyAPI, device)
}) }, httputil.WithAllowGuests())
v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
@ -1375,22 +1375,22 @@ func Setup(
v3mux.Handle("/keys/upload/{deviceID}", v3mux.Handle("/keys/upload/{deviceID}",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device) return UploadKeys(req, keyAPI, device)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/upload", v3mux.Handle("/keys/upload",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device) return UploadKeys(req, keyAPI, device)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/query", v3mux.Handle("/keys/query",
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return QueryKeys(req, keyAPI, device) return QueryKeys(req, keyAPI, device)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/claim", v3mux.Handle("/keys/claim",
httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return ClaimKeys(req, keyAPI) return ClaimKeys(req, keyAPI)
}), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {

View file

@ -186,6 +186,7 @@ func SendEvent(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
e.Headered(verRes.RoomVersion), e.Headered(verRes.RoomVersion),
}, },
device.UserDomain(),
domain, domain,
domain, domain,
txnAndSessionID, txnAndSessionID,
@ -275,8 +276,14 @@ func generateSendEvent(
return nil, &resErr return nil, &resErr
} }
identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
resErr := jsonerror.InternalServerError()
return nil, &resErr
}
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, &queryRes) e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,

View file

@ -231,6 +231,7 @@ func SendServerNotice(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
e.Headered(roomVersion), e.Headered(roomVersion),
}, },
device.UserDomain(),
cfgClient.Matrix.ServerName, cfgClient.Matrix.ServerName,
cfgClient.Matrix.ServerName, cfgClient.Matrix.ServerName,
txnAndSessionID, txnAndSessionID,
@ -286,6 +287,7 @@ func getSenderDevice(
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeUser, AccountType: userapi.AccountTypeUser,
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
OnConflict: userapi.ConflictUpdate, OnConflict: userapi.ConflictUpdate,
}, &accRes) }, &accRes)
if err != nil { if err != nil {
@ -296,6 +298,7 @@ func getSenderDevice(
avatarRes := &userapi.PerformSetAvatarURLResponse{} avatarRes := &userapi.PerformSetAvatarURLResponse{}
if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{ if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
}, avatarRes); err != nil { }, avatarRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed") util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed")
@ -308,6 +311,7 @@ func getSenderDevice(
displayNameRes := &userapi.PerformUpdateDisplayNameResponse{} displayNameRes := &userapi.PerformUpdateDisplayNameResponse{}
if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{ if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
DisplayName: cfg.Matrix.ServerNotices.DisplayName, DisplayName: cfg.Matrix.ServerNotices.DisplayName,
}, displayNameRes); err != nil { }, displayNameRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed") util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed")
@ -353,6 +357,7 @@ func getSenderDevice(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart, DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart,
AccessToken: token, AccessToken: token,
NoDeviceListUpdate: true, NoDeviceListUpdate: true,

View file

@ -36,11 +36,17 @@ func Protocols(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, dev
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if !resp.Exists { if !resp.Exists {
if protocol != "" {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
JSON: jsonerror.NotFound("The protocol is unknown."), JSON: jsonerror.NotFound("The protocol is unknown."),
} }
} }
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
if protocol != "" { if protocol != "" {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,

View file

@ -136,7 +136,7 @@ func CheckAndSave3PIDAssociation(
} }
// Save the association in the database // Save the association in the database
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -145,6 +145,7 @@ func CheckAndSave3PIDAssociation(
if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{ if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{
ThreePID: address, ThreePID: address,
Localpart: localpart, Localpart: localpart,
ServerName: domain,
Medium: medium, Medium: medium,
}, &struct{}{}); err != nil { }, &struct{}{}); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed") util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed")
@ -161,7 +162,7 @@ func CheckAndSave3PIDAssociation(
func GetAssociated3PIDs( func GetAssociated3PIDs(
req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device, req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -170,6 +171,7 @@ func GetAssociated3PIDs(
res := &api.QueryThreePIDsForLocalpartResponse{} res := &api.QueryThreePIDsForLocalpartResponse{}
err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{ err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
}, res) }, res)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed") util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed")

View file

@ -106,7 +106,7 @@ knownUsersLoop:
continue continue
} }
// TODO: We should probably cache/store this // TODO: We should probably cache/store this
fedProfile, fedErr := federation.LookupProfile(ctx, serverName, userID, "") fedProfile, fedErr := federation.LookupProfile(ctx, localServerName, serverName, userID, "")
if fedErr != nil { if fedErr != nil {
if x, ok := fedErr.(gomatrix.HTTPError); ok { if x, ok := fedErr.(gomatrix.HTTPError); ok {
if x.Code == http.StatusNotFound { if x.Code == http.StatusNotFound {

View file

@ -359,8 +359,13 @@ func emit3PIDInviteEvent(
return err return err
} }
identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return err
}
queryRes := api.QueryLatestEventsAndStateResponse{} queryRes := api.QueryLatestEventsAndStateResponse{}
event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, evTime, rsAPI, &queryRes) event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes)
if err != nil { if err != nil {
return err return err
} }
@ -371,6 +376,7 @@ func emit3PIDInviteEvent(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
event.Headered(queryRes.RoomVersion), event.Headered(queryRes.RoomVersion),
}, },
device.UserDomain(),
cfg.Matrix.ServerName, cfg.Matrix.ServerName,
cfg.Matrix.ServerName, cfg.Matrix.ServerName,
nil, nil,

View file

@ -30,7 +30,9 @@ var (
// TestGoodUserID checks that correct localpart is returned for a valid user ID. // TestGoodUserID checks that correct localpart is returned for a valid user ID.
func TestGoodUserID(t *testing.T) { func TestGoodUserID(t *testing.T) {
cfg := &config.Global{ cfg := &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
},
} }
lp, _, err := ParseUsernameParam(goodUserID, cfg) lp, _, err := ParseUsernameParam(goodUserID, cfg)
@ -47,7 +49,9 @@ func TestGoodUserID(t *testing.T) {
// TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart. // TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart.
func TestWithLocalpartOnly(t *testing.T) { func TestWithLocalpartOnly(t *testing.T) {
cfg := &config.Global{ cfg := &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
},
} }
lp, _, err := ParseUsernameParam(localpart, cfg) lp, _, err := ParseUsernameParam(localpart, cfg)
@ -64,7 +68,9 @@ func TestWithLocalpartOnly(t *testing.T) {
// TestIncorrectDomain checks for error when there's server name mismatch. // TestIncorrectDomain checks for error when there's server name mismatch.
func TestIncorrectDomain(t *testing.T) { func TestIncorrectDomain(t *testing.T) {
cfg := &config.Global{ cfg := &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: invalidServerName, ServerName: invalidServerName,
},
} }
_, _, err := ParseUsernameParam(goodUserID, cfg) _, _, err := ParseUsernameParam(goodUserID, cfg)
@ -77,7 +83,9 @@ func TestIncorrectDomain(t *testing.T) {
// TestBadUserID checks that ParseUsernameParam fails for invalid user ID // TestBadUserID checks that ParseUsernameParam fails for invalid user ID
func TestBadUserID(t *testing.T) { func TestBadUserID(t *testing.T) {
cfg := &config.Global{ cfg := &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
},
} }
_, _, err := ParseUsernameParam(badUserID, cfg) _, _, err := ParseUsernameParam(badUserID, cfg)

View file

@ -101,9 +101,7 @@ func CreateFederationClient(
base *base.BaseDendrite, s *pineconeSessions.Sessions, base *base.BaseDendrite, s *pineconeSessions.Sessions,
) *gomatrixserverlib.FederationClient { ) *gomatrixserverlib.FederationClient {
return gomatrixserverlib.NewFederationClient( return gomatrixserverlib.NewFederationClient(
base.Cfg.Global.ServerName, base.Cfg.Global.SigningIdentities(),
base.Cfg.Global.KeyID,
base.Cfg.Global.PrivateKey,
gomatrixserverlib.WithTransport(createTransport(s)), gomatrixserverlib.WithTransport(createTransport(s)),
) )
} }

View file

@ -37,6 +37,7 @@ import (
"github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
"github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/keyserver"
@ -51,6 +52,7 @@ import (
pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeMulticast "github.com/matrix-org/pinecone/multicast"
pineconeRouter "github.com/matrix-org/pinecone/router" pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeEvents "github.com/matrix-org/pinecone/router/events"
pineconeSessions "github.com/matrix-org/pinecone/sessions" pineconeSessions "github.com/matrix-org/pinecone/sessions"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -155,7 +157,12 @@ func main() {
base := base.NewBaseDendrite(cfg, "Monolith") base := base.NewBaseDendrite(cfg, "Monolith")
defer base.Close() // nolint: errcheck defer base.Close() // nolint: errcheck
pineconeEventChannel := make(chan pineconeEvents.Event)
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk)
pRouter.EnableHopLimiting()
pRouter.EnableWakeupBroadcasts()
pRouter.Subscribe(pineconeEventChannel)
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
pManager := pineconeConnections.NewConnectionManager(pRouter, nil) pManager := pineconeConnections.NewConnectionManager(pRouter, nil)
@ -293,5 +300,33 @@ func main() {
logrus.Fatal(http.ListenAndServe(httpBindAddr, httpRouter)) logrus.Fatal(http.ListenAndServe(httpBindAddr, httpRouter))
}() }()
go func(ch <-chan pineconeEvents.Event) {
eLog := logrus.WithField("pinecone", "events")
for event := range ch {
switch e := event.(type) {
case pineconeEvents.PeerAdded:
case pineconeEvents.PeerRemoved:
case pineconeEvents.TreeParentUpdate:
case pineconeEvents.SnakeDescUpdate:
case pineconeEvents.TreeRootAnnUpdate:
case pineconeEvents.SnakeEntryAdded:
case pineconeEvents.SnakeEntryRemoved:
case pineconeEvents.BroadcastReceived:
eLog.Info("Broadcast received from: ", e.PeerID)
req := &api.PerformWakeupServersRequest{
ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)},
}
res := &api.PerformWakeupServersResponse{}
if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil {
logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID)
}
case pineconeEvents.BandwidthReport:
default:
}
}
}(pineconeEventChannel)
base.WaitForShutdown() base.WaitForShutdown()
} }

View file

@ -58,13 +58,17 @@ func (p *PineconeRoomProvider) Rooms() []gomatrixserverlib.PublicRoom {
for _, k := range p.r.Peers() { for _, k := range p.r.Peers() {
list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{} list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{}
} }
return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, list) return bulkFetchPublicRoomsFromServers(
context.Background(), p.fedClient,
gomatrixserverlib.ServerName(p.r.PublicKey().String()), list,
)
} }
// bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers.
// Returns a list of public rooms. // Returns a list of public rooms.
func bulkFetchPublicRoomsFromServers( func bulkFetchPublicRoomsFromServers(
ctx context.Context, fedClient *gomatrixserverlib.FederationClient, ctx context.Context, fedClient *gomatrixserverlib.FederationClient,
origin gomatrixserverlib.ServerName,
homeservers map[gomatrixserverlib.ServerName]struct{}, homeservers map[gomatrixserverlib.ServerName]struct{},
) (publicRooms []gomatrixserverlib.PublicRoom) { ) (publicRooms []gomatrixserverlib.PublicRoom) {
limit := 200 limit := 200
@ -82,7 +86,7 @@ func bulkFetchPublicRoomsFromServers(
go func(homeserverDomain gomatrixserverlib.ServerName) { go func(homeserverDomain gomatrixserverlib.ServerName) {
defer wg.Done() defer wg.Done()
util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms")
fres, err := fedClient.GetPublicRooms(reqctx, homeserverDomain, int(limit), "", false, "") fres, err := fedClient.GetPublicRooms(reqctx, origin, homeserverDomain, int(limit), "", false, "")
if err != nil { if err != nil {
util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn( util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn(
"bulkFetchPublicRoomsFromServers: failed to query hs", "bulkFetchPublicRoomsFromServers: failed to query hs",

View file

@ -55,8 +55,7 @@ func (n *Node) CreateFederationClient(
}, },
) )
return gomatrixserverlib.NewFederationClient( return gomatrixserverlib.NewFederationClient(
base.Cfg.Global.ServerName, base.Cfg.Global.KeyID, base.Cfg.Global.SigningIdentities(),
base.Cfg.Global.PrivateKey,
gomatrixserverlib.WithTransport(tr), gomatrixserverlib.WithTransport(tr),
) )
} }

View file

@ -43,13 +43,18 @@ func NewYggdrasilRoomProvider(
} }
func (p *YggdrasilRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { func (p *YggdrasilRoomProvider) Rooms() []gomatrixserverlib.PublicRoom {
return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, p.node.KnownNodes()) return bulkFetchPublicRoomsFromServers(
context.Background(), p.fedClient,
gomatrixserverlib.ServerName(p.node.DerivedServerName()),
p.node.KnownNodes(),
)
} }
// bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers.
// Returns a list of public rooms. // Returns a list of public rooms.
func bulkFetchPublicRoomsFromServers( func bulkFetchPublicRoomsFromServers(
ctx context.Context, fedClient *gomatrixserverlib.FederationClient, ctx context.Context, fedClient *gomatrixserverlib.FederationClient,
origin gomatrixserverlib.ServerName,
homeservers []gomatrixserverlib.ServerName, homeservers []gomatrixserverlib.ServerName,
) (publicRooms []gomatrixserverlib.PublicRoom) { ) (publicRooms []gomatrixserverlib.PublicRoom) {
limit := 200 limit := 200
@ -66,7 +71,7 @@ func bulkFetchPublicRoomsFromServers(
go func(homeserverDomain gomatrixserverlib.ServerName) { go func(homeserverDomain gomatrixserverlib.ServerName) {
defer wg.Done() defer wg.Done()
util.GetLogger(ctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") util.GetLogger(ctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms")
fres, err := fedClient.GetPublicRooms(ctx, homeserverDomain, int(limit), "", false, "") fres, err := fedClient.GetPublicRooms(ctx, origin, homeserverDomain, int(limit), "", false, "")
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).WithField("hs", homeserverDomain).Warn( util.GetLogger(ctx).WithError(err).WithField("hs", homeserverDomain).Warn(
"bulkFetchPublicRoomsFromServers: failed to query hs", "bulkFetchPublicRoomsFromServers: failed to query hs",

View file

@ -38,6 +38,7 @@ var (
flagHead = flag.String("head", "", "Location to a dendrite repository to treat as HEAD instead of Github") flagHead = flag.String("head", "", "Location to a dendrite repository to treat as HEAD instead of Github")
flagDockerHost = flag.String("docker-host", "localhost", "The hostname of the docker client. 'localhost' if running locally, 'host.docker.internal' if running in Docker.") flagDockerHost = flag.String("docker-host", "localhost", "The hostname of the docker client. 'localhost' if running locally, 'host.docker.internal' if running in Docker.")
flagDirect = flag.Bool("direct", false, "If a direct upgrade from the defined FROM version to TO should be done") flagDirect = flag.Bool("direct", false, "If a direct upgrade from the defined FROM version to TO should be done")
flagSqlite = flag.Bool("sqlite", false, "Test SQLite instead of PostgreSQL")
alphaNumerics = regexp.MustCompile("[^a-zA-Z0-9]+") alphaNumerics = regexp.MustCompile("[^a-zA-Z0-9]+")
) )
@ -49,7 +50,7 @@ const HEAD = "HEAD"
// due to the error: // due to the error:
// When using COPY with more than one source file, the destination must be a directory and end with a / // When using COPY with more than one source file, the destination must be a directory and end with a /
// We need to run a postgres anyway, so use the dockerfile associated with Complement instead. // We need to run a postgres anyway, so use the dockerfile associated with Complement instead.
const Dockerfile = `FROM golang:1.18-stretch as build const DockerfilePostgreSQL = `FROM golang:1.18-stretch as build
RUN apt-get update && apt-get install -y postgresql RUN apt-get update && apt-get install -y postgresql
WORKDIR /build WORKDIR /build
@ -92,6 +93,42 @@ ENV SERVER_NAME=localhost
EXPOSE 8008 8448 EXPOSE 8008 8448
CMD /build/run_dendrite.sh ` CMD /build/run_dendrite.sh `
const DockerfileSQLite = `FROM golang:1.18-stretch as build
RUN apt-get update && apt-get install -y postgresql
WORKDIR /build
# Copy the build context to the repo as this is the right dendrite code. This is different to the
# Complement Dockerfile which wgets a branch.
COPY . .
RUN go build ./cmd/dendrite-monolith-server
RUN go build ./cmd/generate-keys
RUN go build ./cmd/generate-config
RUN ./generate-config --ci > dendrite.yaml
RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key
# Make sure the SQLite databases are in a persistent location, we're already mapping
# the postgresql folder so let's just use that for simplicity
RUN sed -i "s%connection_string:.file:%connection_string: file:\/var\/lib\/postgresql\/9.6\/main\/%g" dendrite.yaml
# This entry script starts postgres, waits for it to be up then starts dendrite
RUN echo '\
sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml \n\
PARAMS="--tls-cert server.crt --tls-key server.key --config dendrite.yaml" \n\
./dendrite-monolith-server --really-enable-open-registration ${PARAMS} || ./dendrite-monolith-server ${PARAMS} \n\
' > run_dendrite.sh && chmod +x run_dendrite.sh
ENV SERVER_NAME=localhost
EXPOSE 8008 8448
CMD /build/run_dendrite.sh `
func dockerfile() []byte {
if *flagSqlite {
return []byte(DockerfileSQLite)
}
return []byte(DockerfilePostgreSQL)
}
const dendriteUpgradeTestLabel = "dendrite_upgrade_test" const dendriteUpgradeTestLabel = "dendrite_upgrade_test"
// downloadArchive downloads an arbitrary github archive of the form: // downloadArchive downloads an arbitrary github archive of the form:
@ -150,7 +187,7 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir,
if branchOrTagName == HEAD && *flagHead != "" { if branchOrTagName == HEAD && *flagHead != "" {
log.Printf("%s: Using %s as HEAD", branchOrTagName, *flagHead) log.Printf("%s: Using %s as HEAD", branchOrTagName, *flagHead)
// add top level Dockerfile // add top level Dockerfile
err = os.WriteFile(path.Join(*flagHead, "Dockerfile"), []byte(Dockerfile), os.ModePerm) err = os.WriteFile(path.Join(*flagHead, "Dockerfile"), dockerfile(), os.ModePerm)
if err != nil { if err != nil {
return "", fmt.Errorf("custom HEAD: failed to inject /Dockerfile: %w", err) return "", fmt.Errorf("custom HEAD: failed to inject /Dockerfile: %w", err)
} }
@ -166,7 +203,7 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir,
// pull an archive, this contains a top-level directory which screws with the build context // pull an archive, this contains a top-level directory which screws with the build context
// which we need to fix up post download // which we need to fix up post download
u := fmt.Sprintf("https://github.com/matrix-org/dendrite/archive/%s.tar.gz", branchOrTagName) u := fmt.Sprintf("https://github.com/matrix-org/dendrite/archive/%s.tar.gz", branchOrTagName)
tarball, err = downloadArchive(httpClient, tmpDir, u, []byte(Dockerfile)) tarball, err = downloadArchive(httpClient, tmpDir, u, dockerfile())
if err != nil { if err != nil {
return "", fmt.Errorf("failed to download archive %s: %w", u, err) return "", fmt.Errorf("failed to download archive %s: %w", u, err)
} }
@ -367,7 +404,8 @@ func runImage(dockerClient *client.Client, volumeName, version, imageID string)
// hit /versions to check it is up // hit /versions to check it is up
var lastErr error var lastErr error
for i := 0; i < 500; i++ { for i := 0; i < 500; i++ {
res, err := http.Get(versionsURL) var res *http.Response
res, err = http.Get(versionsURL)
if err != nil { if err != nil {
lastErr = fmt.Errorf("GET %s => error: %s", versionsURL, err) lastErr = fmt.Errorf("GET %s => error: %s", versionsURL, err)
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
@ -381,18 +419,22 @@ func runImage(dockerClient *client.Client, volumeName, version, imageID string)
lastErr = nil lastErr = nil
break break
} }
if lastErr != nil {
logs, err := dockerClient.ContainerLogs(context.Background(), containerID, types.ContainerLogsOptions{ logs, err := dockerClient.ContainerLogs(context.Background(), containerID, types.ContainerLogsOptions{
ShowStdout: true, ShowStdout: true,
ShowStderr: true, ShowStderr: true,
Follow: true,
}) })
// ignore errors when cannot get logs, it's just for debugging anyways // ignore errors when cannot get logs, it's just for debugging anyways
if err == nil { if err == nil {
logbody, err := io.ReadAll(logs) go func() {
if err == nil { for {
log.Printf("Container logs:\n\n%s\n\n", string(logbody)) if body, err := io.ReadAll(logs); err == nil && len(body) > 0 {
log.Printf("%s: %s", version, string(body))
} else {
return
} }
} }
}()
} }
return baseURL, containerID, lastErr return baseURL, containerID, lastErr
} }

View file

@ -48,10 +48,15 @@ func main() {
panic("unexpected key block") panic("unexpected key block")
} }
serverName := gomatrixserverlib.ServerName(*requestFrom)
client := gomatrixserverlib.NewFederationClient( client := gomatrixserverlib.NewFederationClient(
gomatrixserverlib.ServerName(*requestFrom), []*gomatrixserverlib.SigningIdentity{
gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]), {
privateKey, ServerName: serverName,
KeyID: gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]),
PrivateKey: privateKey,
},
},
) )
u, err := url.Parse(flag.Arg(0)) u, err := url.Parse(flag.Arg(0))
@ -79,6 +84,7 @@ func main() {
req := gomatrixserverlib.NewFederationRequest( req := gomatrixserverlib.NewFederationRequest(
method, method,
serverName,
gomatrixserverlib.ServerName(u.Host), gomatrixserverlib.ServerName(u.Host),
u.RequestURI(), u.RequestURI(),
) )

View file

@ -90,7 +90,7 @@ For example, this can be done with the following Caddy config:
handle /.well-known/matrix/server { handle /.well-known/matrix/server {
header Content-Type application/json header Content-Type application/json
header Access-Control-Allow-Origin * header Access-Control-Allow-Origin *
respond `"m.server": "matrix.example.com:8448"` respond `{"m.server": "matrix.example.com:8448"}`
} }
handle /.well-known/matrix/client { handle /.well-known/matrix/client {

View file

@ -21,8 +21,8 @@ type FederationInternalAPI interface {
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
// Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos.
PerformBroadcastEDU( PerformBroadcastEDU(
@ -30,6 +30,12 @@ type FederationInternalAPI interface {
request *PerformBroadcastEDURequest, request *PerformBroadcastEDURequest,
response *PerformBroadcastEDUResponse, response *PerformBroadcastEDUResponse,
) error ) error
PerformWakeupServers(
ctx context.Context,
request *PerformWakeupServersRequest,
response *PerformWakeupServersResponse,
) error
} }
type ClientFederationAPI interface { type ClientFederationAPI interface {
@ -60,18 +66,18 @@ type RoomserverFederationAPI interface {
// containing only the server names (without information for membership events). // containing only the server names (without information for membership events).
// The response will include this server if they are joined to the room. // The response will include this server if they are joined to the room.
QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
// KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
// this interface are of type FederationClientError // this interface are of type FederationClientError
type KeyserverFederationAPI interface { type KeyserverFederationAPI interface {
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error)
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
} }
// an interface for gmsl.FederationClient - contains functions called by federationapi only. // an interface for gmsl.FederationClient - contains functions called by federationapi only.
@ -80,28 +86,28 @@ type FederationClient interface {
SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error)
// Perform operations // Perform operations
LookupRoomAlias(ctx context.Context, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error)
Peek(ctx context.Context, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error) Peek(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error)
MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error)
SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error)
MakeLeave(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error) MakeLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error)
SendLeave(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) SendLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error)
SendInviteV2(ctx context.Context, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error) SendInviteV2(ctx context.Context, origin, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error) GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error)
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error)
Backfill(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) Backfill(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
ExchangeThirdPartyInvite(ctx context.Context, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) ExchangeThirdPartyInvite(ctx context.Context, origin, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error)
LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error)
LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
// FederationClientError is returned from FederationClient methods in the event of a problem. // FederationClientError is returned from FederationClient methods in the event of a problem.
@ -200,6 +206,7 @@ type PerformInviteResponse struct {
type QueryJoinedHostServerNamesInRoomRequest struct { type QueryJoinedHostServerNamesInRoomRequest struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
ExcludeSelf bool `json:"exclude_self"` ExcludeSelf bool `json:"exclude_self"`
ExcludeBlacklisted bool `json:"exclude_blacklisted"`
} }
// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames // QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames
@ -213,6 +220,13 @@ type PerformBroadcastEDURequest struct {
type PerformBroadcastEDUResponse struct { type PerformBroadcastEDUResponse struct {
} }
type PerformWakeupServersRequest struct {
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
}
type PerformWakeupServersResponse struct {
}
type InputPublicKeysRequest struct { type InputPublicKeysRequest struct {
Keys map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"keys"` Keys map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"keys"`
} }

View file

@ -128,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
} }
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
@ -189,7 +189,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
return true return true
} }
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")

View file

@ -111,7 +111,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
} }
// send this presence to all servers who share rooms with this user. // send this presence to all servers who share rooms with this user.
joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
if err != nil { if err != nil {
log.WithError(err).Error("failed to get joined hosts") log.WithError(err).Error("failed to get joined hosts")
return true return true

View file

@ -18,6 +18,10 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv"
"time"
syncAPITypes "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
@ -38,10 +42,12 @@ type OutputRoomEventConsumer struct {
cfg *config.FederationAPI cfg *config.FederationAPI
rsAPI api.FederationRoomserverAPI rsAPI api.FederationRoomserverAPI
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
natsClient *nats.Conn
durable string durable string
db storage.Database db storage.Database
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
topic string topic string
topicPresence string
} }
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -49,6 +55,7 @@ func NewOutputRoomEventConsumer(
process *process.ProcessContext, process *process.ProcessContext,
cfg *config.FederationAPI, cfg *config.FederationAPI,
js nats.JetStreamContext, js nats.JetStreamContext,
natsClient *nats.Conn,
queues *queue.OutgoingQueues, queues *queue.OutgoingQueues,
store storage.Database, store storage.Database,
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
@ -57,11 +64,13 @@ func NewOutputRoomEventConsumer(
ctx: process.Context(), ctx: process.Context(),
cfg: cfg, cfg: cfg,
jetstream: js, jetstream: js,
natsClient: natsClient,
db: store, db: store,
queues: queues, queues: queues,
rsAPI: rsAPI, rsAPI: rsAPI,
durable: cfg.Matrix.JetStream.Durable("FederationAPIRoomServerConsumer"), durable: cfg.Matrix.JetStream.Durable("FederationAPIRoomServerConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent),
topicPresence: cfg.Matrix.JetStream.Prefixed(jetstream.RequestPresence),
} }
} }
@ -146,6 +155,7 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee
// processMessage updates the list of currently joined hosts in the room // processMessage updates the list of currently joined hosts in the room
// and then sends the event to the hosts that were joined before the event. // and then sends the event to the hosts that were joined before the event.
func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rewritesState bool) error { func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rewritesState bool) error {
addsStateEvents, missingEventIDs := ore.NeededStateEventIDs() addsStateEvents, missingEventIDs := ore.NeededStateEventIDs()
// Ask the roomserver and add in the rest of the results into the set. // Ask the roomserver and add in the rest of the results into the set.
@ -184,6 +194,14 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
return err return err
} }
// If we added new hosts, inform them about our known presence events for this room
if len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil {
membership, _ := ore.Event.Membership()
if membership == gomatrixserverlib.Join {
s.sendPresence(ore.Event.RoomID(), addsJoinedHosts)
}
}
if oldJoinedHosts == nil { if oldJoinedHosts == nil {
// This means that there is nothing to update as this is a duplicate // This means that there is nothing to update as this is a duplicate
// message. // message.
@ -213,6 +231,76 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
) )
} }
func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []types.JoinedHost) {
joined := make([]gomatrixserverlib.ServerName, len(addedJoined))
for _, added := range addedJoined {
joined = append(joined, added.ServerName)
}
// get our locally joined users
var queryRes api.QueryMembershipsForRoomResponse
err := s.rsAPI.QueryMembershipsForRoom(s.ctx, &api.QueryMembershipsForRoomRequest{
JoinedOnly: true,
LocalOnly: true,
RoomID: roomID,
}, &queryRes)
if err != nil {
log.WithError(err).Error("failed to calculate joined rooms for user")
return
}
// send every presence we know about to the remote server
content := types.Presence{}
for _, ev := range queryRes.JoinEvents {
msg := nats.NewMsg(s.topicPresence)
msg.Header.Set(jetstream.UserID, ev.Sender)
var presence *nats.Msg
presence, err = s.natsClient.RequestMsg(msg, time.Second*10)
if err != nil {
log.WithError(err).Errorf("unable to get presence")
continue
}
statusMsg := presence.Header.Get("status_msg")
e := presence.Header.Get("error")
if e != "" {
continue
}
var lastActive int
lastActive, err = strconv.Atoi(presence.Header.Get("last_active_ts"))
if err != nil {
continue
}
p := syncAPITypes.PresenceInternal{LastActiveTS: gomatrixserverlib.Timestamp(lastActive)}
content.Push = append(content.Push, types.PresenceContent{
CurrentlyActive: p.CurrentlyActive(),
LastActiveAgo: p.LastActiveAgo(),
Presence: presence.Header.Get("presence"),
StatusMsg: &statusMsg,
UserID: ev.Sender,
})
}
if len(content.Push) == 0 {
return
}
edu := &gomatrixserverlib.EDU{
Type: gomatrixserverlib.MPresence,
Origin: string(s.cfg.Matrix.ServerName),
}
if edu.Content, err = json.Marshal(content); err != nil {
log.WithError(err).Error("failed to marshal EDU JSON")
return
}
if err := s.queues.SendEDU(edu, s.cfg.Matrix.ServerName, joined); err != nil {
log.WithError(err).Error("failed to send EDU")
}
}
// joinedHostsAtEvent works out a list of matrix servers that were joined to // joinedHostsAtEvent works out a list of matrix servers that were joined to
// the room at the event (including peeking ones) // the room at the event (including peeking ones)
// It is important to use the state at the event for sending messages because: // It is important to use the state at the event for sending messages because:

View file

@ -118,21 +118,19 @@ func NewInternalAPI(
stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1) stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1)
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
signingInfo := base.Cfg.Global.SigningIdentities()
queues := queue.NewOutgoingQueues( queues := queue.NewOutgoingQueues(
federationDB, base.ProcessContext, federationDB, base.ProcessContext,
cfg.Matrix.DisableFederation, cfg.Matrix.DisableFederation,
cfg.Matrix.ServerName, federation, rsAPI, &stats, cfg.Matrix.ServerName, federation, rsAPI, &stats,
&queue.SigningInfo{ signingInfo,
KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey,
ServerName: cfg.Matrix.ServerName,
},
) )
rsConsumer := consumers.NewOutputRoomEventConsumer( rsConsumer := consumers.NewOutputRoomEventConsumer(
base.ProcessContext, cfg, js, queues, base.ProcessContext, cfg, js, nats, queues,
federationDB, rsAPI, federationDB, rsAPI,
) )
if err = rsConsumer.Start(); err != nil { if err = rsConsumer.Start(); err != nil {

View file

@ -104,7 +104,7 @@ func TestMain(m *testing.M) {
// Create the federation client. // Create the federation client.
s.fedclient = gomatrixserverlib.NewFederationClient( s.fedclient = gomatrixserverlib.NewFederationClient(
s.config.Matrix.ServerName, serverKeyID, testPriv, s.config.Matrix.SigningIdentities(),
gomatrixserverlib.WithTransport(transport), gomatrixserverlib.WithTransport(transport),
) )
@ -137,7 +137,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err
} }
// Get the keys and JSON-ify them. // Get the keys and JSON-ify them.
keys := routing.LocalKeys(s.config) keys := routing.LocalKeys(s.config, gomatrixserverlib.ServerName(req.Host))
body, err := json.MarshalIndent(keys.JSON, "", " ") body, err := json.MarshalIndent(keys.JSON, "", " ")
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -103,7 +103,7 @@ func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserv
return keys, nil return keys, nil
} }
func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) { func (f *fedClient) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) {
for _, r := range f.allowJoins { for _, r := range f.allowJoins {
if r.ID == roomID { if r.ID == roomID {
res.RoomVersion = r.Version res.RoomVersion = r.Version
@ -127,7 +127,7 @@ func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName
} }
return return
} }
func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) { func (f *fedClient) SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) {
f.fedClientMutex.Lock() f.fedClientMutex.Lock()
defer f.fedClientMutex.Unlock() defer f.fedClientMutex.Unlock()
for _, r := range f.allowJoins { for _, r := range f.allowJoins {
@ -283,7 +283,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://"))
fedCli := gomatrixserverlib.NewFederationClient( fedCli := gomatrixserverlib.NewFederationClient(
serverName, cfg.Global.KeyID, cfg.Global.PrivateKey, cfg.Global.SigningIdentities(),
gomatrixserverlib.WithSkipVerify(true), gomatrixserverlib.WithSkipVerify(true),
) )
@ -326,7 +326,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
t.Errorf("failed to create invite v2 request: %s", err) t.Errorf("failed to create invite v2 request: %s", err)
continue continue
} }
_, err = fedCli.SendInviteV2(context.Background(), serverName, invReq) _, err = fedCli.SendInviteV2(context.Background(), cfg.Global.ServerName, serverName, invReq)
if err == nil { if err == nil {
t.Errorf("expected an error, got none") t.Errorf("expected an error, got none")
continue continue

View file

@ -11,13 +11,13 @@ import (
// client. // client.
func (a *FederationInternalAPI) GetEventAuth( func (a *FederationInternalAPI) GetEventAuth(
ctx context.Context, s gomatrixserverlib.ServerName, ctx context.Context, origin, s gomatrixserverlib.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (res gomatrixserverlib.RespEventAuth, err error) { ) (res gomatrixserverlib.RespEventAuth, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEventAuth(ctx, s, roomVersion, roomID, eventID) return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespEventAuth{}, err return gomatrixserverlib.RespEventAuth{}, err
@ -26,12 +26,12 @@ func (a *FederationInternalAPI) GetEventAuth(
} }
func (a *FederationInternalAPI) GetUserDevices( func (a *FederationInternalAPI) GetUserDevices(
ctx context.Context, s gomatrixserverlib.ServerName, userID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string,
) (gomatrixserverlib.RespUserDevices, error) { ) (gomatrixserverlib.RespUserDevices, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetUserDevices(ctx, s, userID) return a.federation.GetUserDevices(ctx, origin, s, userID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespUserDevices{}, err return gomatrixserverlib.RespUserDevices{}, err
@ -40,12 +40,12 @@ func (a *FederationInternalAPI) GetUserDevices(
} }
func (a *FederationInternalAPI) ClaimKeys( func (a *FederationInternalAPI) ClaimKeys(
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
) (gomatrixserverlib.RespClaimKeys, error) { ) (gomatrixserverlib.RespClaimKeys, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.ClaimKeys(ctx, s, oneTimeKeys) return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespClaimKeys{}, err return gomatrixserverlib.RespClaimKeys{}, err
@ -54,10 +54,10 @@ func (a *FederationInternalAPI) ClaimKeys(
} }
func (a *FederationInternalAPI) QueryKeys( func (a *FederationInternalAPI) QueryKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string,
) (gomatrixserverlib.RespQueryKeys, error) { ) (gomatrixserverlib.RespQueryKeys, error) {
ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) {
return a.federation.QueryKeys(ctx, s, keys) return a.federation.QueryKeys(ctx, origin, s, keys)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespQueryKeys{}, err return gomatrixserverlib.RespQueryKeys{}, err
@ -66,12 +66,12 @@ func (a *FederationInternalAPI) QueryKeys(
} }
func (a *FederationInternalAPI) Backfill( func (a *FederationInternalAPI) Backfill(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.Backfill(ctx, s, roomID, limit, eventIDs) return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.Transaction{}, err return gomatrixserverlib.Transaction{}, err
@ -80,12 +80,12 @@ func (a *FederationInternalAPI) Backfill(
} }
func (a *FederationInternalAPI) LookupState( func (a *FederationInternalAPI) LookupState(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespState, err error) { ) (res gomatrixserverlib.RespState, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion) return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespState{}, err return gomatrixserverlib.RespState{}, err
@ -94,12 +94,12 @@ func (a *FederationInternalAPI) LookupState(
} }
func (a *FederationInternalAPI) LookupStateIDs( func (a *FederationInternalAPI) LookupStateIDs(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string,
) (res gomatrixserverlib.RespStateIDs, err error) { ) (res gomatrixserverlib.RespStateIDs, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupStateIDs(ctx, s, roomID, eventID) return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespStateIDs{}, err return gomatrixserverlib.RespStateIDs{}, err
@ -108,13 +108,13 @@ func (a *FederationInternalAPI) LookupStateIDs(
} }
func (a *FederationInternalAPI) LookupMissingEvents( func (a *FederationInternalAPI) LookupMissingEvents(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string,
missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespMissingEvents, err error) { ) (res gomatrixserverlib.RespMissingEvents, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupMissingEvents(ctx, s, roomID, missing, roomVersion) return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespMissingEvents{}, err return gomatrixserverlib.RespMissingEvents{}, err
@ -123,12 +123,12 @@ func (a *FederationInternalAPI) LookupMissingEvents(
} }
func (a *FederationInternalAPI) GetEvent( func (a *FederationInternalAPI) GetEvent(
ctx context.Context, s gomatrixserverlib.ServerName, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string,
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEvent(ctx, s, eventID) return a.federation.GetEvent(ctx, origin, s, eventID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.Transaction{}, err return gomatrixserverlib.Transaction{}, err
@ -151,13 +151,13 @@ func (a *FederationInternalAPI) LookupServerKeys(
} }
func (a *FederationInternalAPI) MSC2836EventRelationships( func (a *FederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.MSC2836EventRelationships(ctx, s, r, roomVersion) return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion)
}) })
if err != nil { if err != nil {
return res, err return res, err
@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
} }
func (a *FederationInternalAPI) MSC2946Spaces( func (a *FederationInternalAPI) MSC2946Spaces(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly) return a.federation.MSC2946Spaces(ctx, origin, s, roomID, suggestedOnly)
}) })
if err != nil { if err != nil {
return res, err return res, err

View file

@ -26,6 +26,7 @@ func (r *FederationInternalAPI) PerformDirectoryLookup(
) (err error) { ) (err error) {
dir, err := r.federation.LookupRoomAlias( dir, err := r.federation.LookupRoomAlias(
ctx, ctx,
r.cfg.Matrix.ServerName,
request.ServerName, request.ServerName,
request.RoomAlias, request.RoomAlias,
) )
@ -143,10 +144,16 @@ func (r *FederationInternalAPI) performJoinUsingServer(
supportedVersions []gomatrixserverlib.RoomVersion, supportedVersions []gomatrixserverlib.RoomVersion,
unsigned map[string]interface{}, unsigned map[string]interface{},
) error { ) error {
_, origin, err := r.cfg.Matrix.SplitLocalID('@', userID)
if err != nil {
return err
}
// Try to perform a make_join using the information supplied in the // Try to perform a make_join using the information supplied in the
// request. // request.
respMakeJoin, err := r.federation.MakeJoin( respMakeJoin, err := r.federation.MakeJoin(
ctx, ctx,
origin,
serverName, serverName,
roomID, roomID,
userID, userID,
@ -192,7 +199,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
// Build the join event. // Build the join event.
event, err := respMakeJoin.JoinEvent.Build( event, err := respMakeJoin.JoinEvent.Build(
time.Now(), time.Now(),
r.cfg.Matrix.ServerName, origin,
r.cfg.Matrix.KeyID, r.cfg.Matrix.KeyID,
r.cfg.Matrix.PrivateKey, r.cfg.Matrix.PrivateKey,
respMakeJoin.RoomVersion, respMakeJoin.RoomVersion,
@ -204,6 +211,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
// Try to perform a send_join using the newly built event. // Try to perform a send_join using the newly built event.
respSendJoin, err := r.federation.SendJoin( respSendJoin, err := r.federation.SendJoin(
context.Background(), context.Background(),
origin,
serverName, serverName,
event, event,
) )
@ -246,7 +254,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
respMakeJoin.RoomVersion, respMakeJoin.RoomVersion,
r.keyRing, r.keyRing,
event, event,
federatedAuthProvider(ctx, r.federation, r.keyRing, serverName), federatedAuthProvider(ctx, r.federation, r.keyRing, origin, serverName),
) )
if err != nil { if err != nil {
return fmt.Errorf("respSendJoin.Check: %w", err) return fmt.Errorf("respSendJoin.Check: %w", err)
@ -281,6 +289,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
if err = roomserverAPI.SendEventWithState( if err = roomserverAPI.SendEventWithState(
context.Background(), context.Background(),
r.rsAPI, r.rsAPI,
origin,
roomserverAPI.KindNew, roomserverAPI.KindNew,
respState, respState,
event.Headered(respMakeJoin.RoomVersion), event.Headered(respMakeJoin.RoomVersion),
@ -427,6 +436,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// request. // request.
respPeek, err := r.federation.Peek( respPeek, err := r.federation.Peek(
ctx, ctx,
r.cfg.Matrix.ServerName,
serverName, serverName,
roomID, roomID,
peekID, peekID,
@ -453,7 +463,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// authenticate the state returned (check its auth events etc) // authenticate the state returned (check its auth events etc)
// the equivalent of CheckSendJoinResponse() // the equivalent of CheckSendJoinResponse()
authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)) authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName))
if err != nil { if err != nil {
return fmt.Errorf("error checking state returned from peeking: %w", err) return fmt.Errorf("error checking state returned from peeking: %w", err)
} }
@ -475,7 +485,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// logrus.Warnf("got respPeek %#v", respPeek) // logrus.Warnf("got respPeek %#v", respPeek)
// Send the newly returned state to the roomserver to update our local view. // Send the newly returned state to the roomserver to update our local view.
if err = roomserverAPI.SendEventWithState( if err = roomserverAPI.SendEventWithState(
ctx, r.rsAPI, ctx, r.rsAPI, r.cfg.Matrix.ServerName,
roomserverAPI.KindNew, roomserverAPI.KindNew,
&respState, &respState,
respPeek.LatestEvent.Headered(respPeek.RoomVersion), respPeek.LatestEvent.Headered(respPeek.RoomVersion),
@ -495,6 +505,11 @@ func (r *FederationInternalAPI) PerformLeave(
request *api.PerformLeaveRequest, request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse, response *api.PerformLeaveResponse,
) (err error) { ) (err error) {
_, origin, err := r.cfg.Matrix.SplitLocalID('@', request.UserID)
if err != nil {
return err
}
// Deduplicate the server names we were provided. // Deduplicate the server names we were provided.
util.SortAndUnique(request.ServerNames) util.SortAndUnique(request.ServerNames)
@ -505,6 +520,7 @@ func (r *FederationInternalAPI) PerformLeave(
// request. // request.
respMakeLeave, err := r.federation.MakeLeave( respMakeLeave, err := r.federation.MakeLeave(
ctx, ctx,
origin,
serverName, serverName,
request.RoomID, request.RoomID,
request.UserID, request.UserID,
@ -546,7 +562,7 @@ func (r *FederationInternalAPI) PerformLeave(
// Build the leave event. // Build the leave event.
event, err := respMakeLeave.LeaveEvent.Build( event, err := respMakeLeave.LeaveEvent.Build(
time.Now(), time.Now(),
r.cfg.Matrix.ServerName, origin,
r.cfg.Matrix.KeyID, r.cfg.Matrix.KeyID,
r.cfg.Matrix.PrivateKey, r.cfg.Matrix.PrivateKey,
respMakeLeave.RoomVersion, respMakeLeave.RoomVersion,
@ -559,6 +575,7 @@ func (r *FederationInternalAPI) PerformLeave(
// Try to perform a send_leave using the newly built event. // Try to perform a send_leave using the newly built event.
err = r.federation.SendLeave( err = r.federation.SendLeave(
ctx, ctx,
origin,
serverName, serverName,
event, event,
) )
@ -585,6 +602,11 @@ func (r *FederationInternalAPI) PerformInvite(
request *api.PerformInviteRequest, request *api.PerformInviteRequest,
response *api.PerformInviteResponse, response *api.PerformInviteResponse,
) (err error) { ) (err error) {
_, origin, err := r.cfg.Matrix.SplitLocalID('@', request.Event.Sender())
if err != nil {
return err
}
if request.Event.StateKey() == nil { if request.Event.StateKey() == nil {
return errors.New("invite must be a state event") return errors.New("invite must be a state event")
} }
@ -607,7 +629,7 @@ func (r *FederationInternalAPI) PerformInvite(
return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err)
} }
inviteRes, err := r.federation.SendInviteV2(ctx, destination, inviteReq) inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq)
if err != nil { if err != nil {
return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err)
} }
@ -648,9 +670,23 @@ func (r *FederationInternalAPI) PerformBroadcastEDU(
return nil return nil
} }
// PerformWakeupServers implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformWakeupServers(
ctx context.Context,
request *api.PerformWakeupServersRequest,
response *api.PerformWakeupServersResponse,
) (err error) {
r.MarkServersAlive(request.ServerNames)
return nil
}
func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) { func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) {
for _, srv := range destinations { for _, srv := range destinations {
// Check the statistics cache for the blacklist status to prevent hitting
// the database unnecessarily.
if r.queues.IsServerBlacklisted(srv) {
_ = r.db.RemoveServerFromBlacklist(srv) _ = r.db.RemoveServerFromBlacklist(srv)
}
r.queues.RetryServer(srv) r.queues.RetryServer(srv)
} }
} }
@ -708,7 +744,7 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder
// FederatedAuthProvider is an auth chain provider which fetches events from the server provided // FederatedAuthProvider is an auth chain provider which fetches events from the server provided
func federatedAuthProvider( func federatedAuthProvider(
ctx context.Context, federation api.FederationClient, ctx context.Context, federation api.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName, keyRing gomatrixserverlib.JSONVerifier, origin, server gomatrixserverlib.ServerName,
) gomatrixserverlib.AuthChainProvider { ) gomatrixserverlib.AuthChainProvider {
// A list of events that we have retried, if they were not included in // A list of events that we have retried, if they were not included in
// the auth events supplied in the send_join. // the auth events supplied in the send_join.
@ -738,7 +774,7 @@ func federatedAuthProvider(
// Try to retrieve the event from the server that sent us the send // Try to retrieve the event from the server that sent us the send
// join response. // join response.
tx, txerr := federation.GetEvent(ctx, server, eventID) tx, txerr := federation.GetEvent(ctx, origin, server, eventID)
if txerr != nil { if txerr != nil {
return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr) return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr)
} }

View file

@ -16,7 +16,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom(
request *api.QueryJoinedHostServerNamesInRoomRequest, request *api.QueryJoinedHostServerNamesInRoomRequest,
response *api.QueryJoinedHostServerNamesInRoomResponse, response *api.QueryJoinedHostServerNamesInRoomResponse,
) (err error) { ) (err error) {
joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf) joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf, request.ExcludeBlacklisted)
if err != nil { if err != nil {
return return
} }

View file

@ -23,6 +23,7 @@ const (
FederationAPIPerformInviteRequestPath = "/federationapi/performInviteRequest" FederationAPIPerformInviteRequestPath = "/federationapi/performInviteRequest"
FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest" FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest"
FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU"
FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers"
FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices"
FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" FederationAPIClaimKeysPath = "/federationapi/client/claimKeys"
@ -150,18 +151,32 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU(
) )
} }
// Handle an instruction to remove the respective servers from being blacklisted.
func (h *httpFederationInternalAPI) PerformWakeupServers(
ctx context.Context,
request *api.PerformWakeupServersRequest,
response *api.PerformWakeupServersResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformWakeupServers", h.federationAPIURL+FederationAPIPerformWakeupServers,
h.httpClient, ctx, request, response,
)
}
type getUserDevices struct { type getUserDevices struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
UserID string UserID string
} }
func (h *httpFederationInternalAPI) GetUserDevices( func (h *httpFederationInternalAPI) GetUserDevices(
ctx context.Context, s gomatrixserverlib.ServerName, userID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string,
) (gomatrixserverlib.RespUserDevices, error) { ) (gomatrixserverlib.RespUserDevices, error) {
return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError]( return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError](
"GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient, "GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient,
ctx, &getUserDevices{ ctx, &getUserDevices{
S: s, S: s,
Origin: origin,
UserID: userID, UserID: userID,
}, },
) )
@ -169,16 +184,18 @@ func (h *httpFederationInternalAPI) GetUserDevices(
type claimKeys struct { type claimKeys struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
OneTimeKeys map[string]map[string]string OneTimeKeys map[string]map[string]string
} }
func (h *httpFederationInternalAPI) ClaimKeys( func (h *httpFederationInternalAPI) ClaimKeys(
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
) (gomatrixserverlib.RespClaimKeys, error) { ) (gomatrixserverlib.RespClaimKeys, error) {
return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError]( return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError](
"ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient, "ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient,
ctx, &claimKeys{ ctx, &claimKeys{
S: s, S: s,
Origin: origin,
OneTimeKeys: oneTimeKeys, OneTimeKeys: oneTimeKeys,
}, },
) )
@ -186,16 +203,18 @@ func (h *httpFederationInternalAPI) ClaimKeys(
type queryKeys struct { type queryKeys struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
Keys map[string][]string Keys map[string][]string
} }
func (h *httpFederationInternalAPI) QueryKeys( func (h *httpFederationInternalAPI) QueryKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string,
) (gomatrixserverlib.RespQueryKeys, error) { ) (gomatrixserverlib.RespQueryKeys, error) {
return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError]( return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError](
"QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient, "QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient,
ctx, &queryKeys{ ctx, &queryKeys{
S: s, S: s,
Origin: origin,
Keys: keys, Keys: keys,
}, },
) )
@ -203,18 +222,20 @@ func (h *httpFederationInternalAPI) QueryKeys(
type backfill struct { type backfill struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
Limit int Limit int
EventIDs []string EventIDs []string
} }
func (h *httpFederationInternalAPI) Backfill( func (h *httpFederationInternalAPI) Backfill(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
) (gomatrixserverlib.Transaction, error) { ) (gomatrixserverlib.Transaction, error) {
return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError]( return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError](
"Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient, "Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient,
ctx, &backfill{ ctx, &backfill{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
Limit: limit, Limit: limit,
EventIDs: eventIDs, EventIDs: eventIDs,
@ -224,18 +245,20 @@ func (h *httpFederationInternalAPI) Backfill(
type lookupState struct { type lookupState struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
EventID string EventID string
RoomVersion gomatrixserverlib.RoomVersion RoomVersion gomatrixserverlib.RoomVersion
} }
func (h *httpFederationInternalAPI) LookupState( func (h *httpFederationInternalAPI) LookupState(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (gomatrixserverlib.RespState, error) { ) (gomatrixserverlib.RespState, error) {
return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError]( return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError](
"LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient, "LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient,
ctx, &lookupState{ ctx, &lookupState{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
RoomVersion: roomVersion, RoomVersion: roomVersion,
@ -245,17 +268,19 @@ func (h *httpFederationInternalAPI) LookupState(
type lookupStateIDs struct { type lookupStateIDs struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
EventID string EventID string
} }
func (h *httpFederationInternalAPI) LookupStateIDs( func (h *httpFederationInternalAPI) LookupStateIDs(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string,
) (gomatrixserverlib.RespStateIDs, error) { ) (gomatrixserverlib.RespStateIDs, error) {
return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError]( return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError](
"LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient, "LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient,
ctx, &lookupStateIDs{ ctx, &lookupStateIDs{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
}, },
@ -264,19 +289,21 @@ func (h *httpFederationInternalAPI) LookupStateIDs(
type lookupMissingEvents struct { type lookupMissingEvents struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
Missing gomatrixserverlib.MissingEvents Missing gomatrixserverlib.MissingEvents
RoomVersion gomatrixserverlib.RoomVersion RoomVersion gomatrixserverlib.RoomVersion
} }
func (h *httpFederationInternalAPI) LookupMissingEvents( func (h *httpFederationInternalAPI) LookupMissingEvents(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string,
missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespMissingEvents, err error) { ) (res gomatrixserverlib.RespMissingEvents, err error) {
return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError]( return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError](
"LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient, "LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient,
ctx, &lookupMissingEvents{ ctx, &lookupMissingEvents{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
Missing: missing, Missing: missing,
RoomVersion: roomVersion, RoomVersion: roomVersion,
@ -286,16 +313,18 @@ func (h *httpFederationInternalAPI) LookupMissingEvents(
type getEvent struct { type getEvent struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
EventID string EventID string
} }
func (h *httpFederationInternalAPI) GetEvent( func (h *httpFederationInternalAPI) GetEvent(
ctx context.Context, s gomatrixserverlib.ServerName, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string,
) (gomatrixserverlib.Transaction, error) { ) (gomatrixserverlib.Transaction, error) {
return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError]( return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError](
"GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient, "GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient,
ctx, &getEvent{ ctx, &getEvent{
S: s, S: s,
Origin: origin,
EventID: eventID, EventID: eventID,
}, },
) )
@ -303,19 +332,21 @@ func (h *httpFederationInternalAPI) GetEvent(
type getEventAuth struct { type getEventAuth struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomVersion gomatrixserverlib.RoomVersion RoomVersion gomatrixserverlib.RoomVersion
RoomID string RoomID string
EventID string EventID string
} }
func (h *httpFederationInternalAPI) GetEventAuth( func (h *httpFederationInternalAPI) GetEventAuth(
ctx context.Context, s gomatrixserverlib.ServerName, ctx context.Context, origin, s gomatrixserverlib.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (gomatrixserverlib.RespEventAuth, error) { ) (gomatrixserverlib.RespEventAuth, error) {
return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError]( return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError](
"GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient, "GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient,
ctx, &getEventAuth{ ctx, &getEventAuth{
S: s, S: s,
Origin: origin,
RoomVersion: roomVersion, RoomVersion: roomVersion,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
@ -351,18 +382,20 @@ func (h *httpFederationInternalAPI) LookupServerKeys(
type eventRelationships struct { type eventRelationships struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
Req gomatrixserverlib.MSC2836EventRelationshipsRequest Req gomatrixserverlib.MSC2836EventRelationshipsRequest
RoomVer gomatrixserverlib.RoomVersion RoomVer gomatrixserverlib.RoomVersion
} }
func (h *httpFederationInternalAPI) MSC2836EventRelationships( func (h *httpFederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError]( return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError](
"MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient, "MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient,
ctx, &eventRelationships{ ctx, &eventRelationships{
S: s, S: s,
Origin: origin,
Req: r, Req: r,
RoomVer: roomVersion, RoomVer: roomVersion,
}, },
@ -371,17 +404,19 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships(
type spacesReq struct { type spacesReq struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
SuggestedOnly bool SuggestedOnly bool
RoomID string RoomID string
} }
func (h *httpFederationInternalAPI) MSC2946Spaces( func (h *httpFederationInternalAPI) MSC2946Spaces(
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError]( return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError](
"MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient, "MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient,
ctx, &spacesReq{ ctx, &spacesReq{
S: dst, S: dst,
Origin: origin,
SuggestedOnly: suggestedOnly, SuggestedOnly: suggestedOnly,
RoomID: roomID, RoomID: roomID,
}, },

View file

@ -43,6 +43,11 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU), httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU),
) )
internalAPIMux.Handle(
FederationAPIPerformWakeupServers,
httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", intAPI.PerformWakeupServers),
)
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIPerformJoinRequestPath, FederationAPIPerformJoinRequestPath,
httputil.MakeInternalRPCAPI( httputil.MakeInternalRPCAPI(
@ -59,7 +64,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIGetUserDevices", "FederationAPIGetUserDevices",
func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) { func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) {
res, err := intAPI.GetUserDevices(ctx, req.S, req.UserID) res, err := intAPI.GetUserDevices(ctx, req.Origin, req.S, req.UserID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -70,7 +75,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIClaimKeys", "FederationAPIClaimKeys",
func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) { func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) {
res, err := intAPI.ClaimKeys(ctx, req.S, req.OneTimeKeys) res, err := intAPI.ClaimKeys(ctx, req.Origin, req.S, req.OneTimeKeys)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -81,7 +86,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIQueryKeys", "FederationAPIQueryKeys",
func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) { func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) {
res, err := intAPI.QueryKeys(ctx, req.S, req.Keys) res, err := intAPI.QueryKeys(ctx, req.Origin, req.S, req.Keys)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -92,7 +97,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIBackfill", "FederationAPIBackfill",
func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) { func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) {
res, err := intAPI.Backfill(ctx, req.S, req.RoomID, req.Limit, req.EventIDs) res, err := intAPI.Backfill(ctx, req.Origin, req.S, req.RoomID, req.Limit, req.EventIDs)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -103,7 +108,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPILookupState", "FederationAPILookupState",
func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) { func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) {
res, err := intAPI.LookupState(ctx, req.S, req.RoomID, req.EventID, req.RoomVersion) res, err := intAPI.LookupState(ctx, req.Origin, req.S, req.RoomID, req.EventID, req.RoomVersion)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -114,7 +119,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPILookupStateIDs", "FederationAPILookupStateIDs",
func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) { func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) {
res, err := intAPI.LookupStateIDs(ctx, req.S, req.RoomID, req.EventID) res, err := intAPI.LookupStateIDs(ctx, req.Origin, req.S, req.RoomID, req.EventID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -125,7 +130,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPILookupMissingEvents", "FederationAPILookupMissingEvents",
func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) { func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) {
res, err := intAPI.LookupMissingEvents(ctx, req.S, req.RoomID, req.Missing, req.RoomVersion) res, err := intAPI.LookupMissingEvents(ctx, req.Origin, req.S, req.RoomID, req.Missing, req.RoomVersion)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -136,7 +141,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIGetEvent", "FederationAPIGetEvent",
func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) { func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) {
res, err := intAPI.GetEvent(ctx, req.S, req.EventID) res, err := intAPI.GetEvent(ctx, req.Origin, req.S, req.EventID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -147,7 +152,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIGetEventAuth", "FederationAPIGetEventAuth",
func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) { func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) {
res, err := intAPI.GetEventAuth(ctx, req.S, req.RoomVersion, req.RoomID, req.EventID) res, err := intAPI.GetEventAuth(ctx, req.Origin, req.S, req.RoomVersion, req.RoomID, req.EventID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -174,7 +179,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIMSC2836EventRelationships", "FederationAPIMSC2836EventRelationships",
func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) {
res, err := intAPI.MSC2836EventRelationships(ctx, req.S, req.Req, req.RoomVer) res, err := intAPI.MSC2836EventRelationships(ctx, req.Origin, req.S, req.Req, req.RoomVer)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -185,7 +190,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIMSC2946SpacesSummary", "FederationAPIMSC2946SpacesSummary",
func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) { func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
res, err := intAPI.MSC2946Spaces(ctx, req.S, req.RoomID, req.SuggestedOnly) res, err := intAPI.MSC2946Spaces(ctx, req.Origin, req.S, req.RoomID, req.SuggestedOnly)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),

View file

@ -50,7 +50,7 @@ type destinationQueue struct {
queues *OutgoingQueues queues *OutgoingQueues
db storage.Database db storage.Database
process *process.ProcessContext process *process.ProcessContext
signing *SigningInfo signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity
rsAPI api.FederationRoomserverAPI rsAPI api.FederationRoomserverAPI
client fedapi.FederationClient // federation client client fedapi.FederationClient // federation client
origin gomatrixserverlib.ServerName // origin of requests origin gomatrixserverlib.ServerName // origin of requests
@ -141,23 +141,44 @@ func (oq *destinationQueue) handleBackoffNotifier() {
} }
} }
// wakeQueueIfEventsPending calls wakeQueueAndNotify only if there are
// pending events or if forceWakeup is true. This prevents starting the
// queue unnecessarily.
func (oq *destinationQueue) wakeQueueIfEventsPending(forceWakeup bool) {
eventsPending := func() bool {
oq.pendingMutex.Lock()
defer oq.pendingMutex.Unlock()
return len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0
}
// NOTE : Only wakeup and notify the queue if there are pending events
// or if forceWakeup is true. Otherwise there is no reason to start the
// queue goroutine and waste resources.
if forceWakeup || eventsPending() {
logrus.Info("Starting queue due to pending events or forceWakeup")
oq.wakeQueueAndNotify()
}
}
// wakeQueueAndNotify ensures the destination queue is running and notifies it // wakeQueueAndNotify ensures the destination queue is running and notifies it
// that there is pending work. // that there is pending work.
func (oq *destinationQueue) wakeQueueAndNotify() { func (oq *destinationQueue) wakeQueueAndNotify() {
// Wake up the queue if it's asleep. // NOTE : Send notification before waking queue to prevent a race
oq.wakeQueueIfNeeded() // where the queue was running and stops due to a timeout in between
// checking it and sending the notification.
// Notify the queue that there are events ready to send. // Notify the queue that there are events ready to send.
select { select {
case oq.notify <- struct{}{}: case oq.notify <- struct{}{}:
default: default:
} }
// Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded()
} }
// wakeQueueIfNeeded will wake up the destination queue if it is // wakeQueueIfNeeded will wake up the destination queue if it is
// not already running. If it is running but it is backing off // not already running.
// then we will interrupt the backoff, causing any federation
// requests to retry.
func (oq *destinationQueue) wakeQueueIfNeeded() { func (oq *destinationQueue) wakeQueueIfNeeded() {
// Clear the backingOff flag and update the backoff metrics if it was set. // Clear the backingOff flag and update the backoff metrics if it was set.
if oq.backingOff.CompareAndSwap(true, false) { if oq.backingOff.CompareAndSwap(true, false) {

View file

@ -15,7 +15,6 @@
package queue package queue
import ( import (
"crypto/ed25519"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sync" "sync"
@ -46,7 +45,7 @@ type OutgoingQueues struct {
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
client fedapi.FederationClient client fedapi.FederationClient
statistics *statistics.Statistics statistics *statistics.Statistics
signing *SigningInfo signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity
queuesMutex sync.Mutex // protects the below queuesMutex sync.Mutex // protects the below
queues map[gomatrixserverlib.ServerName]*destinationQueue queues map[gomatrixserverlib.ServerName]*destinationQueue
} }
@ -91,7 +90,7 @@ func NewOutgoingQueues(
client fedapi.FederationClient, client fedapi.FederationClient,
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
statistics *statistics.Statistics, statistics *statistics.Statistics,
signing *SigningInfo, signing []*gomatrixserverlib.SigningIdentity,
) *OutgoingQueues { ) *OutgoingQueues {
queues := &OutgoingQueues{ queues := &OutgoingQueues{
disabled: disabled, disabled: disabled,
@ -101,9 +100,12 @@ func NewOutgoingQueues(
origin: origin, origin: origin,
client: client, client: client,
statistics: statistics, statistics: statistics,
signing: signing, signing: map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity{},
queues: map[gomatrixserverlib.ServerName]*destinationQueue{}, queues: map[gomatrixserverlib.ServerName]*destinationQueue{},
} }
for _, identity := range signing {
queues.signing[identity.ServerName] = identity
}
// Look up which servers we have pending items for and then rehydrate those queues. // Look up which servers we have pending items for and then rehydrate those queues.
if !disabled { if !disabled {
serverNames := map[gomatrixserverlib.ServerName]struct{}{} serverNames := map[gomatrixserverlib.ServerName]struct{}{}
@ -135,14 +137,6 @@ func NewOutgoingQueues(
return queues return queues
} }
// TODO: Move this somewhere useful for other components as we often need to ferry these 3 variables
// around together
type SigningInfo struct {
ServerName gomatrixserverlib.ServerName
KeyID gomatrixserverlib.KeyID
PrivateKey ed25519.PrivateKey
}
type queuedPDU struct { type queuedPDU struct {
receipt *shared.Receipt receipt *shared.Receipt
pdu *gomatrixserverlib.HeaderedEvent pdu *gomatrixserverlib.HeaderedEvent
@ -199,11 +193,10 @@ func (oqs *OutgoingQueues) SendEvent(
log.Trace("Federation is disabled, not sending event") log.Trace("Federation is disabled, not sending event")
return nil return nil
} }
if origin != oqs.origin { if _, ok := oqs.signing[origin]; !ok {
// TODO: Support virtual hosting; gh issue #577.
return fmt.Errorf( return fmt.Errorf(
"sendevent: unexpected server to send as: got %q expected %q", "sendevent: unexpected server to send as %q",
origin, oqs.origin, origin,
) )
} }
@ -214,7 +207,9 @@ func (oqs *OutgoingQueues) SendEvent(
destmap[d] = struct{}{} destmap[d] = struct{}{}
} }
delete(destmap, oqs.origin) delete(destmap, oqs.origin)
delete(destmap, oqs.signing.ServerName) for local := range oqs.signing {
delete(destmap, local)
}
// Check if any of the destinations are prohibited by server ACLs. // Check if any of the destinations are prohibited by server ACLs.
for destination := range destmap { for destination := range destmap {
@ -288,11 +283,10 @@ func (oqs *OutgoingQueues) SendEDU(
log.Trace("Federation is disabled, not sending EDU") log.Trace("Federation is disabled, not sending EDU")
return nil return nil
} }
if origin != oqs.origin { if _, ok := oqs.signing[origin]; !ok {
// TODO: Support virtual hosting; gh issue #577.
return fmt.Errorf( return fmt.Errorf(
"sendevent: unexpected server to send as: got %q expected %q", "sendevent: unexpected server to send as %q",
origin, oqs.origin, origin,
) )
} }
@ -303,7 +297,9 @@ func (oqs *OutgoingQueues) SendEDU(
destmap[d] = struct{}{} destmap[d] = struct{}{}
} }
delete(destmap, oqs.origin) delete(destmap, oqs.origin)
delete(destmap, oqs.signing.ServerName) for local := range oqs.signing {
delete(destmap, local)
}
// There is absolutely no guarantee that the EDU will have a room_id // There is absolutely no guarantee that the EDU will have a room_id
// field, as it is not required by the spec. However, if it *does* // field, as it is not required by the spec. However, if it *does*
@ -378,14 +374,24 @@ func (oqs *OutgoingQueues) SendEDU(
return nil return nil
} }
// IsServerBlacklisted returns whether or not the provided server is currently
// blacklisted.
func (oqs *OutgoingQueues) IsServerBlacklisted(srv gomatrixserverlib.ServerName) bool {
return oqs.statistics.ForServer(srv).Blacklisted()
}
// RetryServer attempts to resend events to the given server if we had given up. // RetryServer attempts to resend events to the given server if we had given up.
func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
if oqs.disabled { if oqs.disabled {
return return
} }
oqs.statistics.ForServer(srv).RemoveBlacklist()
serverStatistics := oqs.statistics.ForServer(srv)
forceWakeup := serverStatistics.Blacklisted()
serverStatistics.RemoveBlacklist()
serverStatistics.ClearBackoff()
if queue := oqs.getQueue(srv); queue != nil { if queue := oqs.getQueue(srv); queue != nil {
queue.statistics.ClearBackoff() queue.wakeQueueIfEventsPending(forceWakeup)
queue.wakeQueueIfNeeded()
} }
} }

View file

@ -350,10 +350,12 @@ func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T
} }
rs := &stubFederationRoomServerAPI{} rs := &stubFederationRoomServerAPI{}
stats := statistics.NewStatistics(db, failuresUntilBlacklist) stats := statistics.NewStatistics(db, failuresUntilBlacklist)
signingInfo := &SigningInfo{ signingInfo := []*gomatrixserverlib.SigningIdentity{
{
KeyID: "ed21019:auto", KeyID: "ed21019:auto",
PrivateKey: test.PrivateKeyA, PrivateKey: test.PrivateKeyA,
ServerName: "localhost", ServerName: "localhost",
},
} }
queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo) queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo)

View file

@ -83,6 +83,7 @@ func Backfill(
"": eIDs, "": eIDs,
}, },
ServerName: request.Origin(), ServerName: request.Origin(),
VirtualHost: request.Destination(),
} }
if req.Limit, err = strconv.Atoi(limit); err != nil { if req.Limit, err = strconv.Atoi(limit); err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed")
@ -123,7 +124,7 @@ func Backfill(
} }
txn := gomatrixserverlib.Transaction{ txn := gomatrixserverlib.Transaction{
Origin: cfg.Matrix.ServerName, Origin: request.Destination(),
PDUs: eventJSONs, PDUs: eventJSONs,
OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()),
} }

View file

@ -140,6 +140,21 @@ func processInvite(
} }
} }
if event.StateKey() == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The invite event has no state key"),
}
}
_, domain, err := cfg.Matrix.SplitLocalID('@', *event.StateKey())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("The user ID is invalid or domain %q does not belong to this server", domain)),
}
}
// Check that the event is signed by the server sending the request. // Check that the event is signed by the server sending the request.
redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version()) redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version())
if err != nil { if err != nil {
@ -175,7 +190,7 @@ func processInvite(
// Sign the event so that other servers will know that we have received the invite. // Sign the event so that other servers will know that we have received the invite.
signedEvent := event.Sign( signedEvent := event.Sign(
string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, string(domain), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey,
) )
// Add the invite event to the roomserver. // Add the invite event to the roomserver.

View file

@ -131,10 +131,20 @@ func MakeJoin(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
identity, err := cfg.Matrix.SigningIdentityFor(request.Destination())
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(
fmt.Sprintf("Server name %q does not exist", request.Destination()),
),
}
}
queryRes := api.QueryLatestEventsAndStateResponse{ queryRes := api.QueryLatestEventsAndStateResponse{
RoomVersion: verRes.RoomVersion, RoomVersion: verRes.RoomVersion,
} }
event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,

View file

@ -134,28 +134,30 @@ func ClaimOneTimeKeys(
// LocalKeys returns the local keys for the server. // LocalKeys returns the local keys for the server.
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
func LocalKeys(cfg *config.FederationAPI) util.JSONResponse { func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse {
keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)) keys, err := localKeys(cfg, serverName)
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.MessageResponse(http.StatusNotFound, err.Error())
} }
return util.JSONResponse{Code: http.StatusOK, JSON: keys} return util.JSONResponse{Code: http.StatusOK, JSON: keys}
} }
func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) { func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) {
var keys gomatrixserverlib.ServerKeys var keys gomatrixserverlib.ServerKeys
var identity *gomatrixserverlib.SigningIdentity
keys.ServerName = cfg.Matrix.ServerName var err error
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil) if virtualHost := cfg.Matrix.VirtualHostForHTTPHost(serverName); virtualHost == nil {
if identity, err = cfg.Matrix.SigningIdentityFor(cfg.Matrix.ServerName); err != nil {
return nil, err
}
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
keys.ServerName = cfg.Matrix.ServerName
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(cfg.Matrix.KeyValidityPeriod))
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
cfg.Matrix.KeyID: { cfg.Matrix.KeyID: {
Key: gomatrixserverlib.Base64Bytes(publicKey), Key: gomatrixserverlib.Base64Bytes(publicKey),
}, },
} }
keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{} keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{}
for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys { for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys {
keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{ keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{
@ -165,6 +167,21 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
ExpiredTS: oldVerifyKey.ExpiredAt, ExpiredTS: oldVerifyKey.ExpiredAt,
} }
} }
} else {
if identity, err = cfg.Matrix.SigningIdentityFor(virtualHost.ServerName); err != nil {
return nil, err
}
publicKey := virtualHost.PrivateKey.Public().(ed25519.PublicKey)
keys.ServerName = virtualHost.ServerName
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(virtualHost.KeyValidityPeriod))
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
virtualHost.KeyID: {
Key: gomatrixserverlib.Base64Bytes(publicKey),
},
}
// TODO: Virtual hosts probably want to be able to specify old signing
// keys too, just in case
}
toSign, err := json.Marshal(keys.ServerKeyFields) toSign, err := json.Marshal(keys.ServerKeyFields)
if err != nil { if err != nil {
@ -172,13 +189,9 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
} }
keys.Raw, err = gomatrixserverlib.SignJSON( keys.Raw, err = gomatrixserverlib.SignJSON(
string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign, string(identity.ServerName), identity.KeyID, identity.PrivateKey, toSign,
) )
if err != nil { return &keys, err
return nil, err
}
return &keys, nil
} }
func NotaryKeys( func NotaryKeys(
@ -186,6 +199,14 @@ func NotaryKeys(
fsAPI federationAPI.FederationInternalAPI, fsAPI federationAPI.FederationInternalAPI,
req *gomatrixserverlib.PublicKeyNotaryLookupRequest, req *gomatrixserverlib.PublicKeyNotaryLookupRequest,
) util.JSONResponse { ) util.JSONResponse {
serverName := gomatrixserverlib.ServerName(httpReq.Host) // TODO: this is not ideal
if !cfg.Matrix.IsLocalServerName(serverName) {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Server name not known"),
}
}
if req == nil { if req == nil {
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{} req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil { if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
@ -201,7 +222,7 @@ func NotaryKeys(
for serverName, kidToCriteria := range req.ServerKeys { for serverName, kidToCriteria := range req.ServerKeys {
var keyList []gomatrixserverlib.ServerKeys var keyList []gomatrixserverlib.ServerKeys
if serverName == cfg.Matrix.ServerName { if serverName == cfg.Matrix.ServerName {
if k, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil { if k, err := localKeys(cfg, serverName); err == nil {
keyList = append(keyList, *k) keyList = append(keyList, *k)
} else { } else {
return util.ErrorResponse(err) return util.ErrorResponse(err)

View file

@ -13,6 +13,7 @@
package routing package routing
import ( import (
"fmt"
"net/http" "net/http"
"time" "time"
@ -60,8 +61,18 @@ func MakeLeave(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
identity, err := cfg.Matrix.SigningIdentityFor(request.Destination())
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(
fmt.Sprintf("Server name %q does not exist", request.Destination()),
),
}
}
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,

View file

@ -22,7 +22,6 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -42,16 +41,9 @@ func GetProfile(
} }
} }
_, domain, err := gomatrixserverlib.SplitID('@', userID) _, domain, err := cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(httpReq.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument(fmt.Sprintf("Format of user ID %q is invalid", userID)),
}
}
if domain != cfg.Matrix.ServerName {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("Domain %q does not match this server", domain)), JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("Domain %q does not match this server", domain)),

View file

@ -83,7 +83,7 @@ func RoomAliasToID(
} }
} }
} else { } else {
resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, roomAlias) resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, cfg.Matrix.ServerName, roomAlias)
if err != nil { if err != nil {
switch x := err.(type) { switch x := err.(type) {
case gomatrix.HTTPError: case gomatrix.HTTPError:

View file

@ -74,7 +74,7 @@ func Setup(
} }
localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse { localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse {
return LocalKeys(cfg) return LocalKeys(cfg, gomatrixserverlib.ServerName(req.Host))
}) })
notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse { notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse {

View file

@ -197,12 +197,12 @@ type txnReq struct {
// A subset of FederationClient functionality that txn requires. Useful for testing. // A subset of FederationClient functionality that txn requires. Useful for testing.
type txnFederationClient interface { type txnFederationClient interface {
LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
res gomatrixserverlib.RespState, err error, res gomatrixserverlib.RespState, err error,
) )
LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents,
roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
@ -287,6 +287,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
event.Headered(roomVersion), event.Headered(roomVersion),
}, },
t.Destination,
t.Origin, t.Origin,
api.DoNotSendToOtherServers, api.DoNotSendToOtherServers,
nil, nil,

View file

@ -147,7 +147,7 @@ type txnFedClient struct {
getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error)
} }
func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( func (c *txnFedClient) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
res gomatrixserverlib.RespState, err error, res gomatrixserverlib.RespState, err error,
) { ) {
fmt.Println("testFederationClient.LookupState", eventID) fmt.Println("testFederationClient.LookupState", eventID)
@ -159,7 +159,7 @@ func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.Serv
res = r res = r
return return
} }
func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { func (c *txnFedClient) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) {
fmt.Println("testFederationClient.LookupStateIDs", eventID) fmt.Println("testFederationClient.LookupStateIDs", eventID)
r, ok := c.stateIDs[eventID] r, ok := c.stateIDs[eventID]
if !ok { if !ok {
@ -169,7 +169,7 @@ func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.S
res = r res = r
return return
} }
func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { func (c *txnFedClient) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) {
fmt.Println("testFederationClient.GetEvent", eventID) fmt.Println("testFederationClient.GetEvent", eventID)
r, ok := c.getEvent[eventID] r, ok := c.getEvent[eventID]
if !ok { if !ok {
@ -179,7 +179,7 @@ func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerN
res = r res = r
return return
} }
func (c *txnFedClient) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, func (c *txnFedClient) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents,
roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) { roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) {
return c.getMissingEvents(missing) return c.getMissingEvents(missing)
} }

View file

@ -90,7 +90,17 @@ func CreateInvitesFrom3PIDInvites(
} }
// Send all the events // Send all the events
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, evs, "TODO", cfg.Matrix.ServerName, nil, false); err != nil { if err := api.SendEvents(
req.Context(),
rsAPI,
api.KindNew,
evs,
cfg.Matrix.ServerName, // TODO: which virtual host?
"TODO",
cfg.Matrix.ServerName,
nil,
false,
); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -126,6 +136,14 @@ func ExchangeThirdPartyInvite(
} }
} }
_, senderDomain, err := cfg.Matrix.SplitLocalID('@', builder.Sender)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Invalid sender ID: " + err.Error()),
}
}
// Check that the state key is correct. // Check that the state key is correct.
_, targetDomain, err := gomatrixserverlib.SplitID('@', *builder.StateKey) _, targetDomain, err := gomatrixserverlib.SplitID('@', *builder.StateKey)
if err != nil { if err != nil {
@ -171,7 +189,7 @@ func ExchangeThirdPartyInvite(
util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request") util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
signedEvent, err := federation.SendInviteV2(httpReq.Context(), request.Origin(), inviteReq) signedEvent, err := federation.SendInviteV2(httpReq.Context(), senderDomain, request.Origin(), inviteReq)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -189,6 +207,7 @@ func ExchangeThirdPartyInvite(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
inviteEvent.Headered(verRes.RoomVersion), inviteEvent.Headered(verRes.RoomVersion),
}, },
request.Destination(),
request.Origin(), request.Origin(),
cfg.Matrix.ServerName, cfg.Matrix.ServerName,
nil, nil,
@ -341,7 +360,7 @@ func buildMembershipEvent(
// them responded with an error. // them responded with an error.
func sendToRemoteServer( func sendToRemoteServer(
ctx context.Context, inv invite, ctx context.Context, inv invite,
federation federationAPI.FederationClient, _ *config.FederationAPI, federation federationAPI.FederationClient, cfg *config.FederationAPI,
builder gomatrixserverlib.EventBuilder, builder gomatrixserverlib.EventBuilder,
) (err error) { ) (err error) {
remoteServers := make([]gomatrixserverlib.ServerName, 2) remoteServers := make([]gomatrixserverlib.ServerName, 2)
@ -357,7 +376,7 @@ func sendToRemoteServer(
} }
for _, server := range remoteServers { for _, server := range remoteServers {
err = federation.ExchangeThirdPartyInvite(ctx, server, builder) err = federation.ExchangeThirdPartyInvite(ctx, cfg.Matrix.ServerName, server, builder)
if err == nil { if err == nil {
return return
} }

View file

@ -32,7 +32,7 @@ type Database interface {
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given. // GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error)
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)

View file

@ -66,6 +66,11 @@ const selectAllJoinedHostsSQL = "" +
const selectJoinedHostsForRoomsSQL = "" + const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)" "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)"
const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id = ANY($1) AND NOT EXISTS (" +
" SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" +
");"
type joinedHostsStatements struct { type joinedHostsStatements struct {
db *sql.DB db *sql.DB
insertJoinedHostsStmt *sql.Stmt insertJoinedHostsStmt *sql.Stmt
@ -74,6 +79,7 @@ type joinedHostsStatements struct {
selectJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt
selectJoinedHostsForRoomsStmt *sql.Stmt selectJoinedHostsForRoomsStmt *sql.Stmt
selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt
} }
func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@ -102,6 +108,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro
if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
return return
} }
if s.selectJoinedHostsForRoomsExcludingBlacklistedStmt, err = s.db.Prepare(selectJoinedHostsForRoomsExcludingBlacklistedSQL); err != nil {
return
}
return return
} }
@ -167,9 +176,13 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
} }
func (s *joinedHostsStatements) SelectJoinedHostsForRooms( func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string, ctx context.Context, roomIDs []string, excludingBlacklisted bool,
) ([]gomatrixserverlib.ServerName, error) { ) ([]gomatrixserverlib.ServerName, error) {
rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs)) stmt := s.selectJoinedHostsForRoomsStmt
if excludingBlacklisted {
stmt = s.selectJoinedHostsForRoomsExcludingBlacklistedStmt
}
rows, err := stmt.QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -42,6 +42,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
return nil, err return nil, err
} }
blacklist, err := NewPostgresBlacklistTable(d.db)
if err != nil {
return nil, err
}
joinedHosts, err := NewPostgresJoinedHostsTable(d.db) joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -58,10 +62,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
blacklist, err := NewPostgresBlacklistTable(d.db)
if err != nil {
return nil, err
}
inboundPeeks, err := NewPostgresInboundPeeksTable(d.db) inboundPeeks, err := NewPostgresInboundPeeksTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -117,15 +117,17 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
} }
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) { func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) {
servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if excludeSelf { if excludeSelf {
for i, server := range servers { for i, server := range servers {
if d.IsLocalServerName(server) { if d.IsLocalServerName(server) {
servers = append(servers[:i], servers[i+1:]...) copy(servers[i:], servers[i+1:])
servers = servers[:len(servers)-1]
break
} }
} }
} }

View file

@ -66,6 +66,11 @@ const selectAllJoinedHostsSQL = "" +
const selectJoinedHostsForRoomsSQL = "" + const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id IN ($1) AND NOT EXISTS (" +
" SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" +
");"
type joinedHostsStatements struct { type joinedHostsStatements struct {
db *sql.DB db *sql.DB
insertJoinedHostsStmt *sql.Stmt insertJoinedHostsStmt *sql.Stmt
@ -74,6 +79,7 @@ type joinedHostsStatements struct {
selectJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
// selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt - prepared at runtime due to variadic
} }
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@ -168,14 +174,17 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
} }
func (s *joinedHostsStatements) SelectJoinedHostsForRooms( func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string, ctx context.Context, roomIDs []string, excludingBlacklisted bool,
) ([]gomatrixserverlib.ServerName, error) { ) ([]gomatrixserverlib.ServerName, error) {
iRoomIDs := make([]interface{}, len(roomIDs)) iRoomIDs := make([]interface{}, len(roomIDs))
for i := range roomIDs { for i := range roomIDs {
iRoomIDs[i] = roomIDs[i] iRoomIDs[i] = roomIDs[i]
} }
query := selectJoinedHostsForRoomsSQL
sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) if excludingBlacklisted {
query = selectJoinedHostsForRoomsExcludingBlacklistedSQL
}
sql := strings.Replace(query, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...) rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -41,6 +41,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err return nil, err
} }
blacklist, err := NewSQLiteBlacklistTable(d.db)
if err != nil {
return nil, err
}
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -57,10 +61,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
blacklist, err := NewSQLiteBlacklistTable(d.db)
if err != nil {
return nil, err
}
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -58,7 +58,7 @@ type FederationJoinedHosts interface {
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludingBlacklisted bool) ([]gomatrixserverlib.ServerName, error)
} }
type FederationBlacklist interface { type FederationBlacklist interface {

10
go.mod
View file

@ -7,7 +7,7 @@ require (
github.com/Masterminds/semver/v3 v3.1.1 github.com/Masterminds/semver/v3 v3.1.1
github.com/blevesearch/bleve/v2 v2.3.4 github.com/blevesearch/bleve/v2 v2.3.4
github.com/codeclysm/extract v2.2.0+incompatible github.com/codeclysm/extract v2.2.0+incompatible
github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d github.com/dgraph-io/ristretto v0.1.1
github.com/docker/docker v20.10.19+incompatible github.com/docker/docker v20.10.19+incompatible
github.com/docker/go-connections v0.4.0 github.com/docker/go-connections v0.4.0
github.com/getsentry/sentry-go v0.14.0 github.com/getsentry/sentry-go v0.14.0
@ -22,12 +22,12 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8
github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.15 github.com/mattn/go-sqlite3 v1.14.15
github.com/nats-io/nats-server/v2 v2.9.4 github.com/nats-io/nats-server/v2 v2.9.8
github.com/nats-io/nats.go v1.19.0 github.com/nats-io/nats.go v1.20.0
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79 github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79

22
go.sum
View file

@ -139,8 +139,8 @@ github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d h1:Wrc3UKTS+cffkOx0xRGFC+ZesNuTfn0ThvEC72N0krk= github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8=
github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d/go.mod h1:RAy2GVV4sTWVlNMavv3xhLsk18rxhfhDnombTe6EF5c= github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA=
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA=
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
github.com/docker/distribution v2.8.1+incompatible h1:Q50tZOPR6T/hjNsyc9g8/syEs6bk8XXApsHjKukMl68= github.com/docker/distribution v2.8.1+incompatible h1:Q50tZOPR6T/hjNsyc9g8/syEs6bk8XXApsHjKukMl68=
@ -346,10 +346,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e h1:6I34fdyiHMRCxL6GOb/G8ZyI1WWlb6ZxCF2hIGSMSCc= github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 h1:jVvlCGs6OosCdvw9MkfiVnTVnIt7vKMHg/F6th9BtSo=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM=
github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE= github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE=
@ -385,10 +385,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI=
github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
github.com/nats-io/nats-server/v2 v2.9.4 h1:GvRgv1936J/zYUwMg/cqtYaJ6L+bgeIOIvPslbesdow= github.com/nats-io/nats-server/v2 v2.9.8 h1:jgxZsv+A3Reb3MgwxaINcNq/za8xZInKhDg9Q0cGN1o=
github.com/nats-io/nats-server/v2 v2.9.4/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= github.com/nats-io/nats-server/v2 v2.9.8/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g=
github.com/nats-io/nats.go v1.19.0 h1:H6j8aBnTQFoVrTGB6Xjd903UMdE7jz6DS4YkmAqgZ9Q= github.com/nats-io/nats.go v1.20.0 h1:T8JJnQfVSdh1CzGiwAOv5hEobYCBho/0EupGznYw0oM=
github.com/nats-io/nats.go v1.19.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA= github.com/nats-io/nats.go v1.20.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA=
github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8=
github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
@ -526,7 +526,6 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ=
go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
@ -697,6 +696,7 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=

View file

@ -38,7 +38,8 @@ var ErrRoomNoExists = errors.New("room does not exist")
// Returns an error if something else went wrong // Returns an error if something else went wrong
func QueryAndBuildEvent( func QueryAndBuildEvent(
ctx context.Context, ctx context.Context,
builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, builder *gomatrixserverlib.EventBuilder, cfg *config.Global,
identity *gomatrixserverlib.SigningIdentity, evTime time.Time,
rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse, rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
if queryRes == nil { if queryRes == nil {
@ -50,24 +51,24 @@ func QueryAndBuildEvent(
// This can pass through a ErrRoomNoExists to the caller // This can pass through a ErrRoomNoExists to the caller
return nil, err return nil, err
} }
return BuildEvent(ctx, builder, cfg, evTime, eventsNeeded, queryRes) return BuildEvent(ctx, builder, cfg, identity, evTime, eventsNeeded, queryRes)
} }
// BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse // BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse
// provided. // provided.
func BuildEvent( func BuildEvent(
ctx context.Context, ctx context.Context,
builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, builder *gomatrixserverlib.EventBuilder, cfg *config.Global,
identity *gomatrixserverlib.SigningIdentity, evTime time.Time,
eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse, eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
err := addPrevEventsToEvent(builder, eventsNeeded, queryRes) if err := addPrevEventsToEvent(builder, eventsNeeded, queryRes); err != nil {
if err != nil {
return nil, err return nil, err
} }
event, err := builder.Build( event, err := builder.Build(
evTime, cfg.ServerName, cfg.KeyID, evTime, identity.ServerName, identity.KeyID,
cfg.PrivateKey, queryRes.RoomVersion, identity.PrivateKey, queryRes.RoomVersion,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -42,10 +42,26 @@ type BasicAuth struct {
Password string `yaml:"password"` Password string `yaml:"password"`
} }
type AuthAPIOpts struct {
GuestAccessAllowed bool
}
// AuthAPIOption is an option to MakeAuthAPI to add additional checks (e.g. guest access) to verify
// the user is allowed to do specific things.
type AuthAPIOption func(opts *AuthAPIOpts)
// WithAllowGuests checks that guest users have access to this endpoint
func WithAllowGuests() AuthAPIOption {
return func(opts *AuthAPIOpts) {
opts.GuestAccessAllowed = true
}
}
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
func MakeAuthAPI( func MakeAuthAPI(
metricsName string, userAPI userapi.QueryAcccessTokenAPI, metricsName string, userAPI userapi.QueryAcccessTokenAPI,
f func(*http.Request, *userapi.Device) util.JSONResponse, f func(*http.Request, *userapi.Device) util.JSONResponse,
checks ...AuthAPIOption,
) http.Handler { ) http.Handler {
h := func(req *http.Request) util.JSONResponse { h := func(req *http.Request) util.JSONResponse {
logger := util.GetLogger(req.Context()) logger := util.GetLogger(req.Context())
@ -76,6 +92,19 @@ func MakeAuthAPI(
} }
}() }()
// apply additional checks, if any
opts := AuthAPIOpts{}
for _, opt := range checks {
opt(&opts)
}
if !opts.GuestAccessAllowed && device.AccountType == userapi.AccountTypeGuest {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.GuestAccessForbidden("Guest access not allowed"),
}
}
jsonRes := f(req, device) jsonRes := f(req, device)
// do not log 4xx as errors as they are client fails, not server fails // do not log 4xx as errors as they are client fails, not server fails
if hub != nil && jsonRes.Code >= 500 { if hub != nil && jsonRes.Code >= 500 {

View file

@ -17,7 +17,7 @@ var build string
const ( const (
VersionMajor = 0 VersionMajor = 0
VersionMinor = 10 VersionMinor = 10
VersionPatch = 7 VersionPatch = 8
VersionTag = "" // example: "rc1" VersionTag = "" // example: "rc1"
) )

View file

@ -35,7 +35,7 @@ type DeviceListUpdateConsumer struct {
durable string durable string
topic string topic string
updater *internal.DeviceListUpdater updater *internal.DeviceListUpdater
serverName gomatrixserverlib.ServerName isLocalServerName func(gomatrixserverlib.ServerName) bool
} }
// NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. // NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers.
@ -51,7 +51,7 @@ func NewDeviceListUpdateConsumer(
durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
updater: updater, updater: updater,
serverName: cfg.Matrix.ServerName, isLocalServerName: cfg.Matrix.IsLocalServerName,
} }
} }
@ -75,7 +75,7 @@ func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M
origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) origin := gomatrixserverlib.ServerName(msg.Header.Get("origin"))
if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil { if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil {
return true return true
} else if serverName == t.serverName { } else if t.isLocalServerName(serverName) {
return true return true
} else if serverName != origin { } else if serverName != origin {
return true return true

View file

@ -37,6 +37,7 @@ type SigningKeyUpdateConsumer struct {
topic string topic string
keyAPI *internal.KeyInternalAPI keyAPI *internal.KeyInternalAPI
cfg *config.KeyServer cfg *config.KeyServer
isLocalServerName func(gomatrixserverlib.ServerName) bool
} }
// NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. // NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers.
@ -53,6 +54,7 @@ func NewSigningKeyUpdateConsumer(
topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
keyAPI: keyAPI, keyAPI: keyAPI,
cfg: cfg, cfg: cfg,
isLocalServerName: cfg.Matrix.IsLocalServerName,
} }
} }
@ -77,7 +79,7 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M
if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil { if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil {
logrus.WithError(err).Error("failed to split user id") logrus.WithError(err).Error("failed to split user id")
return true return true
} else if serverName == t.cfg.Matrix.ServerName { } else if t.isLocalServerName(serverName) {
logrus.Warn("dropping device key update from ourself") logrus.Warn("dropping device key update from ourself")
return true return true
} else if serverName != origin { } else if serverName != origin {

View file

@ -47,7 +47,6 @@ var (
) )
) )
const defaultWaitTime = time.Minute
const requestTimeout = time.Second * 30 const requestTimeout = time.Second * 30
func init() { func init() {
@ -97,6 +96,7 @@ type DeviceListUpdater struct {
producer KeyChangeProducer producer KeyChangeProducer
fedClient fedsenderapi.KeyserverFederationAPI fedClient fedsenderapi.KeyserverFederationAPI
workerChans []chan gomatrixserverlib.ServerName workerChans []chan gomatrixserverlib.ServerName
thisServer gomatrixserverlib.ServerName
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
// block on or timeout via a select. // block on or timeout via a select.
@ -140,6 +140,7 @@ func NewDeviceListUpdater(
process *process.ProcessContext, db DeviceListUpdaterDatabase, process *process.ProcessContext, db DeviceListUpdaterDatabase,
api DeviceListUpdaterAPI, producer KeyChangeProducer, api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
thisServer gomatrixserverlib.ServerName,
) *DeviceListUpdater { ) *DeviceListUpdater {
return &DeviceListUpdater{ return &DeviceListUpdater{
process: process, process: process,
@ -149,6 +150,7 @@ func NewDeviceListUpdater(
api: api, api: api,
producer: producer, producer: producer,
fedClient: fedClient, fedClient: fedClient,
thisServer: thisServer,
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
userIDToChan: make(map[string]chan bool), userIDToChan: make(map[string]chan bool),
userIDToChanMu: &sync.Mutex{}, userIDToChanMu: &sync.Mutex{},
@ -436,8 +438,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go
"server_name": serverName, "server_name": serverName,
"user_id": userID, "user_id": userID,
}) })
res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID)
res, err := u.fedClient.GetUserDevices(ctx, serverName, userID)
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {
return time.Minute * 10, err return time.Minute * 10, err
@ -454,7 +455,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go
} else if e.Code >= 300 { } else if e.Code >= 300 {
// We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError // We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError
// are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds. // are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds.
return time.Hour, err return hourWaitTime, err
} }
case net.Error: case net.Error:
// Use the default waitTime, if it's a timeout. // Use the default waitTime, if it's a timeout.
@ -468,7 +469,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go
// This is to avoid spamming remote servers, which may not be Matrix servers anymore. // This is to avoid spamming remote servers, which may not be Matrix servers anymore.
if e.Code >= 300 { if e.Code >= 300 {
logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError") logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError")
return time.Hour, err return hourWaitTime, err
} }
default: default:
// Something else failed // Something else failed

View file

@ -0,0 +1,22 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !vw
package internal
import "time"
const defaultWaitTime = time.Minute
const hourWaitTime = time.Hour

View file

@ -0,0 +1,25 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build vw
package internal
import "time"
// Sytest is expecting to receive a `/devices` request. The way it is implemented in Dendrite
// results in a one-hour wait time from a previous device so the test times out. This is fine for
// production, but makes an otherwise passing test fail.
const defaultWaitTime = time.Second
const hourWaitTime = time.Second

View file

@ -129,7 +129,13 @@ func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient { func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient {
_, pkey, _ := ed25519.GenerateKey(nil) _, pkey, _ := ed25519.GenerateKey(nil)
fedClient := gomatrixserverlib.NewFederationClient( fedClient := gomatrixserverlib.NewFederationClient(
gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey, []*gomatrixserverlib.SigningIdentity{
{
ServerName: gomatrixserverlib.ServerName("example.test"),
KeyID: gomatrixserverlib.KeyID("ed25519:test"),
PrivateKey: pkey,
},
},
) )
fedClient.Client = *gomatrixserverlib.NewClient( fedClient.Client = *gomatrixserverlib.NewClient(
gomatrixserverlib.WithTransport(&roundTripper{tripper}), gomatrixserverlib.WithTransport(&roundTripper{tripper}),
@ -147,7 +153,7 @@ func TestUpdateHavePrevID(t *testing.T) {
} }
ap := &mockDeviceListUpdaterAPI{} ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{} producer := &mockKeyChangeProducer{}
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost")
event := gomatrixserverlib.DeviceListUpdateEvent{ event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Foo Bar", DeviceDisplayName: "Foo Bar",
Deleted: false, Deleted: false,
@ -219,7 +225,7 @@ func TestUpdateNoPrevID(t *testing.T) {
`)), `)),
}, nil }, nil
}) })
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test")
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err) t.Fatalf("failed to start updater: %s", err)
} }
@ -288,7 +294,7 @@ func TestDebounce(t *testing.T) {
close(incomingFedReq) close(incomingFedReq)
return <-fedCh, nil return <-fedCh, nil
}) })
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost")
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err) t.Fatalf("failed to start updater: %s", err)
} }

View file

@ -33,12 +33,13 @@ import (
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/producers" "github.com/matrix-org/dendrite/keyserver/producers"
"github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
) )
type KeyInternalAPI struct { type KeyInternalAPI struct {
DB storage.Database DB storage.Database
ThisServer gomatrixserverlib.ServerName Cfg *config.KeyServer
FedClient fedsenderapi.KeyserverFederationAPI FedClient fedsenderapi.KeyserverFederationAPI
UserAPI userapi.KeyserverUserAPI UserAPI userapi.KeyserverUserAPI
Producer *producers.KeyChange Producer *producers.KeyChange
@ -95,8 +96,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
nested[userID] = val nested[userID] = val
domainToDeviceKeys[string(serverName)] = nested domainToDeviceKeys[string(serverName)] = nested
} }
for domain, local := range domainToDeviceKeys {
if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue
}
// claim local keys // claim local keys
if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok {
keys, err := a.DB.ClaimKeys(ctx, local) keys, err := a.DB.ClaimKeys(ctx, local)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
@ -117,7 +121,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
} }
} }
delete(domainToDeviceKeys, string(a.ThisServer)) delete(domainToDeviceKeys, domain)
} }
if len(domainToDeviceKeys) > 0 { if len(domainToDeviceKeys) > 0 {
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
@ -142,7 +146,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
defer cancel() defer cancel()
defer wg.Done() defer wg.Done()
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim)
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
@ -258,7 +262,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
} }
domain := string(serverName) domain := string(serverName)
// query local devices // query local devices
if serverName == a.ThisServer { if a.Cfg.Matrix.IsLocalServerName(serverName) {
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
@ -437,13 +441,13 @@ func (a *KeyInternalAPI) queryRemoteKeys(
domains := map[string]struct{}{} domains := map[string]struct{}{}
for domain := range domainToDeviceKeys { for domain := range domainToDeviceKeys {
if domain == string(a.ThisServer) { if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue continue
} }
domains[domain] = struct{}{} domains[domain] = struct{}{}
} }
for domain := range domainToCrossSigningKeys { for domain := range domainToCrossSigningKeys {
if domain == string(a.ThisServer) { if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue continue
} }
domains[domain] = struct{}{} domains[domain] = struct{}{}
@ -555,7 +559,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
if len(devKeys) == 0 { if len(devKeys) == 0 {
return return
} }
queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys) queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys)
if err == nil { if err == nil {
resultCh <- &queryKeysResp resultCh <- &queryKeysResp
return return
@ -689,7 +693,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
if err != nil { if err != nil {
continue // ignore invalid users continue // ignore invalid users
} }
if serverName != a.ThisServer { if !a.Cfg.Matrix.IsLocalServerName(serverName) {
continue // ignore remote users continue // ignore remote users
} }
if len(key.KeyJSON) == 0 { if len(key.KeyJSON) == 0 {

View file

@ -54,11 +54,11 @@ func NewInternalAPI(
} }
ap := &internal.KeyInternalAPI{ ap := &internal.KeyInternalAPI{
DB: db, DB: db,
ThisServer: cfg.Matrix.ServerName, Cfg: cfg,
FedClient: fedClient, FedClient: fedClient,
Producer: keyChangeProducer, Producer: keyChangeProducer,
} }
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable
ap.Updater = updater ap.Updater = updater
go func() { go func() {
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {

View file

@ -43,10 +43,10 @@ const upsertStaleDeviceListSQL = "" +
" DO UPDATE SET is_stale = $3, ts_added_secs = $4" " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
const selectStaleDeviceListsWithDomainsSQL = "" + const selectStaleDeviceListsWithDomainsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
const selectStaleDeviceListsSQL = "" + const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
type staleDeviceListsStatements struct { type staleDeviceListsStatements struct {
upsertStaleDeviceListStmt *sql.Stmt upsertStaleDeviceListStmt *sql.Stmt
@ -77,7 +77,7 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context,
if err != nil { if err != nil {
return err return err
} }
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
return err return err
} }

View file

@ -43,10 +43,10 @@ const upsertStaleDeviceListSQL = "" +
" DO UPDATE SET is_stale = $3, ts_added_secs = $4" " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
const selectStaleDeviceListsWithDomainsSQL = "" + const selectStaleDeviceListsWithDomainsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
const selectStaleDeviceListsSQL = "" + const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
type staleDeviceListsStatements struct { type staleDeviceListsStatements struct {
db *sql.DB db *sql.DB
@ -80,7 +80,7 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context,
if err != nil { if err != nil {
return err return err
} }
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
return err return err
} }

View file

@ -177,6 +177,7 @@ type FederationRoomserverAPI interface {
QueryBulkStateContentAPI QueryBulkStateContentAPI
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
QueryEventsByID(ctx context.Context, req *QueryEventsByIDRequest, res *QueryEventsByIDResponse) error QueryEventsByID(ctx context.Context, req *QueryEventsByIDRequest, res *QueryEventsByIDResponse) error

View file

@ -96,6 +96,7 @@ type TransactionID struct {
type InputRoomEventsRequest struct { type InputRoomEventsRequest struct {
InputRoomEvents []InputRoomEvent `json:"input_room_events"` InputRoomEvents []InputRoomEvent `json:"input_room_events"`
Asynchronous bool `json:"async"` Asynchronous bool `json:"async"`
VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"`
} }
// InputRoomEventsResponse is a response to InputRoomEvents // InputRoomEventsResponse is a response to InputRoomEvents

View file

@ -148,6 +148,8 @@ type PerformBackfillRequest struct {
Limit int `json:"limit"` Limit int `json:"limit"`
// The server interested in the events. // The server interested in the events.
ServerName gomatrixserverlib.ServerName `json:"server_name"` ServerName gomatrixserverlib.ServerName `json:"server_name"`
// Which virtual host are we doing this for?
VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"`
} }
// PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order. // PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order.

View file

@ -26,7 +26,7 @@ import (
func SendEvents( func SendEvents(
ctx context.Context, rsAPI InputRoomEventsAPI, ctx context.Context, rsAPI InputRoomEventsAPI,
kind Kind, events []*gomatrixserverlib.HeaderedEvent, kind Kind, events []*gomatrixserverlib.HeaderedEvent,
origin gomatrixserverlib.ServerName, virtualHost, origin gomatrixserverlib.ServerName,
sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID,
async bool, async bool,
) error { ) error {
@ -40,14 +40,15 @@ func SendEvents(
TransactionID: txnID, TransactionID: txnID,
} }
} }
return SendInputRoomEvents(ctx, rsAPI, ires, async) return SendInputRoomEvents(ctx, rsAPI, virtualHost, ires, async)
} }
// SendEventWithState writes an event with the specified kind to the roomserver // SendEventWithState writes an event with the specified kind to the roomserver
// with the state at the event as KindOutlier before it. Will not send any event that is // with the state at the event as KindOutlier before it. Will not send any event that is
// marked as `true` in haveEventIDs. // marked as `true` in haveEventIDs.
func SendEventWithState( func SendEventWithState(
ctx context.Context, rsAPI InputRoomEventsAPI, kind Kind, ctx context.Context, rsAPI InputRoomEventsAPI,
virtualHost gomatrixserverlib.ServerName, kind Kind,
state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent,
origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool,
) error { ) error {
@ -85,17 +86,19 @@ func SendEventWithState(
StateEventIDs: stateEventIDs, StateEventIDs: stateEventIDs,
}) })
return SendInputRoomEvents(ctx, rsAPI, ires, async) return SendInputRoomEvents(ctx, rsAPI, virtualHost, ires, async)
} }
// SendInputRoomEvents to the roomserver. // SendInputRoomEvents to the roomserver.
func SendInputRoomEvents( func SendInputRoomEvents(
ctx context.Context, rsAPI InputRoomEventsAPI, ctx context.Context, rsAPI InputRoomEventsAPI,
virtualHost gomatrixserverlib.ServerName,
ires []InputRoomEvent, async bool, ires []InputRoomEvent, async bool,
) error { ) error {
request := InputRoomEventsRequest{ request := InputRoomEventsRequest{
InputRoomEvents: ires, InputRoomEvents: ires,
Asynchronous: async, Asynchronous: async,
VirtualHost: virtualHost,
} }
var response InputRoomEventsResponse var response InputRoomEventsResponse
if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil { if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil {

View file

@ -137,6 +137,11 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
request *api.RemoveRoomAliasRequest, request *api.RemoveRoomAliasRequest,
response *api.RemoveRoomAliasResponse, response *api.RemoveRoomAliasResponse,
) error { ) error {
_, virtualHost, err := r.Cfg.Matrix.SplitLocalID('@', request.UserID)
if err != nil {
return err
}
roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err) return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err)
@ -190,6 +195,16 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
sender = ev.Sender() sender = ev.Sender()
} }
_, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', sender)
if err != nil {
return err
}
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil {
return err
}
builder := &gomatrixserverlib.EventBuilder{ builder := &gomatrixserverlib.EventBuilder{
Sender: sender, Sender: sender,
RoomID: ev.RoomID(), RoomID: ev.RoomID(),
@ -211,12 +226,12 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return err return err
} }
newEvent, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, time.Now(), &eventsNeeded, stateRes) newEvent, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, stateRes)
if err != nil { if err != nil {
return err return err
} }
err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, r.ServerName, r.ServerName, nil, false) err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, virtualHost, r.ServerName, r.ServerName, nil, false)
if err != nil { if err != nil {
return err return err
} }

View file

@ -89,7 +89,7 @@ func NewRoomserverAPI(
Queryer: &query.Queryer{ Queryer: &query.Queryer{
DB: roomserverDB, DB: roomserverDB,
Cache: base.Caches, Cache: base.Caches,
ServerName: base.Cfg.Global.ServerName, IsLocalServerName: base.Cfg.Global.IsLocalServerName,
ServerACLs: serverACLs, ServerACLs: serverACLs,
}, },
// perform-er structs get initialised when we have a federation sender to use // perform-er structs get initialised when we have a federation sender to use
@ -127,7 +127,6 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
Inputer: r.Inputer, Inputer: r.Inputer,
} }
r.Joiner = &perform.Joiner{ r.Joiner = &perform.Joiner{
ServerName: r.Cfg.Matrix.ServerName,
Cfg: r.Cfg, Cfg: r.Cfg,
DB: r.DB, DB: r.DB,
FSAPI: r.fsAPI, FSAPI: r.fsAPI,
@ -163,7 +162,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
DB: r.DB, DB: r.DB,
} }
r.Backfiller = &perform.Backfiller{ r.Backfiller = &perform.Backfiller{
ServerName: r.ServerName, IsLocalServerName: r.Cfg.Matrix.IsLocalServerName,
DB: r.DB, DB: r.DB,
FSAPI: r.fsAPI, FSAPI: r.fsAPI,
KeyRing: r.KeyRing, KeyRing: r.KeyRing,

View file

@ -0,0 +1,56 @@
package helpers
import (
"context"
"testing"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/roomserver/storage"
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) {
base, close := testrig.CreateBaseDendrite(t, dbType)
db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches)
if err != nil {
t.Fatalf("failed to create Database: %v", err)
}
return base, db, close
}
func TestIsInvitePendingWithoutNID(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat))
_ = bob
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
_, db, close := mustCreateDatabase(t, dbType)
defer close()
// store all events
var authNIDs []types.EventNID
for _, x := range room.Events() {
evNID, _, _, _, _, err := db.StoreEvent(context.Background(), x.Event, authNIDs, false)
assert.NoError(t, err)
authNIDs = append(authNIDs, evNID)
}
// Alice should have no pending invites and should have a NID
pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, alice.ID)
assert.NoError(t, err, "failed to get pending invites")
assert.False(t, pendingInvite, "unexpected pending invite")
// Bob should have no pending invites and receive a new NID
pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, bob.ID)
assert.NoError(t, err, "failed to get pending invites")
assert.False(t, pendingInvite, "unexpected pending invite")
})
}

View file

@ -278,7 +278,11 @@ func (w *worker) _next() {
// a string, because we might want to return that to the caller if // a string, because we might want to return that to the caller if
// it was a synchronous request. // it was a synchronous request.
var errString string var errString string
if err = w.r.processRoomEvent(w.r.ProcessContext.Context(), &inputRoomEvent); err != nil { if err = w.r.processRoomEvent(
w.r.ProcessContext.Context(),
gomatrixserverlib.ServerName(msg.Header.Get("virtual_host")),
&inputRoomEvent,
); err != nil {
switch err.(type) { switch err.(type) {
case types.RejectedError: case types.RejectedError:
// Don't send events that were rejected to Sentry // Don't send events that were rejected to Sentry
@ -358,6 +362,7 @@ func (r *Inputer) queueInputRoomEvents(
if replyTo != "" { if replyTo != "" {
msg.Header.Set("sync", replyTo) msg.Header.Set("sync", replyTo)
} }
msg.Header.Set("virtual_host", string(request.VirtualHost))
msg.Data, err = json.Marshal(e) msg.Data, err = json.Marshal(e)
if err != nil { if err != nil {
return nil, fmt.Errorf("json.Marshal: %w", err) return nil, fmt.Errorf("json.Marshal: %w", err)

View file

@ -23,6 +23,8 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/tidwall/gjson"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go"
@ -67,6 +69,7 @@ var processRoomEventDuration = prometheus.NewHistogramVec(
// nolint:gocyclo // nolint:gocyclo
func (r *Inputer) processRoomEvent( func (r *Inputer) processRoomEvent(
ctx context.Context, ctx context.Context,
virtualHost gomatrixserverlib.ServerName,
input *api.InputRoomEvent, input *api.InputRoomEvent,
) error { ) error {
select { select {
@ -164,6 +167,7 @@ func (r *Inputer) processRoomEvent(
serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{
RoomID: event.RoomID(), RoomID: event.RoomID(),
ExcludeSelf: true, ExcludeSelf: true,
ExcludeBlacklisted: true,
} }
if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
@ -198,7 +202,7 @@ func (r *Inputer) processRoomEvent(
isRejected := false isRejected := false
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
knownEvents := map[string]*types.Event{} knownEvents := map[string]*types.Event{}
if err = r.fetchAuthEvents(ctx, logger, roomInfo, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
return fmt.Errorf("r.fetchAuthEvents: %w", err) return fmt.Errorf("r.fetchAuthEvents: %w", err)
} }
@ -264,6 +268,7 @@ func (r *Inputer) processRoomEvent(
if len(serverRes.ServerNames) > 0 { if len(serverRes.ServerNames) > 0 {
missingState := missingStateReq{ missingState := missingStateReq{
origin: input.Origin, origin: input.Origin,
virtualHost: virtualHost,
inputer: r, inputer: r,
db: r.DB, db: r.DB,
roomInfo: roomInfo, roomInfo: roomInfo,
@ -409,6 +414,13 @@ func (r *Inputer) processRoomEvent(
} }
} }
// Handle remote room upgrades, e.g. remove published room
if event.Type() == "m.room.tombstone" && event.StateKeyEquals("") && !r.Cfg.Matrix.IsLocalServerName(senderDomain) {
if err = r.handleRemoteRoomUpgrade(ctx, event); err != nil {
return fmt.Errorf("failed to handle remote room upgrade: %w", err)
}
}
// processing this event resulted in an event (which may not be the one we're processing) // processing this event resulted in an event (which may not be the one we're processing)
// being redacted. We are guaranteed to have both sides (the redaction/redacted event), // being redacted. We are guaranteed to have both sides (the redaction/redacted event),
// so notify downstream components to redact this event - they should have it if they've // so notify downstream components to redact this event - they should have it if they've
@ -434,6 +446,13 @@ func (r *Inputer) processRoomEvent(
return nil return nil
} }
// handleRemoteRoomUpgrade updates published rooms and room aliases
func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixserverlib.Event) error {
oldRoomID := event.RoomID()
newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str
return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender())
}
// processStateBefore works out what the state is before the event and // processStateBefore works out what the state is before the event and
// then checks the event auths against the state at the time. It also // then checks the event auths against the state at the time. It also
// tries to determine what the history visibility was of the event at // tries to determine what the history visibility was of the event at
@ -539,6 +558,7 @@ func (r *Inputer) fetchAuthEvents(
ctx context.Context, ctx context.Context,
logger *logrus.Entry, logger *logrus.Entry,
roomInfo *types.RoomInfo, roomInfo *types.RoomInfo,
virtualHost gomatrixserverlib.ServerName,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
auth *gomatrixserverlib.AuthEvents, auth *gomatrixserverlib.AuthEvents,
known map[string]*types.Event, known map[string]*types.Event,
@ -589,7 +609,7 @@ func (r *Inputer) fetchAuthEvents(
// Request the entire auth chain for the event in question. This should // Request the entire auth chain for the event in question. This should
// contain all of the auth events — including ones that we already know — // contain all of the auth events — including ones that we already know —
// so we'll need to filter through those in the next section. // so we'll need to filter through those in the next section.
res, err = r.FSAPI.GetEventAuth(ctx, serverName, event.RoomVersion, event.RoomID(), event.EventID()) res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.RoomVersion, event.RoomID(), event.EventID())
if err != nil { if err != nil {
logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err) logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err)
continue continue

View file

@ -41,6 +41,7 @@ func (p *parsedRespState) Events() []*gomatrixserverlib.Event {
type missingStateReq struct { type missingStateReq struct {
log *logrus.Entry log *logrus.Entry
virtualHost gomatrixserverlib.ServerName
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
db storage.Database db storage.Database
roomInfo *types.RoomInfo roomInfo *types.RoomInfo
@ -101,7 +102,7 @@ func (t *missingStateReq) processEventWithMissingState(
// we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled
// in the gap in the DAG // in the gap in the DAG
for _, newEvent := range newEvents { for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: newEvent.Headered(roomVersion), Event: newEvent.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,
@ -157,7 +158,7 @@ func (t *missingStateReq) processEventWithMissingState(
}) })
} }
for _, ire := range outlierRoomEvents { for _, ire := range outlierRoomEvents {
if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { if err = t.inputer.processRoomEvent(ctx, t.virtualHost, &ire); err != nil {
if _, ok := err.(types.RejectedError); !ok { if _, ok := err.(types.RejectedError); !ok {
return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err)
} }
@ -180,7 +181,7 @@ func (t *missingStateReq) processEventWithMissingState(
stateIDs = append(stateIDs, event.EventID()) stateIDs = append(stateIDs, event.EventID())
} }
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: backwardsExtremity.Headered(roomVersion), Event: backwardsExtremity.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,
@ -199,7 +200,7 @@ func (t *missingStateReq) processEventWithMissingState(
// they will automatically fast-forward based on the room state at the // they will automatically fast-forward based on the room state at the
// extremity in the last step. // extremity in the last step.
for _, newEvent := range newEvents { for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: newEvent.Headered(roomVersion), Event: newEvent.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,
@ -519,7 +520,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve
var missingResp *gomatrixserverlib.RespMissingEvents var missingResp *gomatrixserverlib.RespMissingEvents
for _, server := range t.servers { for _, server := range t.servers {
var m gomatrixserverlib.RespMissingEvents var m gomatrixserverlib.RespMissingEvents
if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ if m, err = t.federation.LookupMissingEvents(ctx, t.virtualHost, server, e.RoomID(), gomatrixserverlib.MissingEvents{
Limit: 20, Limit: 20,
// The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events.
EarliestEvents: latestEvents, EarliestEvents: latestEvents,
@ -635,7 +636,7 @@ func (t *missingStateReq) lookupMissingStateViaState(
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaState") span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaState")
defer span.Finish() defer span.Finish()
state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion) state, err := t.federation.LookupState(ctx, t.virtualHost, t.origin, roomID, eventID, roomVersion)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -675,7 +676,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
totalctx, totalcancel := context.WithTimeout(ctx, time.Minute*5) totalctx, totalcancel := context.WithTimeout(ctx, time.Minute*5)
for _, serverName := range t.servers { for _, serverName := range t.servers {
reqctx, reqcancel := context.WithTimeout(totalctx, time.Second*20) reqctx, reqcancel := context.WithTimeout(totalctx, time.Second*20)
stateIDs, err = t.federation.LookupStateIDs(reqctx, serverName, roomID, eventID) stateIDs, err = t.federation.LookupStateIDs(reqctx, t.virtualHost, serverName, roomID, eventID)
reqcancel() reqcancel()
if err == nil { if err == nil {
break break
@ -855,7 +856,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
for _, serverName := range t.servers { for _, serverName := range t.servers {
reqctx, cancel := context.WithTimeout(ctx, time.Second*30) reqctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
txn, err := t.federation.GetEvent(reqctx, serverName, missingEventID) txn, err := t.federation.GetEvent(reqctx, t.virtualHost, serverName, missingEventID)
if err != nil || len(txn.PDUs) == 0 { if err != nil || len(txn.PDUs) == 0 {
t.log.WithError(err).WithField("missing_event_id", missingEventID).Warn("Failed to get missing /event for event ID") t.log.WithError(err).WithField("missing_event_id", missingEventID).Warn("Failed to get missing /event for event ID")
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {

View file

@ -139,7 +139,12 @@ func (r *Admin) PerformAdminEvacuateRoom(
return nil return nil
} }
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes) identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil {
continue
}
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, latestRes)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
@ -242,6 +247,15 @@ func (r *Admin) PerformAdminDownloadState(
req *api.PerformAdminDownloadStateRequest, req *api.PerformAdminDownloadStateRequest,
res *api.PerformAdminDownloadStateResponse, res *api.PerformAdminDownloadStateResponse,
) error { ) error {
_, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', req.UserID)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Cfg.Matrix.SplitLocalID: %s", err),
}
return nil
}
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
@ -273,7 +287,7 @@ func (r *Admin) PerformAdminDownloadState(
for _, fwdExtremity := range fwdExtremities { for _, fwdExtremity := range fwdExtremities {
var state gomatrixserverlib.RespState var state gomatrixserverlib.RespState
state, err = r.Inputer.FSAPI.LookupState(ctx, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion) state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
@ -331,7 +345,12 @@ func (r *Admin) PerformAdminDownloadState(
Depth: depth, Depth: depth,
} }
ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, time.Now(), &eventsNeeded, queryRes) identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil {
return err
}
ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, queryRes)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,

View file

@ -37,7 +37,7 @@ import (
const maxBackfillServers = 5 const maxBackfillServers = 5
type Backfiller struct { type Backfiller struct {
ServerName gomatrixserverlib.ServerName IsLocalServerName func(gomatrixserverlib.ServerName) bool
DB storage.Database DB storage.Database
FSAPI federationAPI.RoomserverFederationAPI FSAPI federationAPI.RoomserverFederationAPI
KeyRing gomatrixserverlib.JSONVerifier KeyRing gomatrixserverlib.JSONVerifier
@ -55,7 +55,7 @@ func (r *Backfiller) PerformBackfill(
// if we are requesting the backfill then we need to do a federation hit // if we are requesting the backfill then we need to do a federation hit
// TODO: we could be more sensible and fetch as many events we already have then request the rest // TODO: we could be more sensible and fetch as many events we already have then request the rest
// which is what the syncapi does already. // which is what the syncapi does already.
if request.ServerName == r.ServerName { if r.IsLocalServerName(request.ServerName) {
return r.backfillViaFederation(ctx, request, response) return r.backfillViaFederation(ctx, request, response)
} }
// someone else is requesting the backfill, try to service their request. // someone else is requesting the backfill, try to service their request.
@ -112,16 +112,18 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
if info == nil || info.IsStub() { if info == nil || info.IsStub() {
return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID)
} }
requester := newBackfillRequester(r.DB, r.FSAPI, r.ServerName, req.BackwardsExtremities, r.PreferServers) requester := newBackfillRequester(r.DB, r.FSAPI, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers)
// Request 100 items regardless of what the query asks for. // Request 100 items regardless of what the query asks for.
// We don't want to go much higher than this. // We don't want to go much higher than this.
// We can't honour exactly the limit as some sytests rely on requesting more for tests to pass // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
// (so we don't need to hit /state_ids which the test has no listener for) // (so we don't need to hit /state_ids which the test has no listener for)
// Specifically the test "Outbound federation can backfill events" // Specifically the test "Outbound federation can backfill events"
events, err := gomatrixserverlib.RequestBackfill( events, err := gomatrixserverlib.RequestBackfill(
ctx, requester, ctx, req.VirtualHost, requester,
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100) r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100,
)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("gomatrixserverlib.RequestBackfill failed")
return err return err
} }
logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
@ -144,7 +146,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
var entries []types.StateEntry var entries []types.StateEntry
if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true); err != nil { if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true); err != nil {
// attempt to fetch the missing events // attempt to fetch the missing events
r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs) r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs, req.VirtualHost)
// try again // try again
entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true) entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true)
if err != nil { if err != nil {
@ -173,7 +175,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just // fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just
// best effort. // best effort.
func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
backfillRequester *backfillRequester, stateIDs []string) { backfillRequester *backfillRequester, stateIDs []string, virtualHost gomatrixserverlib.ServerName) {
servers := backfillRequester.servers servers := backfillRequester.servers
@ -198,7 +200,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
continue // already found continue // already found
} }
logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id) logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id)
res, err := r.FSAPI.GetEvent(ctx, srv, id) res, err := r.FSAPI.GetEvent(ctx, virtualHost, srv, id)
if err != nil { if err != nil {
logger.WithError(err).Warn("failed to get event from server") logger.WithError(err).Warn("failed to get event from server")
continue continue
@ -243,7 +245,8 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
type backfillRequester struct { type backfillRequester struct {
db storage.Database db storage.Database
fsAPI federationAPI.RoomserverFederationAPI fsAPI federationAPI.RoomserverFederationAPI
thisServer gomatrixserverlib.ServerName virtualHost gomatrixserverlib.ServerName
isLocalServerName func(gomatrixserverlib.ServerName) bool
preferServer map[gomatrixserverlib.ServerName]bool preferServer map[gomatrixserverlib.ServerName]bool
bwExtrems map[string][]string bwExtrems map[string][]string
@ -255,7 +258,9 @@ type backfillRequester struct {
} }
func newBackfillRequester( func newBackfillRequester(
db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, thisServer gomatrixserverlib.ServerName, db storage.Database, fsAPI federationAPI.RoomserverFederationAPI,
virtualHost gomatrixserverlib.ServerName,
isLocalServerName func(gomatrixserverlib.ServerName) bool,
bwExtrems map[string][]string, preferServers []gomatrixserverlib.ServerName, bwExtrems map[string][]string, preferServers []gomatrixserverlib.ServerName,
) *backfillRequester { ) *backfillRequester {
preferServer := make(map[gomatrixserverlib.ServerName]bool) preferServer := make(map[gomatrixserverlib.ServerName]bool)
@ -265,7 +270,8 @@ func newBackfillRequester(
return &backfillRequester{ return &backfillRequester{
db: db, db: db,
fsAPI: fsAPI, fsAPI: fsAPI,
thisServer: thisServer, virtualHost: virtualHost,
isLocalServerName: isLocalServerName,
eventIDToBeforeStateIDs: make(map[string][]string), eventIDToBeforeStateIDs: make(map[string][]string),
eventIDMap: make(map[string]*gomatrixserverlib.Event), eventIDMap: make(map[string]*gomatrixserverlib.Event),
bwExtrems: bwExtrems, bwExtrems: bwExtrems,
@ -450,7 +456,7 @@ FindSuccessor:
} }
// possibly return all joined servers depending on history visiblity // possibly return all joined servers depending on history visiblity
memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.thisServer) memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.virtualHost)
b.historyVisiblity = visibility b.historyVisiblity = visibility
if err != nil { if err != nil {
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
@ -477,7 +483,7 @@ FindSuccessor:
} }
var servers []gomatrixserverlib.ServerName var servers []gomatrixserverlib.ServerName
for server := range serverSet { for server := range serverSet {
if server == b.thisServer { if b.isLocalServerName(server) {
continue continue
} }
if b.preferServer[server] { // insert at the front if b.preferServer[server] { // insert at the front
@ -496,10 +502,10 @@ FindSuccessor:
// Backfill performs a backfill request to the given server. // Backfill performs a backfill request to the given server.
// https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid // https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid
func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, func (b *backfillRequester) Backfill(ctx context.Context, origin, server gomatrixserverlib.ServerName, roomID string,
limit int, fromEventIDs []string) (gomatrixserverlib.Transaction, error) { limit int, fromEventIDs []string) (gomatrixserverlib.Transaction, error) {
tx, err := b.fsAPI.Backfill(ctx, server, roomID, limit, fromEventIDs) tx, err := b.fsAPI.Backfill(ctx, origin, server, roomID, limit, fromEventIDs)
return tx, err return tx, err
} }

View file

@ -39,7 +39,6 @@ import (
) )
type Joiner struct { type Joiner struct {
ServerName gomatrixserverlib.ServerName
Cfg *config.RoomServer Cfg *config.RoomServer
FSAPI fsAPI.RoomserverFederationAPI FSAPI fsAPI.RoomserverFederationAPI
RSAPI rsAPI.RoomserverInternalAPI RSAPI rsAPI.RoomserverInternalAPI
@ -197,7 +196,7 @@ func (r *Joiner) performJoinRoomByID(
// Prepare the template for the join event. // Prepare the template for the join event.
userID := req.UserID userID := req.UserID
_, userDomain, err := gomatrixserverlib.SplitID('@', userID) _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
return "", "", &rsAPI.PerformError{ return "", "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorBadRequest, Code: rsAPI.PerformErrorBadRequest,
@ -283,7 +282,7 @@ func (r *Joiner) performJoinRoomByID(
// locally on the homeserver. // locally on the homeserver.
// TODO: Check what happens if the room exists on the server // TODO: Check what happens if the room exists on the server
// but everyone has since left. I suspect it does the wrong thing. // but everyone has since left. I suspect it does the wrong thing.
event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb) event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, userDomain, &eb)
switch err { switch err {
case nil: case nil:
@ -410,7 +409,9 @@ func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin(
} }
func buildEvent( func buildEvent(
ctx context.Context, db storage.Database, cfg *config.Global, builder *gomatrixserverlib.EventBuilder, ctx context.Context, db storage.Database, cfg *config.Global,
senderDomain gomatrixserverlib.ServerName,
builder *gomatrixserverlib.EventBuilder,
) (*gomatrixserverlib.HeaderedEvent, *rsAPI.QueryLatestEventsAndStateResponse, error) { ) (*gomatrixserverlib.HeaderedEvent, *rsAPI.QueryLatestEventsAndStateResponse, error) {
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
if err != nil { if err != nil {
@ -438,7 +439,12 @@ func buildEvent(
} }
} }
ev, err := eventutil.BuildEvent(ctx, builder, cfg, time.Now(), &eventsNeeded, &queryRes) identity, err := cfg.SigningIdentityFor(senderDomain)
if err != nil {
return nil, nil, err
}
ev, err := eventutil.BuildEvent(ctx, builder, cfg, identity, time.Now(), &eventsNeeded, &queryRes)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

Some files were not shown because too many files have changed in this diff Show more