Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/guestaccess

This commit is contained in:
Till Faelligen 2022-11-28 07:52:00 +01:00
commit 4d4877af5a
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
151 changed files with 2505 additions and 1285 deletions

View file

@ -26,22 +26,14 @@ jobs:
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: 1.18 go-version: 1.18
cache: true
- uses: actions/cache@v2
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go-wasm-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-wasm
- name: Install Node - name: Install Node
uses: actions/setup-node@v2 uses: actions/setup-node@v2
with: with:
node-version: 14 node-version: 14
- uses: actions/cache@v2 - uses: actions/cache@v3
with: with:
path: ~/.npm path: ~/.npm
key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }} key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }}
@ -109,19 +101,12 @@ jobs:
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
cache: true
- name: Set up gotestfmt - name: Set up gotestfmt
uses: gotesttools/gotestfmt-action@v2 uses: gotesttools/gotestfmt-action@v2
with: with:
# Optional: pass GITHUB_TOKEN to avoid rate limiting. # Optional: pass GITHUB_TOKEN to avoid rate limiting.
token: ${{ secrets.GITHUB_TOKEN }} token: ${{ secrets.GITHUB_TOKEN }}
- uses: actions/cache@v3
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-test-
- run: go test -json -v ./... 2>&1 | gotestfmt - run: go test -json -v ./... 2>&1 | gotestfmt
env: env:
POSTGRES_HOST: localhost POSTGRES_HOST: localhost
@ -146,17 +131,17 @@ jobs:
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Install dependencies x86
if: ${{ matrix.goarch == '386' }}
run: sudo apt update && sudo apt-get install -y gcc-multilib
- uses: actions/cache@v3 - uses: actions/cache@v3
with: with:
path: | path: |
~/.cache/go-build ~/.cache/go-build
~/go/pkg/mod ~/go/pkg/mod
key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goarch }}- key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-
- name: Install dependencies x86
if: ${{ matrix.goarch == '386' }}
run: sudo apt update && sudo apt-get install -y gcc-multilib
- env: - env:
GOOS: ${{ matrix.goos }} GOOS: ${{ matrix.goos }}
GOARCH: ${{ matrix.goarch }} GOARCH: ${{ matrix.goarch }}
@ -180,16 +165,16 @@ jobs:
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Install dependencies
run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc
- uses: actions/cache@v3 - uses: actions/cache@v3
with: with:
path: | path: |
~/.cache/go-build ~/.cache/go-build
~/go/pkg/mod ~/go/pkg/mod
key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }}-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }} key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-
- name: Install dependencies
run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc
- env: - env:
GOOS: ${{ matrix.goos }} GOOS: ${{ matrix.goos }}
GOARCH: ${{ matrix.goarch }} GOARCH: ${{ matrix.goarch }}
@ -221,18 +206,13 @@ jobs:
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: "1.18" go-version: "1.18"
- uses: actions/cache@v3 cache: true
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go-upgrade-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-upgrade
- name: Build upgrade-tests - name: Build upgrade-tests
run: go build ./cmd/dendrite-upgrade-tests run: go build ./cmd/dendrite-upgrade-tests
- name: Test upgrade - name: Test upgrade (PostgreSQL)
run: ./dendrite-upgrade-tests --head . run: ./dendrite-upgrade-tests --head .
- name: Test upgrade (SQLite)
run: ./dendrite-upgrade-tests --sqlite --head .
# run database upgrade tests, skipping over one version # run database upgrade tests, skipping over one version
upgrade_test_direct: upgrade_test_direct:
@ -246,17 +226,12 @@ jobs:
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: "1.18" go-version: "1.18"
- uses: actions/cache@v3 cache: true
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go-upgrade-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-upgrade
- name: Build upgrade-tests - name: Build upgrade-tests
run: go build ./cmd/dendrite-upgrade-tests run: go build ./cmd/dendrite-upgrade-tests
- name: Test upgrade - name: Test upgrade (PostgreSQL)
run: ./dendrite-upgrade-tests -direct -from HEAD-2 --head .
- name: Test upgrade (SQLite)
run: ./dendrite-upgrade-tests -direct -from HEAD-2 --head . run: ./dendrite-upgrade-tests -direct -from HEAD-2 --head .
# run Sytest in different variations # run Sytest in different variations
@ -291,6 +266,8 @@ jobs:
image: matrixdotorg/sytest-dendrite:latest image: matrixdotorg/sytest-dendrite:latest
volumes: volumes:
- ${{ github.workspace }}:/src - ${{ github.workspace }}:/src
- /root/.cache/go-build:/github/home/.cache/go-build
- /root/.cache/go-mod:/gopath/pkg/mod
env: env:
POSTGRES: ${{ matrix.postgres && 1}} POSTGRES: ${{ matrix.postgres && 1}}
API: ${{ matrix.api && 1 }} API: ${{ matrix.api && 1 }}
@ -298,6 +275,14 @@ jobs:
CGO_ENABLED: ${{ matrix.cgo && 1 }} CGO_ENABLED: ${{ matrix.cgo && 1 }}
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/cache@v3
with:
path: |
~/.cache/go-build
/gopath/pkg/mod
key: ${{ runner.os }}-go-sytest-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-sytest-
- name: Run Sytest - name: Run Sytest
run: /bootstrap.sh dendrite run: /bootstrap.sh dendrite
working-directory: /src working-directory: /src
@ -332,12 +317,14 @@ jobs:
matrix: matrix:
include: include:
- label: SQLite native - label: SQLite native
cgo: 0
- label: SQLite Cgo - label: SQLite Cgo
cgo: 1 cgo: 1
- label: SQLite native, full HTTP APIs - label: SQLite native, full HTTP APIs
api: full-http api: full-http
cgo: 0
- label: SQLite Cgo, full HTTP APIs - label: SQLite Cgo, full HTTP APIs
api: full-http api: full-http
@ -345,10 +332,12 @@ jobs:
- label: PostgreSQL - label: PostgreSQL
postgres: Postgres postgres: Postgres
cgo: 0
- label: PostgreSQL, full HTTP APIs - label: PostgreSQL, full HTTP APIs
postgres: Postgres postgres: Postgres
api: full-http api: full-http
cgo: 0
steps: steps:
# Env vars are set file a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on env to run Complement. # Env vars are set file a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on env to run Complement.
# See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path # See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path
@ -356,14 +345,12 @@ jobs:
run: | run: |
echo "$GOROOT_1_17_X64/bin" >> $GITHUB_PATH echo "$GOROOT_1_17_X64/bin" >> $GITHUB_PATH
echo "~/go/bin" >> $GITHUB_PATH echo "~/go/bin" >> $GITHUB_PATH
- name: "Install Complement Dependencies" - name: "Install Complement Dependencies"
# We don't need to install Go because it is included on the Ubuntu 20.04 image: # We don't need to install Go because it is included on the Ubuntu 20.04 image:
# See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md specifically GOROOT_1_17_X64 # See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md specifically GOROOT_1_17_X64
run: | run: |
sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev
go get -v github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest go get -v github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest
- name: Run actions/checkout@v3 for dendrite - name: Run actions/checkout@v3 for dendrite
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
@ -389,12 +376,10 @@ jobs:
if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then
continue continue
fi fi
(wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break (wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break
done done
# Build initial Dendrite image # Build initial Dendrite image
- run: docker build -t complement-dendrite -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . - run: docker build --build-arg=CGO=${{ matrix.cgo }} -t complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile .
working-directory: dendrite working-directory: dendrite
env: env:
DOCKER_BUILDKIT: 1 DOCKER_BUILDKIT: 1
@ -406,9 +391,8 @@ jobs:
shell: bash shell: bash
name: Run Complement Tests name: Run Complement Tests
env: env:
COMPLEMENT_BASE_IMAGE: complement-dendrite:latest COMPLEMENT_BASE_IMAGE: complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }}
API: ${{ matrix.api && 1 }} API: ${{ matrix.api && 1 }}
CGO_ENABLED: ${{ matrix.cgo && 1 }}
working-directory: complement working-directory: complement
integration-tests-done: integration-tests-done:

View file

@ -45,6 +45,11 @@ jobs:
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Set up gotestfmt
uses: gotesttools/gotestfmt-action@v2
with:
# Optional: pass GITHUB_TOKEN to avoid rate limiting.
token: ${{ secrets.GITHUB_TOKEN }}
- uses: actions/cache@v3 - uses: actions/cache@v3
with: with:
path: | path: |
@ -53,12 +58,14 @@ jobs:
key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-test-race- ${{ runner.os }}-go${{ matrix.go }}-test-race-
- run: go test -race ./... - run: go test -race -json -v -coverpkg=./... -coverprofile=cover.out $(go list ./... | grep -v /cmd/dendrite*) 2>&1 | gotestfmt
env: env:
POSTGRES_HOST: localhost POSTGRES_HOST: localhost
POSTGRES_USER: postgres POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres POSTGRES_PASSWORD: postgres
POSTGRES_DB: dendrite POSTGRES_DB: dendrite
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
# Dummy step to gate other tests on without repeating the whole list # Dummy step to gate other tests on without repeating the whole list
initial-tests-done: initial-tests-done:

View file

@ -32,6 +32,7 @@ import (
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
) )
// AddInternalRoutes registers HTTP handlers for internal API calls // AddInternalRoutes registers HTTP handlers for internal API calls
@ -74,7 +75,7 @@ func NewInternalAPI(
// events to be sent out. // events to be sent out.
for _, appservice := range base.Cfg.Derived.ApplicationServices { for _, appservice := range base.Cfg.Derived.ApplicationServices {
// Create bot account for this AS if it doesn't already exist // Create bot account for this AS if it doesn't already exist
if err := generateAppServiceAccount(userAPI, appservice); err != nil { if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"appservice": appservice.ID, "appservice": appservice.ID,
}).WithError(err).Panicf("failed to generate bot account for appservice") }).WithError(err).Panicf("failed to generate bot account for appservice")
@ -101,11 +102,13 @@ func NewInternalAPI(
func generateAppServiceAccount( func generateAppServiceAccount(
userAPI userapi.AppserviceUserAPI, userAPI userapi.AppserviceUserAPI,
as config.ApplicationService, as config.ApplicationService,
serverName gomatrixserverlib.ServerName,
) error { ) error {
var accRes userapi.PerformAccountCreationResponse var accRes userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeAppService, AccountType: userapi.AccountTypeAppService,
Localpart: as.SenderLocalpart, Localpart: as.SenderLocalpart,
ServerName: serverName,
AppServiceID: as.ID, AppServiceID: as.ID,
OnConflict: userapi.ConflictUpdate, OnConflict: userapi.ConflictUpdate,
}, &accRes) }, &accRes)
@ -115,6 +118,7 @@ func generateAppServiceAccount(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{
Localpart: as.SenderLocalpart, Localpart: as.SenderLocalpart,
ServerName: serverName,
AccessToken: as.ASToken, AccessToken: as.ASToken,
DeviceID: &as.SenderLocalpart, DeviceID: &as.SenderLocalpart,
DeviceDisplayName: &as.SenderLocalpart, DeviceDisplayName: &as.SenderLocalpart,

View file

@ -40,6 +40,7 @@ import (
"github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
"github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver"
@ -58,6 +59,7 @@ import (
pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeMulticast "github.com/matrix-org/pinecone/multicast"
pineconeRouter "github.com/matrix-org/pinecone/router" pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeEvents "github.com/matrix-org/pinecone/router/events"
pineconeSessions "github.com/matrix-org/pinecone/sessions" pineconeSessions "github.com/matrix-org/pinecone/sessions"
"github.com/matrix-org/pinecone/types" "github.com/matrix-org/pinecone/types"
@ -295,7 +297,12 @@ func (m *DendriteMonolith) Start() {
m.logger.SetOutput(BindLogger{}) m.logger.SetOutput(BindLogger{})
logrus.SetOutput(BindLogger{}) logrus.SetOutput(BindLogger{})
pineconeEventChannel := make(chan pineconeEvents.Event)
m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk)
m.PineconeRouter.EnableHopLimiting()
m.PineconeRouter.EnableWakeupBroadcasts()
m.PineconeRouter.Subscribe(pineconeEventChannel)
m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"}) m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"})
m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter) m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter)
m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter, nil) m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter, nil)
@ -423,6 +430,34 @@ func (m *DendriteMonolith) Start() {
m.logger.Fatal(err) m.logger.Fatal(err)
} }
}() }()
go func(ch <-chan pineconeEvents.Event) {
eLog := logrus.WithField("pinecone", "events")
for event := range ch {
switch e := event.(type) {
case pineconeEvents.PeerAdded:
case pineconeEvents.PeerRemoved:
case pineconeEvents.TreeParentUpdate:
case pineconeEvents.SnakeDescUpdate:
case pineconeEvents.TreeRootAnnUpdate:
case pineconeEvents.SnakeEntryAdded:
case pineconeEvents.SnakeEntryRemoved:
case pineconeEvents.BroadcastReceived:
eLog.Info("Broadcast received from: ", e.PeerID)
req := &api.PerformWakeupServersRequest{
ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)},
}
res := &api.PerformWakeupServersResponse{}
if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil {
logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID)
}
case pineconeEvents.BandwidthReport:
default:
}
}
}(pineconeEventChannel)
} }
func (m *DendriteMonolith) Stop() { func (m *DendriteMonolith) Stop() {

View file

@ -10,12 +10,13 @@ RUN mkdir /dendrite
# Utilise Docker caching when downloading dependencies, this stops us needlessly # Utilise Docker caching when downloading dependencies, this stops us needlessly
# downloading dependencies every time. # downloading dependencies every time.
ARG CGO
RUN --mount=target=. \ RUN --mount=target=. \
--mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,target=/root/.cache/go-build \ --mount=type=cache,target=/root/.cache/go-build \
go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \
go build -o /dendrite ./cmd/generate-keys && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \
go build -o /dendrite ./cmd/dendrite-monolith-server CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server
WORKDIR /dendrite WORKDIR /dendrite
RUN ./generate-keys --private-key matrix_key.pem RUN ./generate-keys --private-key matrix_key.pem

View file

@ -28,12 +28,13 @@ RUN mkdir /dendrite
# Utilise Docker caching when downloading dependencies, this stops us needlessly # Utilise Docker caching when downloading dependencies, this stops us needlessly
# downloading dependencies every time. # downloading dependencies every time.
ARG CGO
RUN --mount=target=. \ RUN --mount=target=. \
--mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,target=/root/.cache/go-build \ --mount=type=cache,target=/root/.cache/go-build \
go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \
go build -o /dendrite ./cmd/generate-keys && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \
go build -o /dendrite ./cmd/dendrite-monolith-server CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server
WORKDIR /dendrite WORKDIR /dendrite
RUN ./generate-keys --private-key matrix_key.pem RUN ./generate-keys --private-key matrix_key.pem

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
uapi "github.com/matrix-org/dendrite/userapi/api" uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -66,8 +67,10 @@ func TestLoginFromJSONReader(t *testing.T) {
var userAPI fakeUserInternalAPI var userAPI fakeUserInternalAPI
cfg := &config.ClientAPI{ cfg := &config.ClientAPI{
Matrix: &config.Global{ Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
}, },
},
} }
login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg)
if err != nil { if err != nil {
@ -144,8 +147,10 @@ func TestBadLoginFromJSONReader(t *testing.T) {
var userAPI fakeUserInternalAPI var userAPI fakeUserInternalAPI
cfg := &config.ClientAPI{ cfg := &config.ClientAPI{
Matrix: &config.Global{ Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
}, },
},
} }
_, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg)
if errRes == nil { if errRes == nil {

View file

@ -61,7 +61,7 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte)
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
r := req.(*PasswordRequest) r := req.(*PasswordRequest)
username := strings.ToLower(r.Username()) username := r.Username()
if username == "" { if username == "" {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,
@ -74,32 +74,43 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
JSON: jsonerror.BadJSON("A password must be supplied."), JSON: jsonerror.BadJSON("A password must be supplied."),
} }
} }
localpart, _, err := userutil.ParseUsernameParam(username, t.Config.Matrix) localpart, domain, err := userutil.ParseUsernameParam(username, t.Config.Matrix)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,
JSON: jsonerror.InvalidUsername(err.Error()), JSON: jsonerror.InvalidUsername(err.Error()),
} }
} }
if !t.Config.Matrix.IsLocalServerName(domain) {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.InvalidUsername("The server name is not known."),
}
}
// Squash username to all lowercase letters // Squash username to all lowercase letters
res := &api.QueryAccountByPasswordResponse{} res := &api.QueryAccountByPasswordResponse{}
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res) err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
Localpart: strings.ToLower(localpart),
ServerName: domain,
PlaintextPassword: r.Password,
}, res)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to fetch account by password"), JSON: jsonerror.Unknown("Unable to fetch account by password."),
} }
} }
if !res.Exists { if !res.Exists {
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
PlaintextPassword: r.Password, PlaintextPassword: r.Password,
}, res) }, res)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to fetch account by password"), JSON: jsonerror.Unknown("Unable to fetch account by password."),
} }
} }
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows

View file

@ -47,8 +47,10 @@ func (d *fakeAccountDatabase) QueryAccountByPassword(ctx context.Context, req *a
func setup() *UserInteractive { func setup() *UserInteractive {
cfg := &config.ClientAPI{ cfg := &config.ClientAPI{
Matrix: &config.Global{ Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
}, },
},
} }
return NewUserInteractive(&fakeAccountDatabase{}, cfg) return NewUserInteractive(&fakeAccountDatabase{}, cfg)
} }

View file

@ -102,6 +102,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
serverName := cfg.Matrix.ServerName
localpart, ok := vars["localpart"] localpart, ok := vars["localpart"]
if !ok { if !ok {
return util.JSONResponse{ return util.JSONResponse{
@ -109,6 +110,9 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
JSON: jsonerror.MissingArgument("Expecting user localpart."), JSON: jsonerror.MissingArgument("Expecting user localpart."),
} }
} }
if l, s, err := cfg.Matrix.SplitLocalID('@', localpart); err == nil {
localpart, serverName = l, s
}
request := struct { request := struct {
Password string `json:"password"` Password string `json:"password"`
}{} }{}
@ -126,6 +130,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
} }
updateReq := &userapi.PerformPasswordUpdateRequest{ updateReq := &userapi.PerformPasswordUpdateRequest{
Localpart: localpart, Localpart: localpart,
ServerName: serverName,
Password: request.Password, Password: request.Password,
LogoutDevices: true, LogoutDevices: true,
} }

View file

@ -477,7 +477,7 @@ func createRoom(
SendAsServer: roomserverAPI.DoNotSendToOtherServers, SendAsServer: roomserverAPI.DoNotSendToOtherServers,
}) })
} }
if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, inputs, false); err != nil { if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, device.UserDomain(), inputs, false); err != nil {
util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -77,7 +77,7 @@ func DirectoryRoom(
// If we don't know it locally, do a federation query. // If we don't know it locally, do a federation query.
// But don't send the query to ourselves. // But don't send the query to ourselves.
if !cfg.Matrix.IsLocalServerName(domain) { if !cfg.Matrix.IsLocalServerName(domain) {
fedRes, fedErr := federation.LookupRoomAlias(req.Context(), domain, roomAlias) fedRes, fedErr := federation.LookupRoomAlias(req.Context(), cfg.Matrix.ServerName, domain, roomAlias)
if fedErr != nil { if fedErr != nil {
// TODO: Return 502 if the remote server errored. // TODO: Return 502 if the remote server errored.
// TODO: Return 504 if the remote server timed out. // TODO: Return 504 if the remote server timed out.

View file

@ -74,7 +74,7 @@ func GetPostPublicRooms(
serverName := gomatrixserverlib.ServerName(request.Server) serverName := gomatrixserverlib.ServerName(request.Server)
if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) { if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) {
res, err := federation.GetPublicRoomsFiltered( res, err := federation.GetPublicRoomsFiltered(
req.Context(), serverName, req.Context(), cfg.Matrix.ServerName, serverName,
int(request.Limit), request.Since, int(request.Limit), request.Since,
request.Filter.SearchTerms, false, request.Filter.SearchTerms, false,
"", "",

View file

@ -100,6 +100,7 @@ func completeAuth(
DeviceID: login.DeviceID, DeviceID: login.DeviceID,
AccessToken: token, AccessToken: token,
Localpart: localpart, Localpart: localpart,
ServerName: serverName,
IPAddr: ipAddr, IPAddr: ipAddr,
UserAgent: userAgent, UserAgent: userAgent,
}, &performRes) }, &performRes)

View file

@ -110,6 +110,7 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic
ctx, rsAPI, ctx, rsAPI,
roomserverAPI.KindNew, roomserverAPI.KindNew,
[]*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)},
device.UserDomain(),
serverName, serverName,
serverName, serverName,
nil, nil,
@ -322,7 +323,12 @@ func buildMembershipEvent(
return nil, err return nil, err
} }
return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return nil, err
}
return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil)
} }
// loadProfile lookups the profile of a given user from the database and returns // loadProfile lookups the profile of a given user from the database and returns

View file

@ -40,13 +40,14 @@ func GetNotifications(
} }
var queryRes userapi.QueryNotificationsResponse var queryRes userapi.QueryNotificationsResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{ err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
From: req.URL.Query().Get("from"), From: req.URL.Query().Get("from"),
Limit: int(limit), Limit: int(limit),
Only: req.URL.Query().Get("only"), Only: req.URL.Query().Get("only"),

View file

@ -86,7 +86,7 @@ func Password(
} }
// Get the local part. // Get the local part.
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -95,6 +95,7 @@ func Password(
// Ask the user API to perform the password change. // Ask the user API to perform the password change.
passwordReq := &api.PerformPasswordUpdateRequest{ passwordReq := &api.PerformPasswordUpdateRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
Password: r.NewPassword, Password: r.NewPassword,
} }
passwordRes := &api.PerformPasswordUpdateResponse{} passwordRes := &api.PerformPasswordUpdateResponse{}
@ -123,6 +124,7 @@ func Password(
pushersReq := &api.PerformPusherDeletionRequest{ pushersReq := &api.PerformPusherDeletionRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
SessionID: device.SessionID, SessionID: device.SessionID,
} }
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {

View file

@ -284,7 +284,7 @@ func updateProfile(
} }
events, err := buildMembershipEvents( events, err := buildMembershipEvents(
ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, ctx, device, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI,
) )
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:
@ -298,7 +298,7 @@ func updateProfile(
return jsonerror.InternalServerError(), e return jsonerror.InternalServerError(), e
} }
if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, domain, domain, nil, true); err != nil { if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, device.UserDomain(), domain, domain, nil, true); err != nil {
util.GetLogger(ctx).WithError(err).Error("SendEvents failed") util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError(), err return jsonerror.InternalServerError(), err
} }
@ -321,7 +321,7 @@ func getProfile(
} }
if !cfg.Matrix.IsLocalServerName(domain) { if !cfg.Matrix.IsLocalServerName(domain) {
profile, fedErr := federation.LookupProfile(ctx, domain, userID, "") profile, fedErr := federation.LookupProfile(ctx, cfg.Matrix.ServerName, domain, userID, "")
if fedErr != nil { if fedErr != nil {
if x, ok := fedErr.(gomatrix.HTTPError); ok { if x, ok := fedErr.(gomatrix.HTTPError); ok {
if x.Code == http.StatusNotFound { if x.Code == http.StatusNotFound {
@ -349,6 +349,7 @@ func getProfile(
func buildMembershipEvents( func buildMembershipEvents(
ctx context.Context, ctx context.Context,
device *userapi.Device,
roomIDs []string, roomIDs []string,
newProfile authtypes.Profile, userID string, cfg *config.ClientAPI, newProfile authtypes.Profile, userID string, cfg *config.ClientAPI,
evTime time.Time, rsAPI api.ClientRoomserverAPI, evTime time.Time, rsAPI api.ClientRoomserverAPI,
@ -380,7 +381,12 @@ func buildMembershipEvents(
return nil, err return nil, err
} }
event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return nil, err
}
event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -31,13 +31,14 @@ func GetPushers(
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
) util.JSONResponse { ) util.JSONResponse {
var queryRes userapi.QueryPushersResponse var queryRes userapi.QueryPushersResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{ err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed") util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
@ -59,7 +60,7 @@ func SetPusher(
req *http.Request, device *userapi.Device, req *http.Request, device *userapi.Device,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -93,6 +94,7 @@ func SetPusher(
} }
body.Localpart = localpart body.Localpart = localpart
body.ServerName = domain
body.SessionID = device.SessionID body.SessionID = device.SessionID
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{}) err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
if err != nil { if err != nil {

View file

@ -123,8 +123,13 @@ func SendRedaction(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return jsonerror.InternalServerError()
}
var queryRes roomserverAPI.QueryLatestEventsAndStateResponse var queryRes roomserverAPI.QueryLatestEventsAndStateResponse
e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
@ -132,7 +137,7 @@ func SendRedaction(
} }
} }
domain := device.UserDomain() domain := device.UserDomain()
if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, domain, domain, nil, false); err != nil { if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, device.UserDomain(), domain, domain, nil, false); err != nil {
util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -213,6 +213,7 @@ type registerRequest struct {
// registration parameters // registration parameters
Password string `json:"password"` Password string `json:"password"`
Username string `json:"username"` Username string `json:"username"`
ServerName gomatrixserverlib.ServerName `json:"-"`
Admin bool `json:"admin"` Admin bool `json:"admin"`
// user-interactive auth params // user-interactive auth params
Auth authDict `json:"auth"` Auth authDict `json:"auth"`
@ -550,6 +551,12 @@ func Register(
} }
var r registerRequest var r registerRequest
host := gomatrixserverlib.ServerName(req.Host)
if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil {
r.ServerName = v.ServerName
} else {
r.ServerName = cfg.Matrix.ServerName
}
sessionID := gjson.GetBytes(reqBody, "auth.session").String() sessionID := gjson.GetBytes(reqBody, "auth.session").String()
if sessionID == "" { if sessionID == "" {
// Generate a new, random session ID // Generate a new, random session ID
@ -559,6 +566,7 @@ func Register(
// Some of these might end up being overwritten if the // Some of these might end up being overwritten if the
// values are specified again in the request body. // values are specified again in the request body.
r.Username = data.Username r.Username = data.Username
r.ServerName = data.ServerName
r.Password = data.Password r.Password = data.Password
r.DeviceID = data.DeviceID r.DeviceID = data.DeviceID
r.InitialDisplayName = data.InitialDisplayName r.InitialDisplayName = data.InitialDisplayName
@ -570,11 +578,13 @@ func Register(
JSON: response, JSON: response,
} }
} }
} }
if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil {
return *resErr return *resErr
} }
if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil {
r.Username, r.ServerName = l, d
}
if req.URL.Query().Get("kind") == "guest" { if req.URL.Query().Get("kind") == "guest" {
return handleGuestRegistration(req, r, cfg, userAPI) return handleGuestRegistration(req, r, cfg, userAPI)
} }
@ -588,12 +598,15 @@ func Register(
} }
// Auto generate a numeric username if r.Username is empty // Auto generate a numeric username if r.Username is empty
if r.Username == "" { if r.Username == "" {
res := &userapi.QueryNumericLocalpartResponse{} nreq := &userapi.QueryNumericLocalpartRequest{
if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil { ServerName: r.ServerName,
}
nres := &userapi.QueryNumericLocalpartResponse{}
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
r.Username = strconv.FormatInt(res.ID, 10) r.Username = strconv.FormatInt(nres.ID, 10)
} }
// Is this an appservice registration? It will be if the access // Is this an appservice registration? It will be if the access
@ -606,7 +619,7 @@ func Register(
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
// Spec-compliant case (the access_token is specified and the login type // Spec-compliant case (the access_token is specified and the login type
// is correctly set, so it's an appservice registration) // is correctly set, so it's an appservice registration)
if resErr := validateApplicationServiceUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil {
return *resErr return *resErr
} }
case accessTokenErr == nil: case accessTokenErr == nil:
@ -619,7 +632,7 @@ func Register(
default: default:
// Spec-compliant case (neither the access_token nor the login type are // Spec-compliant case (neither the access_token nor the login type are
// specified, so it's a normal user registration) // specified, so it's a normal user registration)
if resErr := validateUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { if resErr := validateUsername(r.Username, r.ServerName); resErr != nil {
return *resErr return *resErr
} }
} }
@ -643,16 +656,25 @@ func handleGuestRegistration(
cfg *config.ClientAPI, cfg *config.ClientAPI,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
) util.JSONResponse { ) util.JSONResponse {
if cfg.RegistrationDisabled || cfg.GuestsDisabled { registrationEnabled := !cfg.RegistrationDisabled
guestsEnabled := !cfg.GuestsDisabled
if v := cfg.Matrix.VirtualHost(r.ServerName); v != nil {
registrationEnabled, guestsEnabled = v.RegistrationAllowed()
}
if !registrationEnabled || !guestsEnabled {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Guest registration is disabled"), JSON: jsonerror.Forbidden(
fmt.Sprintf("Guest registration is disabled on %q", r.ServerName),
),
} }
} }
var res userapi.PerformAccountCreationResponse var res userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeGuest, AccountType: userapi.AccountTypeGuest,
ServerName: r.ServerName,
}, &res) }, &res)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -676,6 +698,7 @@ func handleGuestRegistration(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{
Localpart: res.Account.Localpart, Localpart: res.Account.Localpart,
ServerName: res.Account.ServerName,
DeviceDisplayName: r.InitialDisplayName, DeviceDisplayName: r.InitialDisplayName,
AccessToken: token, AccessToken: token,
IPAddr: req.RemoteAddr, IPAddr: req.RemoteAddr,
@ -728,10 +751,16 @@ func handleRegistrationFlow(
) )
} }
if cfg.RegistrationDisabled && r.Auth.Type != authtypes.LoginTypeSharedSecret { registrationEnabled := !cfg.RegistrationDisabled
if v := cfg.Matrix.VirtualHost(r.ServerName); v != nil {
registrationEnabled, _ = v.RegistrationAllowed()
}
if !registrationEnabled && r.Auth.Type != authtypes.LoginTypeSharedSecret {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Registration is disabled"), JSON: jsonerror.Forbidden(
fmt.Sprintf("Registration is disabled on %q", r.ServerName),
),
} }
} }
@ -819,8 +848,9 @@ func handleApplicationServiceRegistration(
// Don't need to worry about appending to registration stages as // Don't need to worry about appending to registration stages as
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, req.Context(), userAPI, r.Username, r.ServerName, "", appserviceID, req.RemoteAddr,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, req.UserAgent(), r.Auth.Session, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
userapi.AccountTypeAppService,
) )
} }
@ -838,8 +868,9 @@ func checkAndCompleteFlow(
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue // This flow was completed, registration can continue
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, req.Context(), userAPI, r.Username, r.ServerName, r.Password, "", req.RemoteAddr,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, req.UserAgent(), sessionID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
userapi.AccountTypeUser,
) )
} }
sessions.addParams(sessionID, r) sessions.addParams(sessionID, r)
@ -861,7 +892,8 @@ func checkAndCompleteFlow(
func completeRegistration( func completeRegistration(
ctx context.Context, ctx context.Context,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
username, password, appserviceID, ipAddr, userAgent, sessionID string, username string, serverName gomatrixserverlib.ServerName,
password, appserviceID, ipAddr, userAgent, sessionID string,
inhibitLogin eventutil.WeakBoolean, inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string, displayName, deviceID *string,
accType userapi.AccountType, accType userapi.AccountType,
@ -883,6 +915,7 @@ func completeRegistration(
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
AppServiceID: appserviceID, AppServiceID: appserviceID,
Localpart: username, Localpart: username,
ServerName: serverName,
Password: password, Password: password,
AccountType: accType, AccountType: accType,
OnConflict: userapi.ConflictAbort, OnConflict: userapi.ConflictAbort,
@ -926,6 +959,7 @@ func completeRegistration(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
Localpart: username, Localpart: username,
ServerName: serverName,
AccessToken: token, AccessToken: token,
DeviceDisplayName: displayName, DeviceDisplayName: displayName,
DeviceID: deviceID, DeviceID: deviceID,
@ -1019,13 +1053,31 @@ func RegisterAvailable(
// Squash username to all lowercase letters // Squash username to all lowercase letters
username = strings.ToLower(username) username = strings.ToLower(username)
domain := cfg.Matrix.ServerName
host := gomatrixserverlib.ServerName(req.Host)
if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil {
domain = v.ServerName
}
if u, l, err := cfg.Matrix.SplitLocalID('@', username); err == nil {
username, domain = u, l
}
for _, v := range cfg.Matrix.VirtualHosts {
if v.ServerName == domain && !v.AllowRegistration {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(
fmt.Sprintf("Registration is not allowed on %q", string(v.ServerName)),
),
}
}
}
if err := validateUsername(username, cfg.Matrix.ServerName); err != nil { if err := validateUsername(username, domain); err != nil {
return *err return *err
} }
// Check if this username is reserved by an application service // Check if this username is reserved by an application service
userID := userutil.MakeUserID(username, cfg.Matrix.ServerName) userID := userutil.MakeUserID(username, domain)
for _, appservice := range cfg.Derived.ApplicationServices { for _, appservice := range cfg.Derived.ApplicationServices {
if appservice.OwnsNamespaceCoveringUserId(userID) { if appservice.OwnsNamespaceCoveringUserId(userID) {
return util.JSONResponse{ return util.JSONResponse{
@ -1038,6 +1090,7 @@ func RegisterAvailable(
res := &userapi.QueryAccountAvailabilityResponse{} res := &userapi.QueryAccountAvailabilityResponse{}
err := registerAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{ err := registerAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{
Localpart: username, Localpart: username,
ServerName: domain,
}, res) }, res)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -1094,5 +1147,5 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien
if ssrr.Admin { if ssrr.Admin {
accType = userapi.AccountTypeAdmin accType = userapi.AccountTypeAdmin
} }
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) return completeRegistration(req.Context(), userAPI, ssrr.User, cfg.Matrix.ServerName, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType)
} }

View file

@ -157,7 +157,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/resetPassword/{localpart}", dendriteAdminRouter.Handle("/admin/resetPassword/{userID}",
httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminResetPassword(req, cfg, device, userAPI) return AdminResetPassword(req, cfg, device, userAPI)
}), }),

View file

@ -186,6 +186,7 @@ func SendEvent(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
e.Headered(verRes.RoomVersion), e.Headered(verRes.RoomVersion),
}, },
device.UserDomain(),
domain, domain,
domain, domain,
txnAndSessionID, txnAndSessionID,
@ -275,8 +276,14 @@ func generateSendEvent(
return nil, &resErr return nil, &resErr
} }
identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
resErr := jsonerror.InternalServerError()
return nil, &resErr
}
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, &queryRes) e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,

View file

@ -231,6 +231,7 @@ func SendServerNotice(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
e.Headered(roomVersion), e.Headered(roomVersion),
}, },
device.UserDomain(),
cfgClient.Matrix.ServerName, cfgClient.Matrix.ServerName,
cfgClient.Matrix.ServerName, cfgClient.Matrix.ServerName,
txnAndSessionID, txnAndSessionID,
@ -286,6 +287,7 @@ func getSenderDevice(
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeUser, AccountType: userapi.AccountTypeUser,
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
OnConflict: userapi.ConflictUpdate, OnConflict: userapi.ConflictUpdate,
}, &accRes) }, &accRes)
if err != nil { if err != nil {
@ -296,6 +298,7 @@ func getSenderDevice(
avatarRes := &userapi.PerformSetAvatarURLResponse{} avatarRes := &userapi.PerformSetAvatarURLResponse{}
if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{ if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
}, avatarRes); err != nil { }, avatarRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed") util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed")
@ -308,6 +311,7 @@ func getSenderDevice(
displayNameRes := &userapi.PerformUpdateDisplayNameResponse{} displayNameRes := &userapi.PerformUpdateDisplayNameResponse{}
if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{ if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
DisplayName: cfg.Matrix.ServerNotices.DisplayName, DisplayName: cfg.Matrix.ServerNotices.DisplayName,
}, displayNameRes); err != nil { }, displayNameRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed") util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed")
@ -353,6 +357,7 @@ func getSenderDevice(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart, DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart,
AccessToken: token, AccessToken: token,
NoDeviceListUpdate: true, NoDeviceListUpdate: true,

View file

@ -136,7 +136,7 @@ func CheckAndSave3PIDAssociation(
} }
// Save the association in the database // Save the association in the database
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -145,6 +145,7 @@ func CheckAndSave3PIDAssociation(
if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{ if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{
ThreePID: address, ThreePID: address,
Localpart: localpart, Localpart: localpart,
ServerName: domain,
Medium: medium, Medium: medium,
}, &struct{}{}); err != nil { }, &struct{}{}); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed") util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed")
@ -161,7 +162,7 @@ func CheckAndSave3PIDAssociation(
func GetAssociated3PIDs( func GetAssociated3PIDs(
req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device, req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -170,6 +171,7 @@ func GetAssociated3PIDs(
res := &api.QueryThreePIDsForLocalpartResponse{} res := &api.QueryThreePIDsForLocalpartResponse{}
err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{ err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
}, res) }, res)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed") util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed")

View file

@ -106,7 +106,7 @@ knownUsersLoop:
continue continue
} }
// TODO: We should probably cache/store this // TODO: We should probably cache/store this
fedProfile, fedErr := federation.LookupProfile(ctx, serverName, userID, "") fedProfile, fedErr := federation.LookupProfile(ctx, localServerName, serverName, userID, "")
if fedErr != nil { if fedErr != nil {
if x, ok := fedErr.(gomatrix.HTTPError); ok { if x, ok := fedErr.(gomatrix.HTTPError); ok {
if x.Code == http.StatusNotFound { if x.Code == http.StatusNotFound {

View file

@ -359,8 +359,13 @@ func emit3PIDInviteEvent(
return err return err
} }
identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return err
}
queryRes := api.QueryLatestEventsAndStateResponse{} queryRes := api.QueryLatestEventsAndStateResponse{}
event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, evTime, rsAPI, &queryRes) event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes)
if err != nil { if err != nil {
return err return err
} }
@ -371,6 +376,7 @@ func emit3PIDInviteEvent(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
event.Headered(queryRes.RoomVersion), event.Headered(queryRes.RoomVersion),
}, },
device.UserDomain(),
cfg.Matrix.ServerName, cfg.Matrix.ServerName,
cfg.Matrix.ServerName, cfg.Matrix.ServerName,
nil, nil,

View file

@ -30,7 +30,9 @@ var (
// TestGoodUserID checks that correct localpart is returned for a valid user ID. // TestGoodUserID checks that correct localpart is returned for a valid user ID.
func TestGoodUserID(t *testing.T) { func TestGoodUserID(t *testing.T) {
cfg := &config.Global{ cfg := &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
},
} }
lp, _, err := ParseUsernameParam(goodUserID, cfg) lp, _, err := ParseUsernameParam(goodUserID, cfg)
@ -47,7 +49,9 @@ func TestGoodUserID(t *testing.T) {
// TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart. // TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart.
func TestWithLocalpartOnly(t *testing.T) { func TestWithLocalpartOnly(t *testing.T) {
cfg := &config.Global{ cfg := &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
},
} }
lp, _, err := ParseUsernameParam(localpart, cfg) lp, _, err := ParseUsernameParam(localpart, cfg)
@ -64,7 +68,9 @@ func TestWithLocalpartOnly(t *testing.T) {
// TestIncorrectDomain checks for error when there's server name mismatch. // TestIncorrectDomain checks for error when there's server name mismatch.
func TestIncorrectDomain(t *testing.T) { func TestIncorrectDomain(t *testing.T) {
cfg := &config.Global{ cfg := &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: invalidServerName, ServerName: invalidServerName,
},
} }
_, _, err := ParseUsernameParam(goodUserID, cfg) _, _, err := ParseUsernameParam(goodUserID, cfg)
@ -77,7 +83,9 @@ func TestIncorrectDomain(t *testing.T) {
// TestBadUserID checks that ParseUsernameParam fails for invalid user ID // TestBadUserID checks that ParseUsernameParam fails for invalid user ID
func TestBadUserID(t *testing.T) { func TestBadUserID(t *testing.T) {
cfg := &config.Global{ cfg := &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: serverName, ServerName: serverName,
},
} }
_, _, err := ParseUsernameParam(badUserID, cfg) _, _, err := ParseUsernameParam(badUserID, cfg)

View file

@ -101,9 +101,7 @@ func CreateFederationClient(
base *base.BaseDendrite, s *pineconeSessions.Sessions, base *base.BaseDendrite, s *pineconeSessions.Sessions,
) *gomatrixserverlib.FederationClient { ) *gomatrixserverlib.FederationClient {
return gomatrixserverlib.NewFederationClient( return gomatrixserverlib.NewFederationClient(
base.Cfg.Global.ServerName, base.Cfg.Global.SigningIdentities(),
base.Cfg.Global.KeyID,
base.Cfg.Global.PrivateKey,
gomatrixserverlib.WithTransport(createTransport(s)), gomatrixserverlib.WithTransport(createTransport(s)),
) )
} }

View file

@ -37,6 +37,7 @@ import (
"github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
"github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/keyserver"
@ -51,6 +52,7 @@ import (
pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeMulticast "github.com/matrix-org/pinecone/multicast"
pineconeRouter "github.com/matrix-org/pinecone/router" pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeEvents "github.com/matrix-org/pinecone/router/events"
pineconeSessions "github.com/matrix-org/pinecone/sessions" pineconeSessions "github.com/matrix-org/pinecone/sessions"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -155,7 +157,12 @@ func main() {
base := base.NewBaseDendrite(cfg, "Monolith") base := base.NewBaseDendrite(cfg, "Monolith")
defer base.Close() // nolint: errcheck defer base.Close() // nolint: errcheck
pineconeEventChannel := make(chan pineconeEvents.Event)
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk)
pRouter.EnableHopLimiting()
pRouter.EnableWakeupBroadcasts()
pRouter.Subscribe(pineconeEventChannel)
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
pManager := pineconeConnections.NewConnectionManager(pRouter, nil) pManager := pineconeConnections.NewConnectionManager(pRouter, nil)
@ -293,5 +300,33 @@ func main() {
logrus.Fatal(http.ListenAndServe(httpBindAddr, httpRouter)) logrus.Fatal(http.ListenAndServe(httpBindAddr, httpRouter))
}() }()
go func(ch <-chan pineconeEvents.Event) {
eLog := logrus.WithField("pinecone", "events")
for event := range ch {
switch e := event.(type) {
case pineconeEvents.PeerAdded:
case pineconeEvents.PeerRemoved:
case pineconeEvents.TreeParentUpdate:
case pineconeEvents.SnakeDescUpdate:
case pineconeEvents.TreeRootAnnUpdate:
case pineconeEvents.SnakeEntryAdded:
case pineconeEvents.SnakeEntryRemoved:
case pineconeEvents.BroadcastReceived:
eLog.Info("Broadcast received from: ", e.PeerID)
req := &api.PerformWakeupServersRequest{
ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)},
}
res := &api.PerformWakeupServersResponse{}
if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil {
logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID)
}
case pineconeEvents.BandwidthReport:
default:
}
}
}(pineconeEventChannel)
base.WaitForShutdown() base.WaitForShutdown()
} }

View file

@ -58,13 +58,17 @@ func (p *PineconeRoomProvider) Rooms() []gomatrixserverlib.PublicRoom {
for _, k := range p.r.Peers() { for _, k := range p.r.Peers() {
list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{} list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{}
} }
return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, list) return bulkFetchPublicRoomsFromServers(
context.Background(), p.fedClient,
gomatrixserverlib.ServerName(p.r.PublicKey().String()), list,
)
} }
// bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers.
// Returns a list of public rooms. // Returns a list of public rooms.
func bulkFetchPublicRoomsFromServers( func bulkFetchPublicRoomsFromServers(
ctx context.Context, fedClient *gomatrixserverlib.FederationClient, ctx context.Context, fedClient *gomatrixserverlib.FederationClient,
origin gomatrixserverlib.ServerName,
homeservers map[gomatrixserverlib.ServerName]struct{}, homeservers map[gomatrixserverlib.ServerName]struct{},
) (publicRooms []gomatrixserverlib.PublicRoom) { ) (publicRooms []gomatrixserverlib.PublicRoom) {
limit := 200 limit := 200
@ -82,7 +86,7 @@ func bulkFetchPublicRoomsFromServers(
go func(homeserverDomain gomatrixserverlib.ServerName) { go func(homeserverDomain gomatrixserverlib.ServerName) {
defer wg.Done() defer wg.Done()
util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms")
fres, err := fedClient.GetPublicRooms(reqctx, homeserverDomain, int(limit), "", false, "") fres, err := fedClient.GetPublicRooms(reqctx, origin, homeserverDomain, int(limit), "", false, "")
if err != nil { if err != nil {
util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn( util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn(
"bulkFetchPublicRoomsFromServers: failed to query hs", "bulkFetchPublicRoomsFromServers: failed to query hs",

View file

@ -55,8 +55,7 @@ func (n *Node) CreateFederationClient(
}, },
) )
return gomatrixserverlib.NewFederationClient( return gomatrixserverlib.NewFederationClient(
base.Cfg.Global.ServerName, base.Cfg.Global.KeyID, base.Cfg.Global.SigningIdentities(),
base.Cfg.Global.PrivateKey,
gomatrixserverlib.WithTransport(tr), gomatrixserverlib.WithTransport(tr),
) )
} }

View file

@ -43,13 +43,18 @@ func NewYggdrasilRoomProvider(
} }
func (p *YggdrasilRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { func (p *YggdrasilRoomProvider) Rooms() []gomatrixserverlib.PublicRoom {
return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, p.node.KnownNodes()) return bulkFetchPublicRoomsFromServers(
context.Background(), p.fedClient,
gomatrixserverlib.ServerName(p.node.DerivedServerName()),
p.node.KnownNodes(),
)
} }
// bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers.
// Returns a list of public rooms. // Returns a list of public rooms.
func bulkFetchPublicRoomsFromServers( func bulkFetchPublicRoomsFromServers(
ctx context.Context, fedClient *gomatrixserverlib.FederationClient, ctx context.Context, fedClient *gomatrixserverlib.FederationClient,
origin gomatrixserverlib.ServerName,
homeservers []gomatrixserverlib.ServerName, homeservers []gomatrixserverlib.ServerName,
) (publicRooms []gomatrixserverlib.PublicRoom) { ) (publicRooms []gomatrixserverlib.PublicRoom) {
limit := 200 limit := 200
@ -66,7 +71,7 @@ func bulkFetchPublicRoomsFromServers(
go func(homeserverDomain gomatrixserverlib.ServerName) { go func(homeserverDomain gomatrixserverlib.ServerName) {
defer wg.Done() defer wg.Done()
util.GetLogger(ctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") util.GetLogger(ctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms")
fres, err := fedClient.GetPublicRooms(ctx, homeserverDomain, int(limit), "", false, "") fres, err := fedClient.GetPublicRooms(ctx, origin, homeserverDomain, int(limit), "", false, "")
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).WithField("hs", homeserverDomain).Warn( util.GetLogger(ctx).WithError(err).WithField("hs", homeserverDomain).Warn(
"bulkFetchPublicRoomsFromServers: failed to query hs", "bulkFetchPublicRoomsFromServers: failed to query hs",

View file

@ -38,6 +38,7 @@ var (
flagHead = flag.String("head", "", "Location to a dendrite repository to treat as HEAD instead of Github") flagHead = flag.String("head", "", "Location to a dendrite repository to treat as HEAD instead of Github")
flagDockerHost = flag.String("docker-host", "localhost", "The hostname of the docker client. 'localhost' if running locally, 'host.docker.internal' if running in Docker.") flagDockerHost = flag.String("docker-host", "localhost", "The hostname of the docker client. 'localhost' if running locally, 'host.docker.internal' if running in Docker.")
flagDirect = flag.Bool("direct", false, "If a direct upgrade from the defined FROM version to TO should be done") flagDirect = flag.Bool("direct", false, "If a direct upgrade from the defined FROM version to TO should be done")
flagSqlite = flag.Bool("sqlite", false, "Test SQLite instead of PostgreSQL")
alphaNumerics = regexp.MustCompile("[^a-zA-Z0-9]+") alphaNumerics = regexp.MustCompile("[^a-zA-Z0-9]+")
) )
@ -49,7 +50,7 @@ const HEAD = "HEAD"
// due to the error: // due to the error:
// When using COPY with more than one source file, the destination must be a directory and end with a / // When using COPY with more than one source file, the destination must be a directory and end with a /
// We need to run a postgres anyway, so use the dockerfile associated with Complement instead. // We need to run a postgres anyway, so use the dockerfile associated with Complement instead.
const Dockerfile = `FROM golang:1.18-stretch as build const DockerfilePostgreSQL = `FROM golang:1.18-stretch as build
RUN apt-get update && apt-get install -y postgresql RUN apt-get update && apt-get install -y postgresql
WORKDIR /build WORKDIR /build
@ -92,6 +93,42 @@ ENV SERVER_NAME=localhost
EXPOSE 8008 8448 EXPOSE 8008 8448
CMD /build/run_dendrite.sh ` CMD /build/run_dendrite.sh `
const DockerfileSQLite = `FROM golang:1.18-stretch as build
RUN apt-get update && apt-get install -y postgresql
WORKDIR /build
# Copy the build context to the repo as this is the right dendrite code. This is different to the
# Complement Dockerfile which wgets a branch.
COPY . .
RUN go build ./cmd/dendrite-monolith-server
RUN go build ./cmd/generate-keys
RUN go build ./cmd/generate-config
RUN ./generate-config --ci > dendrite.yaml
RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key
# Make sure the SQLite databases are in a persistent location, we're already mapping
# the postgresql folder so let's just use that for simplicity
RUN sed -i "s%connection_string:.file:%connection_string: file:\/var\/lib\/postgresql\/9.6\/main\/%g" dendrite.yaml
# This entry script starts postgres, waits for it to be up then starts dendrite
RUN echo '\
sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml \n\
PARAMS="--tls-cert server.crt --tls-key server.key --config dendrite.yaml" \n\
./dendrite-monolith-server --really-enable-open-registration ${PARAMS} || ./dendrite-monolith-server ${PARAMS} \n\
' > run_dendrite.sh && chmod +x run_dendrite.sh
ENV SERVER_NAME=localhost
EXPOSE 8008 8448
CMD /build/run_dendrite.sh `
func dockerfile() []byte {
if *flagSqlite {
return []byte(DockerfileSQLite)
}
return []byte(DockerfilePostgreSQL)
}
const dendriteUpgradeTestLabel = "dendrite_upgrade_test" const dendriteUpgradeTestLabel = "dendrite_upgrade_test"
// downloadArchive downloads an arbitrary github archive of the form: // downloadArchive downloads an arbitrary github archive of the form:
@ -150,7 +187,7 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir,
if branchOrTagName == HEAD && *flagHead != "" { if branchOrTagName == HEAD && *flagHead != "" {
log.Printf("%s: Using %s as HEAD", branchOrTagName, *flagHead) log.Printf("%s: Using %s as HEAD", branchOrTagName, *flagHead)
// add top level Dockerfile // add top level Dockerfile
err = os.WriteFile(path.Join(*flagHead, "Dockerfile"), []byte(Dockerfile), os.ModePerm) err = os.WriteFile(path.Join(*flagHead, "Dockerfile"), dockerfile(), os.ModePerm)
if err != nil { if err != nil {
return "", fmt.Errorf("custom HEAD: failed to inject /Dockerfile: %w", err) return "", fmt.Errorf("custom HEAD: failed to inject /Dockerfile: %w", err)
} }
@ -166,7 +203,7 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir,
// pull an archive, this contains a top-level directory which screws with the build context // pull an archive, this contains a top-level directory which screws with the build context
// which we need to fix up post download // which we need to fix up post download
u := fmt.Sprintf("https://github.com/matrix-org/dendrite/archive/%s.tar.gz", branchOrTagName) u := fmt.Sprintf("https://github.com/matrix-org/dendrite/archive/%s.tar.gz", branchOrTagName)
tarball, err = downloadArchive(httpClient, tmpDir, u, []byte(Dockerfile)) tarball, err = downloadArchive(httpClient, tmpDir, u, dockerfile())
if err != nil { if err != nil {
return "", fmt.Errorf("failed to download archive %s: %w", u, err) return "", fmt.Errorf("failed to download archive %s: %w", u, err)
} }

View file

@ -48,10 +48,15 @@ func main() {
panic("unexpected key block") panic("unexpected key block")
} }
serverName := gomatrixserverlib.ServerName(*requestFrom)
client := gomatrixserverlib.NewFederationClient( client := gomatrixserverlib.NewFederationClient(
gomatrixserverlib.ServerName(*requestFrom), []*gomatrixserverlib.SigningIdentity{
gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]), {
privateKey, ServerName: serverName,
KeyID: gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]),
PrivateKey: privateKey,
},
},
) )
u, err := url.Parse(flag.Arg(0)) u, err := url.Parse(flag.Arg(0))
@ -79,6 +84,7 @@ func main() {
req := gomatrixserverlib.NewFederationRequest( req := gomatrixserverlib.NewFederationRequest(
method, method,
serverName,
gomatrixserverlib.ServerName(u.Host), gomatrixserverlib.ServerName(u.Host),
u.RequestURI(), u.RequestURI(),
) )

View file

@ -90,7 +90,7 @@ For example, this can be done with the following Caddy config:
handle /.well-known/matrix/server { handle /.well-known/matrix/server {
header Content-Type application/json header Content-Type application/json
header Access-Control-Allow-Origin * header Access-Control-Allow-Origin *
respond `"m.server": "matrix.example.com:8448"` respond `{"m.server": "matrix.example.com:8448"}`
} }
handle /.well-known/matrix/client { handle /.well-known/matrix/client {

View file

@ -21,8 +21,8 @@ type FederationInternalAPI interface {
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
// Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos.
PerformBroadcastEDU( PerformBroadcastEDU(
@ -30,6 +30,12 @@ type FederationInternalAPI interface {
request *PerformBroadcastEDURequest, request *PerformBroadcastEDURequest,
response *PerformBroadcastEDUResponse, response *PerformBroadcastEDUResponse,
) error ) error
PerformWakeupServers(
ctx context.Context,
request *PerformWakeupServersRequest,
response *PerformWakeupServersResponse,
) error
} }
type ClientFederationAPI interface { type ClientFederationAPI interface {
@ -60,18 +66,18 @@ type RoomserverFederationAPI interface {
// containing only the server names (without information for membership events). // containing only the server names (without information for membership events).
// The response will include this server if they are joined to the room. // The response will include this server if they are joined to the room.
QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
// KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
// this interface are of type FederationClientError // this interface are of type FederationClientError
type KeyserverFederationAPI interface { type KeyserverFederationAPI interface {
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error)
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
} }
// an interface for gmsl.FederationClient - contains functions called by federationapi only. // an interface for gmsl.FederationClient - contains functions called by federationapi only.
@ -80,28 +86,28 @@ type FederationClient interface {
SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error)
// Perform operations // Perform operations
LookupRoomAlias(ctx context.Context, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error)
Peek(ctx context.Context, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error) Peek(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error)
MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error)
SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error)
MakeLeave(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error) MakeLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error)
SendLeave(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) SendLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error)
SendInviteV2(ctx context.Context, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error) SendInviteV2(ctx context.Context, origin, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error) GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error)
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error)
Backfill(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) Backfill(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
ExchangeThirdPartyInvite(ctx context.Context, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) ExchangeThirdPartyInvite(ctx context.Context, origin, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error)
LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error)
LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
// FederationClientError is returned from FederationClient methods in the event of a problem. // FederationClientError is returned from FederationClient methods in the event of a problem.
@ -200,6 +206,7 @@ type PerformInviteResponse struct {
type QueryJoinedHostServerNamesInRoomRequest struct { type QueryJoinedHostServerNamesInRoomRequest struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
ExcludeSelf bool `json:"exclude_self"` ExcludeSelf bool `json:"exclude_self"`
ExcludeBlacklisted bool `json:"exclude_blacklisted"`
} }
// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames // QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames
@ -213,6 +220,13 @@ type PerformBroadcastEDURequest struct {
type PerformBroadcastEDUResponse struct { type PerformBroadcastEDUResponse struct {
} }
type PerformWakeupServersRequest struct {
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
}
type PerformWakeupServersResponse struct {
}
type InputPublicKeysRequest struct { type InputPublicKeysRequest struct {
Keys map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"keys"` Keys map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"keys"`
} }

View file

@ -128,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
} }
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
@ -189,7 +189,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
return true return true
} }
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")

View file

@ -111,7 +111,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
} }
// send this presence to all servers who share rooms with this user. // send this presence to all servers who share rooms with this user.
joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
if err != nil { if err != nil {
log.WithError(err).Error("failed to get joined hosts") log.WithError(err).Error("failed to get joined hosts")
return true return true

View file

@ -120,15 +120,13 @@ func NewInternalAPI(
js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
signingInfo := base.Cfg.Global.SigningIdentities()
queues := queue.NewOutgoingQueues( queues := queue.NewOutgoingQueues(
federationDB, base.ProcessContext, federationDB, base.ProcessContext,
cfg.Matrix.DisableFederation, cfg.Matrix.DisableFederation,
cfg.Matrix.ServerName, federation, rsAPI, &stats, cfg.Matrix.ServerName, federation, rsAPI, &stats,
&queue.SigningInfo{ signingInfo,
KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey,
ServerName: cfg.Matrix.ServerName,
},
) )
rsConsumer := consumers.NewOutputRoomEventConsumer( rsConsumer := consumers.NewOutputRoomEventConsumer(

View file

@ -104,7 +104,7 @@ func TestMain(m *testing.M) {
// Create the federation client. // Create the federation client.
s.fedclient = gomatrixserverlib.NewFederationClient( s.fedclient = gomatrixserverlib.NewFederationClient(
s.config.Matrix.ServerName, serverKeyID, testPriv, s.config.Matrix.SigningIdentities(),
gomatrixserverlib.WithTransport(transport), gomatrixserverlib.WithTransport(transport),
) )
@ -137,7 +137,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err
} }
// Get the keys and JSON-ify them. // Get the keys and JSON-ify them.
keys := routing.LocalKeys(s.config) keys := routing.LocalKeys(s.config, gomatrixserverlib.ServerName(req.Host))
body, err := json.MarshalIndent(keys.JSON, "", " ") body, err := json.MarshalIndent(keys.JSON, "", " ")
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -103,7 +103,7 @@ func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserv
return keys, nil return keys, nil
} }
func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) { func (f *fedClient) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) {
for _, r := range f.allowJoins { for _, r := range f.allowJoins {
if r.ID == roomID { if r.ID == roomID {
res.RoomVersion = r.Version res.RoomVersion = r.Version
@ -127,7 +127,7 @@ func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName
} }
return return
} }
func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) { func (f *fedClient) SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) {
f.fedClientMutex.Lock() f.fedClientMutex.Lock()
defer f.fedClientMutex.Unlock() defer f.fedClientMutex.Unlock()
for _, r := range f.allowJoins { for _, r := range f.allowJoins {
@ -283,7 +283,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://"))
fedCli := gomatrixserverlib.NewFederationClient( fedCli := gomatrixserverlib.NewFederationClient(
serverName, cfg.Global.KeyID, cfg.Global.PrivateKey, cfg.Global.SigningIdentities(),
gomatrixserverlib.WithSkipVerify(true), gomatrixserverlib.WithSkipVerify(true),
) )
@ -326,7 +326,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
t.Errorf("failed to create invite v2 request: %s", err) t.Errorf("failed to create invite v2 request: %s", err)
continue continue
} }
_, err = fedCli.SendInviteV2(context.Background(), serverName, invReq) _, err = fedCli.SendInviteV2(context.Background(), cfg.Global.ServerName, serverName, invReq)
if err == nil { if err == nil {
t.Errorf("expected an error, got none") t.Errorf("expected an error, got none")
continue continue

View file

@ -11,13 +11,13 @@ import (
// client. // client.
func (a *FederationInternalAPI) GetEventAuth( func (a *FederationInternalAPI) GetEventAuth(
ctx context.Context, s gomatrixserverlib.ServerName, ctx context.Context, origin, s gomatrixserverlib.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (res gomatrixserverlib.RespEventAuth, err error) { ) (res gomatrixserverlib.RespEventAuth, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEventAuth(ctx, s, roomVersion, roomID, eventID) return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespEventAuth{}, err return gomatrixserverlib.RespEventAuth{}, err
@ -26,12 +26,12 @@ func (a *FederationInternalAPI) GetEventAuth(
} }
func (a *FederationInternalAPI) GetUserDevices( func (a *FederationInternalAPI) GetUserDevices(
ctx context.Context, s gomatrixserverlib.ServerName, userID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string,
) (gomatrixserverlib.RespUserDevices, error) { ) (gomatrixserverlib.RespUserDevices, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetUserDevices(ctx, s, userID) return a.federation.GetUserDevices(ctx, origin, s, userID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespUserDevices{}, err return gomatrixserverlib.RespUserDevices{}, err
@ -40,12 +40,12 @@ func (a *FederationInternalAPI) GetUserDevices(
} }
func (a *FederationInternalAPI) ClaimKeys( func (a *FederationInternalAPI) ClaimKeys(
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
) (gomatrixserverlib.RespClaimKeys, error) { ) (gomatrixserverlib.RespClaimKeys, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.ClaimKeys(ctx, s, oneTimeKeys) return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespClaimKeys{}, err return gomatrixserverlib.RespClaimKeys{}, err
@ -54,10 +54,10 @@ func (a *FederationInternalAPI) ClaimKeys(
} }
func (a *FederationInternalAPI) QueryKeys( func (a *FederationInternalAPI) QueryKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string,
) (gomatrixserverlib.RespQueryKeys, error) { ) (gomatrixserverlib.RespQueryKeys, error) {
ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) {
return a.federation.QueryKeys(ctx, s, keys) return a.federation.QueryKeys(ctx, origin, s, keys)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespQueryKeys{}, err return gomatrixserverlib.RespQueryKeys{}, err
@ -66,12 +66,12 @@ func (a *FederationInternalAPI) QueryKeys(
} }
func (a *FederationInternalAPI) Backfill( func (a *FederationInternalAPI) Backfill(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.Backfill(ctx, s, roomID, limit, eventIDs) return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.Transaction{}, err return gomatrixserverlib.Transaction{}, err
@ -80,12 +80,12 @@ func (a *FederationInternalAPI) Backfill(
} }
func (a *FederationInternalAPI) LookupState( func (a *FederationInternalAPI) LookupState(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespState, err error) { ) (res gomatrixserverlib.RespState, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion) return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespState{}, err return gomatrixserverlib.RespState{}, err
@ -94,12 +94,12 @@ func (a *FederationInternalAPI) LookupState(
} }
func (a *FederationInternalAPI) LookupStateIDs( func (a *FederationInternalAPI) LookupStateIDs(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string,
) (res gomatrixserverlib.RespStateIDs, err error) { ) (res gomatrixserverlib.RespStateIDs, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupStateIDs(ctx, s, roomID, eventID) return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespStateIDs{}, err return gomatrixserverlib.RespStateIDs{}, err
@ -108,13 +108,13 @@ func (a *FederationInternalAPI) LookupStateIDs(
} }
func (a *FederationInternalAPI) LookupMissingEvents( func (a *FederationInternalAPI) LookupMissingEvents(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string,
missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespMissingEvents, err error) { ) (res gomatrixserverlib.RespMissingEvents, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupMissingEvents(ctx, s, roomID, missing, roomVersion) return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespMissingEvents{}, err return gomatrixserverlib.RespMissingEvents{}, err
@ -123,12 +123,12 @@ func (a *FederationInternalAPI) LookupMissingEvents(
} }
func (a *FederationInternalAPI) GetEvent( func (a *FederationInternalAPI) GetEvent(
ctx context.Context, s gomatrixserverlib.ServerName, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string,
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEvent(ctx, s, eventID) return a.federation.GetEvent(ctx, origin, s, eventID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.Transaction{}, err return gomatrixserverlib.Transaction{}, err
@ -151,13 +151,13 @@ func (a *FederationInternalAPI) LookupServerKeys(
} }
func (a *FederationInternalAPI) MSC2836EventRelationships( func (a *FederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.MSC2836EventRelationships(ctx, s, r, roomVersion) return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion)
}) })
if err != nil { if err != nil {
return res, err return res, err
@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
} }
func (a *FederationInternalAPI) MSC2946Spaces( func (a *FederationInternalAPI) MSC2946Spaces(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly) return a.federation.MSC2946Spaces(ctx, origin, s, roomID, suggestedOnly)
}) })
if err != nil { if err != nil {
return res, err return res, err

View file

@ -26,6 +26,7 @@ func (r *FederationInternalAPI) PerformDirectoryLookup(
) (err error) { ) (err error) {
dir, err := r.federation.LookupRoomAlias( dir, err := r.federation.LookupRoomAlias(
ctx, ctx,
r.cfg.Matrix.ServerName,
request.ServerName, request.ServerName,
request.RoomAlias, request.RoomAlias,
) )
@ -143,10 +144,16 @@ func (r *FederationInternalAPI) performJoinUsingServer(
supportedVersions []gomatrixserverlib.RoomVersion, supportedVersions []gomatrixserverlib.RoomVersion,
unsigned map[string]interface{}, unsigned map[string]interface{},
) error { ) error {
_, origin, err := r.cfg.Matrix.SplitLocalID('@', userID)
if err != nil {
return err
}
// Try to perform a make_join using the information supplied in the // Try to perform a make_join using the information supplied in the
// request. // request.
respMakeJoin, err := r.federation.MakeJoin( respMakeJoin, err := r.federation.MakeJoin(
ctx, ctx,
origin,
serverName, serverName,
roomID, roomID,
userID, userID,
@ -192,7 +199,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
// Build the join event. // Build the join event.
event, err := respMakeJoin.JoinEvent.Build( event, err := respMakeJoin.JoinEvent.Build(
time.Now(), time.Now(),
r.cfg.Matrix.ServerName, origin,
r.cfg.Matrix.KeyID, r.cfg.Matrix.KeyID,
r.cfg.Matrix.PrivateKey, r.cfg.Matrix.PrivateKey,
respMakeJoin.RoomVersion, respMakeJoin.RoomVersion,
@ -204,6 +211,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
// Try to perform a send_join using the newly built event. // Try to perform a send_join using the newly built event.
respSendJoin, err := r.federation.SendJoin( respSendJoin, err := r.federation.SendJoin(
context.Background(), context.Background(),
origin,
serverName, serverName,
event, event,
) )
@ -246,7 +254,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
respMakeJoin.RoomVersion, respMakeJoin.RoomVersion,
r.keyRing, r.keyRing,
event, event,
federatedAuthProvider(ctx, r.federation, r.keyRing, serverName), federatedAuthProvider(ctx, r.federation, r.keyRing, origin, serverName),
) )
if err != nil { if err != nil {
return fmt.Errorf("respSendJoin.Check: %w", err) return fmt.Errorf("respSendJoin.Check: %w", err)
@ -281,6 +289,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
if err = roomserverAPI.SendEventWithState( if err = roomserverAPI.SendEventWithState(
context.Background(), context.Background(),
r.rsAPI, r.rsAPI,
origin,
roomserverAPI.KindNew, roomserverAPI.KindNew,
respState, respState,
event.Headered(respMakeJoin.RoomVersion), event.Headered(respMakeJoin.RoomVersion),
@ -427,6 +436,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// request. // request.
respPeek, err := r.federation.Peek( respPeek, err := r.federation.Peek(
ctx, ctx,
r.cfg.Matrix.ServerName,
serverName, serverName,
roomID, roomID,
peekID, peekID,
@ -453,7 +463,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// authenticate the state returned (check its auth events etc) // authenticate the state returned (check its auth events etc)
// the equivalent of CheckSendJoinResponse() // the equivalent of CheckSendJoinResponse()
authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)) authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName))
if err != nil { if err != nil {
return fmt.Errorf("error checking state returned from peeking: %w", err) return fmt.Errorf("error checking state returned from peeking: %w", err)
} }
@ -475,7 +485,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// logrus.Warnf("got respPeek %#v", respPeek) // logrus.Warnf("got respPeek %#v", respPeek)
// Send the newly returned state to the roomserver to update our local view. // Send the newly returned state to the roomserver to update our local view.
if err = roomserverAPI.SendEventWithState( if err = roomserverAPI.SendEventWithState(
ctx, r.rsAPI, ctx, r.rsAPI, r.cfg.Matrix.ServerName,
roomserverAPI.KindNew, roomserverAPI.KindNew,
&respState, &respState,
respPeek.LatestEvent.Headered(respPeek.RoomVersion), respPeek.LatestEvent.Headered(respPeek.RoomVersion),
@ -495,6 +505,11 @@ func (r *FederationInternalAPI) PerformLeave(
request *api.PerformLeaveRequest, request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse, response *api.PerformLeaveResponse,
) (err error) { ) (err error) {
_, origin, err := r.cfg.Matrix.SplitLocalID('@', request.UserID)
if err != nil {
return err
}
// Deduplicate the server names we were provided. // Deduplicate the server names we were provided.
util.SortAndUnique(request.ServerNames) util.SortAndUnique(request.ServerNames)
@ -505,6 +520,7 @@ func (r *FederationInternalAPI) PerformLeave(
// request. // request.
respMakeLeave, err := r.federation.MakeLeave( respMakeLeave, err := r.federation.MakeLeave(
ctx, ctx,
origin,
serverName, serverName,
request.RoomID, request.RoomID,
request.UserID, request.UserID,
@ -546,7 +562,7 @@ func (r *FederationInternalAPI) PerformLeave(
// Build the leave event. // Build the leave event.
event, err := respMakeLeave.LeaveEvent.Build( event, err := respMakeLeave.LeaveEvent.Build(
time.Now(), time.Now(),
r.cfg.Matrix.ServerName, origin,
r.cfg.Matrix.KeyID, r.cfg.Matrix.KeyID,
r.cfg.Matrix.PrivateKey, r.cfg.Matrix.PrivateKey,
respMakeLeave.RoomVersion, respMakeLeave.RoomVersion,
@ -559,6 +575,7 @@ func (r *FederationInternalAPI) PerformLeave(
// Try to perform a send_leave using the newly built event. // Try to perform a send_leave using the newly built event.
err = r.federation.SendLeave( err = r.federation.SendLeave(
ctx, ctx,
origin,
serverName, serverName,
event, event,
) )
@ -585,6 +602,11 @@ func (r *FederationInternalAPI) PerformInvite(
request *api.PerformInviteRequest, request *api.PerformInviteRequest,
response *api.PerformInviteResponse, response *api.PerformInviteResponse,
) (err error) { ) (err error) {
_, origin, err := r.cfg.Matrix.SplitLocalID('@', request.Event.Sender())
if err != nil {
return err
}
if request.Event.StateKey() == nil { if request.Event.StateKey() == nil {
return errors.New("invite must be a state event") return errors.New("invite must be a state event")
} }
@ -607,7 +629,7 @@ func (r *FederationInternalAPI) PerformInvite(
return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err)
} }
inviteRes, err := r.federation.SendInviteV2(ctx, destination, inviteReq) inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq)
if err != nil { if err != nil {
return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err)
} }
@ -648,9 +670,23 @@ func (r *FederationInternalAPI) PerformBroadcastEDU(
return nil return nil
} }
// PerformWakeupServers implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformWakeupServers(
ctx context.Context,
request *api.PerformWakeupServersRequest,
response *api.PerformWakeupServersResponse,
) (err error) {
r.MarkServersAlive(request.ServerNames)
return nil
}
func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) { func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) {
for _, srv := range destinations { for _, srv := range destinations {
// Check the statistics cache for the blacklist status to prevent hitting
// the database unnecessarily.
if r.queues.IsServerBlacklisted(srv) {
_ = r.db.RemoveServerFromBlacklist(srv) _ = r.db.RemoveServerFromBlacklist(srv)
}
r.queues.RetryServer(srv) r.queues.RetryServer(srv)
} }
} }
@ -708,7 +744,7 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder
// FederatedAuthProvider is an auth chain provider which fetches events from the server provided // FederatedAuthProvider is an auth chain provider which fetches events from the server provided
func federatedAuthProvider( func federatedAuthProvider(
ctx context.Context, federation api.FederationClient, ctx context.Context, federation api.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName, keyRing gomatrixserverlib.JSONVerifier, origin, server gomatrixserverlib.ServerName,
) gomatrixserverlib.AuthChainProvider { ) gomatrixserverlib.AuthChainProvider {
// A list of events that we have retried, if they were not included in // A list of events that we have retried, if they were not included in
// the auth events supplied in the send_join. // the auth events supplied in the send_join.
@ -738,7 +774,7 @@ func federatedAuthProvider(
// Try to retrieve the event from the server that sent us the send // Try to retrieve the event from the server that sent us the send
// join response. // join response.
tx, txerr := federation.GetEvent(ctx, server, eventID) tx, txerr := federation.GetEvent(ctx, origin, server, eventID)
if txerr != nil { if txerr != nil {
return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr) return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr)
} }

View file

@ -16,7 +16,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom(
request *api.QueryJoinedHostServerNamesInRoomRequest, request *api.QueryJoinedHostServerNamesInRoomRequest,
response *api.QueryJoinedHostServerNamesInRoomResponse, response *api.QueryJoinedHostServerNamesInRoomResponse,
) (err error) { ) (err error) {
joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf) joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf, request.ExcludeBlacklisted)
if err != nil { if err != nil {
return return
} }

View file

@ -23,6 +23,7 @@ const (
FederationAPIPerformInviteRequestPath = "/federationapi/performInviteRequest" FederationAPIPerformInviteRequestPath = "/federationapi/performInviteRequest"
FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest" FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest"
FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU"
FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers"
FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices"
FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" FederationAPIClaimKeysPath = "/federationapi/client/claimKeys"
@ -150,18 +151,32 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU(
) )
} }
// Handle an instruction to remove the respective servers from being blacklisted.
func (h *httpFederationInternalAPI) PerformWakeupServers(
ctx context.Context,
request *api.PerformWakeupServersRequest,
response *api.PerformWakeupServersResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformWakeupServers", h.federationAPIURL+FederationAPIPerformWakeupServers,
h.httpClient, ctx, request, response,
)
}
type getUserDevices struct { type getUserDevices struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
UserID string UserID string
} }
func (h *httpFederationInternalAPI) GetUserDevices( func (h *httpFederationInternalAPI) GetUserDevices(
ctx context.Context, s gomatrixserverlib.ServerName, userID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string,
) (gomatrixserverlib.RespUserDevices, error) { ) (gomatrixserverlib.RespUserDevices, error) {
return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError]( return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError](
"GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient, "GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient,
ctx, &getUserDevices{ ctx, &getUserDevices{
S: s, S: s,
Origin: origin,
UserID: userID, UserID: userID,
}, },
) )
@ -169,16 +184,18 @@ func (h *httpFederationInternalAPI) GetUserDevices(
type claimKeys struct { type claimKeys struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
OneTimeKeys map[string]map[string]string OneTimeKeys map[string]map[string]string
} }
func (h *httpFederationInternalAPI) ClaimKeys( func (h *httpFederationInternalAPI) ClaimKeys(
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
) (gomatrixserverlib.RespClaimKeys, error) { ) (gomatrixserverlib.RespClaimKeys, error) {
return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError]( return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError](
"ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient, "ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient,
ctx, &claimKeys{ ctx, &claimKeys{
S: s, S: s,
Origin: origin,
OneTimeKeys: oneTimeKeys, OneTimeKeys: oneTimeKeys,
}, },
) )
@ -186,16 +203,18 @@ func (h *httpFederationInternalAPI) ClaimKeys(
type queryKeys struct { type queryKeys struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
Keys map[string][]string Keys map[string][]string
} }
func (h *httpFederationInternalAPI) QueryKeys( func (h *httpFederationInternalAPI) QueryKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string,
) (gomatrixserverlib.RespQueryKeys, error) { ) (gomatrixserverlib.RespQueryKeys, error) {
return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError]( return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError](
"QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient, "QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient,
ctx, &queryKeys{ ctx, &queryKeys{
S: s, S: s,
Origin: origin,
Keys: keys, Keys: keys,
}, },
) )
@ -203,18 +222,20 @@ func (h *httpFederationInternalAPI) QueryKeys(
type backfill struct { type backfill struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
Limit int Limit int
EventIDs []string EventIDs []string
} }
func (h *httpFederationInternalAPI) Backfill( func (h *httpFederationInternalAPI) Backfill(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
) (gomatrixserverlib.Transaction, error) { ) (gomatrixserverlib.Transaction, error) {
return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError]( return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError](
"Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient, "Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient,
ctx, &backfill{ ctx, &backfill{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
Limit: limit, Limit: limit,
EventIDs: eventIDs, EventIDs: eventIDs,
@ -224,18 +245,20 @@ func (h *httpFederationInternalAPI) Backfill(
type lookupState struct { type lookupState struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
EventID string EventID string
RoomVersion gomatrixserverlib.RoomVersion RoomVersion gomatrixserverlib.RoomVersion
} }
func (h *httpFederationInternalAPI) LookupState( func (h *httpFederationInternalAPI) LookupState(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (gomatrixserverlib.RespState, error) { ) (gomatrixserverlib.RespState, error) {
return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError]( return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError](
"LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient, "LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient,
ctx, &lookupState{ ctx, &lookupState{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
RoomVersion: roomVersion, RoomVersion: roomVersion,
@ -245,17 +268,19 @@ func (h *httpFederationInternalAPI) LookupState(
type lookupStateIDs struct { type lookupStateIDs struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
EventID string EventID string
} }
func (h *httpFederationInternalAPI) LookupStateIDs( func (h *httpFederationInternalAPI) LookupStateIDs(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string,
) (gomatrixserverlib.RespStateIDs, error) { ) (gomatrixserverlib.RespStateIDs, error) {
return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError]( return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError](
"LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient, "LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient,
ctx, &lookupStateIDs{ ctx, &lookupStateIDs{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
}, },
@ -264,19 +289,21 @@ func (h *httpFederationInternalAPI) LookupStateIDs(
type lookupMissingEvents struct { type lookupMissingEvents struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
Missing gomatrixserverlib.MissingEvents Missing gomatrixserverlib.MissingEvents
RoomVersion gomatrixserverlib.RoomVersion RoomVersion gomatrixserverlib.RoomVersion
} }
func (h *httpFederationInternalAPI) LookupMissingEvents( func (h *httpFederationInternalAPI) LookupMissingEvents(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string,
missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespMissingEvents, err error) { ) (res gomatrixserverlib.RespMissingEvents, err error) {
return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError]( return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError](
"LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient, "LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient,
ctx, &lookupMissingEvents{ ctx, &lookupMissingEvents{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
Missing: missing, Missing: missing,
RoomVersion: roomVersion, RoomVersion: roomVersion,
@ -286,16 +313,18 @@ func (h *httpFederationInternalAPI) LookupMissingEvents(
type getEvent struct { type getEvent struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
EventID string EventID string
} }
func (h *httpFederationInternalAPI) GetEvent( func (h *httpFederationInternalAPI) GetEvent(
ctx context.Context, s gomatrixserverlib.ServerName, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string,
) (gomatrixserverlib.Transaction, error) { ) (gomatrixserverlib.Transaction, error) {
return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError]( return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError](
"GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient, "GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient,
ctx, &getEvent{ ctx, &getEvent{
S: s, S: s,
Origin: origin,
EventID: eventID, EventID: eventID,
}, },
) )
@ -303,19 +332,21 @@ func (h *httpFederationInternalAPI) GetEvent(
type getEventAuth struct { type getEventAuth struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomVersion gomatrixserverlib.RoomVersion RoomVersion gomatrixserverlib.RoomVersion
RoomID string RoomID string
EventID string EventID string
} }
func (h *httpFederationInternalAPI) GetEventAuth( func (h *httpFederationInternalAPI) GetEventAuth(
ctx context.Context, s gomatrixserverlib.ServerName, ctx context.Context, origin, s gomatrixserverlib.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (gomatrixserverlib.RespEventAuth, error) { ) (gomatrixserverlib.RespEventAuth, error) {
return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError]( return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError](
"GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient, "GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient,
ctx, &getEventAuth{ ctx, &getEventAuth{
S: s, S: s,
Origin: origin,
RoomVersion: roomVersion, RoomVersion: roomVersion,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
@ -351,18 +382,20 @@ func (h *httpFederationInternalAPI) LookupServerKeys(
type eventRelationships struct { type eventRelationships struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
Req gomatrixserverlib.MSC2836EventRelationshipsRequest Req gomatrixserverlib.MSC2836EventRelationshipsRequest
RoomVer gomatrixserverlib.RoomVersion RoomVer gomatrixserverlib.RoomVersion
} }
func (h *httpFederationInternalAPI) MSC2836EventRelationships( func (h *httpFederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError]( return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError](
"MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient, "MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient,
ctx, &eventRelationships{ ctx, &eventRelationships{
S: s, S: s,
Origin: origin,
Req: r, Req: r,
RoomVer: roomVersion, RoomVer: roomVersion,
}, },
@ -371,17 +404,19 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships(
type spacesReq struct { type spacesReq struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
SuggestedOnly bool SuggestedOnly bool
RoomID string RoomID string
} }
func (h *httpFederationInternalAPI) MSC2946Spaces( func (h *httpFederationInternalAPI) MSC2946Spaces(
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError]( return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError](
"MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient, "MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient,
ctx, &spacesReq{ ctx, &spacesReq{
S: dst, S: dst,
Origin: origin,
SuggestedOnly: suggestedOnly, SuggestedOnly: suggestedOnly,
RoomID: roomID, RoomID: roomID,
}, },

View file

@ -43,6 +43,11 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU), httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU),
) )
internalAPIMux.Handle(
FederationAPIPerformWakeupServers,
httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", intAPI.PerformWakeupServers),
)
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIPerformJoinRequestPath, FederationAPIPerformJoinRequestPath,
httputil.MakeInternalRPCAPI( httputil.MakeInternalRPCAPI(
@ -59,7 +64,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIGetUserDevices", "FederationAPIGetUserDevices",
func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) { func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) {
res, err := intAPI.GetUserDevices(ctx, req.S, req.UserID) res, err := intAPI.GetUserDevices(ctx, req.Origin, req.S, req.UserID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -70,7 +75,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIClaimKeys", "FederationAPIClaimKeys",
func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) { func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) {
res, err := intAPI.ClaimKeys(ctx, req.S, req.OneTimeKeys) res, err := intAPI.ClaimKeys(ctx, req.Origin, req.S, req.OneTimeKeys)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -81,7 +86,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIQueryKeys", "FederationAPIQueryKeys",
func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) { func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) {
res, err := intAPI.QueryKeys(ctx, req.S, req.Keys) res, err := intAPI.QueryKeys(ctx, req.Origin, req.S, req.Keys)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -92,7 +97,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIBackfill", "FederationAPIBackfill",
func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) { func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) {
res, err := intAPI.Backfill(ctx, req.S, req.RoomID, req.Limit, req.EventIDs) res, err := intAPI.Backfill(ctx, req.Origin, req.S, req.RoomID, req.Limit, req.EventIDs)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -103,7 +108,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPILookupState", "FederationAPILookupState",
func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) { func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) {
res, err := intAPI.LookupState(ctx, req.S, req.RoomID, req.EventID, req.RoomVersion) res, err := intAPI.LookupState(ctx, req.Origin, req.S, req.RoomID, req.EventID, req.RoomVersion)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -114,7 +119,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPILookupStateIDs", "FederationAPILookupStateIDs",
func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) { func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) {
res, err := intAPI.LookupStateIDs(ctx, req.S, req.RoomID, req.EventID) res, err := intAPI.LookupStateIDs(ctx, req.Origin, req.S, req.RoomID, req.EventID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -125,7 +130,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPILookupMissingEvents", "FederationAPILookupMissingEvents",
func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) { func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) {
res, err := intAPI.LookupMissingEvents(ctx, req.S, req.RoomID, req.Missing, req.RoomVersion) res, err := intAPI.LookupMissingEvents(ctx, req.Origin, req.S, req.RoomID, req.Missing, req.RoomVersion)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -136,7 +141,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIGetEvent", "FederationAPIGetEvent",
func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) { func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) {
res, err := intAPI.GetEvent(ctx, req.S, req.EventID) res, err := intAPI.GetEvent(ctx, req.Origin, req.S, req.EventID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -147,7 +152,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIGetEventAuth", "FederationAPIGetEventAuth",
func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) { func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) {
res, err := intAPI.GetEventAuth(ctx, req.S, req.RoomVersion, req.RoomID, req.EventID) res, err := intAPI.GetEventAuth(ctx, req.Origin, req.S, req.RoomVersion, req.RoomID, req.EventID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -174,7 +179,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIMSC2836EventRelationships", "FederationAPIMSC2836EventRelationships",
func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) {
res, err := intAPI.MSC2836EventRelationships(ctx, req.S, req.Req, req.RoomVer) res, err := intAPI.MSC2836EventRelationships(ctx, req.Origin, req.S, req.Req, req.RoomVer)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -185,7 +190,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIMSC2946SpacesSummary", "FederationAPIMSC2946SpacesSummary",
func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) { func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
res, err := intAPI.MSC2946Spaces(ctx, req.S, req.RoomID, req.SuggestedOnly) res, err := intAPI.MSC2946Spaces(ctx, req.Origin, req.S, req.RoomID, req.SuggestedOnly)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),

View file

@ -50,7 +50,7 @@ type destinationQueue struct {
queues *OutgoingQueues queues *OutgoingQueues
db storage.Database db storage.Database
process *process.ProcessContext process *process.ProcessContext
signing *SigningInfo signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity
rsAPI api.FederationRoomserverAPI rsAPI api.FederationRoomserverAPI
client fedapi.FederationClient // federation client client fedapi.FederationClient // federation client
origin gomatrixserverlib.ServerName // origin of requests origin gomatrixserverlib.ServerName // origin of requests
@ -141,23 +141,44 @@ func (oq *destinationQueue) handleBackoffNotifier() {
} }
} }
// wakeQueueIfEventsPending calls wakeQueueAndNotify only if there are
// pending events or if forceWakeup is true. This prevents starting the
// queue unnecessarily.
func (oq *destinationQueue) wakeQueueIfEventsPending(forceWakeup bool) {
eventsPending := func() bool {
oq.pendingMutex.Lock()
defer oq.pendingMutex.Unlock()
return len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0
}
// NOTE : Only wakeup and notify the queue if there are pending events
// or if forceWakeup is true. Otherwise there is no reason to start the
// queue goroutine and waste resources.
if forceWakeup || eventsPending() {
logrus.Info("Starting queue due to pending events or forceWakeup")
oq.wakeQueueAndNotify()
}
}
// wakeQueueAndNotify ensures the destination queue is running and notifies it // wakeQueueAndNotify ensures the destination queue is running and notifies it
// that there is pending work. // that there is pending work.
func (oq *destinationQueue) wakeQueueAndNotify() { func (oq *destinationQueue) wakeQueueAndNotify() {
// Wake up the queue if it's asleep. // NOTE : Send notification before waking queue to prevent a race
oq.wakeQueueIfNeeded() // where the queue was running and stops due to a timeout in between
// checking it and sending the notification.
// Notify the queue that there are events ready to send. // Notify the queue that there are events ready to send.
select { select {
case oq.notify <- struct{}{}: case oq.notify <- struct{}{}:
default: default:
} }
// Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded()
} }
// wakeQueueIfNeeded will wake up the destination queue if it is // wakeQueueIfNeeded will wake up the destination queue if it is
// not already running. If it is running but it is backing off // not already running.
// then we will interrupt the backoff, causing any federation
// requests to retry.
func (oq *destinationQueue) wakeQueueIfNeeded() { func (oq *destinationQueue) wakeQueueIfNeeded() {
// Clear the backingOff flag and update the backoff metrics if it was set. // Clear the backingOff flag and update the backoff metrics if it was set.
if oq.backingOff.CompareAndSwap(true, false) { if oq.backingOff.CompareAndSwap(true, false) {

View file

@ -15,7 +15,6 @@
package queue package queue
import ( import (
"crypto/ed25519"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sync" "sync"
@ -46,7 +45,7 @@ type OutgoingQueues struct {
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
client fedapi.FederationClient client fedapi.FederationClient
statistics *statistics.Statistics statistics *statistics.Statistics
signing *SigningInfo signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity
queuesMutex sync.Mutex // protects the below queuesMutex sync.Mutex // protects the below
queues map[gomatrixserverlib.ServerName]*destinationQueue queues map[gomatrixserverlib.ServerName]*destinationQueue
} }
@ -91,7 +90,7 @@ func NewOutgoingQueues(
client fedapi.FederationClient, client fedapi.FederationClient,
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
statistics *statistics.Statistics, statistics *statistics.Statistics,
signing *SigningInfo, signing []*gomatrixserverlib.SigningIdentity,
) *OutgoingQueues { ) *OutgoingQueues {
queues := &OutgoingQueues{ queues := &OutgoingQueues{
disabled: disabled, disabled: disabled,
@ -101,9 +100,12 @@ func NewOutgoingQueues(
origin: origin, origin: origin,
client: client, client: client,
statistics: statistics, statistics: statistics,
signing: signing, signing: map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity{},
queues: map[gomatrixserverlib.ServerName]*destinationQueue{}, queues: map[gomatrixserverlib.ServerName]*destinationQueue{},
} }
for _, identity := range signing {
queues.signing[identity.ServerName] = identity
}
// Look up which servers we have pending items for and then rehydrate those queues. // Look up which servers we have pending items for and then rehydrate those queues.
if !disabled { if !disabled {
serverNames := map[gomatrixserverlib.ServerName]struct{}{} serverNames := map[gomatrixserverlib.ServerName]struct{}{}
@ -135,14 +137,6 @@ func NewOutgoingQueues(
return queues return queues
} }
// TODO: Move this somewhere useful for other components as we often need to ferry these 3 variables
// around together
type SigningInfo struct {
ServerName gomatrixserverlib.ServerName
KeyID gomatrixserverlib.KeyID
PrivateKey ed25519.PrivateKey
}
type queuedPDU struct { type queuedPDU struct {
receipt *shared.Receipt receipt *shared.Receipt
pdu *gomatrixserverlib.HeaderedEvent pdu *gomatrixserverlib.HeaderedEvent
@ -199,11 +193,10 @@ func (oqs *OutgoingQueues) SendEvent(
log.Trace("Federation is disabled, not sending event") log.Trace("Federation is disabled, not sending event")
return nil return nil
} }
if origin != oqs.origin { if _, ok := oqs.signing[origin]; !ok {
// TODO: Support virtual hosting; gh issue #577.
return fmt.Errorf( return fmt.Errorf(
"sendevent: unexpected server to send as: got %q expected %q", "sendevent: unexpected server to send as %q",
origin, oqs.origin, origin,
) )
} }
@ -214,7 +207,9 @@ func (oqs *OutgoingQueues) SendEvent(
destmap[d] = struct{}{} destmap[d] = struct{}{}
} }
delete(destmap, oqs.origin) delete(destmap, oqs.origin)
delete(destmap, oqs.signing.ServerName) for local := range oqs.signing {
delete(destmap, local)
}
// Check if any of the destinations are prohibited by server ACLs. // Check if any of the destinations are prohibited by server ACLs.
for destination := range destmap { for destination := range destmap {
@ -288,11 +283,10 @@ func (oqs *OutgoingQueues) SendEDU(
log.Trace("Federation is disabled, not sending EDU") log.Trace("Federation is disabled, not sending EDU")
return nil return nil
} }
if origin != oqs.origin { if _, ok := oqs.signing[origin]; !ok {
// TODO: Support virtual hosting; gh issue #577.
return fmt.Errorf( return fmt.Errorf(
"sendevent: unexpected server to send as: got %q expected %q", "sendevent: unexpected server to send as %q",
origin, oqs.origin, origin,
) )
} }
@ -303,7 +297,9 @@ func (oqs *OutgoingQueues) SendEDU(
destmap[d] = struct{}{} destmap[d] = struct{}{}
} }
delete(destmap, oqs.origin) delete(destmap, oqs.origin)
delete(destmap, oqs.signing.ServerName) for local := range oqs.signing {
delete(destmap, local)
}
// There is absolutely no guarantee that the EDU will have a room_id // There is absolutely no guarantee that the EDU will have a room_id
// field, as it is not required by the spec. However, if it *does* // field, as it is not required by the spec. However, if it *does*
@ -378,14 +374,24 @@ func (oqs *OutgoingQueues) SendEDU(
return nil return nil
} }
// IsServerBlacklisted returns whether or not the provided server is currently
// blacklisted.
func (oqs *OutgoingQueues) IsServerBlacklisted(srv gomatrixserverlib.ServerName) bool {
return oqs.statistics.ForServer(srv).Blacklisted()
}
// RetryServer attempts to resend events to the given server if we had given up. // RetryServer attempts to resend events to the given server if we had given up.
func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
if oqs.disabled { if oqs.disabled {
return return
} }
oqs.statistics.ForServer(srv).RemoveBlacklist()
serverStatistics := oqs.statistics.ForServer(srv)
forceWakeup := serverStatistics.Blacklisted()
serverStatistics.RemoveBlacklist()
serverStatistics.ClearBackoff()
if queue := oqs.getQueue(srv); queue != nil { if queue := oqs.getQueue(srv); queue != nil {
queue.statistics.ClearBackoff() queue.wakeQueueIfEventsPending(forceWakeup)
queue.wakeQueueIfNeeded()
} }
} }

View file

@ -350,10 +350,12 @@ func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T
} }
rs := &stubFederationRoomServerAPI{} rs := &stubFederationRoomServerAPI{}
stats := statistics.NewStatistics(db, failuresUntilBlacklist) stats := statistics.NewStatistics(db, failuresUntilBlacklist)
signingInfo := &SigningInfo{ signingInfo := []*gomatrixserverlib.SigningIdentity{
{
KeyID: "ed21019:auto", KeyID: "ed21019:auto",
PrivateKey: test.PrivateKeyA, PrivateKey: test.PrivateKeyA,
ServerName: "localhost", ServerName: "localhost",
},
} }
queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo) queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo)

View file

@ -83,6 +83,7 @@ func Backfill(
"": eIDs, "": eIDs,
}, },
ServerName: request.Origin(), ServerName: request.Origin(),
VirtualHost: request.Destination(),
} }
if req.Limit, err = strconv.Atoi(limit); err != nil { if req.Limit, err = strconv.Atoi(limit); err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed")
@ -123,7 +124,7 @@ func Backfill(
} }
txn := gomatrixserverlib.Transaction{ txn := gomatrixserverlib.Transaction{
Origin: cfg.Matrix.ServerName, Origin: request.Destination(),
PDUs: eventJSONs, PDUs: eventJSONs,
OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()),
} }

View file

@ -140,6 +140,21 @@ func processInvite(
} }
} }
if event.StateKey() == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The invite event has no state key"),
}
}
_, domain, err := cfg.Matrix.SplitLocalID('@', *event.StateKey())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("The user ID is invalid or domain %q does not belong to this server", domain)),
}
}
// Check that the event is signed by the server sending the request. // Check that the event is signed by the server sending the request.
redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version()) redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version())
if err != nil { if err != nil {
@ -175,7 +190,7 @@ func processInvite(
// Sign the event so that other servers will know that we have received the invite. // Sign the event so that other servers will know that we have received the invite.
signedEvent := event.Sign( signedEvent := event.Sign(
string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, string(domain), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey,
) )
// Add the invite event to the roomserver. // Add the invite event to the roomserver.

View file

@ -131,10 +131,20 @@ func MakeJoin(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
identity, err := cfg.Matrix.SigningIdentityFor(request.Destination())
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(
fmt.Sprintf("Server name %q does not exist", request.Destination()),
),
}
}
queryRes := api.QueryLatestEventsAndStateResponse{ queryRes := api.QueryLatestEventsAndStateResponse{
RoomVersion: verRes.RoomVersion, RoomVersion: verRes.RoomVersion,
} }
event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,

View file

@ -134,28 +134,30 @@ func ClaimOneTimeKeys(
// LocalKeys returns the local keys for the server. // LocalKeys returns the local keys for the server.
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
func LocalKeys(cfg *config.FederationAPI) util.JSONResponse { func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse {
keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)) keys, err := localKeys(cfg, serverName)
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.MessageResponse(http.StatusNotFound, err.Error())
} }
return util.JSONResponse{Code: http.StatusOK, JSON: keys} return util.JSONResponse{Code: http.StatusOK, JSON: keys}
} }
func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) { func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) {
var keys gomatrixserverlib.ServerKeys var keys gomatrixserverlib.ServerKeys
var identity *gomatrixserverlib.SigningIdentity
keys.ServerName = cfg.Matrix.ServerName var err error
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil) if virtualHost := cfg.Matrix.VirtualHostForHTTPHost(serverName); virtualHost == nil {
if identity, err = cfg.Matrix.SigningIdentityFor(cfg.Matrix.ServerName); err != nil {
return nil, err
}
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
keys.ServerName = cfg.Matrix.ServerName
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(cfg.Matrix.KeyValidityPeriod))
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
cfg.Matrix.KeyID: { cfg.Matrix.KeyID: {
Key: gomatrixserverlib.Base64Bytes(publicKey), Key: gomatrixserverlib.Base64Bytes(publicKey),
}, },
} }
keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{} keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{}
for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys { for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys {
keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{ keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{
@ -165,6 +167,21 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
ExpiredTS: oldVerifyKey.ExpiredAt, ExpiredTS: oldVerifyKey.ExpiredAt,
} }
} }
} else {
if identity, err = cfg.Matrix.SigningIdentityFor(virtualHost.ServerName); err != nil {
return nil, err
}
publicKey := virtualHost.PrivateKey.Public().(ed25519.PublicKey)
keys.ServerName = virtualHost.ServerName
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(virtualHost.KeyValidityPeriod))
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
virtualHost.KeyID: {
Key: gomatrixserverlib.Base64Bytes(publicKey),
},
}
// TODO: Virtual hosts probably want to be able to specify old signing
// keys too, just in case
}
toSign, err := json.Marshal(keys.ServerKeyFields) toSign, err := json.Marshal(keys.ServerKeyFields)
if err != nil { if err != nil {
@ -172,13 +189,9 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
} }
keys.Raw, err = gomatrixserverlib.SignJSON( keys.Raw, err = gomatrixserverlib.SignJSON(
string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign, string(identity.ServerName), identity.KeyID, identity.PrivateKey, toSign,
) )
if err != nil { return &keys, err
return nil, err
}
return &keys, nil
} }
func NotaryKeys( func NotaryKeys(
@ -186,6 +199,14 @@ func NotaryKeys(
fsAPI federationAPI.FederationInternalAPI, fsAPI federationAPI.FederationInternalAPI,
req *gomatrixserverlib.PublicKeyNotaryLookupRequest, req *gomatrixserverlib.PublicKeyNotaryLookupRequest,
) util.JSONResponse { ) util.JSONResponse {
serverName := gomatrixserverlib.ServerName(httpReq.Host) // TODO: this is not ideal
if !cfg.Matrix.IsLocalServerName(serverName) {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Server name not known"),
}
}
if req == nil { if req == nil {
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{} req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil { if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
@ -201,7 +222,7 @@ func NotaryKeys(
for serverName, kidToCriteria := range req.ServerKeys { for serverName, kidToCriteria := range req.ServerKeys {
var keyList []gomatrixserverlib.ServerKeys var keyList []gomatrixserverlib.ServerKeys
if serverName == cfg.Matrix.ServerName { if serverName == cfg.Matrix.ServerName {
if k, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil { if k, err := localKeys(cfg, serverName); err == nil {
keyList = append(keyList, *k) keyList = append(keyList, *k)
} else { } else {
return util.ErrorResponse(err) return util.ErrorResponse(err)

View file

@ -13,6 +13,7 @@
package routing package routing
import ( import (
"fmt"
"net/http" "net/http"
"time" "time"
@ -60,8 +61,18 @@ func MakeLeave(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
identity, err := cfg.Matrix.SigningIdentityFor(request.Destination())
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(
fmt.Sprintf("Server name %q does not exist", request.Destination()),
),
}
}
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,

View file

@ -22,7 +22,6 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -42,16 +41,9 @@ func GetProfile(
} }
} }
_, domain, err := gomatrixserverlib.SplitID('@', userID) _, domain, err := cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(httpReq.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument(fmt.Sprintf("Format of user ID %q is invalid", userID)),
}
}
if domain != cfg.Matrix.ServerName {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("Domain %q does not match this server", domain)), JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("Domain %q does not match this server", domain)),

View file

@ -83,7 +83,7 @@ func RoomAliasToID(
} }
} }
} else { } else {
resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, roomAlias) resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, cfg.Matrix.ServerName, roomAlias)
if err != nil { if err != nil {
switch x := err.(type) { switch x := err.(type) {
case gomatrix.HTTPError: case gomatrix.HTTPError:

View file

@ -74,7 +74,7 @@ func Setup(
} }
localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse { localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse {
return LocalKeys(cfg) return LocalKeys(cfg, gomatrixserverlib.ServerName(req.Host))
}) })
notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse { notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse {

View file

@ -197,12 +197,12 @@ type txnReq struct {
// A subset of FederationClient functionality that txn requires. Useful for testing. // A subset of FederationClient functionality that txn requires. Useful for testing.
type txnFederationClient interface { type txnFederationClient interface {
LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
res gomatrixserverlib.RespState, err error, res gomatrixserverlib.RespState, err error,
) )
LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents,
roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
@ -287,6 +287,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
event.Headered(roomVersion), event.Headered(roomVersion),
}, },
t.Destination,
t.Origin, t.Origin,
api.DoNotSendToOtherServers, api.DoNotSendToOtherServers,
nil, nil,

View file

@ -147,7 +147,7 @@ type txnFedClient struct {
getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error)
} }
func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( func (c *txnFedClient) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
res gomatrixserverlib.RespState, err error, res gomatrixserverlib.RespState, err error,
) { ) {
fmt.Println("testFederationClient.LookupState", eventID) fmt.Println("testFederationClient.LookupState", eventID)
@ -159,7 +159,7 @@ func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.Serv
res = r res = r
return return
} }
func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { func (c *txnFedClient) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) {
fmt.Println("testFederationClient.LookupStateIDs", eventID) fmt.Println("testFederationClient.LookupStateIDs", eventID)
r, ok := c.stateIDs[eventID] r, ok := c.stateIDs[eventID]
if !ok { if !ok {
@ -169,7 +169,7 @@ func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.S
res = r res = r
return return
} }
func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { func (c *txnFedClient) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) {
fmt.Println("testFederationClient.GetEvent", eventID) fmt.Println("testFederationClient.GetEvent", eventID)
r, ok := c.getEvent[eventID] r, ok := c.getEvent[eventID]
if !ok { if !ok {
@ -179,7 +179,7 @@ func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerN
res = r res = r
return return
} }
func (c *txnFedClient) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, func (c *txnFedClient) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents,
roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) { roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) {
return c.getMissingEvents(missing) return c.getMissingEvents(missing)
} }

View file

@ -90,7 +90,17 @@ func CreateInvitesFrom3PIDInvites(
} }
// Send all the events // Send all the events
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, evs, "TODO", cfg.Matrix.ServerName, nil, false); err != nil { if err := api.SendEvents(
req.Context(),
rsAPI,
api.KindNew,
evs,
cfg.Matrix.ServerName, // TODO: which virtual host?
"TODO",
cfg.Matrix.ServerName,
nil,
false,
); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -126,6 +136,14 @@ func ExchangeThirdPartyInvite(
} }
} }
_, senderDomain, err := cfg.Matrix.SplitLocalID('@', builder.Sender)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Invalid sender ID: " + err.Error()),
}
}
// Check that the state key is correct. // Check that the state key is correct.
_, targetDomain, err := gomatrixserverlib.SplitID('@', *builder.StateKey) _, targetDomain, err := gomatrixserverlib.SplitID('@', *builder.StateKey)
if err != nil { if err != nil {
@ -171,7 +189,7 @@ func ExchangeThirdPartyInvite(
util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request") util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
signedEvent, err := federation.SendInviteV2(httpReq.Context(), request.Origin(), inviteReq) signedEvent, err := federation.SendInviteV2(httpReq.Context(), senderDomain, request.Origin(), inviteReq)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -189,6 +207,7 @@ func ExchangeThirdPartyInvite(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
inviteEvent.Headered(verRes.RoomVersion), inviteEvent.Headered(verRes.RoomVersion),
}, },
request.Destination(),
request.Origin(), request.Origin(),
cfg.Matrix.ServerName, cfg.Matrix.ServerName,
nil, nil,
@ -341,7 +360,7 @@ func buildMembershipEvent(
// them responded with an error. // them responded with an error.
func sendToRemoteServer( func sendToRemoteServer(
ctx context.Context, inv invite, ctx context.Context, inv invite,
federation federationAPI.FederationClient, _ *config.FederationAPI, federation federationAPI.FederationClient, cfg *config.FederationAPI,
builder gomatrixserverlib.EventBuilder, builder gomatrixserverlib.EventBuilder,
) (err error) { ) (err error) {
remoteServers := make([]gomatrixserverlib.ServerName, 2) remoteServers := make([]gomatrixserverlib.ServerName, 2)
@ -357,7 +376,7 @@ func sendToRemoteServer(
} }
for _, server := range remoteServers { for _, server := range remoteServers {
err = federation.ExchangeThirdPartyInvite(ctx, server, builder) err = federation.ExchangeThirdPartyInvite(ctx, cfg.Matrix.ServerName, server, builder)
if err == nil { if err == nil {
return return
} }

View file

@ -32,7 +32,7 @@ type Database interface {
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given. // GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error)
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)

View file

@ -66,6 +66,11 @@ const selectAllJoinedHostsSQL = "" +
const selectJoinedHostsForRoomsSQL = "" + const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)" "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)"
const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id = ANY($1) AND NOT EXISTS (" +
" SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" +
");"
type joinedHostsStatements struct { type joinedHostsStatements struct {
db *sql.DB db *sql.DB
insertJoinedHostsStmt *sql.Stmt insertJoinedHostsStmt *sql.Stmt
@ -74,6 +79,7 @@ type joinedHostsStatements struct {
selectJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt
selectJoinedHostsForRoomsStmt *sql.Stmt selectJoinedHostsForRoomsStmt *sql.Stmt
selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt
} }
func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@ -102,6 +108,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro
if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
return return
} }
if s.selectJoinedHostsForRoomsExcludingBlacklistedStmt, err = s.db.Prepare(selectJoinedHostsForRoomsExcludingBlacklistedSQL); err != nil {
return
}
return return
} }
@ -167,9 +176,13 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
} }
func (s *joinedHostsStatements) SelectJoinedHostsForRooms( func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string, ctx context.Context, roomIDs []string, excludingBlacklisted bool,
) ([]gomatrixserverlib.ServerName, error) { ) ([]gomatrixserverlib.ServerName, error) {
rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs)) stmt := s.selectJoinedHostsForRoomsStmt
if excludingBlacklisted {
stmt = s.selectJoinedHostsForRoomsExcludingBlacklistedStmt
}
rows, err := stmt.QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -42,6 +42,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
return nil, err return nil, err
} }
blacklist, err := NewPostgresBlacklistTable(d.db)
if err != nil {
return nil, err
}
joinedHosts, err := NewPostgresJoinedHostsTable(d.db) joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -58,10 +62,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
blacklist, err := NewPostgresBlacklistTable(d.db)
if err != nil {
return nil, err
}
inboundPeeks, err := NewPostgresInboundPeeksTable(d.db) inboundPeeks, err := NewPostgresInboundPeeksTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -117,15 +117,17 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
} }
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) { func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) {
servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if excludeSelf { if excludeSelf {
for i, server := range servers { for i, server := range servers {
if d.IsLocalServerName(server) { if d.IsLocalServerName(server) {
servers = append(servers[:i], servers[i+1:]...) copy(servers[i:], servers[i+1:])
servers = servers[:len(servers)-1]
break
} }
} }
} }

View file

@ -66,6 +66,11 @@ const selectAllJoinedHostsSQL = "" +
const selectJoinedHostsForRoomsSQL = "" + const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id IN ($1) AND NOT EXISTS (" +
" SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" +
");"
type joinedHostsStatements struct { type joinedHostsStatements struct {
db *sql.DB db *sql.DB
insertJoinedHostsStmt *sql.Stmt insertJoinedHostsStmt *sql.Stmt
@ -74,6 +79,7 @@ type joinedHostsStatements struct {
selectJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
// selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt - prepared at runtime due to variadic
} }
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@ -168,14 +174,17 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
} }
func (s *joinedHostsStatements) SelectJoinedHostsForRooms( func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string, ctx context.Context, roomIDs []string, excludingBlacklisted bool,
) ([]gomatrixserverlib.ServerName, error) { ) ([]gomatrixserverlib.ServerName, error) {
iRoomIDs := make([]interface{}, len(roomIDs)) iRoomIDs := make([]interface{}, len(roomIDs))
for i := range roomIDs { for i := range roomIDs {
iRoomIDs[i] = roomIDs[i] iRoomIDs[i] = roomIDs[i]
} }
query := selectJoinedHostsForRoomsSQL
sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) if excludingBlacklisted {
query = selectJoinedHostsForRoomsExcludingBlacklistedSQL
}
sql := strings.Replace(query, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...) rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -41,6 +41,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err return nil, err
} }
blacklist, err := NewSQLiteBlacklistTable(d.db)
if err != nil {
return nil, err
}
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -57,10 +61,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
blacklist, err := NewSQLiteBlacklistTable(d.db)
if err != nil {
return nil, err
}
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -58,7 +58,7 @@ type FederationJoinedHosts interface {
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludingBlacklisted bool) ([]gomatrixserverlib.ServerName, error)
} }
type FederationBlacklist interface { type FederationBlacklist interface {

8
go.mod
View file

@ -22,12 +22,12 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2 github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7
github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.15 github.com/mattn/go-sqlite3 v1.14.15
github.com/nats-io/nats-server/v2 v2.9.4 github.com/nats-io/nats-server/v2 v2.9.6
github.com/nats-io/nats.go v1.19.0 github.com/nats-io/nats.go v1.20.0
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79 github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79

16
go.sum
View file

@ -348,10 +348,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2 h1:Bet5n+//Yh+A2SuPHD67N8jrOhC/EIKvEgisfsKhTss= github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7 h1:S2TNN7C00CZlE1Af31LzxkOsAEkFt0RYZ7/3VdR1D5U=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM=
github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
@ -385,10 +385,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI=
github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
github.com/nats-io/nats-server/v2 v2.9.4 h1:GvRgv1936J/zYUwMg/cqtYaJ6L+bgeIOIvPslbesdow= github.com/nats-io/nats-server/v2 v2.9.6 h1:RTtK+rv/4CcliOuqGsy58g7MuWkBaWmF5TUNwuUo9Uw=
github.com/nats-io/nats-server/v2 v2.9.4/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= github.com/nats-io/nats-server/v2 v2.9.6/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g=
github.com/nats-io/nats.go v1.19.0 h1:H6j8aBnTQFoVrTGB6Xjd903UMdE7jz6DS4YkmAqgZ9Q= github.com/nats-io/nats.go v1.20.0 h1:T8JJnQfVSdh1CzGiwAOv5hEobYCBho/0EupGznYw0oM=
github.com/nats-io/nats.go v1.19.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA= github.com/nats-io/nats.go v1.20.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA=
github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8=
github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=

View file

@ -38,7 +38,8 @@ var ErrRoomNoExists = errors.New("room does not exist")
// Returns an error if something else went wrong // Returns an error if something else went wrong
func QueryAndBuildEvent( func QueryAndBuildEvent(
ctx context.Context, ctx context.Context,
builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, builder *gomatrixserverlib.EventBuilder, cfg *config.Global,
identity *gomatrixserverlib.SigningIdentity, evTime time.Time,
rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse, rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
if queryRes == nil { if queryRes == nil {
@ -50,24 +51,24 @@ func QueryAndBuildEvent(
// This can pass through a ErrRoomNoExists to the caller // This can pass through a ErrRoomNoExists to the caller
return nil, err return nil, err
} }
return BuildEvent(ctx, builder, cfg, evTime, eventsNeeded, queryRes) return BuildEvent(ctx, builder, cfg, identity, evTime, eventsNeeded, queryRes)
} }
// BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse // BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse
// provided. // provided.
func BuildEvent( func BuildEvent(
ctx context.Context, ctx context.Context,
builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, builder *gomatrixserverlib.EventBuilder, cfg *config.Global,
identity *gomatrixserverlib.SigningIdentity, evTime time.Time,
eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse, eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
err := addPrevEventsToEvent(builder, eventsNeeded, queryRes) if err := addPrevEventsToEvent(builder, eventsNeeded, queryRes); err != nil {
if err != nil {
return nil, err return nil, err
} }
event, err := builder.Build( event, err := builder.Build(
evTime, cfg.ServerName, cfg.KeyID, evTime, identity.ServerName, identity.KeyID,
cfg.PrivateKey, queryRes.RoomVersion, identity.PrivateKey, queryRes.RoomVersion,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -35,7 +35,7 @@ type DeviceListUpdateConsumer struct {
durable string durable string
topic string topic string
updater *internal.DeviceListUpdater updater *internal.DeviceListUpdater
serverName gomatrixserverlib.ServerName isLocalServerName func(gomatrixserverlib.ServerName) bool
} }
// NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. // NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers.
@ -51,7 +51,7 @@ func NewDeviceListUpdateConsumer(
durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
updater: updater, updater: updater,
serverName: cfg.Matrix.ServerName, isLocalServerName: cfg.Matrix.IsLocalServerName,
} }
} }
@ -75,7 +75,7 @@ func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M
origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) origin := gomatrixserverlib.ServerName(msg.Header.Get("origin"))
if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil { if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil {
return true return true
} else if serverName == t.serverName { } else if t.isLocalServerName(serverName) {
return true return true
} else if serverName != origin { } else if serverName != origin {
return true return true

View file

@ -37,6 +37,7 @@ type SigningKeyUpdateConsumer struct {
topic string topic string
keyAPI *internal.KeyInternalAPI keyAPI *internal.KeyInternalAPI
cfg *config.KeyServer cfg *config.KeyServer
isLocalServerName func(gomatrixserverlib.ServerName) bool
} }
// NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. // NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers.
@ -53,6 +54,7 @@ func NewSigningKeyUpdateConsumer(
topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
keyAPI: keyAPI, keyAPI: keyAPI,
cfg: cfg, cfg: cfg,
isLocalServerName: cfg.Matrix.IsLocalServerName,
} }
} }
@ -77,7 +79,7 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M
if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil { if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil {
logrus.WithError(err).Error("failed to split user id") logrus.WithError(err).Error("failed to split user id")
return true return true
} else if serverName == t.cfg.Matrix.ServerName { } else if t.isLocalServerName(serverName) {
logrus.Warn("dropping device key update from ourself") logrus.Warn("dropping device key update from ourself")
return true return true
} else if serverName != origin { } else if serverName != origin {

View file

@ -96,6 +96,7 @@ type DeviceListUpdater struct {
producer KeyChangeProducer producer KeyChangeProducer
fedClient fedsenderapi.KeyserverFederationAPI fedClient fedsenderapi.KeyserverFederationAPI
workerChans []chan gomatrixserverlib.ServerName workerChans []chan gomatrixserverlib.ServerName
thisServer gomatrixserverlib.ServerName
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
// block on or timeout via a select. // block on or timeout via a select.
@ -139,6 +140,7 @@ func NewDeviceListUpdater(
process *process.ProcessContext, db DeviceListUpdaterDatabase, process *process.ProcessContext, db DeviceListUpdaterDatabase,
api DeviceListUpdaterAPI, producer KeyChangeProducer, api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
thisServer gomatrixserverlib.ServerName,
) *DeviceListUpdater { ) *DeviceListUpdater {
return &DeviceListUpdater{ return &DeviceListUpdater{
process: process, process: process,
@ -148,6 +150,7 @@ func NewDeviceListUpdater(
api: api, api: api,
producer: producer, producer: producer,
fedClient: fedClient, fedClient: fedClient,
thisServer: thisServer,
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
userIDToChan: make(map[string]chan bool), userIDToChan: make(map[string]chan bool),
userIDToChanMu: &sync.Mutex{}, userIDToChanMu: &sync.Mutex{},
@ -435,8 +438,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go
"server_name": serverName, "server_name": serverName,
"user_id": userID, "user_id": userID,
}) })
res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID)
res, err := u.fedClient.GetUserDevices(ctx, serverName, userID)
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {
return time.Minute * 10, err return time.Minute * 10, err

View file

@ -129,7 +129,13 @@ func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient { func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient {
_, pkey, _ := ed25519.GenerateKey(nil) _, pkey, _ := ed25519.GenerateKey(nil)
fedClient := gomatrixserverlib.NewFederationClient( fedClient := gomatrixserverlib.NewFederationClient(
gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey, []*gomatrixserverlib.SigningIdentity{
{
ServerName: gomatrixserverlib.ServerName("example.test"),
KeyID: gomatrixserverlib.KeyID("ed25519:test"),
PrivateKey: pkey,
},
},
) )
fedClient.Client = *gomatrixserverlib.NewClient( fedClient.Client = *gomatrixserverlib.NewClient(
gomatrixserverlib.WithTransport(&roundTripper{tripper}), gomatrixserverlib.WithTransport(&roundTripper{tripper}),
@ -147,7 +153,7 @@ func TestUpdateHavePrevID(t *testing.T) {
} }
ap := &mockDeviceListUpdaterAPI{} ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{} producer := &mockKeyChangeProducer{}
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost")
event := gomatrixserverlib.DeviceListUpdateEvent{ event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Foo Bar", DeviceDisplayName: "Foo Bar",
Deleted: false, Deleted: false,
@ -219,7 +225,7 @@ func TestUpdateNoPrevID(t *testing.T) {
`)), `)),
}, nil }, nil
}) })
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test")
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err) t.Fatalf("failed to start updater: %s", err)
} }
@ -288,7 +294,7 @@ func TestDebounce(t *testing.T) {
close(incomingFedReq) close(incomingFedReq)
return <-fedCh, nil return <-fedCh, nil
}) })
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost")
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err) t.Fatalf("failed to start updater: %s", err)
} }

View file

@ -33,12 +33,13 @@ import (
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/producers" "github.com/matrix-org/dendrite/keyserver/producers"
"github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
) )
type KeyInternalAPI struct { type KeyInternalAPI struct {
DB storage.Database DB storage.Database
ThisServer gomatrixserverlib.ServerName Cfg *config.KeyServer
FedClient fedsenderapi.KeyserverFederationAPI FedClient fedsenderapi.KeyserverFederationAPI
UserAPI userapi.KeyserverUserAPI UserAPI userapi.KeyserverUserAPI
Producer *producers.KeyChange Producer *producers.KeyChange
@ -95,8 +96,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
nested[userID] = val nested[userID] = val
domainToDeviceKeys[string(serverName)] = nested domainToDeviceKeys[string(serverName)] = nested
} }
for domain, local := range domainToDeviceKeys {
if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue
}
// claim local keys // claim local keys
if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok {
keys, err := a.DB.ClaimKeys(ctx, local) keys, err := a.DB.ClaimKeys(ctx, local)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
@ -117,7 +121,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
} }
} }
delete(domainToDeviceKeys, string(a.ThisServer)) delete(domainToDeviceKeys, domain)
} }
if len(domainToDeviceKeys) > 0 { if len(domainToDeviceKeys) > 0 {
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
@ -142,7 +146,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
defer cancel() defer cancel()
defer wg.Done() defer wg.Done()
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim)
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
@ -258,7 +262,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
} }
domain := string(serverName) domain := string(serverName)
// query local devices // query local devices
if serverName == a.ThisServer { if a.Cfg.Matrix.IsLocalServerName(serverName) {
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
@ -437,13 +441,13 @@ func (a *KeyInternalAPI) queryRemoteKeys(
domains := map[string]struct{}{} domains := map[string]struct{}{}
for domain := range domainToDeviceKeys { for domain := range domainToDeviceKeys {
if domain == string(a.ThisServer) { if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue continue
} }
domains[domain] = struct{}{} domains[domain] = struct{}{}
} }
for domain := range domainToCrossSigningKeys { for domain := range domainToCrossSigningKeys {
if domain == string(a.ThisServer) { if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue continue
} }
domains[domain] = struct{}{} domains[domain] = struct{}{}
@ -555,7 +559,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
if len(devKeys) == 0 { if len(devKeys) == 0 {
return return
} }
queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys) queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys)
if err == nil { if err == nil {
resultCh <- &queryKeysResp resultCh <- &queryKeysResp
return return
@ -689,7 +693,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
if err != nil { if err != nil {
continue // ignore invalid users continue // ignore invalid users
} }
if serverName != a.ThisServer { if !a.Cfg.Matrix.IsLocalServerName(serverName) {
continue // ignore remote users continue // ignore remote users
} }
if len(key.KeyJSON) == 0 { if len(key.KeyJSON) == 0 {

View file

@ -54,11 +54,11 @@ func NewInternalAPI(
} }
ap := &internal.KeyInternalAPI{ ap := &internal.KeyInternalAPI{
DB: db, DB: db,
ThisServer: cfg.Matrix.ServerName, Cfg: cfg,
FedClient: fedClient, FedClient: fedClient,
Producer: keyChangeProducer, Producer: keyChangeProducer,
} }
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable
ap.Updater = updater ap.Updater = updater
go func() { go func() {
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {

View file

@ -96,6 +96,7 @@ type TransactionID struct {
type InputRoomEventsRequest struct { type InputRoomEventsRequest struct {
InputRoomEvents []InputRoomEvent `json:"input_room_events"` InputRoomEvents []InputRoomEvent `json:"input_room_events"`
Asynchronous bool `json:"async"` Asynchronous bool `json:"async"`
VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"`
} }
// InputRoomEventsResponse is a response to InputRoomEvents // InputRoomEventsResponse is a response to InputRoomEvents

View file

@ -149,6 +149,8 @@ type PerformBackfillRequest struct {
Limit int `json:"limit"` Limit int `json:"limit"`
// The server interested in the events. // The server interested in the events.
ServerName gomatrixserverlib.ServerName `json:"server_name"` ServerName gomatrixserverlib.ServerName `json:"server_name"`
// Which virtual host are we doing this for?
VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"`
} }
// PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order. // PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order.

View file

@ -26,7 +26,7 @@ import (
func SendEvents( func SendEvents(
ctx context.Context, rsAPI InputRoomEventsAPI, ctx context.Context, rsAPI InputRoomEventsAPI,
kind Kind, events []*gomatrixserverlib.HeaderedEvent, kind Kind, events []*gomatrixserverlib.HeaderedEvent,
origin gomatrixserverlib.ServerName, virtualHost, origin gomatrixserverlib.ServerName,
sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID,
async bool, async bool,
) error { ) error {
@ -40,14 +40,15 @@ func SendEvents(
TransactionID: txnID, TransactionID: txnID,
} }
} }
return SendInputRoomEvents(ctx, rsAPI, ires, async) return SendInputRoomEvents(ctx, rsAPI, virtualHost, ires, async)
} }
// SendEventWithState writes an event with the specified kind to the roomserver // SendEventWithState writes an event with the specified kind to the roomserver
// with the state at the event as KindOutlier before it. Will not send any event that is // with the state at the event as KindOutlier before it. Will not send any event that is
// marked as `true` in haveEventIDs. // marked as `true` in haveEventIDs.
func SendEventWithState( func SendEventWithState(
ctx context.Context, rsAPI InputRoomEventsAPI, kind Kind, ctx context.Context, rsAPI InputRoomEventsAPI,
virtualHost gomatrixserverlib.ServerName, kind Kind,
state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent,
origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool,
) error { ) error {
@ -85,17 +86,19 @@ func SendEventWithState(
StateEventIDs: stateEventIDs, StateEventIDs: stateEventIDs,
}) })
return SendInputRoomEvents(ctx, rsAPI, ires, async) return SendInputRoomEvents(ctx, rsAPI, virtualHost, ires, async)
} }
// SendInputRoomEvents to the roomserver. // SendInputRoomEvents to the roomserver.
func SendInputRoomEvents( func SendInputRoomEvents(
ctx context.Context, rsAPI InputRoomEventsAPI, ctx context.Context, rsAPI InputRoomEventsAPI,
virtualHost gomatrixserverlib.ServerName,
ires []InputRoomEvent, async bool, ires []InputRoomEvent, async bool,
) error { ) error {
request := InputRoomEventsRequest{ request := InputRoomEventsRequest{
InputRoomEvents: ires, InputRoomEvents: ires,
Asynchronous: async, Asynchronous: async,
VirtualHost: virtualHost,
} }
var response InputRoomEventsResponse var response InputRoomEventsResponse
if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil { if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil {

View file

@ -137,6 +137,11 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
request *api.RemoveRoomAliasRequest, request *api.RemoveRoomAliasRequest,
response *api.RemoveRoomAliasResponse, response *api.RemoveRoomAliasResponse,
) error { ) error {
_, virtualHost, err := r.Cfg.Matrix.SplitLocalID('@', request.UserID)
if err != nil {
return err
}
roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err) return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err)
@ -190,6 +195,16 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
sender = ev.Sender() sender = ev.Sender()
} }
_, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', sender)
if err != nil {
return err
}
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil {
return err
}
builder := &gomatrixserverlib.EventBuilder{ builder := &gomatrixserverlib.EventBuilder{
Sender: sender, Sender: sender,
RoomID: ev.RoomID(), RoomID: ev.RoomID(),
@ -211,12 +226,12 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return err return err
} }
newEvent, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, time.Now(), &eventsNeeded, stateRes) newEvent, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, stateRes)
if err != nil { if err != nil {
return err return err
} }
err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, r.ServerName, r.ServerName, nil, false) err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, virtualHost, r.ServerName, r.ServerName, nil, false)
if err != nil { if err != nil {
return err return err
} }

View file

@ -89,7 +89,7 @@ func NewRoomserverAPI(
Queryer: &query.Queryer{ Queryer: &query.Queryer{
DB: roomserverDB, DB: roomserverDB,
Cache: base.Caches, Cache: base.Caches,
ServerName: base.Cfg.Global.ServerName, IsLocalServerName: base.Cfg.Global.IsLocalServerName,
ServerACLs: serverACLs, ServerACLs: serverACLs,
}, },
// perform-er structs get initialised when we have a federation sender to use // perform-er structs get initialised when we have a federation sender to use
@ -104,6 +104,12 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
r.fsAPI = fsAPI r.fsAPI = fsAPI
r.KeyRing = keyRing r.KeyRing = keyRing
identity, err := r.Cfg.Matrix.SigningIdentityFor(r.Cfg.Matrix.ServerName)
// If federation is enabled, but we don't have a signing key, bail.
if err != nil && !r.Cfg.Matrix.DisableFederation {
logrus.Panic(err)
}
r.Inputer = &input.Inputer{ r.Inputer = &input.Inputer{
Cfg: &r.Base.Cfg.RoomServer, Cfg: &r.Base.Cfg.RoomServer,
Base: r.Base, Base: r.Base,
@ -115,6 +121,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
NATSClient: r.NATSClient, NATSClient: r.NATSClient,
Durable: nats.Durable(r.Durable), Durable: nats.Durable(r.Durable),
ServerName: r.Cfg.Matrix.ServerName, ServerName: r.Cfg.Matrix.ServerName,
SigningIdentity: identity,
FSAPI: fsAPI, FSAPI: fsAPI,
KeyRing: keyRing, KeyRing: keyRing,
ACLs: r.ServerACLs, ACLs: r.ServerACLs,
@ -127,7 +134,6 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
Inputer: r.Inputer, Inputer: r.Inputer,
} }
r.Joiner = &perform.Joiner{ r.Joiner = &perform.Joiner{
ServerName: r.Cfg.Matrix.ServerName,
Cfg: r.Cfg, Cfg: r.Cfg,
DB: r.DB, DB: r.DB,
FSAPI: r.fsAPI, FSAPI: r.fsAPI,
@ -163,7 +169,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
DB: r.DB, DB: r.DB,
} }
r.Backfiller = &perform.Backfiller{ r.Backfiller = &perform.Backfiller{
ServerName: r.ServerName, IsLocalServerName: r.Cfg.Matrix.IsLocalServerName,
DB: r.DB, DB: r.DB,
FSAPI: r.fsAPI, FSAPI: r.fsAPI,
KeyRing: r.KeyRing, KeyRing: r.KeyRing,

View file

@ -0,0 +1,56 @@
package helpers
import (
"context"
"testing"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/roomserver/storage"
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) {
base, close := testrig.CreateBaseDendrite(t, dbType)
db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches)
if err != nil {
t.Fatalf("failed to create Database: %v", err)
}
return base, db, close
}
func TestIsInvitePendingWithoutNID(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat))
_ = bob
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
_, db, close := mustCreateDatabase(t, dbType)
defer close()
// store all events
var authNIDs []types.EventNID
for _, x := range room.Events() {
evNID, _, _, _, _, err := db.StoreEvent(context.Background(), x.Event, authNIDs, false)
assert.NoError(t, err)
authNIDs = append(authNIDs, evNID)
}
// Alice should have no pending invites and should have a NID
pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, alice.ID)
assert.NoError(t, err, "failed to get pending invites")
assert.False(t, pendingInvite, "unexpected pending invite")
// Bob should have no pending invites and receive a new NID
pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, bob.ID)
assert.NoError(t, err, "failed to get pending invites")
assert.False(t, pendingInvite, "unexpected pending invite")
})
}

View file

@ -81,6 +81,7 @@ type Inputer struct {
JetStream nats.JetStreamContext JetStream nats.JetStreamContext
Durable nats.SubOpt Durable nats.SubOpt
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
SigningIdentity *gomatrixserverlib.SigningIdentity
FSAPI fedapi.RoomserverFederationAPI FSAPI fedapi.RoomserverFederationAPI
KeyRing gomatrixserverlib.JSONVerifier KeyRing gomatrixserverlib.JSONVerifier
ACLs *acls.ServerACLs ACLs *acls.ServerACLs
@ -281,7 +282,11 @@ func (w *worker) _next() {
// a string, because we might want to return that to the caller if // a string, because we might want to return that to the caller if
// it was a synchronous request. // it was a synchronous request.
var errString string var errString string
if err = w.r.processRoomEvent(w.r.ProcessContext.Context(), &inputRoomEvent); err != nil { if err = w.r.processRoomEvent(
w.r.ProcessContext.Context(),
gomatrixserverlib.ServerName(msg.Header.Get("virtual_host")),
&inputRoomEvent,
); err != nil {
switch err.(type) { switch err.(type) {
case types.RejectedError: case types.RejectedError:
// Don't send events that were rejected to Sentry // Don't send events that were rejected to Sentry
@ -361,6 +366,7 @@ func (r *Inputer) queueInputRoomEvents(
if replyTo != "" { if replyTo != "" {
msg.Header.Set("sync", replyTo) msg.Header.Set("sync", replyTo)
} }
msg.Header.Set("virtual_host", string(request.VirtualHost))
msg.Data, err = json.Marshal(e) msg.Data, err = json.Marshal(e)
if err != nil { if err != nil {
return nil, fmt.Errorf("json.Marshal: %w", err) return nil, fmt.Errorf("json.Marshal: %w", err)

View file

@ -24,9 +24,9 @@ import (
"fmt" "fmt"
"time" "time"
userAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
userAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go"
@ -71,6 +71,7 @@ var processRoomEventDuration = prometheus.NewHistogramVec(
// nolint:gocyclo // nolint:gocyclo
func (r *Inputer) processRoomEvent( func (r *Inputer) processRoomEvent(
ctx context.Context, ctx context.Context,
virtualHost gomatrixserverlib.ServerName,
input *api.InputRoomEvent, input *api.InputRoomEvent,
) error { ) error {
select { select {
@ -168,6 +169,7 @@ func (r *Inputer) processRoomEvent(
serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{
RoomID: event.RoomID(), RoomID: event.RoomID(),
ExcludeSelf: true, ExcludeSelf: true,
ExcludeBlacklisted: true,
} }
if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
@ -202,7 +204,7 @@ func (r *Inputer) processRoomEvent(
isRejected := false isRejected := false
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
knownEvents := map[string]*types.Event{} knownEvents := map[string]*types.Event{}
if err = r.fetchAuthEvents(ctx, logger, roomInfo, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
return fmt.Errorf("r.fetchAuthEvents: %w", err) return fmt.Errorf("r.fetchAuthEvents: %w", err)
} }
@ -268,6 +270,7 @@ func (r *Inputer) processRoomEvent(
if len(serverRes.ServerNames) > 0 { if len(serverRes.ServerNames) > 0 {
missingState := missingStateReq{ missingState := missingStateReq{
origin: input.Origin, origin: input.Origin,
virtualHost: virtualHost,
inputer: r, inputer: r,
db: r.DB, db: r.DB,
roomInfo: roomInfo, roomInfo: roomInfo,
@ -413,6 +416,13 @@ func (r *Inputer) processRoomEvent(
} }
} }
// Handle remote room upgrades, e.g. remove published room
if event.Type() == "m.room.tombstone" && event.StateKeyEquals("") && !r.Cfg.Matrix.IsLocalServerName(senderDomain) {
if err = r.handleRemoteRoomUpgrade(ctx, event); err != nil {
return fmt.Errorf("failed to handle remote room upgrade: %w", err)
}
}
// processing this event resulted in an event (which may not be the one we're processing) // processing this event resulted in an event (which may not be the one we're processing)
// being redacted. We are guaranteed to have both sides (the redaction/redacted event), // being redacted. We are guaranteed to have both sides (the redaction/redacted event),
// so notify downstream components to redact this event - they should have it if they've // so notify downstream components to redact this event - they should have it if they've
@ -442,6 +452,13 @@ func (r *Inputer) processRoomEvent(
return nil return nil
} }
// handleRemoteRoomUpgrade updates published rooms and room aliases
func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixserverlib.Event) error {
oldRoomID := event.RoomID()
newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str
return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender())
}
// processStateBefore works out what the state is before the event and // processStateBefore works out what the state is before the event and
// then checks the event auths against the state at the time. It also // then checks the event auths against the state at the time. It also
// tries to determine what the history visibility was of the event at // tries to determine what the history visibility was of the event at
@ -547,6 +564,7 @@ func (r *Inputer) fetchAuthEvents(
ctx context.Context, ctx context.Context,
logger *logrus.Entry, logger *logrus.Entry,
roomInfo *types.RoomInfo, roomInfo *types.RoomInfo,
virtualHost gomatrixserverlib.ServerName,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
auth *gomatrixserverlib.AuthEvents, auth *gomatrixserverlib.AuthEvents,
known map[string]*types.Event, known map[string]*types.Event,
@ -597,7 +615,7 @@ func (r *Inputer) fetchAuthEvents(
// Request the entire auth chain for the event in question. This should // Request the entire auth chain for the event in question. This should
// contain all of the auth events — including ones that we already know — // contain all of the auth events — including ones that we already know —
// so we'll need to filter through those in the next section. // so we'll need to filter through those in the next section.
res, err = r.FSAPI.GetEventAuth(ctx, serverName, event.RoomVersion, event.RoomID(), event.EventID()) res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.RoomVersion, event.RoomID(), event.EventID())
if err != nil { if err != nil {
logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err) logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err)
continue continue
@ -792,7 +810,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event
return err return err
} }
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes) event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes)
if err != nil { if err != nil {
return err return err
} }

View file

@ -41,6 +41,7 @@ func (p *parsedRespState) Events() []*gomatrixserverlib.Event {
type missingStateReq struct { type missingStateReq struct {
log *logrus.Entry log *logrus.Entry
virtualHost gomatrixserverlib.ServerName
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
db storage.Database db storage.Database
roomInfo *types.RoomInfo roomInfo *types.RoomInfo
@ -101,7 +102,7 @@ func (t *missingStateReq) processEventWithMissingState(
// we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled
// in the gap in the DAG // in the gap in the DAG
for _, newEvent := range newEvents { for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: newEvent.Headered(roomVersion), Event: newEvent.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,
@ -157,7 +158,7 @@ func (t *missingStateReq) processEventWithMissingState(
}) })
} }
for _, ire := range outlierRoomEvents { for _, ire := range outlierRoomEvents {
if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { if err = t.inputer.processRoomEvent(ctx, t.virtualHost, &ire); err != nil {
if _, ok := err.(types.RejectedError); !ok { if _, ok := err.(types.RejectedError); !ok {
return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err)
} }
@ -180,7 +181,7 @@ func (t *missingStateReq) processEventWithMissingState(
stateIDs = append(stateIDs, event.EventID()) stateIDs = append(stateIDs, event.EventID())
} }
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: backwardsExtremity.Headered(roomVersion), Event: backwardsExtremity.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,
@ -199,7 +200,7 @@ func (t *missingStateReq) processEventWithMissingState(
// they will automatically fast-forward based on the room state at the // they will automatically fast-forward based on the room state at the
// extremity in the last step. // extremity in the last step.
for _, newEvent := range newEvents { for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: newEvent.Headered(roomVersion), Event: newEvent.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,
@ -519,7 +520,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve
var missingResp *gomatrixserverlib.RespMissingEvents var missingResp *gomatrixserverlib.RespMissingEvents
for _, server := range t.servers { for _, server := range t.servers {
var m gomatrixserverlib.RespMissingEvents var m gomatrixserverlib.RespMissingEvents
if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ if m, err = t.federation.LookupMissingEvents(ctx, t.virtualHost, server, e.RoomID(), gomatrixserverlib.MissingEvents{
Limit: 20, Limit: 20,
// The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events.
EarliestEvents: latestEvents, EarliestEvents: latestEvents,
@ -635,7 +636,7 @@ func (t *missingStateReq) lookupMissingStateViaState(
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaState") span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaState")
defer span.Finish() defer span.Finish()
state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion) state, err := t.federation.LookupState(ctx, t.virtualHost, t.origin, roomID, eventID, roomVersion)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -675,7 +676,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
totalctx, totalcancel := context.WithTimeout(ctx, time.Minute*5) totalctx, totalcancel := context.WithTimeout(ctx, time.Minute*5)
for _, serverName := range t.servers { for _, serverName := range t.servers {
reqctx, reqcancel := context.WithTimeout(totalctx, time.Second*20) reqctx, reqcancel := context.WithTimeout(totalctx, time.Second*20)
stateIDs, err = t.federation.LookupStateIDs(reqctx, serverName, roomID, eventID) stateIDs, err = t.federation.LookupStateIDs(reqctx, t.virtualHost, serverName, roomID, eventID)
reqcancel() reqcancel()
if err == nil { if err == nil {
break break
@ -855,7 +856,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
for _, serverName := range t.servers { for _, serverName := range t.servers {
reqctx, cancel := context.WithTimeout(ctx, time.Second*30) reqctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
txn, err := t.federation.GetEvent(reqctx, serverName, missingEventID) txn, err := t.federation.GetEvent(reqctx, t.virtualHost, serverName, missingEventID)
if err != nil || len(txn.PDUs) == 0 { if err != nil || len(txn.PDUs) == 0 {
t.log.WithError(err).WithField("missing_event_id", missingEventID).Warn("Failed to get missing /event for event ID") t.log.WithError(err).WithField("missing_event_id", missingEventID).Warn("Failed to get missing /event for event ID")
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {

View file

@ -139,7 +139,12 @@ func (r *Admin) PerformAdminEvacuateRoom(
return nil return nil
} }
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes) identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil {
continue
}
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, latestRes)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
@ -242,6 +247,15 @@ func (r *Admin) PerformAdminDownloadState(
req *api.PerformAdminDownloadStateRequest, req *api.PerformAdminDownloadStateRequest,
res *api.PerformAdminDownloadStateResponse, res *api.PerformAdminDownloadStateResponse,
) error { ) error {
_, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', req.UserID)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Cfg.Matrix.SplitLocalID: %s", err),
}
return nil
}
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
@ -273,7 +287,7 @@ func (r *Admin) PerformAdminDownloadState(
for _, fwdExtremity := range fwdExtremities { for _, fwdExtremity := range fwdExtremities {
var state gomatrixserverlib.RespState var state gomatrixserverlib.RespState
state, err = r.Inputer.FSAPI.LookupState(ctx, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion) state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
@ -331,7 +345,12 @@ func (r *Admin) PerformAdminDownloadState(
Depth: depth, Depth: depth,
} }
ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, time.Now(), &eventsNeeded, queryRes) identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil {
return err
}
ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, queryRes)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,

View file

@ -37,7 +37,7 @@ import (
const maxBackfillServers = 5 const maxBackfillServers = 5
type Backfiller struct { type Backfiller struct {
ServerName gomatrixserverlib.ServerName IsLocalServerName func(gomatrixserverlib.ServerName) bool
DB storage.Database DB storage.Database
FSAPI federationAPI.RoomserverFederationAPI FSAPI federationAPI.RoomserverFederationAPI
KeyRing gomatrixserverlib.JSONVerifier KeyRing gomatrixserverlib.JSONVerifier
@ -55,7 +55,7 @@ func (r *Backfiller) PerformBackfill(
// if we are requesting the backfill then we need to do a federation hit // if we are requesting the backfill then we need to do a federation hit
// TODO: we could be more sensible and fetch as many events we already have then request the rest // TODO: we could be more sensible and fetch as many events we already have then request the rest
// which is what the syncapi does already. // which is what the syncapi does already.
if request.ServerName == r.ServerName { if r.IsLocalServerName(request.ServerName) {
return r.backfillViaFederation(ctx, request, response) return r.backfillViaFederation(ctx, request, response)
} }
// someone else is requesting the backfill, try to service their request. // someone else is requesting the backfill, try to service their request.
@ -112,16 +112,18 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
if info == nil || info.IsStub() { if info == nil || info.IsStub() {
return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID)
} }
requester := newBackfillRequester(r.DB, r.FSAPI, r.ServerName, req.BackwardsExtremities, r.PreferServers) requester := newBackfillRequester(r.DB, r.FSAPI, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers)
// Request 100 items regardless of what the query asks for. // Request 100 items regardless of what the query asks for.
// We don't want to go much higher than this. // We don't want to go much higher than this.
// We can't honour exactly the limit as some sytests rely on requesting more for tests to pass // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
// (so we don't need to hit /state_ids which the test has no listener for) // (so we don't need to hit /state_ids which the test has no listener for)
// Specifically the test "Outbound federation can backfill events" // Specifically the test "Outbound federation can backfill events"
events, err := gomatrixserverlib.RequestBackfill( events, err := gomatrixserverlib.RequestBackfill(
ctx, requester, ctx, req.VirtualHost, requester,
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100) r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100,
)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("gomatrixserverlib.RequestBackfill failed")
return err return err
} }
logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
@ -144,7 +146,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
var entries []types.StateEntry var entries []types.StateEntry
if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true); err != nil { if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true); err != nil {
// attempt to fetch the missing events // attempt to fetch the missing events
r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs) r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs, req.VirtualHost)
// try again // try again
entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true) entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true)
if err != nil { if err != nil {
@ -173,7 +175,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just // fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just
// best effort. // best effort.
func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
backfillRequester *backfillRequester, stateIDs []string) { backfillRequester *backfillRequester, stateIDs []string, virtualHost gomatrixserverlib.ServerName) {
servers := backfillRequester.servers servers := backfillRequester.servers
@ -198,7 +200,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
continue // already found continue // already found
} }
logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id) logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id)
res, err := r.FSAPI.GetEvent(ctx, srv, id) res, err := r.FSAPI.GetEvent(ctx, virtualHost, srv, id)
if err != nil { if err != nil {
logger.WithError(err).Warn("failed to get event from server") logger.WithError(err).Warn("failed to get event from server")
continue continue
@ -243,7 +245,8 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
type backfillRequester struct { type backfillRequester struct {
db storage.Database db storage.Database
fsAPI federationAPI.RoomserverFederationAPI fsAPI federationAPI.RoomserverFederationAPI
thisServer gomatrixserverlib.ServerName virtualHost gomatrixserverlib.ServerName
isLocalServerName func(gomatrixserverlib.ServerName) bool
preferServer map[gomatrixserverlib.ServerName]bool preferServer map[gomatrixserverlib.ServerName]bool
bwExtrems map[string][]string bwExtrems map[string][]string
@ -255,7 +258,9 @@ type backfillRequester struct {
} }
func newBackfillRequester( func newBackfillRequester(
db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, thisServer gomatrixserverlib.ServerName, db storage.Database, fsAPI federationAPI.RoomserverFederationAPI,
virtualHost gomatrixserverlib.ServerName,
isLocalServerName func(gomatrixserverlib.ServerName) bool,
bwExtrems map[string][]string, preferServers []gomatrixserverlib.ServerName, bwExtrems map[string][]string, preferServers []gomatrixserverlib.ServerName,
) *backfillRequester { ) *backfillRequester {
preferServer := make(map[gomatrixserverlib.ServerName]bool) preferServer := make(map[gomatrixserverlib.ServerName]bool)
@ -265,7 +270,8 @@ func newBackfillRequester(
return &backfillRequester{ return &backfillRequester{
db: db, db: db,
fsAPI: fsAPI, fsAPI: fsAPI,
thisServer: thisServer, virtualHost: virtualHost,
isLocalServerName: isLocalServerName,
eventIDToBeforeStateIDs: make(map[string][]string), eventIDToBeforeStateIDs: make(map[string][]string),
eventIDMap: make(map[string]*gomatrixserverlib.Event), eventIDMap: make(map[string]*gomatrixserverlib.Event),
bwExtrems: bwExtrems, bwExtrems: bwExtrems,
@ -450,7 +456,7 @@ FindSuccessor:
} }
// possibly return all joined servers depending on history visiblity // possibly return all joined servers depending on history visiblity
memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.thisServer) memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.virtualHost)
b.historyVisiblity = visibility b.historyVisiblity = visibility
if err != nil { if err != nil {
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
@ -477,7 +483,7 @@ FindSuccessor:
} }
var servers []gomatrixserverlib.ServerName var servers []gomatrixserverlib.ServerName
for server := range serverSet { for server := range serverSet {
if server == b.thisServer { if b.isLocalServerName(server) {
continue continue
} }
if b.preferServer[server] { // insert at the front if b.preferServer[server] { // insert at the front
@ -496,10 +502,10 @@ FindSuccessor:
// Backfill performs a backfill request to the given server. // Backfill performs a backfill request to the given server.
// https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid // https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid
func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, func (b *backfillRequester) Backfill(ctx context.Context, origin, server gomatrixserverlib.ServerName, roomID string,
limit int, fromEventIDs []string) (gomatrixserverlib.Transaction, error) { limit int, fromEventIDs []string) (gomatrixserverlib.Transaction, error) {
tx, err := b.fsAPI.Backfill(ctx, server, roomID, limit, fromEventIDs) tx, err := b.fsAPI.Backfill(ctx, origin, server, roomID, limit, fromEventIDs)
return tx, err return tx, err
} }

View file

@ -40,7 +40,6 @@ import (
) )
type Joiner struct { type Joiner struct {
ServerName gomatrixserverlib.ServerName
Cfg *config.RoomServer Cfg *config.RoomServer
FSAPI fsAPI.RoomserverFederationAPI FSAPI fsAPI.RoomserverFederationAPI
RSAPI rsAPI.RoomserverInternalAPI RSAPI rsAPI.RoomserverInternalAPI
@ -198,7 +197,7 @@ func (r *Joiner) performJoinRoomByID(
// Prepare the template for the join event. // Prepare the template for the join event.
userID := req.UserID userID := req.UserID
_, userDomain, err := gomatrixserverlib.SplitID('@', userID) _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
return "", "", &rsAPI.PerformError{ return "", "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorBadRequest, Code: rsAPI.PerformErrorBadRequest,
@ -306,7 +305,7 @@ func (r *Joiner) performJoinRoomByID(
// locally on the homeserver. // locally on the homeserver.
// TODO: Check what happens if the room exists on the server // TODO: Check what happens if the room exists on the server
// but everyone has since left. I suspect it does the wrong thing. // but everyone has since left. I suspect it does the wrong thing.
event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb) event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, userDomain, &eb)
switch err { switch err {
case nil: case nil:
@ -433,7 +432,9 @@ func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin(
} }
func buildEvent( func buildEvent(
ctx context.Context, db storage.Database, cfg *config.Global, builder *gomatrixserverlib.EventBuilder, ctx context.Context, db storage.Database, cfg *config.Global,
senderDomain gomatrixserverlib.ServerName,
builder *gomatrixserverlib.EventBuilder,
) (*gomatrixserverlib.HeaderedEvent, *rsAPI.QueryLatestEventsAndStateResponse, error) { ) (*gomatrixserverlib.HeaderedEvent, *rsAPI.QueryLatestEventsAndStateResponse, error) {
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
if err != nil { if err != nil {
@ -461,7 +462,12 @@ func buildEvent(
} }
} }
ev, err := eventutil.BuildEvent(ctx, builder, cfg, time.Now(), &eventsNeeded, &queryRes) identity, err := cfg.SigningIdentityFor(senderDomain)
if err != nil {
return nil, nil, err
}
ev, err := eventutil.BuildEvent(ctx, builder, cfg, identity, time.Now(), &eventsNeeded, &queryRes)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View file

@ -162,21 +162,21 @@ func (r *Leaver) performLeaveRoomByID(
return nil, fmt.Errorf("eb.SetUnsigned: %w", err) return nil, fmt.Errorf("eb.SetUnsigned: %w", err)
} }
// Get the sender domain.
_, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', eb.Sender)
if serr != nil {
return nil, fmt.Errorf("sender %q is invalid", eb.Sender)
}
// We know that the user is in the room at this point so let's build // We know that the user is in the room at this point so let's build
// a leave event. // a leave event.
// TODO: Check what happens if the room exists on the server // TODO: Check what happens if the room exists on the server
// but everyone has since left. I suspect it does the wrong thing. // but everyone has since left. I suspect it does the wrong thing.
event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb) event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, senderDomain, &eb)
if err != nil { if err != nil {
return nil, fmt.Errorf("eventutil.BuildEvent: %w", err) return nil, fmt.Errorf("eventutil.BuildEvent: %w", err)
} }
// Get the sender domain.
_, senderDomain, serr := gomatrixserverlib.SplitID('@', event.Sender())
if serr != nil {
return nil, fmt.Errorf("sender %q is invalid", event.Sender())
}
// Give our leave event to the roomserver input stream. The // Give our leave event to the roomserver input stream. The
// roomserver will process the membership change and notify // roomserver will process the membership change and notify
// downstream automatically. // downstream automatically.

View file

@ -60,7 +60,7 @@ func (r *Upgrader) performRoomUpgrade(
) (string, *api.PerformError) { ) (string, *api.PerformError) {
roomID := req.RoomID roomID := req.RoomID
userID := req.UserID userID := req.UserID
_, userDomain, err := gomatrixserverlib.SplitID('@', userID) _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
return "", &api.PerformError{ return "", &api.PerformError{
Code: api.PerformErrorNotAllowed, Code: api.PerformErrorNotAllowed,
@ -558,7 +558,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
SendAsServer: api.DoNotSendToOtherServers, SendAsServer: api.DoNotSendToOtherServers,
}) })
} }
if err = api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { if err = api.SendInputRoomEvents(ctx, r.URSAPI, userDomain, inputs, false); err != nil {
return &api.PerformError{ return &api.PerformError{
Msg: fmt.Sprintf("Failed to send new room %q to roomserver: %s", newRoomID, err), Msg: fmt.Sprintf("Failed to send new room %q to roomserver: %s", newRoomID, err),
} }
@ -595,8 +595,21 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user
Msg: fmt.Sprintf("Failed to set new %q event content: %s", builder.Type, err), Msg: fmt.Sprintf("Failed to set new %q event content: %s", builder.Type, err),
} }
} }
// Get the sender domain.
_, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', builder.Sender)
if serr != nil {
return nil, &api.PerformError{
Msg: fmt.Sprintf("Failed to split user ID %q: %s", builder.Sender, err),
}
}
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil {
return nil, &api.PerformError{
Msg: fmt.Sprintf("Failed to get signing identity for %q: %s", senderDomain, err),
}
}
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &builder, r.Cfg.Matrix, evTime, r.URSAPI, &queryRes) headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &builder, r.Cfg.Matrix, identity, evTime, r.URSAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return nil, &api.PerformError{ return nil, &api.PerformError{
Code: api.PerformErrorNoRoom, Code: api.PerformErrorNoRoom,
@ -686,7 +699,7 @@ func (r *Upgrader) sendHeaderedEvent(
Origin: serverName, Origin: serverName,
SendAsServer: sendAsServer, SendAsServer: sendAsServer,
}) })
if err := api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { if err := api.SendInputRoomEvents(ctx, r.URSAPI, serverName, inputs, false); err != nil {
return &api.PerformError{ return &api.PerformError{
Msg: fmt.Sprintf("Failed to send new %q event to roomserver: %s", headeredEvent.Type(), err), Msg: fmt.Sprintf("Failed to send new %q event to roomserver: %s", headeredEvent.Type(), err),
} }

View file

@ -39,7 +39,7 @@ import (
type Queryer struct { type Queryer struct {
DB storage.Database DB storage.Database
Cache caching.RoomServerCaches Cache caching.RoomServerCaches
ServerName gomatrixserverlib.ServerName IsLocalServerName func(gomatrixserverlib.ServerName) bool
ServerACLs *acls.ServerACLs ServerACLs *acls.ServerACLs
} }
@ -392,7 +392,7 @@ func (r *Queryer) QueryServerJoinedToRoom(
} }
response.RoomExists = true response.RoomExists = true
if request.ServerName == r.ServerName || request.ServerName == "" { if r.IsLocalServerName(request.ServerName) || request.ServerName == "" {
response.IsInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID) response.IsInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err) return fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err)

View file

@ -44,7 +44,7 @@ func Test_SharedUsers(t *testing.T) {
// SetFederationAPI starts the room event input consumer // SetFederationAPI starts the room event input consumer
rsAPI.SetFederationAPI(nil, nil) rsAPI.SetFederationAPI(nil, nil)
// Create the room // Create the room
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", nil, false); err != nil { if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err) t.Fatalf("failed to send events: %v", err)
} }

View file

@ -172,4 +172,5 @@ type Database interface {
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error)
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
} }

View file

@ -103,6 +103,7 @@ func (d *Database) eventStateKeyNIDs(
ctx context.Context, txn *sql.Tx, eventStateKeys []string, ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) { ) (map[string]types.EventStateKeyNID, error) {
result := make(map[string]types.EventStateKeyNID) result := make(map[string]types.EventStateKeyNID)
eventStateKeys = util.UniqueStrings(eventStateKeys)
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys)
if err != nil { if err != nil {
return nil, err return nil, err
@ -111,17 +112,25 @@ func (d *Database) eventStateKeyNIDs(
result[eventStateKey] = nid result[eventStateKey] = nid
} }
// We received some nids, but are still missing some, work out which and create them // We received some nids, but are still missing some, work out which and create them
if len(eventStateKeys) < len(result) { if len(eventStateKeys) > len(result) {
var nid types.EventStateKeyNID
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
for _, eventStateKey := range eventStateKeys { for _, eventStateKey := range eventStateKeys {
if _, ok := result[eventStateKey]; ok { if _, ok := result[eventStateKey]; ok {
continue continue
} }
nid, err := d.assignStateKeyNID(ctx, txn, eventStateKey)
nid, err = d.assignStateKeyNID(ctx, txn, eventStateKey)
if err != nil { if err != nil {
return result, err return err
} }
result[eventStateKey] = nid result[eventStateKey] = nid
} }
return nil
})
if err != nil {
return nil, err
}
} }
return result, nil return result, nil
} }
@ -1399,6 +1408,36 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget
}) })
} }
func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// un-publish old room
if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, oldRoomID, "", "", false); err != nil {
return fmt.Errorf("failed to unpublish room: %w", err)
}
// publish new room
if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, newRoomID, "", "", true); err != nil {
return fmt.Errorf("failed to publish room: %w", err)
}
// Migrate any existing room aliases
aliases, err := d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, txn, oldRoomID)
if err != nil {
return fmt.Errorf("failed to get room aliases: %w", err)
}
for _, alias := range aliases {
if err = d.RoomAliasesTable.DeleteRoomAlias(ctx, txn, alias); err != nil {
return fmt.Errorf("failed to remove room alias: %w", err)
}
if err = d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, newRoomID, eventSender); err != nil {
return fmt.Errorf("failed to set room alias: %w", err)
}
}
return nil
})
}
// FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops // FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops
// it should live in this package! // it should live in this package!

View file

@ -364,10 +364,10 @@ func (b *BaseDendrite) CreateClient() *gomatrixserverlib.Client {
// CreateFederationClient creates a new federation client. Should only be called // CreateFederationClient creates a new federation client. Should only be called
// once per component. // once per component.
func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient { func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient {
identities := b.Cfg.Global.SigningIdentities()
if b.Cfg.Global.DisableFederation { if b.Cfg.Global.DisableFederation {
return gomatrixserverlib.NewFederationClient( return gomatrixserverlib.NewFederationClient(
b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, b.Cfg.Global.PrivateKey, identities, gomatrixserverlib.WithTransport(noOpHTTPTransport),
gomatrixserverlib.WithTransport(noOpHTTPTransport),
) )
} }
opts := []gomatrixserverlib.ClientOption{ opts := []gomatrixserverlib.ClientOption{
@ -379,8 +379,7 @@ func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationCli
opts = append(opts, gomatrixserverlib.WithDNSCache(b.DNSCache)) opts = append(opts, gomatrixserverlib.WithDNSCache(b.DNSCache))
} }
client := gomatrixserverlib.NewFederationClient( client := gomatrixserverlib.NewFederationClient(
b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, identities, opts...,
b.Cfg.Global.PrivateKey, opts...,
) )
client.SetUserAgent(fmt.Sprintf("Dendrite/%s", internal.VersionString())) client.SetUserAgent(fmt.Sprintf("Dendrite/%s", internal.VersionString()))
return client return client

View file

@ -231,6 +231,21 @@ func loadConfig(
return nil, err return nil, err
} }
for _, v := range c.Global.VirtualHosts {
if v.KeyValidityPeriod == 0 {
v.KeyValidityPeriod = c.Global.KeyValidityPeriod
}
if v.PrivateKeyPath == "" || v.PrivateKey == nil || v.KeyID == "" {
v.KeyID = c.Global.KeyID
v.PrivateKey = c.Global.PrivateKey
continue
}
privateKeyPath := absPath(basePath, v.PrivateKeyPath)
if v.KeyID, v.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil {
return nil, err
}
}
for _, key := range c.Global.OldVerifyKeys { for _, key := range c.Global.OldVerifyKeys {
switch { switch {
case key.PrivateKeyPath != "": case key.PrivateKeyPath != "":

View file

@ -1,6 +1,7 @@
package config package config
import ( import (
"fmt"
"math/rand" "math/rand"
"strconv" "strconv"
"strings" "strings"
@ -11,22 +12,16 @@ import (
) )
type Global struct { type Global struct {
// The name of the server. This is usually the domain name, e.g 'matrix.org', 'localhost'. // Signing identity contains the server name, private key and key ID of
ServerName gomatrixserverlib.ServerName `yaml:"server_name"` // the deployment.
gomatrixserverlib.SigningIdentity `yaml:",inline"`
// The secondary server names, used for virtual hosting. // The secondary server names, used for virtual hosting.
SecondaryServerNames []gomatrixserverlib.ServerName `yaml:"-"` VirtualHosts []*VirtualHost `yaml:"virtual_hosts"`
// Path to the private key which will be used to sign requests and events. // Path to the private key which will be used to sign requests and events.
PrivateKeyPath Path `yaml:"private_key"` PrivateKeyPath Path `yaml:"private_key"`
// The private key which will be used to sign requests and events.
PrivateKey ed25519.PrivateKey `yaml:"-"`
// An arbitrary string used to uniquely identify the PrivateKey. Must start with the
// prefix "ed25519:".
KeyID gomatrixserverlib.KeyID `yaml:"-"`
// Information about old private keys that used to be used to sign requests and // Information about old private keys that used to be used to sign requests and
// events on this domain. They will not be used but will be advertised to other // events on this domain. They will not be used but will be advertised to other
// servers that ask for them to help verify old events. // servers that ask for them to help verify old events.
@ -114,6 +109,10 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) {
checkNotEmpty(configErrs, "global.server_name", string(c.ServerName)) checkNotEmpty(configErrs, "global.server_name", string(c.ServerName))
checkNotEmpty(configErrs, "global.private_key", string(c.PrivateKeyPath)) checkNotEmpty(configErrs, "global.private_key", string(c.PrivateKeyPath))
for _, v := range c.VirtualHosts {
v.Verify(configErrs)
}
c.JetStream.Verify(configErrs, isMonolith) c.JetStream.Verify(configErrs, isMonolith)
c.Metrics.Verify(configErrs, isMonolith) c.Metrics.Verify(configErrs, isMonolith)
c.Sentry.Verify(configErrs, isMonolith) c.Sentry.Verify(configErrs, isMonolith)
@ -127,14 +126,108 @@ func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool
if c.ServerName == serverName { if c.ServerName == serverName {
return true return true
} }
for _, secondaryName := range c.SecondaryServerNames { for _, v := range c.VirtualHosts {
if secondaryName == serverName { if v.ServerName == serverName {
return true return true
} }
} }
return false return false
} }
func (c *Global) SplitLocalID(sigil byte, id string) (string, gomatrixserverlib.ServerName, error) {
u, s, err := gomatrixserverlib.SplitID(sigil, id)
if err != nil {
return u, s, err
}
if !c.IsLocalServerName(s) {
return u, s, fmt.Errorf("server name %q not known", s)
}
return u, s, nil
}
func (c *Global) VirtualHost(serverName gomatrixserverlib.ServerName) *VirtualHost {
for _, v := range c.VirtualHosts {
if v.ServerName == serverName {
return v
}
}
return nil
}
func (c *Global) VirtualHostForHTTPHost(serverName gomatrixserverlib.ServerName) *VirtualHost {
for _, v := range c.VirtualHosts {
if v.ServerName == serverName {
return v
}
for _, h := range v.MatchHTTPHosts {
if h == serverName {
return v
}
}
}
return nil
}
func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.SigningIdentity, error) {
for _, id := range c.SigningIdentities() {
if id.ServerName == serverName {
return id, nil
}
}
return nil, fmt.Errorf("no signing identity for %q", serverName)
}
func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity {
identities := make([]*gomatrixserverlib.SigningIdentity, 0, len(c.VirtualHosts)+1)
identities = append(identities, &c.SigningIdentity)
for _, v := range c.VirtualHosts {
identities = append(identities, &v.SigningIdentity)
}
return identities
}
type VirtualHost struct {
// Signing identity contains the server name, private key and key ID of
// the virtual host.
gomatrixserverlib.SigningIdentity `yaml:",inline"`
// Path to the private key. If not specified, the default global private key
// will be used instead.
PrivateKeyPath Path `yaml:"private_key"`
// How long a remote server can cache our server key for before requesting it again.
// Increasing this number will reduce the number of requests made by remote servers
// for our key, but increases the period a compromised key will be considered valid
// by remote servers.
// Defaults to 24 hours.
KeyValidityPeriod time.Duration `yaml:"key_validity_period"`
// Match these HTTP Host headers on the `/key/v2/server` endpoint, this needs
// to match all delegated names, likely including the port number too if
// the well-known delegation includes that also.
MatchHTTPHosts []gomatrixserverlib.ServerName `yaml:"match_http_hosts"`
// Is registration enabled on this virtual host?
AllowRegistration bool `yaml:"allow_registration"`
// Is guest registration enabled on this virtual host?
AllowGuests bool `yaml:"allow_guests"`
}
func (v *VirtualHost) Verify(configErrs *ConfigErrors) {
checkNotEmpty(configErrs, "virtual_host.*.server_name", string(v.ServerName))
}
// RegistrationAllowed returns two bools, the first states whether registration
// is allowed for this virtual host and the second states whether guests are
// allowed for this virtual host.
func (v *VirtualHost) RegistrationAllowed() (bool, bool) {
if v == nil {
return false, false
}
return v.AllowRegistration, v.AllowGuests
}
type OldVerifyKeys struct { type OldVerifyKeys struct {
// Path to the private key. // Path to the private key.
PrivateKeyPath Path `yaml:"private_key"` PrivateKeyPath Path `yaml:"private_key"`

View file

@ -2,6 +2,7 @@ package jetstream
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
@ -72,6 +73,9 @@ func JetStreamConsumer(
// just timed out and we should try again. // just timed out and we should try again.
continue continue
} }
} else if errors.Is(err, nats.ErrConsumerDeleted) {
// The consumer was deleted so stop.
return
} else { } else {
// Something else went wrong, so we'll panic. // Something else went wrong, so we'll panic.
sentry.CaptureException(err) sentry.CaptureException(err)

View file

@ -397,7 +397,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen
serversToQuery := rc.getServersForEventID(parentID) serversToQuery := rc.getServersForEventID(parentID)
var result *MSC2836EventRelationshipsResponse var result *MSC2836EventRelationshipsResponse
for _, srv := range serversToQuery { for _, srv := range serversToQuery {
res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{
EventID: parentID, EventID: parentID,
Direction: "down", Direction: "down",
Limit: 100, Limit: 100,
@ -484,7 +484,7 @@ func walkThread(
// MSC2836EventRelationships performs an /event_relationships request to a remote server // MSC2836EventRelationships performs an /event_relationships request to a remote server
func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*MSC2836EventRelationshipsResponse, error) { func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*MSC2836EventRelationshipsResponse, error) {
res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{
EventID: eventID, EventID: eventID,
DepthFirst: rc.req.DepthFirst, DepthFirst: rc.req.DepthFirst,
Direction: rc.req.Direction, Direction: rc.req.Direction,
@ -665,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo
}) })
} }
// we've got the data by this point so use a background context // we've got the data by this point so use a background context
err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, rc.serverName, ires, false)
if err != nil { if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver") util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver")
} }

Some files were not shown because too many files have changed in this diff Show more