Merge branch 'main' into moderncsqlite2

This commit is contained in:
0x1a8510f2 2022-10-31 17:35:41 +00:00 committed by GitHub
commit debb8348dc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
88 changed files with 1405 additions and 483 deletions

128
.github/workflows/schedules.yaml vendored Normal file
View file

@ -0,0 +1,128 @@
name: Scheduled
on:
schedule:
- cron: '0 0 * * *' # every day at midnight
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
# run go test with different go versions
test:
timeout-minutes: 20
name: Unit tests (Go ${{ matrix.go }})
runs-on: ubuntu-latest
# Service containers to run with `container-job`
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres:13-alpine
# Provide the password for postgres
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: dendrite
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
strategy:
fail-fast: false
matrix:
go: ["1.18", "1.19"]
steps:
- uses: actions/checkout@v3
- name: Setup go
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- uses: actions/cache@v3
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-test-race-
- run: go test -race ./...
env:
POSTGRES_HOST: localhost
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: dendrite
# Dummy step to gate other tests on without repeating the whole list
initial-tests-done:
name: Initial tests passed
needs: [test]
runs-on: ubuntu-latest
if: ${{ !cancelled() }} # Run this even if prior jobs were skipped
steps:
- name: Check initial tests passed
uses: re-actors/alls-green@release/v1
with:
jobs: ${{ toJSON(needs) }}
# run Sytest in different variations
sytest:
timeout-minutes: 60
needs: initial-tests-done
name: "Sytest (${{ matrix.label }})"
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
- label: SQLite
- label: SQLite, full HTTP APIs
api: full-http
- label: PostgreSQL
postgres: postgres
- label: PostgreSQL, full HTTP APIs
postgres: postgres
api: full-http
container:
image: matrixdotorg/sytest-dendrite:latest
volumes:
- ${{ github.workspace }}:/src
env:
POSTGRES: ${{ matrix.postgres && 1}}
API: ${{ matrix.api && 1 }}
SYTEST_BRANCH: ${{ github.head_ref }}
RACE_DETECTION: 1
steps:
- uses: actions/checkout@v2
- name: Run Sytest
run: /bootstrap.sh dendrite
working-directory: /src
- name: Summarise results.tap
if: ${{ always() }}
run: /sytest/scripts/tap_to_gha.pl /logs/results.tap
- name: Sytest List Maintenance
if: ${{ always() }}
run: /src/show-expected-fail-tests.sh /logs/results.tap /src/sytest-whitelist /src/sytest-blacklist
continue-on-error: true # not fatal
- name: Are We Synapse Yet?
if: ${{ always() }}
run: /src/are-we-synapse-yet.py /logs/results.tap -v
continue-on-error: true # not fatal
- name: Upload Sytest logs
uses: actions/upload-artifact@v2
if: ${{ always() }}
with:
name: Sytest Logs - ${{ job.status }} - (Dendrite, ${{ join(matrix.*, ', ') }})
path: |
/logs/results.tap
/logs/**/*.log*

View file

@ -1,5 +1,25 @@
# Changelog # Changelog
## Dendrite 0.10.5 (2022-10-31)
### Features
* It is now possible to use hCaptcha instead of reCAPTCHA for protecting registration
* A new `auto_join_rooms` configuration option has been added for automatically joining new users to a set of rooms
* A new `/_dendrite/admin/downloadState/{serverName}/{roomID}` endpoint has been added, which allows a server administrator to attempt to repair a room with broken room state by downloading a state snapshot from another federated server in the room
### Fixes
* Querying cross-signing keys for users should now be considerably faster
* A bug in state resolution where some events were not correctly selected for third-party invites has been fixed
* A bug in state resolution which could result in `not in room` event rejections has been fixed
* When accepting a DM invite, it should now be possible to see messages that were sent before the invite was accepted
* Claiming remote E2EE one-time keys has been refactored and should be more reliable now
* Various fixes have been made to the `/members` endpoint, which may help with E2EE reliability and clients rendering memberships
* A race condition in the federation API destination queues has been fixed when associating queued events with remote server destinations
* A bug in the sync API where too many events were selected resulting in high CPU usage has been fixed
* Configuring the avatar URL for the Server Notices user should work correctly now
## Dendrite 0.10.4 (2022-10-21) ## Dendrite 0.10.4 (2022-10-21)
### Features ### Features

View file

@ -0,0 +1,25 @@
FROM docker.io/golang:1.19-alpine AS base
RUN apk --update --no-cache add bash build-base
WORKDIR /build
COPY . /build
RUN mkdir -p bin
RUN go build -trimpath -o bin/ ./cmd/dendrite-demo-yggdrasil
RUN go build -trimpath -o bin/ ./cmd/create-account
RUN go build -trimpath -o bin/ ./cmd/generate-keys
FROM alpine:latest
LABEL org.opencontainers.image.title="Dendrite (Yggdrasil demo)"
LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go"
LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite"
LABEL org.opencontainers.image.licenses="Apache-2.0"
COPY --from=base /build/bin/* /usr/bin/
VOLUME /etc/dendrite
WORKDIR /etc/dendrite
ENTRYPOINT ["/usr/bin/dendrite-demo-yggdrasil"]

View file

@ -101,18 +101,46 @@ func (m *DendriteMonolith) SessionCount() int {
return len(m.PineconeQUIC.Protocol("matrix").Sessions()) return len(m.PineconeQUIC.Protocol("matrix").Sessions())
} }
func (m *DendriteMonolith) RegisterNetworkInterface(name string, index int, mtu int, up bool, broadcast bool, loopback bool, pointToPoint bool, multicast bool, addrs string) { type InterfaceInfo struct {
m.PineconeMulticast.RegisterInterface(pineconeMulticast.InterfaceInfo{ Name string
Name: name, Index int
Index: index, Mtu int
Mtu: mtu, Up bool
Up: up, Broadcast bool
Broadcast: broadcast, Loopback bool
Loopback: loopback, PointToPoint bool
PointToPoint: pointToPoint, Multicast bool
Multicast: multicast, Addrs string
Addrs: addrs, }
})
type InterfaceRetriever interface {
CacheCurrentInterfaces() int
GetCachedInterface(index int) *InterfaceInfo
}
func (m *DendriteMonolith) RegisterNetworkCallback(intfCallback InterfaceRetriever) {
callback := func() []pineconeMulticast.InterfaceInfo {
count := intfCallback.CacheCurrentInterfaces()
intfs := []pineconeMulticast.InterfaceInfo{}
for i := 0; i < count; i++ {
iface := intfCallback.GetCachedInterface(i)
if iface != nil {
intfs = append(intfs, pineconeMulticast.InterfaceInfo{
Name: iface.Name,
Index: iface.Index,
Mtu: iface.Mtu,
Up: iface.Up,
Broadcast: iface.Broadcast,
Loopback: iface.Loopback,
PointToPoint: iface.PointToPoint,
Multicast: iface.Multicast,
Addrs: iface.Addrs,
})
}
}
return intfs
}
m.PineconeMulticast.RegisterNetworkCallback(callback)
} }
func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) { func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) {

View file

@ -74,7 +74,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
JSON: jsonerror.BadJSON("A password must be supplied."), JSON: jsonerror.BadJSON("A password must be supplied."),
} }
} }
localpart, err := userutil.ParseUsernameParam(username, &t.Config.Matrix.ServerName) localpart, _, err := userutil.ParseUsernameParam(username, t.Config.Matrix)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,

View file

@ -70,7 +70,7 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi
if err != nil { if err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error()) return util.MessageResponse(http.StatusBadRequest, err.Error())
} }
if domain != cfg.Matrix.ServerName { if !cfg.Matrix.IsLocalServerName(domain) {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("User ID must belong to this server."), JSON: jsonerror.MissingArgument("User ID must belong to this server."),
@ -169,7 +169,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien
if err != nil { if err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error()) return util.MessageResponse(http.StatusBadRequest, err.Error())
} }
if domain == cfg.Matrix.ServerName { if cfg.Matrix.IsLocalServerName(domain) {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.InvalidParam("Can not mark local device list as stale"), JSON: jsonerror.InvalidParam("Can not mark local device list as stale"),
@ -191,3 +191,43 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien
JSON: struct{}{}, JSON: struct{}{},
} }
} }
func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
roomID, ok := vars["roomID"]
if !ok {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("Expecting room ID."),
}
}
serverName, ok := vars["serverName"]
if !ok {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("Expecting remote server name."),
}
}
res := &roomserverAPI.PerformAdminDownloadStateResponse{}
if err := rsAPI.PerformAdminDownloadState(
req.Context(),
&roomserverAPI.PerformAdminDownloadStateRequest{
UserID: device.UserID,
RoomID: roomID,
ServerName: gomatrixserverlib.ServerName(serverName),
},
res,
); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if err := res.Error; err != nil {
return err.JSONResponse()
}
return util.JSONResponse{
Code: 200,
JSON: map[string]interface{}{},
}
}

View file

@ -31,8 +31,7 @@ const recaptchaTemplate = `
<title>Authentication</title> <title>Authentication</title>
<meta name='viewport' content='width=device-width, initial-scale=1, <meta name='viewport' content='width=device-width, initial-scale=1,
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<script src="https://www.google.com/recaptcha/api.js" <script src="{{.apiJsUrl}}" async defer></script>
async defer></script>
<script src="//code.jquery.com/jquery-1.11.2.min.js"></script> <script src="//code.jquery.com/jquery-1.11.2.min.js"></script>
<script> <script>
function captchaDone() { function captchaDone() {
@ -51,8 +50,8 @@ function captchaDone() {
Please verify that you're not a robot. Please verify that you're not a robot.
</p> </p>
<input type="hidden" name="session" value="{{.session}}" /> <input type="hidden" name="session" value="{{.session}}" />
<div class="g-recaptcha" <div class="{{.sitekeyClass}}"
data-sitekey="{{.siteKey}}" data-sitekey="{{.sitekey}}"
data-callback="captchaDone"> data-callback="captchaDone">
</div> </div>
<noscript> <noscript>
@ -114,9 +113,12 @@ func AuthFallback(
serveRecaptcha := func() { serveRecaptcha := func() {
data := map[string]string{ data := map[string]string{
"myUrl": req.URL.String(), "myUrl": req.URL.String(),
"session": sessionID, "session": sessionID,
"siteKey": cfg.RecaptchaPublicKey, "apiJsUrl": cfg.RecaptchaApiJsUrl,
"sitekey": cfg.RecaptchaPublicKey,
"sitekeyClass": cfg.RecaptchaSitekeyClass,
"formField": cfg.RecaptchaFormField,
} }
serveTemplate(w, recaptchaTemplate, data) serveTemplate(w, recaptchaTemplate, data)
} }
@ -155,7 +157,7 @@ func AuthFallback(
return &res return &res
} }
response := req.Form.Get("g-recaptcha-response") response := req.Form.Get(cfg.RecaptchaFormField)
if err := validateRecaptcha(cfg, response, clientIP); err != nil { if err := validateRecaptcha(cfg, response, clientIP); err != nil {
util.GetLogger(req.Context()).Error(err) util.GetLogger(req.Context()).Error(err)
return err return err

View file

@ -169,9 +169,21 @@ func createRoom(
asAPI appserviceAPI.AppServiceInternalAPI, asAPI appserviceAPI.AppServiceInternalAPI,
evTime time.Time, evTime time.Time,
) util.JSONResponse { ) util.JSONResponse {
_, userDomain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
if !cfg.Matrix.IsLocalServerName(userDomain) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(fmt.Sprintf("User domain %q not configured locally", userDomain)),
}
}
// TODO (#267): Check room ID doesn't clash with an existing one, and we // TODO (#267): Check room ID doesn't clash with an existing one, and we
// probably shouldn't be using pseudo-random strings, maybe GUIDs? // probably shouldn't be using pseudo-random strings, maybe GUIDs?
roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName) roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain)
logger := util.GetLogger(ctx) logger := util.GetLogger(ctx)
userID := device.UserID userID := device.UserID
@ -314,7 +326,7 @@ func createRoom(
var roomAlias string var roomAlias string
if r.RoomAliasName != "" { if r.RoomAliasName != "" {
roomAlias = fmt.Sprintf("#%s:%s", r.RoomAliasName, cfg.Matrix.ServerName) roomAlias = fmt.Sprintf("#%s:%s", r.RoomAliasName, userDomain)
// check it's free TODO: This races but is better than nothing // check it's free TODO: This races but is better than nothing
hasAliasReq := roomserverAPI.GetRoomIDForAliasRequest{ hasAliasReq := roomserverAPI.GetRoomIDForAliasRequest{
Alias: roomAlias, Alias: roomAlias,
@ -436,7 +448,7 @@ func createRoom(
builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()}
} }
var ev *gomatrixserverlib.Event var ev *gomatrixserverlib.Event
ev, err = buildEvent(&builder, &authEvents, cfg, evTime, roomVersion) ev, err = buildEvent(&builder, userDomain, &authEvents, cfg, evTime, roomVersion)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("buildEvent failed") util.GetLogger(ctx).WithError(err).Error("buildEvent failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -461,7 +473,7 @@ func createRoom(
inputs = append(inputs, roomserverAPI.InputRoomEvent{ inputs = append(inputs, roomserverAPI.InputRoomEvent{
Kind: roomserverAPI.KindNew, Kind: roomserverAPI.KindNew,
Event: event, Event: event,
Origin: cfg.Matrix.ServerName, Origin: userDomain,
SendAsServer: roomserverAPI.DoNotSendToOtherServers, SendAsServer: roomserverAPI.DoNotSendToOtherServers,
}) })
} }
@ -548,7 +560,7 @@ func createRoom(
Event: event, Event: event,
InviteRoomState: inviteStrippedState, InviteRoomState: inviteStrippedState,
RoomVersion: event.RoomVersion, RoomVersion: event.RoomVersion,
SendAsServer: string(cfg.Matrix.ServerName), SendAsServer: string(userDomain),
}, &inviteRes); err != nil { }, &inviteRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") util.GetLogger(ctx).WithError(err).Error("PerformInvite failed")
return util.JSONResponse{ return util.JSONResponse{
@ -591,6 +603,7 @@ func createRoom(
// buildEvent fills out auth_events for the builder then builds the event // buildEvent fills out auth_events for the builder then builds the event
func buildEvent( func buildEvent(
builder *gomatrixserverlib.EventBuilder, builder *gomatrixserverlib.EventBuilder,
serverName gomatrixserverlib.ServerName,
provider gomatrixserverlib.AuthEventProvider, provider gomatrixserverlib.AuthEventProvider,
cfg *config.ClientAPI, cfg *config.ClientAPI,
evTime time.Time, evTime time.Time,
@ -606,7 +619,7 @@ func buildEvent(
} }
builder.AuthEvents = refs builder.AuthEvents = refs
event, err := builder.Build( event, err := builder.Build(
evTime, cfg.Matrix.ServerName, cfg.Matrix.KeyID, evTime, serverName, cfg.Matrix.KeyID,
cfg.Matrix.PrivateKey, roomVersion, cfg.Matrix.PrivateKey, roomVersion,
) )
if err != nil { if err != nil {

View file

@ -18,14 +18,15 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
federationAPI "github.com/matrix-org/dendrite/federationapi/api" federationAPI "github.com/matrix-org/dendrite/federationapi/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
) )
type roomDirectoryResponse struct { type roomDirectoryResponse struct {
@ -75,7 +76,7 @@ func DirectoryRoom(
if res.RoomID == "" { if res.RoomID == "" {
// If we don't know it locally, do a federation query. // If we don't know it locally, do a federation query.
// But don't send the query to ourselves. // But don't send the query to ourselves.
if domain != cfg.Matrix.ServerName { if !cfg.Matrix.IsLocalServerName(domain) {
fedRes, fedErr := federation.LookupRoomAlias(req.Context(), domain, roomAlias) fedRes, fedErr := federation.LookupRoomAlias(req.Context(), domain, roomAlias)
if fedErr != nil { if fedErr != nil {
// TODO: Return 502 if the remote server errored. // TODO: Return 502 if the remote server errored.
@ -127,7 +128,7 @@ func SetLocalAlias(
} }
} }
if domain != cfg.Matrix.ServerName { if !cfg.Matrix.IsLocalServerName(domain) {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Alias must be on local homeserver"), JSON: jsonerror.Forbidden("Alias must be on local homeserver"),
@ -318,3 +319,43 @@ func SetVisibility(
JSON: struct{}{}, JSON: struct{}{},
} }
} }
func SetVisibilityAS(
req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device,
networkID, roomID string,
) util.JSONResponse {
if dev.AccountType != userapi.AccountTypeAppService {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Only appservice may use this endpoint"),
}
}
var v roomVisibility
// If the method is delete, we simply mark the visibility as private
if req.Method == http.MethodDelete {
v.Visibility = "private"
} else {
if reqErr := httputil.UnmarshalJSONRequest(req, &v); reqErr != nil {
return *reqErr
}
}
var publishRes roomserverAPI.PerformPublishResponse
if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{
RoomID: roomID,
Visibility: v.Visibility,
NetworkID: networkID,
AppserviceID: dev.AppserviceID,
}, &publishRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if publishRes.Error != nil {
util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed")
return publishRes.Error.JSONResponse()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}

View file

@ -39,14 +39,17 @@ var (
) )
type PublicRoomReq struct { type PublicRoomReq struct {
Since string `json:"since,omitempty"` Since string `json:"since,omitempty"`
Limit int16 `json:"limit,omitempty"` Limit int64 `json:"limit,omitempty"`
Filter filter `json:"filter,omitempty"` Filter filter `json:"filter,omitempty"`
Server string `json:"server,omitempty"` Server string `json:"server,omitempty"`
IncludeAllNetworks bool `json:"include_all_networks,omitempty"`
NetworkID string `json:"third_party_instance_id,omitempty"`
} }
type filter struct { type filter struct {
SearchTerms string `json:"generic_search_term,omitempty"` SearchTerms string `json:"generic_search_term,omitempty"`
RoomTypes []string `json:"room_types,omitempty"` // TODO: Implement filter on this
} }
// GetPostPublicRooms implements GET and POST /publicRooms // GetPostPublicRooms implements GET and POST /publicRooms
@ -61,9 +64,15 @@ func GetPostPublicRooms(
return *fillErr return *fillErr
} }
serverName := gomatrixserverlib.ServerName(request.Server) if request.IncludeAllNetworks && request.NetworkID != "" {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidParam("include_all_networks and third_party_instance_id can not be used together"),
}
}
if serverName != "" && serverName != cfg.Matrix.ServerName { serverName := gomatrixserverlib.ServerName(request.Server)
if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) {
res, err := federation.GetPublicRoomsFiltered( res, err := federation.GetPublicRoomsFiltered(
req.Context(), serverName, req.Context(), serverName,
int(request.Limit), request.Since, int(request.Limit), request.Since,
@ -98,7 +107,7 @@ func publicRooms(
response := gomatrixserverlib.RespPublicRooms{ response := gomatrixserverlib.RespPublicRooms{
Chunk: []gomatrixserverlib.PublicRoom{}, Chunk: []gomatrixserverlib.PublicRoom{},
} }
var limit int16 var limit int64
var offset int64 var offset int64
limit = request.Limit limit = request.Limit
if limit == 0 { if limit == 0 {
@ -115,7 +124,7 @@ func publicRooms(
var rooms []gomatrixserverlib.PublicRoom var rooms []gomatrixserverlib.PublicRoom
if request.Since == "" { if request.Since == "" {
rooms = refreshPublicRoomCache(ctx, rsAPI, extRoomsProvider) rooms = refreshPublicRoomCache(ctx, rsAPI, extRoomsProvider, request)
} else { } else {
rooms = getPublicRoomsFromCache() rooms = getPublicRoomsFromCache()
} }
@ -177,7 +186,7 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO
JSON: jsonerror.BadJSON("limit param is not a number"), JSON: jsonerror.BadJSON("limit param is not a number"),
} }
} }
request.Limit = int16(limit) request.Limit = int64(limit)
request.Since = httpReq.FormValue("since") request.Since = httpReq.FormValue("since")
request.Server = httpReq.FormValue("server") request.Server = httpReq.FormValue("server")
} else { } else {
@ -205,7 +214,7 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO
// limit=3&since=6 => G (prev='3', next='') // limit=3&since=6 => G (prev='3', next='')
// //
// A value of '-1' for prev/next indicates no position. // A value of '-1' for prev/next indicates no position.
func sliceInto(slice []gomatrixserverlib.PublicRoom, since int64, limit int16) (subset []gomatrixserverlib.PublicRoom, prev, next int) { func sliceInto(slice []gomatrixserverlib.PublicRoom, since int64, limit int64) (subset []gomatrixserverlib.PublicRoom, prev, next int) {
prev = -1 prev = -1
next = -1 next = -1
@ -231,6 +240,7 @@ func sliceInto(slice []gomatrixserverlib.PublicRoom, since int64, limit int16) (
func refreshPublicRoomCache( func refreshPublicRoomCache(
ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, extRoomsProvider api.ExtraPublicRoomsProvider, ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, extRoomsProvider api.ExtraPublicRoomsProvider,
request PublicRoomReq,
) []gomatrixserverlib.PublicRoom { ) []gomatrixserverlib.PublicRoom {
cacheMu.Lock() cacheMu.Lock()
defer cacheMu.Unlock() defer cacheMu.Unlock()
@ -239,8 +249,17 @@ func refreshPublicRoomCache(
extraRooms = extRoomsProvider.Rooms() extraRooms = extRoomsProvider.Rooms()
} }
// TODO: this is only here to make Sytest happy, for now.
ns := strings.Split(request.NetworkID, "|")
if len(ns) == 2 {
request.NetworkID = ns[1]
}
var queryRes roomserverAPI.QueryPublishedRoomsResponse var queryRes roomserverAPI.QueryPublishedRoomsResponse
err := rsAPI.QueryPublishedRooms(ctx, &roomserverAPI.QueryPublishedRoomsRequest{}, &queryRes) err := rsAPI.QueryPublishedRooms(ctx, &roomserverAPI.QueryPublishedRoomsRequest{
NetworkID: request.NetworkID,
IncludeAllNetworks: request.IncludeAllNetworks,
}, &queryRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("QueryPublishedRooms failed") util.GetLogger(ctx).WithError(err).Error("QueryPublishedRooms failed")
return publicRoomsCache return publicRoomsCache

View file

@ -17,7 +17,7 @@ func TestSliceInto(t *testing.T) {
slice := []gomatrixserverlib.PublicRoom{ slice := []gomatrixserverlib.PublicRoom{
pubRoom("a"), pubRoom("b"), pubRoom("c"), pubRoom("d"), pubRoom("e"), pubRoom("f"), pubRoom("g"), pubRoom("a"), pubRoom("b"), pubRoom("c"), pubRoom("d"), pubRoom("e"), pubRoom("f"), pubRoom("g"),
} }
limit := int16(3) limit := int64(3)
testCases := []struct { testCases := []struct {
since int64 since int64
wantPrev int wantPrev int

View file

@ -68,7 +68,7 @@ func Login(
return *authErr return *authErr
} }
// make a device/access token // make a device/access token
authErr2 := completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) authErr2 := completeAuth(req.Context(), cfg.Matrix, userAPI, login, req.RemoteAddr, req.UserAgent())
cleanup(req.Context(), &authErr2) cleanup(req.Context(), &authErr2)
return authErr2 return authErr2
} }
@ -79,7 +79,7 @@ func Login(
} }
func completeAuth( func completeAuth(
ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.ClientUserAPI, login *auth.Login, ctx context.Context, cfg *config.Global, userAPI userapi.ClientUserAPI, login *auth.Login,
ipAddr, userAgent string, ipAddr, userAgent string,
) util.JSONResponse { ) util.JSONResponse {
token, err := auth.GenerateAccessToken() token, err := auth.GenerateAccessToken()
@ -88,7 +88,7 @@ func completeAuth(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
localpart, err := userutil.ParseUsernameParam(login.Username(), &serverName) localpart, serverName, err := userutil.ParseUsernameParam(login.Username(), cfg)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("auth.ParseUsernameParam failed") util.GetLogger(ctx).WithError(err).Error("auth.ParseUsernameParam failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()

View file

@ -105,12 +105,13 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
serverName := device.UserDomain()
if err = roomserverAPI.SendEvents( if err = roomserverAPI.SendEvents(
ctx, rsAPI, ctx, rsAPI,
roomserverAPI.KindNew, roomserverAPI.KindNew,
[]*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)},
cfg.Matrix.ServerName, serverName,
cfg.Matrix.ServerName, serverName,
nil, nil,
false, false,
); err != nil { ); err != nil {
@ -271,7 +272,7 @@ func sendInvite(
Event: event, Event: event,
InviteRoomState: nil, // ask the roomserver to draw up invite room state for us InviteRoomState: nil, // ask the roomserver to draw up invite room state for us
RoomVersion: event.RoomVersion, RoomVersion: event.RoomVersion,
SendAsServer: string(cfg.Matrix.ServerName), SendAsServer: string(device.UserDomain()),
}, &inviteRes); err != nil { }, &inviteRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") util.GetLogger(ctx).WithError(err).Error("PerformInvite failed")
return util.JSONResponse{ return util.JSONResponse{
@ -341,7 +342,7 @@ func loadProfile(
} }
var profile *authtypes.Profile var profile *authtypes.Profile
if serverName == cfg.Matrix.ServerName { if cfg.Matrix.IsLocalServerName(serverName) {
profile, err = appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, profileAPI) profile, err = appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, profileAPI)
} else { } else {
profile = &authtypes.Profile{} profile = &authtypes.Profile{}

View file

@ -63,7 +63,7 @@ func CreateOpenIDToken(
JSON: openIDTokenResponse{ JSON: openIDTokenResponse{
AccessToken: response.Token.Token, AccessToken: response.Token.Token,
TokenType: "Bearer", TokenType: "Bearer",
MatrixServerName: string(cfg.Matrix.ServerName), MatrixServerName: string(device.UserDomain()),
ExpiresIn: response.Token.ExpiresAtMS / 1000, // convert ms to s ExpiresIn: response.Token.ExpiresAtMS / 1000, // convert ms to s
}, },
} }

View file

@ -113,12 +113,19 @@ func SetAvatarURL(
} }
} }
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if !cfg.Matrix.IsLocalServerName(domain) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("userID does not belong to a locally configured domain"),
}
}
evTime, err := httputil.ParseTSParam(req) evTime, err := httputil.ParseTSParam(req)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -129,8 +136,9 @@ func SetAvatarURL(
setRes := &userapi.PerformSetAvatarURLResponse{} setRes := &userapi.PerformSetAvatarURLResponse{}
if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{ if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{
Localpart: localpart, Localpart: localpart,
AvatarURL: r.AvatarURL, ServerName: domain,
AvatarURL: r.AvatarURL,
}, setRes); err != nil { }, setRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -204,12 +212,19 @@ func SetDisplayName(
} }
} }
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if !cfg.Matrix.IsLocalServerName(domain) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("userID does not belong to a locally configured domain"),
}
}
evTime, err := httputil.ParseTSParam(req) evTime, err := httputil.ParseTSParam(req)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -221,6 +236,7 @@ func SetDisplayName(
profileRes := &userapi.PerformUpdateDisplayNameResponse{} profileRes := &userapi.PerformUpdateDisplayNameResponse{}
err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{ err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
DisplayName: r.DisplayName, DisplayName: r.DisplayName,
}, profileRes) }, profileRes)
if err != nil { if err != nil {
@ -261,6 +277,12 @@ func updateProfile(
return jsonerror.InternalServerError(), err return jsonerror.InternalServerError(), err
} }
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError(), err
}
events, err := buildMembershipEvents( events, err := buildMembershipEvents(
ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI,
) )
@ -276,7 +298,7 @@ func updateProfile(
return jsonerror.InternalServerError(), e return jsonerror.InternalServerError(), e
} }
if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, domain, domain, nil, true); err != nil {
util.GetLogger(ctx).WithError(err).Error("SendEvents failed") util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError(), err return jsonerror.InternalServerError(), err
} }
@ -298,7 +320,7 @@ func getProfile(
return nil, err return nil, err
} }
if domain != cfg.Matrix.ServerName { if !cfg.Matrix.IsLocalServerName(domain) {
profile, fedErr := federation.LookupProfile(ctx, domain, userID, "") profile, fedErr := federation.LookupProfile(ctx, domain, userID, "")
if fedErr != nil { if fedErr != nil {
if x, ok := fedErr.(gomatrix.HTTPError); ok { if x, ok := fedErr.(gomatrix.HTTPError); ok {

View file

@ -131,7 +131,8 @@ func SendRedaction(
JSON: jsonerror.NotFound("Room does not exist"), JSON: jsonerror.NotFound("Room does not exist"),
} }
} }
if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, false); err != nil { domain := device.UserDomain()
if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, domain, domain, nil, false); err != nil {
util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -20,6 +20,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"regexp" "regexp"
@ -336,6 +337,7 @@ func validateRecaptcha(
response string, response string,
clientip string, clientip string,
) *util.JSONResponse { ) *util.JSONResponse {
ip, _, _ := net.SplitHostPort(clientip)
if !cfg.RecaptchaEnabled { if !cfg.RecaptchaEnabled {
return &util.JSONResponse{ return &util.JSONResponse{
Code: http.StatusConflict, Code: http.StatusConflict,
@ -355,7 +357,7 @@ func validateRecaptcha(
url.Values{ url.Values{
"secret": {cfg.RecaptchaPrivateKey}, "secret": {cfg.RecaptchaPrivateKey},
"response": {response}, "response": {response},
"remoteip": {clientip}, "remoteip": {ip},
}, },
) )
@ -412,7 +414,7 @@ func UserIDIsWithinApplicationServiceNamespace(
return false return false
} }
if domain != cfg.Matrix.ServerName { if !cfg.Matrix.IsLocalServerName(domain) {
return false return false
} }

View file

@ -163,6 +163,12 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/downloadState/{serverName}/{roomID}",
httputil.MakeAdminAPI("admin_download_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminDownloadState(req, cfg, device, rsAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/fulltext/reindex", dendriteAdminRouter.Handle("/admin/fulltext/reindex",
httputil.MakeAdminAPI("admin_fultext_reindex", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_fultext_reindex", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminReindex(req, cfg, device, natsClient) return AdminReindex(req, cfg, device, natsClient)
@ -480,7 +486,7 @@ func Setup(
return GetVisibility(req, rsAPI, vars["roomID"]) return GetVisibility(req, rsAPI, vars["roomID"])
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
// TODO: Add AS support
v3mux.Handle("/directory/list/room/{roomID}", v3mux.Handle("/directory/list/room/{roomID}",
httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -490,6 +496,27 @@ func Setup(
return SetVisibility(req, rsAPI, device, vars["roomID"]) return SetVisibility(req, rsAPI, device, vars["roomID"])
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/directory/list/appservice/{networkID}/{roomID}",
httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return SetVisibilityAS(req, rsAPI, device, vars["networkID"], vars["roomID"])
}),
).Methods(http.MethodPut, http.MethodOptions)
// Undocumented endpoint
v3mux.Handle("/directory/list/appservice/{networkID}/{roomID}",
httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return SetVisibilityAS(req, rsAPI, device, vars["networkID"], vars["roomID"])
}),
).Methods(http.MethodDelete, http.MethodOptions)
v3mux.Handle("/publicRooms", v3mux.Handle("/publicRooms",
httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse {
return GetPostPublicRooms(req, rsAPI, extRoomsProvider, federation, cfg) return GetPostPublicRooms(req, rsAPI, extRoomsProvider, federation, cfg)

View file

@ -94,6 +94,7 @@ func SendEvent(
// create a mutex for the specific user in the specific room // create a mutex for the specific user in the specific room
// this avoids a situation where events that are received in quick succession are sent to the roomserver in a jumbled order // this avoids a situation where events that are received in quick succession are sent to the roomserver in a jumbled order
userID := device.UserID userID := device.UserID
domain := device.UserDomain()
mutex, _ := userRoomSendMutexes.LoadOrStore(roomID+userID, &sync.Mutex{}) mutex, _ := userRoomSendMutexes.LoadOrStore(roomID+userID, &sync.Mutex{})
mutex.(*sync.Mutex).Lock() mutex.(*sync.Mutex).Lock()
defer mutex.(*sync.Mutex).Unlock() defer mutex.(*sync.Mutex).Unlock()
@ -185,8 +186,8 @@ func SendEvent(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
e.Headered(verRes.RoomVersion), e.Headered(verRes.RoomVersion),
}, },
cfg.Matrix.ServerName, domain,
cfg.Matrix.ServerName, domain,
txnAndSessionID, txnAndSessionID,
false, false,
); err != nil { ); err != nil {

View file

@ -215,7 +215,7 @@ func queryIDServerStoreInvite(
} }
var profile *authtypes.Profile var profile *authtypes.Profile
if serverName == cfg.Matrix.ServerName { if cfg.Matrix.IsLocalServerName(serverName) {
res := &userapi.QueryProfileResponse{} res := &userapi.QueryProfileResponse{}
err = userAPI.QueryProfile(ctx, &userapi.QueryProfileRequest{UserID: device.UserID}, res) err = userAPI.QueryProfile(ctx, &userapi.QueryProfileRequest{UserID: device.UserID}, res)
if err != nil { if err != nil {

View file

@ -17,6 +17,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -24,23 +25,23 @@ import (
// usernameParam can either be a user ID or just the localpart/username. // usernameParam can either be a user ID or just the localpart/username.
// If serverName is passed, it is verified against the domain obtained from usernameParam (if present) // If serverName is passed, it is verified against the domain obtained from usernameParam (if present)
// Returns error in case of invalid usernameParam. // Returns error in case of invalid usernameParam.
func ParseUsernameParam(usernameParam string, expectedServerName *gomatrixserverlib.ServerName) (string, error) { func ParseUsernameParam(usernameParam string, cfg *config.Global) (string, gomatrixserverlib.ServerName, error) {
localpart := usernameParam localpart := usernameParam
if strings.HasPrefix(usernameParam, "@") { if strings.HasPrefix(usernameParam, "@") {
lp, domain, err := gomatrixserverlib.SplitID('@', usernameParam) lp, domain, err := gomatrixserverlib.SplitID('@', usernameParam)
if err != nil { if err != nil {
return "", errors.New("invalid username") return "", "", errors.New("invalid username")
} }
if expectedServerName != nil && domain != *expectedServerName { if !cfg.IsLocalServerName(domain) {
return "", errors.New("user ID does not belong to this server") return "", "", errors.New("user ID does not belong to this server")
} }
localpart = lp return lp, domain, nil
} }
return localpart, nil return localpart, cfg.ServerName, nil
} }
// MakeUserID generates user ID from localpart & server name // MakeUserID generates user ID from localpart & server name

View file

@ -15,6 +15,7 @@ package userutil
import ( import (
"testing" "testing"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -28,7 +29,11 @@ var (
// TestGoodUserID checks that correct localpart is returned for a valid user ID. // TestGoodUserID checks that correct localpart is returned for a valid user ID.
func TestGoodUserID(t *testing.T) { func TestGoodUserID(t *testing.T) {
lp, err := ParseUsernameParam(goodUserID, &serverName) cfg := &config.Global{
ServerName: serverName,
}
lp, _, err := ParseUsernameParam(goodUserID, cfg)
if err != nil { if err != nil {
t.Error("User ID Parsing failed for ", goodUserID, " with error: ", err.Error()) t.Error("User ID Parsing failed for ", goodUserID, " with error: ", err.Error())
@ -41,7 +46,11 @@ func TestGoodUserID(t *testing.T) {
// TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart. // TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart.
func TestWithLocalpartOnly(t *testing.T) { func TestWithLocalpartOnly(t *testing.T) {
lp, err := ParseUsernameParam(localpart, &serverName) cfg := &config.Global{
ServerName: serverName,
}
lp, _, err := ParseUsernameParam(localpart, cfg)
if err != nil { if err != nil {
t.Error("User ID Parsing failed for ", localpart, " with error: ", err.Error()) t.Error("User ID Parsing failed for ", localpart, " with error: ", err.Error())
@ -54,7 +63,11 @@ func TestWithLocalpartOnly(t *testing.T) {
// TestIncorrectDomain checks for error when there's server name mismatch. // TestIncorrectDomain checks for error when there's server name mismatch.
func TestIncorrectDomain(t *testing.T) { func TestIncorrectDomain(t *testing.T) {
_, err := ParseUsernameParam(goodUserID, &invalidServerName) cfg := &config.Global{
ServerName: invalidServerName,
}
_, _, err := ParseUsernameParam(goodUserID, cfg)
if err == nil { if err == nil {
t.Error("Invalid Domain should return an error") t.Error("Invalid Domain should return an error")
@ -63,7 +76,11 @@ func TestIncorrectDomain(t *testing.T) {
// TestBadUserID checks that ParseUsernameParam fails for invalid user ID // TestBadUserID checks that ParseUsernameParam fails for invalid user ID
func TestBadUserID(t *testing.T) { func TestBadUserID(t *testing.T) {
_, err := ParseUsernameParam(badUserID, &serverName) cfg := &config.Global{
ServerName: serverName,
}
_, _, err := ParseUsernameParam(badUserID, cfg)
if err == nil { if err == nil {
t.Error("Illegal User ID should return an error") t.Error("Illegal User ID should return an error")

View file

@ -20,6 +20,7 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net" "net"
"regexp"
"strings" "strings"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -27,9 +28,9 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
ironwoodtypes "github.com/Arceliar/ironwood/types" ironwoodtypes "github.com/Arceliar/ironwood/types"
yggdrasilconfig "github.com/yggdrasil-network/yggdrasil-go/src/config" "github.com/yggdrasil-network/yggdrasil-go/src/core"
yggdrasilcore "github.com/yggdrasil-network/yggdrasil-go/src/core" yggdrasilcore "github.com/yggdrasil-network/yggdrasil-go/src/core"
yggdrasildefaults "github.com/yggdrasil-network/yggdrasil-go/src/defaults" "github.com/yggdrasil-network/yggdrasil-go/src/multicast"
yggdrasilmulticast "github.com/yggdrasil-network/yggdrasil-go/src/multicast" yggdrasilmulticast "github.com/yggdrasil-network/yggdrasil-go/src/multicast"
gologme "github.com/gologme/log" gologme "github.com/gologme/log"
@ -37,7 +38,6 @@ import (
type Node struct { type Node struct {
core *yggdrasilcore.Core core *yggdrasilcore.Core
config *yggdrasilconfig.NodeConfig
multicast *yggdrasilmulticast.Multicast multicast *yggdrasilmulticast.Multicast
log *gologme.Logger log *gologme.Logger
utpSocket *utp.Socket utpSocket *utp.Socket
@ -57,43 +57,52 @@ func (n *Node) DialerContext(ctx context.Context, _, address string) (net.Conn,
func Setup(sk ed25519.PrivateKey, instanceName, storageDirectory, peerURI, listenURI string) (*Node, error) { func Setup(sk ed25519.PrivateKey, instanceName, storageDirectory, peerURI, listenURI string) (*Node, error) {
n := &Node{ n := &Node{
core: &yggdrasilcore.Core{}, log: gologme.New(logrus.StandardLogger().Writer(), "", 0),
config: yggdrasildefaults.GenerateConfig(), incoming: make(chan net.Conn),
multicast: &yggdrasilmulticast.Multicast{},
log: gologme.New(logrus.StandardLogger().Writer(), "", 0),
incoming: make(chan net.Conn),
} }
options := []yggdrasilcore.SetupOption{
yggdrasilcore.AdminListenAddress("none"),
}
if listenURI != "" {
options = append(options, yggdrasilcore.ListenAddress(listenURI))
}
if peerURI != "" {
for _, uri := range strings.Split(peerURI, ",") {
options = append(options, yggdrasilcore.Peer{
URI: uri,
})
}
}
var err error
if n.core, err = yggdrasilcore.New(sk, options...); err != nil {
panic(err)
}
n.log.EnableLevel("error") n.log.EnableLevel("error")
n.log.EnableLevel("warn") n.log.EnableLevel("warn")
n.log.EnableLevel("info") n.log.EnableLevel("info")
n.core.SetLogger(n.log)
if n.utpSocket, err = utp.NewSocketFromPacketConnNoClose(n.core); err != nil { {
panic(err) var err error
options := []yggdrasilcore.SetupOption{}
if listenURI != "" {
options = append(options, yggdrasilcore.ListenAddress(listenURI))
}
if peerURI != "" {
for _, uri := range strings.Split(peerURI, ",") {
options = append(options, yggdrasilcore.Peer{
URI: uri,
})
}
}
if n.core, err = core.New(sk[:], n.log, options...); err != nil {
panic(err)
}
n.core.SetLogger(n.log)
if n.utpSocket, err = utp.NewSocketFromPacketConnNoClose(n.core); err != nil {
panic(err)
}
} }
if err = n.multicast.Init(n.core, n.config, n.log, nil); err != nil {
panic(err) // Setup the multicast module.
} {
if err = n.multicast.Start(); err != nil { var err error
panic(err) options := []multicast.SetupOption{
multicast.MulticastInterface{
Regex: regexp.MustCompile(".*"),
Beacon: true,
Listen: true,
Port: 0,
Priority: 0,
},
}
if n.multicast, err = multicast.New(n.core, n.log, options...); err != nil {
panic(err)
}
} }
n.log.Printf("Public key: %x", n.core.PublicKey()) n.log.Printf("Public key: %x", n.core.PublicKey())
@ -114,14 +123,7 @@ func (n *Node) DerivedServerName() string {
} }
func (n *Node) PrivateKey() ed25519.PrivateKey { func (n *Node) PrivateKey() ed25519.PrivateKey {
sk := make(ed25519.PrivateKey, ed25519.PrivateKeySize) return n.core.PrivateKey()
sb, err := hex.DecodeString(n.config.PrivateKey)
if err == nil {
copy(sk, sb[:])
} else {
panic(err)
}
return sk
} }
func (n *Node) PublicKey() ed25519.PublicKey { func (n *Node) PublicKey() ed25519.PublicKey {

View file

@ -179,7 +179,13 @@ client_api:
recaptcha_public_key: "" recaptcha_public_key: ""
recaptcha_private_key: "" recaptcha_private_key: ""
recaptcha_bypass_secret: "" recaptcha_bypass_secret: ""
recaptcha_siteverify_api: ""
# To use hcaptcha.com instead of ReCAPTCHA, set the following parameters, otherwise just keep them empty.
# recaptcha_siteverify_api: "https://hcaptcha.com/siteverify"
# recaptcha_api_js_url: "https://js.hcaptcha.com/1/api.js"
# recaptcha_form_field: "h-captcha-response"
# recaptcha_sitekey_class: "h-captcha"
# TURN server information that this homeserver should send to clients. # TURN server information that this homeserver should send to clients.
turn: turn:

View file

@ -175,7 +175,13 @@ client_api:
recaptcha_public_key: "" recaptcha_public_key: ""
recaptcha_private_key: "" recaptcha_private_key: ""
recaptcha_bypass_secret: "" recaptcha_bypass_secret: ""
recaptcha_siteverify_api: ""
# To use hcaptcha.com instead of ReCAPTCHA, set the following parameters, otherwise just keep them empty.
# recaptcha_siteverify_api: "https://hcaptcha.com/siteverify"
# recaptcha_api_js_url: "https://js.hcaptcha.com/1/api.js"
# recaptcha_form_field: "h-captcha-response"
# recaptcha_sitekey_class: "h-captcha"
# TURN server information that this homeserver should send to clients. # TURN server information that this homeserver should send to clients.
turn: turn:

View file

@ -35,14 +35,14 @@ import (
// KeyChangeConsumer consumes events that originate in key server. // KeyChangeConsumer consumes events that originate in key server.
type KeyChangeConsumer struct { type KeyChangeConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable string durable string
db storage.Database db storage.Database
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
serverName gomatrixserverlib.ServerName isLocalServerName func(gomatrixserverlib.ServerName) bool
rsAPI roomserverAPI.FederationRoomserverAPI rsAPI roomserverAPI.FederationRoomserverAPI
topic string topic string
} }
// NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers. // NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers.
@ -55,14 +55,14 @@ func NewKeyChangeConsumer(
rsAPI roomserverAPI.FederationRoomserverAPI, rsAPI roomserverAPI.FederationRoomserverAPI,
) *KeyChangeConsumer { ) *KeyChangeConsumer {
return &KeyChangeConsumer{ return &KeyChangeConsumer{
ctx: process.Context(), ctx: process.Context(),
jetstream: js, jetstream: js,
durable: cfg.Matrix.JetStream.Prefixed("FederationAPIKeyChangeConsumer"), durable: cfg.Matrix.JetStream.Prefixed("FederationAPIKeyChangeConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent),
queues: queues, queues: queues,
db: store, db: store,
serverName: cfg.Matrix.ServerName, isLocalServerName: cfg.Matrix.IsLocalServerName,
rsAPI: rsAPI, rsAPI: rsAPI,
} }
} }
@ -112,7 +112,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
logger.WithError(err).Error("Failed to extract domain from key change event") logger.WithError(err).Error("Failed to extract domain from key change event")
return true return true
} }
if originServerName != t.serverName { if !t.isLocalServerName(originServerName) {
return true return true
} }
@ -141,7 +141,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
// Pack the EDU and marshal it // Pack the EDU and marshal it
edu := &gomatrixserverlib.EDU{ edu := &gomatrixserverlib.EDU{
Type: gomatrixserverlib.MDeviceListUpdate, Type: gomatrixserverlib.MDeviceListUpdate,
Origin: string(t.serverName), Origin: string(originServerName),
} }
event := gomatrixserverlib.DeviceListUpdateEvent{ event := gomatrixserverlib.DeviceListUpdateEvent{
UserID: m.UserID, UserID: m.UserID,
@ -159,7 +159,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
} }
logger.Debugf("Sending device list update message to %q", destinations) logger.Debugf("Sending device list update message to %q", destinations)
err = t.queues.SendEDU(edu, t.serverName, destinations) err = t.queues.SendEDU(edu, originServerName, destinations)
return err == nil return err == nil
} }
@ -171,7 +171,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
logrus.WithError(err).Errorf("fedsender key change consumer: user ID parse failure") logrus.WithError(err).Errorf("fedsender key change consumer: user ID parse failure")
return true return true
} }
if host != gomatrixserverlib.ServerName(t.serverName) { if !t.isLocalServerName(host) {
// Ignore any messages that didn't originate locally, otherwise we'll // Ignore any messages that didn't originate locally, otherwise we'll
// end up parroting information we received from other servers. // end up parroting information we received from other servers.
return true return true
@ -203,7 +203,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
// Pack the EDU and marshal it // Pack the EDU and marshal it
edu := &gomatrixserverlib.EDU{ edu := &gomatrixserverlib.EDU{
Type: types.MSigningKeyUpdate, Type: types.MSigningKeyUpdate,
Origin: string(t.serverName), Origin: string(host),
} }
if edu.Content, err = json.Marshal(output); err != nil { if edu.Content, err = json.Marshal(output); err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
@ -212,7 +212,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
} }
logger.Debugf("Sending cross-signing update message to %q", destinations) logger.Debugf("Sending cross-signing update message to %q", destinations)
err = t.queues.SendEDU(edu, t.serverName, destinations) err = t.queues.SendEDU(edu, host, destinations)
return err == nil return err == nil
} }

View file

@ -38,7 +38,7 @@ type OutputPresenceConsumer struct {
durable string durable string
db storage.Database db storage.Database
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
ServerName gomatrixserverlib.ServerName isLocalServerName func(gomatrixserverlib.ServerName) bool
topic string topic string
outboundPresenceEnabled bool outboundPresenceEnabled bool
} }
@ -56,7 +56,7 @@ func NewOutputPresenceConsumer(
jetstream: js, jetstream: js,
queues: queues, queues: queues,
db: store, db: store,
ServerName: cfg.Matrix.ServerName, isLocalServerName: cfg.Matrix.IsLocalServerName,
durable: cfg.Matrix.JetStream.Durable("FederationAPIPresenceConsumer"), durable: cfg.Matrix.JetStream.Durable("FederationAPIPresenceConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent),
outboundPresenceEnabled: cfg.Matrix.Presence.EnableOutbound, outboundPresenceEnabled: cfg.Matrix.Presence.EnableOutbound,
@ -85,7 +85,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
log.WithError(err).WithField("user_id", userID).Error("failed to extract domain from receipt sender") log.WithError(err).WithField("user_id", userID).Error("failed to extract domain from receipt sender")
return true return true
} }
if serverName != t.ServerName { if !t.isLocalServerName(serverName) {
return true return true
} }
@ -127,7 +127,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
edu := &gomatrixserverlib.EDU{ edu := &gomatrixserverlib.EDU{
Type: gomatrixserverlib.MPresence, Type: gomatrixserverlib.MPresence,
Origin: string(t.ServerName), Origin: string(serverName),
} }
if edu.Content, err = json.Marshal(content); err != nil { if edu.Content, err = json.Marshal(content); err != nil {
log.WithError(err).Error("failed to marshal EDU JSON") log.WithError(err).Error("failed to marshal EDU JSON")
@ -135,7 +135,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
} }
log.Tracef("sending presence EDU to %d servers", len(joined)) log.Tracef("sending presence EDU to %d servers", len(joined))
if err = t.queues.SendEDU(edu, t.ServerName, joined); err != nil { if err = t.queues.SendEDU(edu, serverName, joined); err != nil {
log.WithError(err).Error("failed to send EDU") log.WithError(err).Error("failed to send EDU")
return false return false
} }

View file

@ -34,13 +34,13 @@ import (
// OutputReceiptConsumer consumes events that originate in the clientapi. // OutputReceiptConsumer consumes events that originate in the clientapi.
type OutputReceiptConsumer struct { type OutputReceiptConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable string durable string
db storage.Database db storage.Database
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
ServerName gomatrixserverlib.ServerName isLocalServerName func(gomatrixserverlib.ServerName) bool
topic string topic string
} }
// NewOutputReceiptConsumer creates a new OutputReceiptConsumer. Call Start() to begin consuming typing events. // NewOutputReceiptConsumer creates a new OutputReceiptConsumer. Call Start() to begin consuming typing events.
@ -52,13 +52,13 @@ func NewOutputReceiptConsumer(
store storage.Database, store storage.Database,
) *OutputReceiptConsumer { ) *OutputReceiptConsumer {
return &OutputReceiptConsumer{ return &OutputReceiptConsumer{
ctx: process.Context(), ctx: process.Context(),
jetstream: js, jetstream: js,
queues: queues, queues: queues,
db: store, db: store,
ServerName: cfg.Matrix.ServerName, isLocalServerName: cfg.Matrix.IsLocalServerName,
durable: cfg.Matrix.JetStream.Durable("FederationAPIReceiptConsumer"), durable: cfg.Matrix.JetStream.Durable("FederationAPIReceiptConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent),
} }
} }
@ -95,7 +95,7 @@ func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg)
log.WithError(err).WithField("user_id", receipt.UserID).Error("failed to extract domain from receipt sender") log.WithError(err).WithField("user_id", receipt.UserID).Error("failed to extract domain from receipt sender")
return true return true
} }
if receiptServerName != t.ServerName { if !t.isLocalServerName(receiptServerName) {
return true return true
} }
@ -134,14 +134,14 @@ func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg)
edu := &gomatrixserverlib.EDU{ edu := &gomatrixserverlib.EDU{
Type: gomatrixserverlib.MReceipt, Type: gomatrixserverlib.MReceipt,
Origin: string(t.ServerName), Origin: string(receiptServerName),
} }
if edu.Content, err = json.Marshal(content); err != nil { if edu.Content, err = json.Marshal(content); err != nil {
log.WithError(err).Error("failed to marshal EDU JSON") log.WithError(err).Error("failed to marshal EDU JSON")
return true return true
} }
if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { if err := t.queues.SendEDU(edu, receiptServerName, names); err != nil {
log.WithError(err).Error("failed to send EDU") log.WithError(err).Error("failed to send EDU")
return false return false
} }

View file

@ -34,13 +34,13 @@ import (
// OutputSendToDeviceConsumer consumes events that originate in the clientapi. // OutputSendToDeviceConsumer consumes events that originate in the clientapi.
type OutputSendToDeviceConsumer struct { type OutputSendToDeviceConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable string durable string
db storage.Database db storage.Database
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
ServerName gomatrixserverlib.ServerName isLocalServerName func(gomatrixserverlib.ServerName) bool
topic string topic string
} }
// NewOutputSendToDeviceConsumer creates a new OutputSendToDeviceConsumer. Call Start() to begin consuming send-to-device events. // NewOutputSendToDeviceConsumer creates a new OutputSendToDeviceConsumer. Call Start() to begin consuming send-to-device events.
@ -52,13 +52,13 @@ func NewOutputSendToDeviceConsumer(
store storage.Database, store storage.Database,
) *OutputSendToDeviceConsumer { ) *OutputSendToDeviceConsumer {
return &OutputSendToDeviceConsumer{ return &OutputSendToDeviceConsumer{
ctx: process.Context(), ctx: process.Context(),
jetstream: js, jetstream: js,
queues: queues, queues: queues,
db: store, db: store,
ServerName: cfg.Matrix.ServerName, isLocalServerName: cfg.Matrix.IsLocalServerName,
durable: cfg.Matrix.JetStream.Durable("FederationAPIESendToDeviceConsumer"), durable: cfg.Matrix.JetStream.Durable("FederationAPIESendToDeviceConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
} }
} }
@ -82,7 +82,7 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats
log.WithError(err).WithField("user_id", sender).Error("Failed to extract domain from send-to-device sender") log.WithError(err).WithField("user_id", sender).Error("Failed to extract domain from send-to-device sender")
return true return true
} }
if originServerName != t.ServerName { if !t.isLocalServerName(originServerName) {
return true return true
} }
// Extract the send-to-device event from msg. // Extract the send-to-device event from msg.
@ -101,14 +101,14 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats
} }
// The SyncAPI is already handling sendToDevice for the local server // The SyncAPI is already handling sendToDevice for the local server
if destServerName == t.ServerName { if t.isLocalServerName(destServerName) {
return true return true
} }
// Pack the EDU and marshal it // Pack the EDU and marshal it
edu := &gomatrixserverlib.EDU{ edu := &gomatrixserverlib.EDU{
Type: gomatrixserverlib.MDirectToDevice, Type: gomatrixserverlib.MDirectToDevice,
Origin: string(t.ServerName), Origin: string(originServerName),
} }
tdm := gomatrixserverlib.ToDeviceMessage{ tdm := gomatrixserverlib.ToDeviceMessage{
Sender: ote.Sender, Sender: ote.Sender,
@ -127,7 +127,7 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats
} }
log.Debugf("Sending send-to-device message into %q destination queue", destServerName) log.Debugf("Sending send-to-device message into %q destination queue", destServerName)
if err := t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { if err := t.queues.SendEDU(edu, originServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil {
log.WithError(err).Error("failed to send EDU") log.WithError(err).Error("failed to send EDU")
return false return false
} }

View file

@ -31,13 +31,13 @@ import (
// OutputTypingConsumer consumes events that originate in the clientapi. // OutputTypingConsumer consumes events that originate in the clientapi.
type OutputTypingConsumer struct { type OutputTypingConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable string durable string
db storage.Database db storage.Database
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
ServerName gomatrixserverlib.ServerName isLocalServerName func(gomatrixserverlib.ServerName) bool
topic string topic string
} }
// NewOutputTypingConsumer creates a new OutputTypingConsumer. Call Start() to begin consuming typing events. // NewOutputTypingConsumer creates a new OutputTypingConsumer. Call Start() to begin consuming typing events.
@ -49,13 +49,13 @@ func NewOutputTypingConsumer(
store storage.Database, store storage.Database,
) *OutputTypingConsumer { ) *OutputTypingConsumer {
return &OutputTypingConsumer{ return &OutputTypingConsumer{
ctx: process.Context(), ctx: process.Context(),
jetstream: js, jetstream: js,
queues: queues, queues: queues,
db: store, db: store,
ServerName: cfg.Matrix.ServerName, isLocalServerName: cfg.Matrix.IsLocalServerName,
durable: cfg.Matrix.JetStream.Durable("FederationAPITypingConsumer"), durable: cfg.Matrix.JetStream.Durable("FederationAPITypingConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputTypingEvent), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputTypingEvent),
} }
} }
@ -87,7 +87,7 @@ func (t *OutputTypingConsumer) onMessage(ctx context.Context, msgs []*nats.Msg)
_ = msg.Ack() _ = msg.Ack()
return true return true
} }
if typingServerName != t.ServerName { if !t.isLocalServerName(typingServerName) {
return true return true
} }
@ -111,7 +111,7 @@ func (t *OutputTypingConsumer) onMessage(ctx context.Context, msgs []*nats.Msg)
log.WithError(err).Error("failed to marshal EDU JSON") log.WithError(err).Error("failed to marshal EDU JSON")
return true return true
} }
if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { if err := t.queues.SendEDU(edu, typingServerName, names); err != nil {
log.WithError(err).Error("failed to send EDU") log.WithError(err).Error("failed to send EDU")
return false return false
} }

View file

@ -69,7 +69,7 @@ func AddPublicRoutes(
TopicPresenceEvent: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent), TopicPresenceEvent: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent),
TopicDeviceListUpdate: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), TopicDeviceListUpdate: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
TopicSigningKeyUpdate: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), TopicSigningKeyUpdate: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
ServerName: cfg.Matrix.ServerName, Config: cfg,
UserAPI: userAPI, UserAPI: userAPI,
} }
@ -107,7 +107,7 @@ func NewInternalAPI(
) api.FederationInternalAPI { ) api.FederationInternalAPI {
cfg := &base.Cfg.FederationAPI cfg := &base.Cfg.FederationAPI
federationDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.ServerName) federationDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName)
if err != nil { if err != nil {
logrus.WithError(err).Panic("failed to connect to federation sender db") logrus.WithError(err).Panic("failed to connect to federation sender db")
} }

View file

@ -87,6 +87,7 @@ func TestMain(m *testing.M) {
cfg.Global.JetStream.StoragePath = config.Path(d) cfg.Global.JetStream.StoragePath = config.Path(d)
cfg.Global.KeyID = serverKeyID cfg.Global.KeyID = serverKeyID
cfg.Global.KeyValidityPeriod = s.validity cfg.Global.KeyValidityPeriod = s.validity
cfg.FederationAPI.KeyPerspectives = nil
f, err := os.CreateTemp(d, "federation_keys_test*.db") f, err := os.CreateTemp(d, "federation_keys_test*.db")
if err != nil { if err != nil {
return -1 return -1
@ -207,7 +208,6 @@ func TestRenewalBehaviour(t *testing.T) {
// happy at this point that the key that we already have is from the past // happy at this point that the key that we already have is from the past
// then repeating a key fetch should cause us to try and renew the key. // then repeating a key fetch should cause us to try and renew the key.
// If so, then the new key will end up in our cache. // If so, then the new key will end up in our cache.
serverC.renew() serverC.renew()
res, err = serverA.api.FetchKeys( res, err = serverA.api.FetchKeys(

View file

@ -164,6 +164,7 @@ func TestFederationAPIJoinThenKeyUpdate(t *testing.T) {
func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType) base, close := testrig.CreateBaseDendrite(t, dbType)
base.Cfg.FederationAPI.PreferDirectFetch = true base.Cfg.FederationAPI.PreferDirectFetch = true
base.Cfg.FederationAPI.KeyPerspectives = nil
defer close() defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)

View file

@ -44,7 +44,7 @@ func (a *FederationInternalAPI) ClaimKeys(
) (gomatrixserverlib.RespClaimKeys, error) { ) (gomatrixserverlib.RespClaimKeys, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.ClaimKeys(ctx, s, oneTimeKeys) return a.federation.ClaimKeys(ctx, s, oneTimeKeys)
}) })
if err != nil { if err != nil {

View file

@ -99,7 +99,7 @@ func (s *FederationInternalAPI) handleLocalKeys(
results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
) { ) {
for req := range requests { for req := range requests {
if req.ServerName != s.cfg.Matrix.ServerName { if !s.cfg.Matrix.IsLocalServerName(req.ServerName) {
continue continue
} }
if req.KeyID == s.cfg.Matrix.KeyID { if req.KeyID == s.cfg.Matrix.KeyID {

View file

@ -77,7 +77,7 @@ func (r *FederationInternalAPI) PerformJoin(
seenSet := make(map[gomatrixserverlib.ServerName]bool) seenSet := make(map[gomatrixserverlib.ServerName]bool)
var uniqueList []gomatrixserverlib.ServerName var uniqueList []gomatrixserverlib.ServerName
for _, srv := range request.ServerNames { for _, srv := range request.ServerNames {
if seenSet[srv] || srv == r.cfg.Matrix.ServerName { if seenSet[srv] || r.cfg.Matrix.IsLocalServerName(srv) {
continue continue
} }
seenSet[srv] = true seenSet[srv] = true

View file

@ -25,6 +25,7 @@ import (
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -39,7 +40,7 @@ type SyncAPIProducer struct {
TopicDeviceListUpdate string TopicDeviceListUpdate string
TopicSigningKeyUpdate string TopicSigningKeyUpdate string
JetStream nats.JetStreamContext JetStream nats.JetStreamContext
ServerName gomatrixserverlib.ServerName Config *config.FederationAPI
UserAPI userapi.UserInternalAPI UserAPI userapi.UserInternalAPI
} }
@ -77,7 +78,7 @@ func (p *SyncAPIProducer) SendToDevice(
// device. If the event isn't targeted locally then we can't expand the // device. If the event isn't targeted locally then we can't expand the
// wildcard as we don't know about the remote devices, so instead we leave it // wildcard as we don't know about the remote devices, so instead we leave it
// as-is, so that the federation sender can send it on with the wildcard intact. // as-is, so that the federation sender can send it on with the wildcard intact.
if domain == p.ServerName && deviceID == "*" { if p.Config.Matrix.IsLocalServerName(domain) && deviceID == "*" {
var res userapi.QueryDevicesResponse var res userapi.QueryDevicesResponse
err = p.UserAPI.QueryDevices(context.TODO(), &userapi.QueryDevicesRequest{ err = p.UserAPI.QueryDevices(context.TODO(), &userapi.QueryDevicesRequest{
UserID: userID, UserID: userID,

View file

@ -76,21 +76,25 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
return return
} }
// If there's room in memory to hold the event then add it to the // Check if the destination is blacklisted. If it isn't then wake
// list. // up the queue.
oq.pendingMutex.Lock() if !oq.statistics.Blacklisted() {
if len(oq.pendingPDUs) < maxPDUsInMemory { // If there's room in memory to hold the event then add it to the
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ // list.
pdu: event, oq.pendingMutex.Lock()
receipt: receipt, if len(oq.pendingPDUs) < maxPDUsInMemory {
}) oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{
} else { pdu: event,
oq.overflowed.Store(true) receipt: receipt,
} })
oq.pendingMutex.Unlock() } else {
oq.overflowed.Store(true)
}
oq.pendingMutex.Unlock()
if !oq.backingOff.Load() { if !oq.backingOff.Load() {
oq.wakeQueueAndNotify() oq.wakeQueueAndNotify()
}
} }
} }
@ -103,21 +107,25 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
return return
} }
// If there's room in memory to hold the event then add it to the // Check if the destination is blacklisted. If it isn't then wake
// list. // up the queue.
oq.pendingMutex.Lock() if !oq.statistics.Blacklisted() {
if len(oq.pendingEDUs) < maxEDUsInMemory { // If there's room in memory to hold the event then add it to the
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ // list.
edu: event, oq.pendingMutex.Lock()
receipt: receipt, if len(oq.pendingEDUs) < maxEDUsInMemory {
}) oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{
} else { edu: event,
oq.overflowed.Store(true) receipt: receipt,
} })
oq.pendingMutex.Unlock() } else {
oq.overflowed.Store(true)
}
oq.pendingMutex.Unlock()
if !oq.backingOff.Load() { if !oq.backingOff.Load() {
oq.wakeQueueAndNotify() oq.wakeQueueAndNotify()
}
} }
} }

View file

@ -247,9 +247,10 @@ func (oqs *OutgoingQueues) SendEvent(
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err) return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
} }
destQueues := make([]*destinationQueue, 0, len(destmap))
for destination := range destmap { for destination := range destmap {
if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() { if queue := oqs.getQueue(destination); queue != nil {
queue.sendEvent(ev, nid) destQueues = append(destQueues, queue)
} else { } else {
delete(destmap, destination) delete(destmap, destination)
} }
@ -267,6 +268,14 @@ func (oqs *OutgoingQueues) SendEvent(
return err return err
} }
// NOTE : PDUs should be associated with destinations before sending
// them, otherwise this is technically a race.
// If the send completes before they are associated then they won't
// get properly cleaned up in the database.
for _, queue := range destQueues {
queue.sendEvent(ev, nid)
}
return nil return nil
} }
@ -335,20 +344,21 @@ func (oqs *OutgoingQueues) SendEDU(
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err) return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
} }
destQueues := make([]*destinationQueue, 0, len(destmap))
for destination := range destmap { for destination := range destmap {
if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() { if queue := oqs.getQueue(destination); queue != nil {
queue.sendEDU(e, nid) destQueues = append(destQueues, queue)
} else { } else {
delete(destmap, destination) delete(destmap, destination)
} }
} }
// Create a database entry that associates the given PDU NID with // Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU // these destination queues. We'll then be able to retrieve the PDU
// later. // later.
if err := oqs.db.AssociateEDUWithDestinations( if err := oqs.db.AssociateEDUWithDestinations(
oqs.process.Context(), oqs.process.Context(),
destmap, // the destination server name destmap, // the destination server names
nid, // NIDs from federationapi_queue_json table nid, // NIDs from federationapi_queue_json table
e.Type, e.Type,
nil, // this will use the default expireEDUTypes map nil, // this will use the default expireEDUTypes map
@ -357,6 +367,14 @@ func (oqs *OutgoingQueues) SendEDU(
return err return err
} }
// NOTE : EDUs should be associated with destinations before sending
// them, otherwise this is technically a race.
// If the send completes before they are associated then they won't
// get properly cleaned up in the database.
for _, queue := range destQueues {
queue.sendEDU(e, nid)
}
return nil return nil
} }

View file

@ -47,7 +47,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase
connStr, dbClose := test.PrepareDBConnectionString(t, dbType) connStr, dbClose := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewDatabase(b, &config.DatabaseOptions{ db, err := storage.NewDatabase(b, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr), ConnectionString: config.DataSource(connStr),
}, b.Caches, b.Cfg.Global.ServerName) }, b.Caches, b.Cfg.Global.IsLocalServerName)
if err != nil { if err != nil {
t.Fatalf("NewDatabase returned %s", err) t.Fatalf("NewDatabase returned %s", err)
} }

View file

@ -2,24 +2,29 @@ package routing
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"strconv" "strconv"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
) )
type PublicRoomReq struct { type PublicRoomReq struct {
Since string `json:"since,omitempty"` Since string `json:"since,omitempty"`
Limit int16 `json:"limit,omitempty"` Limit int16 `json:"limit,omitempty"`
Filter filter `json:"filter,omitempty"` Filter filter `json:"filter,omitempty"`
IncludeAllNetworks bool `json:"include_all_networks,omitempty"`
NetworkID string `json:"third_party_instance_id,omitempty"`
} }
type filter struct { type filter struct {
SearchTerms string `json:"generic_search_term,omitempty"` SearchTerms string `json:"generic_search_term,omitempty"`
RoomTypes []string `json:"room_types,omitempty"`
} }
// GetPostPublicRooms implements GET and POST /publicRooms // GetPostPublicRooms implements GET and POST /publicRooms
@ -57,8 +62,14 @@ func publicRooms(
return nil, err return nil, err
} }
if request.IncludeAllNetworks && request.NetworkID != "" {
return nil, fmt.Errorf("include_all_networks and third_party_instance_id can not be used together")
}
var queryRes roomserverAPI.QueryPublishedRoomsResponse var queryRes roomserverAPI.QueryPublishedRoomsResponse
err = rsAPI.QueryPublishedRooms(ctx, &roomserverAPI.QueryPublishedRoomsRequest{}, &queryRes) err = rsAPI.QueryPublishedRooms(ctx, &roomserverAPI.QueryPublishedRoomsRequest{
NetworkID: request.NetworkID,
}, &queryRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("QueryPublishedRooms failed") util.GetLogger(ctx).WithError(err).Error("QueryPublishedRooms failed")
return nil, err return nil, err

View file

@ -124,7 +124,7 @@ func Setup(
mu := internal.NewMutexByRoom() mu := internal.NewMutexByRoom()
v1fedmux.Handle("/send/{txnID}", MakeFedAPI( v1fedmux.Handle("/send/{txnID}", MakeFedAPI(
"federation_send", cfg.Matrix.ServerName, keys, wakeup, "federation_send", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return Send( return Send(
httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]),
@ -134,7 +134,7 @@ func Setup(
)).Methods(http.MethodPut, http.MethodOptions) )).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI(
"federation_invite", cfg.Matrix.ServerName, keys, wakeup, "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{ return util.JSONResponse{
@ -150,7 +150,7 @@ func Setup(
)).Methods(http.MethodPut, http.MethodOptions) )).Methods(http.MethodPut, http.MethodOptions)
v2fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( v2fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI(
"federation_invite", cfg.Matrix.ServerName, keys, wakeup, "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{ return util.JSONResponse{
@ -172,7 +172,7 @@ func Setup(
)).Methods(http.MethodPost, http.MethodOptions) )).Methods(http.MethodPost, http.MethodOptions)
v1fedmux.Handle("/exchange_third_party_invite/{roomID}", MakeFedAPI( v1fedmux.Handle("/exchange_third_party_invite/{roomID}", MakeFedAPI(
"exchange_third_party_invite", cfg.Matrix.ServerName, keys, wakeup, "exchange_third_party_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return ExchangeThirdPartyInvite( return ExchangeThirdPartyInvite(
httpReq, request, vars["roomID"], rsAPI, cfg, federation, httpReq, request, vars["roomID"], rsAPI, cfg, federation,
@ -181,7 +181,7 @@ func Setup(
)).Methods(http.MethodPut, http.MethodOptions) )).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/event/{eventID}", MakeFedAPI( v1fedmux.Handle("/event/{eventID}", MakeFedAPI(
"federation_get_event", cfg.Matrix.ServerName, keys, wakeup, "federation_get_event", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return GetEvent( return GetEvent(
httpReq.Context(), request, rsAPI, vars["eventID"], cfg.Matrix.ServerName, httpReq.Context(), request, rsAPI, vars["eventID"], cfg.Matrix.ServerName,
@ -190,7 +190,7 @@ func Setup(
)).Methods(http.MethodGet) )).Methods(http.MethodGet)
v1fedmux.Handle("/state/{roomID}", MakeFedAPI( v1fedmux.Handle("/state/{roomID}", MakeFedAPI(
"federation_get_state", cfg.Matrix.ServerName, keys, wakeup, "federation_get_state", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{ return util.JSONResponse{
@ -205,7 +205,7 @@ func Setup(
)).Methods(http.MethodGet) )).Methods(http.MethodGet)
v1fedmux.Handle("/state_ids/{roomID}", MakeFedAPI( v1fedmux.Handle("/state_ids/{roomID}", MakeFedAPI(
"federation_get_state_ids", cfg.Matrix.ServerName, keys, wakeup, "federation_get_state_ids", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{ return util.JSONResponse{
@ -220,7 +220,7 @@ func Setup(
)).Methods(http.MethodGet) )).Methods(http.MethodGet)
v1fedmux.Handle("/event_auth/{roomID}/{eventID}", MakeFedAPI( v1fedmux.Handle("/event_auth/{roomID}/{eventID}", MakeFedAPI(
"federation_get_event_auth", cfg.Matrix.ServerName, keys, wakeup, "federation_get_event_auth", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{ return util.JSONResponse{
@ -235,7 +235,7 @@ func Setup(
)).Methods(http.MethodGet) )).Methods(http.MethodGet)
v1fedmux.Handle("/query/directory", MakeFedAPI( v1fedmux.Handle("/query/directory", MakeFedAPI(
"federation_query_room_alias", cfg.Matrix.ServerName, keys, wakeup, "federation_query_room_alias", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return RoomAliasToID( return RoomAliasToID(
httpReq, federation, cfg, rsAPI, fsAPI, httpReq, federation, cfg, rsAPI, fsAPI,
@ -244,7 +244,7 @@ func Setup(
)).Methods(http.MethodGet) )).Methods(http.MethodGet)
v1fedmux.Handle("/query/profile", MakeFedAPI( v1fedmux.Handle("/query/profile", MakeFedAPI(
"federation_query_profile", cfg.Matrix.ServerName, keys, wakeup, "federation_query_profile", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return GetProfile( return GetProfile(
httpReq, userAPI, cfg, httpReq, userAPI, cfg,
@ -253,7 +253,7 @@ func Setup(
)).Methods(http.MethodGet) )).Methods(http.MethodGet)
v1fedmux.Handle("/user/devices/{userID}", MakeFedAPI( v1fedmux.Handle("/user/devices/{userID}", MakeFedAPI(
"federation_user_devices", cfg.Matrix.ServerName, keys, wakeup, "federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return GetUserDevices( return GetUserDevices(
httpReq, keyAPI, vars["userID"], httpReq, keyAPI, vars["userID"],
@ -263,7 +263,7 @@ func Setup(
if mscCfg.Enabled("msc2444") { if mscCfg.Enabled("msc2444") {
v1fedmux.Handle("/peek/{roomID}/{peekID}", MakeFedAPI( v1fedmux.Handle("/peek/{roomID}/{peekID}", MakeFedAPI(
"federation_peek", cfg.Matrix.ServerName, keys, wakeup, "federation_peek", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{ return util.JSONResponse{
@ -294,7 +294,7 @@ func Setup(
} }
v1fedmux.Handle("/make_join/{roomID}/{userID}", MakeFedAPI( v1fedmux.Handle("/make_join/{roomID}/{userID}", MakeFedAPI(
"federation_make_join", cfg.Matrix.ServerName, keys, wakeup, "federation_make_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{ return util.JSONResponse{
@ -325,7 +325,7 @@ func Setup(
)).Methods(http.MethodGet) )).Methods(http.MethodGet)
v1fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI( v1fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI(
"federation_send_join", cfg.Matrix.ServerName, keys, wakeup, "federation_send_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{ return util.JSONResponse{
@ -357,7 +357,7 @@ func Setup(
)).Methods(http.MethodPut) )).Methods(http.MethodPut)
v2fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI( v2fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI(
"federation_send_join", cfg.Matrix.ServerName, keys, wakeup, "federation_send_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{ return util.JSONResponse{
@ -374,7 +374,7 @@ func Setup(
)).Methods(http.MethodPut) )).Methods(http.MethodPut)
v1fedmux.Handle("/make_leave/{roomID}/{eventID}", MakeFedAPI( v1fedmux.Handle("/make_leave/{roomID}/{eventID}", MakeFedAPI(
"federation_make_leave", cfg.Matrix.ServerName, keys, wakeup, "federation_make_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{ return util.JSONResponse{
@ -391,7 +391,7 @@ func Setup(
)).Methods(http.MethodGet) )).Methods(http.MethodGet)
v1fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI( v1fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI(
"federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, "federation_send_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{ return util.JSONResponse{
@ -423,7 +423,7 @@ func Setup(
)).Methods(http.MethodPut) )).Methods(http.MethodPut)
v2fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI( v2fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI(
"federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, "federation_send_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{ return util.JSONResponse{
@ -447,7 +447,7 @@ func Setup(
)).Methods(http.MethodGet) )).Methods(http.MethodGet)
v1fedmux.Handle("/get_missing_events/{roomID}", MakeFedAPI( v1fedmux.Handle("/get_missing_events/{roomID}", MakeFedAPI(
"federation_get_missing_events", cfg.Matrix.ServerName, keys, wakeup, "federation_get_missing_events", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{ return util.JSONResponse{
@ -460,7 +460,7 @@ func Setup(
)).Methods(http.MethodPost) )).Methods(http.MethodPost)
v1fedmux.Handle("/backfill/{roomID}", MakeFedAPI( v1fedmux.Handle("/backfill/{roomID}", MakeFedAPI(
"federation_backfill", cfg.Matrix.ServerName, keys, wakeup, "federation_backfill", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{ return util.JSONResponse{
@ -479,14 +479,14 @@ func Setup(
).Methods(http.MethodGet, http.MethodPost) ).Methods(http.MethodGet, http.MethodPost)
v1fedmux.Handle("/user/keys/claim", MakeFedAPI( v1fedmux.Handle("/user/keys/claim", MakeFedAPI(
"federation_keys_claim", cfg.Matrix.ServerName, keys, wakeup, "federation_keys_claim", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
}, },
)).Methods(http.MethodPost) )).Methods(http.MethodPost)
v1fedmux.Handle("/user/keys/query", MakeFedAPI( v1fedmux.Handle("/user/keys/query", MakeFedAPI(
"federation_keys_query", cfg.Matrix.ServerName, keys, wakeup, "federation_keys_query", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
}, },
@ -525,15 +525,15 @@ func ErrorIfLocalServerNotInRoom(
// MakeFedAPI makes an http.Handler that checks matrix federation authentication. // MakeFedAPI makes an http.Handler that checks matrix federation authentication.
func MakeFedAPI( func MakeFedAPI(
metricsName string, metricsName string, serverName gomatrixserverlib.ServerName,
serverName gomatrixserverlib.ServerName, isLocalServerName func(gomatrixserverlib.ServerName) bool,
keyRing gomatrixserverlib.JSONVerifier, keyRing gomatrixserverlib.JSONVerifier,
wakeup *FederationWakeups, wakeup *FederationWakeups,
f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse,
) http.Handler { ) http.Handler {
h := func(req *http.Request) util.JSONResponse { h := func(req *http.Request) util.JSONResponse {
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
req, time.Now(), serverName, keyRing, req, time.Now(), serverName, isLocalServerName, keyRing,
) )
if fedReq == nil { if fedReq == nil {
return errResp return errResp

View file

@ -36,7 +36,7 @@ type Database struct {
} }
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) { func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) {
var d Database var d Database
var err error var err error
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
@ -96,7 +96,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
ServerName: serverName, IsLocalServerName: isLocalServerName,
Cache: cache, Cache: cache,
Writer: d.writer, Writer: d.writer,
FederationJoinedHosts: joinedHosts, FederationJoinedHosts: joinedHosts,

View file

@ -29,7 +29,7 @@ import (
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
ServerName gomatrixserverlib.ServerName IsLocalServerName func(gomatrixserverlib.ServerName) bool
Cache caching.FederationCache Cache caching.FederationCache
Writer sqlutil.Writer Writer sqlutil.Writer
FederationQueuePDUs tables.FederationQueuePDUs FederationQueuePDUs tables.FederationQueuePDUs
@ -124,7 +124,7 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string,
} }
if excludeSelf { if excludeSelf {
for i, server := range servers { for i, server := range servers {
if server == d.ServerName { if d.IsLocalServerName(server) {
servers = append(servers[:i], servers[i+1:]...) servers = append(servers[:i], servers[i+1:]...)
} }
} }

View file

@ -35,7 +35,7 @@ type Database struct {
} }
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) { func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) {
var d Database var d Database
var err error var err error
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
@ -95,7 +95,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
ServerName: serverName, IsLocalServerName: isLocalServerName,
Cache: cache, Cache: cache,
Writer: d.writer, Writer: d.writer,
FederationJoinedHosts: joinedHosts, FederationJoinedHosts: joinedHosts,

View file

@ -29,12 +29,12 @@ import (
) )
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (Database, error) { func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(base, dbProperties, cache, serverName) return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName)
case dbProperties.ConnectionString.IsPostgres(): case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(base, dbProperties, cache, serverName) return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName)
default: default:
return nil, fmt.Errorf("unexpected database type") return nil, fmt.Errorf("unexpected database type")
} }

View file

@ -19,7 +19,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Dat
connStr, dbClose := test.PrepareDBConnectionString(t, dbType) connStr, dbClose := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewDatabase(b, &config.DatabaseOptions{ db, err := storage.NewDatabase(b, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr), ConnectionString: config.DataSource(connStr),
}, b.Caches, b.Cfg.Global.ServerName) }, b.Caches, func(server gomatrixserverlib.ServerName) bool { return server == "localhost" })
if err != nil { if err != nil {
t.Fatalf("NewDatabase returned %s", err) t.Fatalf("NewDatabase returned %s", err)
} }

8
go.mod
View file

@ -1,7 +1,7 @@
module github.com/matrix-org/dendrite module github.com/matrix-org/dendrite
require ( require (
github.com/Arceliar/ironwood v0.0.0-20220903132624-ee60c16bcfcf github.com/Arceliar/ironwood v0.0.0-20221025225125-45b4281814c2
github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979 github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979
github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/MFAshby/stdemuxerhook v1.0.0 github.com/MFAshby/stdemuxerhook v1.0.0
@ -22,8 +22,8 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a github.com/matrix-org/gomatrixserverlib v0.0.0-20221031151122-0885c35ebe74
github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/nats-io/nats-server/v2 v2.9.3 github.com/nats-io/nats-server/v2 v2.9.3
github.com/nats-io/nats.go v1.18.0 github.com/nats-io/nats.go v1.18.0
@ -40,7 +40,7 @@ require (
github.com/tidwall/sjson v1.2.5 github.com/tidwall/sjson v1.2.5
github.com/uber/jaeger-client-go v2.30.0+incompatible github.com/uber/jaeger-client-go v2.30.0+incompatible
github.com/uber/jaeger-lib v2.4.1+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.5-0.20220901155642-4f2abece817c github.com/yggdrasil-network/yggdrasil-go v0.4.6
go.uber.org/atomic v1.10.0 go.uber.org/atomic v1.10.0
golang.org/x/crypto v0.1.0 golang.org/x/crypto v0.1.0
golang.org/x/image v0.0.0-20220902085622-e7cb96979f69 golang.org/x/image v0.0.0-20220902085622-e7cb96979f69

16
go.sum
View file

@ -38,8 +38,8 @@ dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBr
dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4=
dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU=
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg=
github.com/Arceliar/ironwood v0.0.0-20220903132624-ee60c16bcfcf h1:kjPkmDHUTWUma/4tqDl208bOk3jsUEqOJA6TsMZo5Jk= github.com/Arceliar/ironwood v0.0.0-20221025225125-45b4281814c2 h1:Usab30pNT2i/vZvpXcN9uOr5IO1RZPcUqoGH0DIAPnU=
github.com/Arceliar/ironwood v0.0.0-20220903132624-ee60c16bcfcf/go.mod h1:RP72rucOFm5udrnEzTmIWLRVGQiV/fSUAQXJ0RST/nk= github.com/Arceliar/ironwood v0.0.0-20221025225125-45b4281814c2/go.mod h1:RP72rucOFm5udrnEzTmIWLRVGQiV/fSUAQXJ0RST/nk=
github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979 h1:WndgpSW13S32VLQ3ugUxx2EnnWmgba1kCqPkd4Gk1yQ= github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979 h1:WndgpSW13S32VLQ3ugUxx2EnnWmgba1kCqPkd4Gk1yQ=
github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI= github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
@ -387,10 +387,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a h1:6rJFN5NBuzZ7h5meYkLtXKa6VFZfDc8oVXHd4SDXr5o= github.com/matrix-org/gomatrixserverlib v0.0.0-20221031151122-0885c35ebe74 h1:I4LUlFqxZ72m3s9wIvUIV2FpprsxW28dO/0lAgepCZY=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/gomatrixserverlib v0.0.0-20221031151122-0885c35ebe74/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 h1:lzkSQvBv8TuqKJCPoVwOVvEnARTlua5rrNy/Qw2Vxeo= github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6 h1:nAT5w41Q9uWTSnpKW55/hBwP91j2IFYPDRs0jJ8TyFI=
github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
@ -595,8 +595,8 @@ github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
github.com/yggdrasil-network/yggdrasil-go v0.4.5-0.20220901155642-4f2abece817c h1:/cTmA6pV2Z20BT/FGSmnb5BmJ8eRbDP0HbCB5IO1aKw= github.com/yggdrasil-network/yggdrasil-go v0.4.6 h1:GALUDV9QPz/5FVkbazpkTc9EABHufA556JwUJZr41j4=
github.com/yggdrasil-network/yggdrasil-go v0.4.5-0.20220901155642-4f2abece817c/go.mod h1:cIwhYwX9yT9Bcei59O0oOBSaj+kQP+9aVQUMWHh5R00= github.com/yggdrasil-network/yggdrasil-go v0.4.6/go.mod h1:PBMoAOvQjA9geNEeGyMXA9QgCS6Bu+9V+1VkWM84wpw=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=

View file

@ -17,7 +17,7 @@ var build string
const ( const (
VersionMajor = 0 VersionMajor = 0
VersionMinor = 10 VersionMinor = 10
VersionPatch = 4 VersionPatch = 5
VersionTag = "" // example: "rc1" VersionTag = "" // example: "rc1"
) )

View file

@ -128,58 +128,49 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
func (a *KeyInternalAPI) claimRemoteKeys( func (a *KeyInternalAPI) claimRemoteKeys(
ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string, ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
) { ) {
resultCh := make(chan *gomatrixserverlib.RespClaimKeys, len(domainToDeviceKeys)) var wg sync.WaitGroup // Wait for fan-out goroutines to finish
// allows us to wait until all federation servers have been poked var mu sync.Mutex // Protects the response struct
var wg sync.WaitGroup var claimed int // Number of keys claimed in total
wg.Add(len(domainToDeviceKeys)) var failures int // Number of servers we failed to ask
// mutex for failures
var failMu sync.Mutex util.GetLogger(ctx).Infof("Claiming remote keys from %d server(s)", len(domainToDeviceKeys))
util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers") wg.Add(len(domainToDeviceKeys))
// fan out
for d, k := range domainToDeviceKeys { for d, k := range domainToDeviceKeys {
go func(domain string, keysToClaim map[string]map[string]string) { go func(domain string, keysToClaim map[string]map[string]string) {
defer wg.Done()
fedCtx, cancel := context.WithTimeout(ctx, timeout) fedCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel() defer cancel()
defer wg.Done()
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim)
mu.Lock()
defer mu.Unlock()
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed") util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
failMu.Lock()
res.Failures[domain] = map[string]interface{}{ res.Failures[domain] = map[string]interface{}{
"message": err.Error(), "message": err.Error(),
} }
failMu.Unlock() failures++
return return
} }
resultCh <- &claimKeyRes
for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys {
res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
for deviceID, keys := range deviceIDToKeys {
res.OneTimeKeys[userID][deviceID] = keys
claimed += len(keys)
}
}
}(d, k) }(d, k)
} }
// Close the result channel when the goroutines have quit so the for .. range exits wg.Wait()
go func() { util.GetLogger(ctx).WithFields(logrus.Fields{
wg.Wait() "num_keys": claimed,
close(resultCh) "num_failures": failures,
}() }).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys))
keysClaimed := 0
for result := range resultCh {
for userID, nest := range result.OneTimeKeys {
res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
for deviceID, nest2 := range nest {
res.OneTimeKeys[userID][deviceID] = make(map[string]json.RawMessage)
for keyIDWithAlgo, otk := range nest2 {
keyJSON, err := json.Marshal(otk)
if err != nil {
continue
}
res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON
keysClaimed++
}
}
}
}
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
} }
func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error { func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {

View file

@ -150,6 +150,7 @@ type ClientRoomserverAPI interface {
PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error
PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error
PerformAdminDownloadState(ctx context.Context, req *PerformAdminDownloadStateRequest, res *PerformAdminDownloadStateResponse) error
PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error
PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error

View file

@ -131,6 +131,16 @@ func (t *RoomserverInternalAPITrace) PerformAdminEvacuateUser(
return err return err
} }
func (t *RoomserverInternalAPITrace) PerformAdminDownloadState(
ctx context.Context,
req *PerformAdminDownloadStateRequest,
res *PerformAdminDownloadStateResponse,
) error {
err := t.Impl.PerformAdminDownloadState(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformAdminDownloadState req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformInboundPeek( func (t *RoomserverInternalAPITrace) PerformInboundPeek(
ctx context.Context, ctx context.Context,
req *PerformInboundPeekRequest, req *PerformInboundPeekRequest,

View file

@ -168,8 +168,10 @@ type PerformBackfillResponse struct {
} }
type PerformPublishRequest struct { type PerformPublishRequest struct {
RoomID string RoomID string
Visibility string Visibility string
AppserviceID string
NetworkID string
} }
type PerformPublishResponse struct { type PerformPublishResponse struct {
@ -235,3 +237,13 @@ type PerformAdminEvacuateUserResponse struct {
Affected []string `json:"affected"` Affected []string `json:"affected"`
Error *PerformError Error *PerformError
} }
type PerformAdminDownloadStateRequest struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
ServerName gomatrixserverlib.ServerName `json:"server_name"`
}
type PerformAdminDownloadStateResponse struct {
Error *PerformError `json:"error,omitempty"`
}

View file

@ -21,8 +21,9 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
// QueryLatestEventsAndStateRequest is a request to QueryLatestEventsAndState // QueryLatestEventsAndStateRequest is a request to QueryLatestEventsAndState
@ -257,7 +258,9 @@ type QueryRoomVersionForRoomResponse struct {
type QueryPublishedRoomsRequest struct { type QueryPublishedRoomsRequest struct {
// Optional. If specified, returns whether this room is published or not. // Optional. If specified, returns whether this room is published or not.
RoomID string RoomID string
NetworkID string
IncludeAllNetworks bool
} }
type QueryPublishedRoomsResponse struct { type QueryPublishedRoomsResponse struct {

View file

@ -117,6 +117,11 @@ func (r *Admin) PerformAdminEvacuateRoom(
PrevEvents: prevEvents, PrevEvents: prevEvents,
} }
_, senderDomain, err := gomatrixserverlib.SplitID('@', fledglingEvent.Sender)
if err != nil {
continue
}
if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
@ -146,8 +151,8 @@ func (r *Admin) PerformAdminEvacuateRoom(
inputEvents = append(inputEvents, api.InputRoomEvent{ inputEvents = append(inputEvents, api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
Event: event, Event: event,
Origin: r.Cfg.Matrix.ServerName, Origin: senderDomain,
SendAsServer: string(r.Cfg.Matrix.ServerName), SendAsServer: string(senderDomain),
}) })
res.Affected = append(res.Affected, stateKey) res.Affected = append(res.Affected, stateKey)
prevEvents = []gomatrixserverlib.EventReference{ prevEvents = []gomatrixserverlib.EventReference{
@ -176,7 +181,7 @@ func (r *Admin) PerformAdminEvacuateUser(
} }
return nil return nil
} }
if domain != r.Cfg.Matrix.ServerName { if !r.Cfg.Matrix.IsLocalServerName(domain) {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: "Can only evacuate local users using this endpoint", Msg: "Can only evacuate local users using this endpoint",
@ -231,3 +236,145 @@ func (r *Admin) PerformAdminEvacuateUser(
} }
return nil return nil
} }
func (r *Admin) PerformAdminDownloadState(
ctx context.Context,
req *api.PerformAdminDownloadStateRequest,
res *api.PerformAdminDownloadStateResponse,
) error {
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err),
}
return nil
}
if roomInfo == nil || roomInfo.IsStub() {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("room %q not found", req.RoomID),
}
return nil
}
fwdExtremities, _, depth, err := r.DB.LatestEventIDs(ctx, roomInfo.RoomNID)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.LatestEventIDs: %s", err),
}
return nil
}
authEventMap := map[string]*gomatrixserverlib.Event{}
stateEventMap := map[string]*gomatrixserverlib.Event{}
for _, fwdExtremity := range fwdExtremities {
var state gomatrixserverlib.RespState
state, err = r.Inputer.FSAPI.LookupState(ctx, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity.EventID, err),
}
return nil
}
for _, authEvent := range state.AuthEvents.UntrustedEvents(roomInfo.RoomVersion) {
if err = authEvent.VerifyEventSignatures(ctx, r.Inputer.KeyRing); err != nil {
continue
}
authEventMap[authEvent.EventID()] = authEvent
}
for _, stateEvent := range state.StateEvents.UntrustedEvents(roomInfo.RoomVersion) {
if err = stateEvent.VerifyEventSignatures(ctx, r.Inputer.KeyRing); err != nil {
continue
}
stateEventMap[stateEvent.EventID()] = stateEvent
}
}
authEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(authEventMap))
stateEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(stateEventMap))
stateIDs := make([]string, 0, len(stateEventMap))
for _, authEvent := range authEventMap {
authEvents = append(authEvents, authEvent.Headered(roomInfo.RoomVersion))
}
for _, stateEvent := range stateEventMap {
stateEvents = append(stateEvents, stateEvent.Headered(roomInfo.RoomVersion))
stateIDs = append(stateIDs, stateEvent.EventID())
}
builder := &gomatrixserverlib.EventBuilder{
Type: "org.matrix.dendrite.state_download",
Sender: req.UserID,
RoomID: req.RoomID,
Content: gomatrixserverlib.RawJSON("{}"),
}
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err),
}
return nil
}
queryRes := &api.QueryLatestEventsAndStateResponse{
RoomExists: true,
RoomVersion: roomInfo.RoomVersion,
LatestEvents: fwdExtremities,
StateEvents: stateEvents,
Depth: depth,
}
ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, time.Now(), &eventsNeeded, queryRes)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err),
}
return nil
}
inputReq := &api.InputRoomEventsRequest{
Asynchronous: false,
}
inputRes := &api.InputRoomEventsResponse{}
for _, authEvent := range append(authEvents, stateEvents...) {
inputReq.InputRoomEvents = append(inputReq.InputRoomEvents, api.InputRoomEvent{
Kind: api.KindOutlier,
Event: authEvent,
})
}
inputReq.InputRoomEvents = append(inputReq.InputRoomEvents, api.InputRoomEvent{
Kind: api.KindNew,
Event: ev,
Origin: r.Cfg.Matrix.ServerName,
HasState: true,
StateEventIDs: stateIDs,
SendAsServer: string(r.Cfg.Matrix.ServerName),
})
if err := r.Inputer.InputRoomEvents(ctx, inputReq, inputRes); err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Inputer.InputRoomEvents: %s", err),
}
return nil
}
if inputRes.ErrMsg != "" {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: inputRes.ErrMsg,
}
}
return nil
}

View file

@ -70,8 +70,8 @@ func (r *Inviter) PerformInvite(
} }
return nil, nil return nil, nil
} }
isTargetLocal := domain == r.Cfg.Matrix.ServerName isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain)
isOriginLocal := senderDomain == r.Cfg.Matrix.ServerName isOriginLocal := r.Cfg.Matrix.IsLocalServerName(senderDomain)
if !isOriginLocal && !isTargetLocal { if !isOriginLocal && !isTargetLocal {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,

View file

@ -92,7 +92,7 @@ func (r *Joiner) performJoin(
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
} }
} }
if domain != r.Cfg.Matrix.ServerName { if !r.Cfg.Matrix.IsLocalServerName(domain) {
return "", "", &rsAPI.PerformError{ return "", "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorBadRequest, Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),
@ -124,7 +124,7 @@ func (r *Joiner) performJoinRoomByAlias(
// Check if this alias matches our own server configuration. If it // Check if this alias matches our own server configuration. If it
// doesn't then we'll need to try a federated join. // doesn't then we'll need to try a federated join.
var roomID string var roomID string
if domain != r.Cfg.Matrix.ServerName { if !r.Cfg.Matrix.IsLocalServerName(domain) {
// The alias isn't owned by us, so we will need to try joining using // The alias isn't owned by us, so we will need to try joining using
// a remote server. // a remote server.
dirReq := fsAPI.PerformDirectoryLookupRequest{ dirReq := fsAPI.PerformDirectoryLookupRequest{
@ -172,7 +172,7 @@ func (r *Joiner) performJoinRoomByID(
// The original client request ?server_name=... may include this HS so filter that out so we // The original client request ?server_name=... may include this HS so filter that out so we
// don't attempt to make_join with ourselves // don't attempt to make_join with ourselves
for i := 0; i < len(req.ServerNames); i++ { for i := 0; i < len(req.ServerNames); i++ {
if req.ServerNames[i] == r.Cfg.Matrix.ServerName { if r.Cfg.Matrix.IsLocalServerName(req.ServerNames[i]) {
// delete this entry // delete this entry
req.ServerNames = append(req.ServerNames[:i], req.ServerNames[i+1:]...) req.ServerNames = append(req.ServerNames[:i], req.ServerNames[i+1:]...)
i-- i--
@ -191,12 +191,19 @@ func (r *Joiner) performJoinRoomByID(
// If the server name in the room ID isn't ours then it's a // If the server name in the room ID isn't ours then it's a
// possible candidate for finding the room via federation. Add // possible candidate for finding the room via federation. Add
// it to the list of servers to try. // it to the list of servers to try.
if domain != r.Cfg.Matrix.ServerName { if !r.Cfg.Matrix.IsLocalServerName(domain) {
req.ServerNames = append(req.ServerNames, domain) req.ServerNames = append(req.ServerNames, domain)
} }
// Prepare the template for the join event. // Prepare the template for the join event.
userID := req.UserID userID := req.UserID
_, userDomain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return "", "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("User ID %q is invalid: %s", userID, err),
}
}
eb := gomatrixserverlib.EventBuilder{ eb := gomatrixserverlib.EventBuilder{
Type: gomatrixserverlib.MRoomMember, Type: gomatrixserverlib.MRoomMember,
Sender: userID, Sender: userID,
@ -247,7 +254,7 @@ func (r *Joiner) performJoinRoomByID(
// If we were invited by someone from another server then we can // If we were invited by someone from another server then we can
// assume they are in the room so we can join via them. // assume they are in the room so we can join via them.
if inviterDomain != r.Cfg.Matrix.ServerName { if !r.Cfg.Matrix.IsLocalServerName(inviterDomain) {
req.ServerNames = append(req.ServerNames, inviterDomain) req.ServerNames = append(req.ServerNames, inviterDomain)
forceFederatedJoin = true forceFederatedJoin = true
memberEvent := gjson.Parse(string(inviteEvent.JSON())) memberEvent := gjson.Parse(string(inviteEvent.JSON()))
@ -300,7 +307,7 @@ func (r *Joiner) performJoinRoomByID(
{ {
Kind: rsAPI.KindNew, Kind: rsAPI.KindNew,
Event: event.Headered(buildRes.RoomVersion), Event: event.Headered(buildRes.RoomVersion),
SendAsServer: string(r.Cfg.Matrix.ServerName), SendAsServer: string(userDomain),
}, },
}, },
} }
@ -323,7 +330,7 @@ func (r *Joiner) performJoinRoomByID(
// The room doesn't exist locally. If the room ID looks like it should // The room doesn't exist locally. If the room ID looks like it should
// be ours then this probably means that we've nuked our database at // be ours then this probably means that we've nuked our database at
// some point. // some point.
if domain == r.Cfg.Matrix.ServerName { if r.Cfg.Matrix.IsLocalServerName(domain) {
// If there are no more server names to try then give up here. // If there are no more server names to try then give up here.
// Otherwise we'll try a federated join as normal, since it's quite // Otherwise we'll try a federated join as normal, since it's quite
// possible that the room still exists on other servers. // possible that the room still exists on other servers.
@ -348,7 +355,7 @@ func (r *Joiner) performJoinRoomByID(
// it will have been overwritten with a room ID by performJoinRoomByAlias. // it will have been overwritten with a room ID by performJoinRoomByAlias.
// We should now include this in the response so that the CS API can // We should now include this in the response so that the CS API can
// return the right room ID. // return the right room ID.
return req.RoomIDOrAlias, r.Cfg.Matrix.ServerName, nil return req.RoomIDOrAlias, userDomain, nil
} }
func (r *Joiner) performFederatedJoinRoomByID( func (r *Joiner) performFederatedJoinRoomByID(

View file

@ -52,7 +52,7 @@ func (r *Leaver) PerformLeave(
if err != nil { if err != nil {
return nil, fmt.Errorf("supplied user ID %q in incorrect format", req.UserID) return nil, fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)
} }
if domain != r.Cfg.Matrix.ServerName { if !r.Cfg.Matrix.IsLocalServerName(domain) {
return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID) return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID)
} }
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
@ -85,7 +85,7 @@ func (r *Leaver) performLeaveRoomByID(
if serr != nil { if serr != nil {
return nil, fmt.Errorf("sender %q is invalid", senderUser) return nil, fmt.Errorf("sender %q is invalid", senderUser)
} }
if senderDomain != r.Cfg.Matrix.ServerName { if !r.Cfg.Matrix.IsLocalServerName(senderDomain) {
return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID)
} }
// check that this is not a "server notice room" // check that this is not a "server notice room"
@ -186,7 +186,7 @@ func (r *Leaver) performLeaveRoomByID(
Kind: api.KindNew, Kind: api.KindNew,
Event: event.Headered(buildRes.RoomVersion), Event: event.Headered(buildRes.RoomVersion),
Origin: senderDomain, Origin: senderDomain,
SendAsServer: string(r.Cfg.Matrix.ServerName), SendAsServer: string(senderDomain),
}, },
}, },
} }

View file

@ -72,7 +72,7 @@ func (r *Peeker) performPeek(
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
} }
} }
if domain != r.Cfg.Matrix.ServerName { if !r.Cfg.Matrix.IsLocalServerName(domain) {
return "", &api.PerformError{ return "", &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),
@ -104,7 +104,7 @@ func (r *Peeker) performPeekRoomByAlias(
// Check if this alias matches our own server configuration. If it // Check if this alias matches our own server configuration. If it
// doesn't then we'll need to try a federated peek. // doesn't then we'll need to try a federated peek.
var roomID string var roomID string
if domain != r.Cfg.Matrix.ServerName { if !r.Cfg.Matrix.IsLocalServerName(domain) {
// The alias isn't owned by us, so we will need to try peeking using // The alias isn't owned by us, so we will need to try peeking using
// a remote server. // a remote server.
dirReq := fsAPI.PerformDirectoryLookupRequest{ dirReq := fsAPI.PerformDirectoryLookupRequest{
@ -154,7 +154,7 @@ func (r *Peeker) performPeekRoomByID(
// handle federated peeks // handle federated peeks
// FIXME: don't create an outbound peek if we already have one going. // FIXME: don't create an outbound peek if we already have one going.
if domain != r.Cfg.Matrix.ServerName { if !r.Cfg.Matrix.IsLocalServerName(domain) {
// If the server name in the room ID isn't ours then it's a // If the server name in the room ID isn't ours then it's a
// possible candidate for finding the room via federation. Add // possible candidate for finding the room via federation. Add
// it to the list of servers to try. // it to the list of servers to try.

View file

@ -30,7 +30,7 @@ func (r *Publisher) PerformPublish(
req *api.PerformPublishRequest, req *api.PerformPublishRequest,
res *api.PerformPublishResponse, res *api.PerformPublishResponse,
) error { ) error {
err := r.DB.PublishRoom(ctx, req.RoomID, req.Visibility == "public") err := r.DB.PublishRoom(ctx, req.RoomID, req.AppserviceID, req.NetworkID, req.Visibility == "public")
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Msg: err.Error(), Msg: err.Error(),

View file

@ -67,7 +67,7 @@ func (r *Unpeeker) performUnpeek(
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
} }
} }
if domain != r.Cfg.Matrix.ServerName { if !r.Cfg.Matrix.IsLocalServerName(domain) {
return &api.PerformError{ return &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),

View file

@ -60,6 +60,13 @@ func (r *Upgrader) performRoomUpgrade(
) (string, *api.PerformError) { ) (string, *api.PerformError) {
roomID := req.RoomID roomID := req.RoomID
userID := req.UserID userID := req.UserID
_, userDomain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return "", &api.PerformError{
Code: api.PerformErrorNotAllowed,
Msg: "Error validating the user ID",
}
}
evTime := time.Now() evTime := time.Now()
// Return an immediate error if the room does not exist // Return an immediate error if the room does not exist
@ -80,7 +87,7 @@ func (r *Upgrader) performRoomUpgrade(
// TODO (#267): Check room ID doesn't clash with an existing one, and we // TODO (#267): Check room ID doesn't clash with an existing one, and we
// probably shouldn't be using pseudo-random strings, maybe GUIDs? // probably shouldn't be using pseudo-random strings, maybe GUIDs?
newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), r.Cfg.Matrix.ServerName) newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain)
// Get the existing room state for the old room. // Get the existing room state for the old room.
oldRoomReq := &api.QueryLatestEventsAndStateRequest{ oldRoomReq := &api.QueryLatestEventsAndStateRequest{
@ -107,12 +114,12 @@ func (r *Upgrader) performRoomUpgrade(
} }
// Send the setup events to the new room // Send the setup events to the new room
if pErr = r.sendInitialEvents(ctx, evTime, userID, newRoomID, string(req.RoomVersion), eventsToMake); pErr != nil { if pErr = r.sendInitialEvents(ctx, evTime, userID, userDomain, newRoomID, string(req.RoomVersion), eventsToMake); pErr != nil {
return "", pErr return "", pErr
} }
// 5. Send the tombstone event to the old room // 5. Send the tombstone event to the old room
if pErr = r.sendHeaderedEvent(ctx, tombstoneEvent, string(r.Cfg.Matrix.ServerName)); pErr != nil { if pErr = r.sendHeaderedEvent(ctx, userDomain, tombstoneEvent, string(userDomain)); pErr != nil {
return "", pErr return "", pErr
} }
@ -122,7 +129,7 @@ func (r *Upgrader) performRoomUpgrade(
} }
// If the old room had a canonical alias event, it should be deleted in the old room // If the old room had a canonical alias event, it should be deleted in the old room
if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, userID, roomID); pErr != nil { if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, userID, userDomain, roomID); pErr != nil {
return "", pErr return "", pErr
} }
@ -132,7 +139,7 @@ func (r *Upgrader) performRoomUpgrade(
} }
// 6. Restrict power levels in the old room // 6. Restrict power levels in the old room
if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, userID, roomID); pErr != nil { if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, userID, userDomain, roomID); pErr != nil {
return "", pErr return "", pErr
} }
@ -154,7 +161,7 @@ func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*goma
return powerLevelContent, nil return powerLevelContent, nil
} }
func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID, roomID string) *api.PerformError { func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, roomID string) *api.PerformError {
restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID) restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID)
if pErr != nil { if pErr != nil {
return pErr return pErr
@ -183,7 +190,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T
return resErr return resErr
} }
} else { } else {
if resErr = r.sendHeaderedEvent(ctx, restrictedPowerLevelsHeadered, api.DoNotSendToOtherServers); resErr != nil { if resErr = r.sendHeaderedEvent(ctx, userDomain, restrictedPowerLevelsHeadered, api.DoNotSendToOtherServers); resErr != nil {
return resErr return resErr
} }
} }
@ -223,7 +230,7 @@ func moveLocalAliases(ctx context.Context,
return nil return nil
} }
func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID, roomID string) *api.PerformError { func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, roomID string) *api.PerformError {
for _, event := range oldRoom.StateEvents { for _, event := range oldRoom.StateEvents {
if event.Type() != gomatrixserverlib.MRoomCanonicalAlias || !event.StateKeyEquals("") { if event.Type() != gomatrixserverlib.MRoomCanonicalAlias || !event.StateKeyEquals("") {
continue continue
@ -254,7 +261,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api
return resErr return resErr
} }
} else { } else {
if resErr = r.sendHeaderedEvent(ctx, emptyCanonicalAliasEvent, api.DoNotSendToOtherServers); resErr != nil { if resErr = r.sendHeaderedEvent(ctx, userDomain, emptyCanonicalAliasEvent, api.DoNotSendToOtherServers); resErr != nil {
return resErr return resErr
} }
} }
@ -495,7 +502,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
return eventsToMake, nil return eventsToMake, nil
} }
func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID, newRoomID, newVersion string, eventsToMake []fledglingEvent) *api.PerformError { func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, newRoomID, newVersion string, eventsToMake []fledglingEvent) *api.PerformError {
var err error var err error
var builtEvents []*gomatrixserverlib.HeaderedEvent var builtEvents []*gomatrixserverlib.HeaderedEvent
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
@ -519,7 +526,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()}
} }
var event *gomatrixserverlib.Event var event *gomatrixserverlib.Event
event, err = r.buildEvent(&builder, &authEvents, evTime, gomatrixserverlib.RoomVersion(newVersion)) event, err = r.buildEvent(&builder, userDomain, &authEvents, evTime, gomatrixserverlib.RoomVersion(newVersion))
if err != nil { if err != nil {
return &api.PerformError{ return &api.PerformError{
Msg: fmt.Sprintf("Failed to build new %q event: %s", builder.Type, err), Msg: fmt.Sprintf("Failed to build new %q event: %s", builder.Type, err),
@ -547,7 +554,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
inputs = append(inputs, api.InputRoomEvent{ inputs = append(inputs, api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
Event: event, Event: event,
Origin: r.Cfg.Matrix.ServerName, Origin: userDomain,
SendAsServer: api.DoNotSendToOtherServers, SendAsServer: api.DoNotSendToOtherServers,
}) })
} }
@ -668,6 +675,7 @@ func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelC
func (r *Upgrader) sendHeaderedEvent( func (r *Upgrader) sendHeaderedEvent(
ctx context.Context, ctx context.Context,
serverName gomatrixserverlib.ServerName,
headeredEvent *gomatrixserverlib.HeaderedEvent, headeredEvent *gomatrixserverlib.HeaderedEvent,
sendAsServer string, sendAsServer string,
) *api.PerformError { ) *api.PerformError {
@ -675,7 +683,7 @@ func (r *Upgrader) sendHeaderedEvent(
inputs = append(inputs, api.InputRoomEvent{ inputs = append(inputs, api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
Event: headeredEvent, Event: headeredEvent,
Origin: r.Cfg.Matrix.ServerName, Origin: serverName,
SendAsServer: sendAsServer, SendAsServer: sendAsServer,
}) })
if err := api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { if err := api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil {
@ -689,6 +697,7 @@ func (r *Upgrader) sendHeaderedEvent(
func (r *Upgrader) buildEvent( func (r *Upgrader) buildEvent(
builder *gomatrixserverlib.EventBuilder, builder *gomatrixserverlib.EventBuilder,
serverName gomatrixserverlib.ServerName,
provider gomatrixserverlib.AuthEventProvider, provider gomatrixserverlib.AuthEventProvider,
evTime time.Time, evTime time.Time,
roomVersion gomatrixserverlib.RoomVersion, roomVersion gomatrixserverlib.RoomVersion,
@ -703,7 +712,7 @@ func (r *Upgrader) buildEvent(
} }
builder.AuthEvents = refs builder.AuthEvents = refs
event, err := builder.Build( event, err := builder.Build(
evTime, r.Cfg.Matrix.ServerName, r.Cfg.Matrix.KeyID, evTime, serverName, r.Cfg.Matrix.KeyID,
r.Cfg.Matrix.PrivateKey, roomVersion, r.Cfg.Matrix.PrivateKey, roomVersion,
) )
if err != nil { if err != nil {

View file

@ -702,7 +702,7 @@ func (r *Queryer) QueryPublishedRooms(
} }
return err return err
} }
rooms, err := r.DB.GetPublishedRooms(ctx) rooms, err := r.DB.GetPublishedRooms(ctx, req.NetworkID, req.IncludeAllNetworks)
if err != nil { if err != nil {
return err return err
} }

View file

@ -27,18 +27,19 @@ const (
RoomserverInputRoomEventsPath = "/roomserver/inputRoomEvents" RoomserverInputRoomEventsPath = "/roomserver/inputRoomEvents"
// Perform operations // Perform operations
RoomserverPerformInvitePath = "/roomserver/performInvite" RoomserverPerformInvitePath = "/roomserver/performInvite"
RoomserverPerformPeekPath = "/roomserver/performPeek" RoomserverPerformPeekPath = "/roomserver/performPeek"
RoomserverPerformUnpeekPath = "/roomserver/performUnpeek" RoomserverPerformUnpeekPath = "/roomserver/performUnpeek"
RoomserverPerformRoomUpgradePath = "/roomserver/performRoomUpgrade" RoomserverPerformRoomUpgradePath = "/roomserver/performRoomUpgrade"
RoomserverPerformJoinPath = "/roomserver/performJoin" RoomserverPerformJoinPath = "/roomserver/performJoin"
RoomserverPerformLeavePath = "/roomserver/performLeave" RoomserverPerformLeavePath = "/roomserver/performLeave"
RoomserverPerformBackfillPath = "/roomserver/performBackfill" RoomserverPerformBackfillPath = "/roomserver/performBackfill"
RoomserverPerformPublishPath = "/roomserver/performPublish" RoomserverPerformPublishPath = "/roomserver/performPublish"
RoomserverPerformInboundPeekPath = "/roomserver/performInboundPeek" RoomserverPerformInboundPeekPath = "/roomserver/performInboundPeek"
RoomserverPerformForgetPath = "/roomserver/performForget" RoomserverPerformForgetPath = "/roomserver/performForget"
RoomserverPerformAdminEvacuateRoomPath = "/roomserver/performAdminEvacuateRoom" RoomserverPerformAdminEvacuateRoomPath = "/roomserver/performAdminEvacuateRoom"
RoomserverPerformAdminEvacuateUserPath = "/roomserver/performAdminEvacuateUser" RoomserverPerformAdminEvacuateUserPath = "/roomserver/performAdminEvacuateUser"
RoomserverPerformAdminDownloadStatePath = "/roomserver/performAdminDownloadState"
// Query operations // Query operations
RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState"
@ -261,6 +262,17 @@ func (h *httpRoomserverInternalAPI) PerformAdminEvacuateRoom(
) )
} }
func (h *httpRoomserverInternalAPI) PerformAdminDownloadState(
ctx context.Context,
request *api.PerformAdminDownloadStateRequest,
response *api.PerformAdminDownloadStateResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformAdminDownloadState", h.roomserverURL+RoomserverPerformAdminDownloadStatePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformAdminEvacuateUser( func (h *httpRoomserverInternalAPI) PerformAdminEvacuateUser(
ctx context.Context, ctx context.Context,
request *api.PerformAdminEvacuateUserRequest, request *api.PerformAdminEvacuateUserRequest,

View file

@ -65,6 +65,11 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", r.PerformAdminEvacuateUser), httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", r.PerformAdminEvacuateUser),
) )
internalAPIMux.Handle(
RoomserverPerformAdminDownloadStatePath,
httputil.MakeInternalRPCAPI("RoomserverPerformAdminDownloadState", r.PerformAdminDownloadState),
)
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryPublishedRoomsPath, RoomserverQueryPublishedRoomsPath,
httputil.MakeInternalRPCAPI("RoomserverQueryPublishedRooms", r.QueryPublishedRooms), httputil.MakeInternalRPCAPI("RoomserverQueryPublishedRooms", r.QueryPublishedRooms),

View file

@ -139,9 +139,9 @@ type Database interface {
// Returns an error if the retrieval went wrong. // Returns an error if the retrieval went wrong.
EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error)
// Publish or unpublish a room from the room directory. // Publish or unpublish a room from the room directory.
PublishRoom(ctx context.Context, roomID string, publish bool) error PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error
// Returns a list of room IDs for rooms which are published. // Returns a list of room IDs for rooms which are published.
GetPublishedRooms(ctx context.Context) ([]string, error) GetPublishedRooms(ctx context.Context, networkID string, includeAllNetworks bool) ([]string, error)
// Returns whether a given room is published or not. // Returns whether a given room is published or not.
GetPublishedRoom(ctx context.Context, roomID string) (bool, error) GetPublishedRoom(ctx context.Context, roomID string) (bool, error)

View file

@ -0,0 +1,45 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"context"
"database/sql"
"fmt"
)
func UpPulishedAppservice(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_published ADD COLUMN IF NOT EXISTS appservice_id TEXT NOT NULL DEFAULT '';`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
_, err = tx.ExecContext(ctx, `ALTER TABLE roomserver_published ADD COLUMN IF NOT EXISTS network_id TEXT NOT NULL DEFAULT '';`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownPublishedAppservice(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_published DROP COLUMN IF EXISTS appservice_id;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
_, err = tx.ExecContext(ctx, `ALTER TABLE roomserver_published DROP COLUMN IF EXISTS network_id;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
) )
@ -27,31 +28,48 @@ const publishedSchema = `
-- Stores which rooms are published in the room directory -- Stores which rooms are published in the room directory
CREATE TABLE IF NOT EXISTS roomserver_published ( CREATE TABLE IF NOT EXISTS roomserver_published (
-- The room ID of the room -- The room ID of the room
room_id TEXT NOT NULL PRIMARY KEY, room_id TEXT NOT NULL,
-- The appservice ID of the room
appservice_id TEXT NOT NULL,
-- The network_id of the room
network_id TEXT NOT NULL,
-- Whether it is published or not -- Whether it is published or not
published BOOLEAN NOT NULL DEFAULT false published BOOLEAN NOT NULL DEFAULT false,
PRIMARY KEY (room_id, appservice_id, network_id)
); );
` `
const upsertPublishedSQL = "" + const upsertPublishedSQL = "" +
"INSERT INTO roomserver_published (room_id, published) VALUES ($1, $2) " + "INSERT INTO roomserver_published (room_id, appservice_id, network_id, published) VALUES ($1, $2, $3, $4) " +
"ON CONFLICT (room_id) DO UPDATE SET published=$2" "ON CONFLICT (room_id, appservice_id, network_id) DO UPDATE SET published=$4"
const selectAllPublishedSQL = "" + const selectAllPublishedSQL = "" +
"SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC" "SELECT room_id FROM roomserver_published WHERE published = $1 AND CASE WHEN $2 THEN 1=1 ELSE network_id = '' END ORDER BY room_id ASC"
const selectNetworkPublishedSQL = "" +
"SELECT room_id FROM roomserver_published WHERE published = $1 AND network_id = $2 ORDER BY room_id ASC"
const selectPublishedSQL = "" + const selectPublishedSQL = "" +
"SELECT published FROM roomserver_published WHERE room_id = $1" "SELECT published FROM roomserver_published WHERE room_id = $1"
type publishedStatements struct { type publishedStatements struct {
upsertPublishedStmt *sql.Stmt upsertPublishedStmt *sql.Stmt
selectAllPublishedStmt *sql.Stmt selectAllPublishedStmt *sql.Stmt
selectPublishedStmt *sql.Stmt selectPublishedStmt *sql.Stmt
selectNetworkPublishedStmt *sql.Stmt
} }
func CreatePublishedTable(db *sql.DB) error { func CreatePublishedTable(db *sql.DB) error {
_, err := db.Exec(publishedSchema) _, err := db.Exec(publishedSchema)
return err if err != nil {
return err
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "roomserver: published appservice",
Up: deltas.UpPulishedAppservice,
})
return m.Up(context.Background())
} }
func PreparePublishedTable(db *sql.DB) (tables.Published, error) { func PreparePublishedTable(db *sql.DB) (tables.Published, error) {
@ -61,14 +79,15 @@ func PreparePublishedTable(db *sql.DB) (tables.Published, error) {
{&s.upsertPublishedStmt, upsertPublishedSQL}, {&s.upsertPublishedStmt, upsertPublishedSQL},
{&s.selectAllPublishedStmt, selectAllPublishedSQL}, {&s.selectAllPublishedStmt, selectAllPublishedSQL},
{&s.selectPublishedStmt, selectPublishedSQL}, {&s.selectPublishedStmt, selectPublishedSQL},
{&s.selectNetworkPublishedStmt, selectNetworkPublishedSQL},
}.Prepare(db) }.Prepare(db)
} }
func (s *publishedStatements) UpsertRoomPublished( func (s *publishedStatements) UpsertRoomPublished(
ctx context.Context, txn *sql.Tx, roomID string, published bool, ctx context.Context, txn *sql.Tx, roomID, appserviceID, networkID string, published bool,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt) stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt)
_, err = stmt.ExecContext(ctx, roomID, published) _, err = stmt.ExecContext(ctx, roomID, appserviceID, networkID, published)
return return
} }
@ -84,10 +103,18 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
} }
func (s *publishedStatements) SelectAllPublishedRooms( func (s *publishedStatements) SelectAllPublishedRooms(
ctx context.Context, txn *sql.Tx, published bool, ctx context.Context, txn *sql.Tx, networkID string, published, includeAllNetworks bool,
) ([]string, error) { ) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) var rows *sql.Rows
rows, err := stmt.QueryContext(ctx, published) var err error
if networkID != "" {
stmt := sqlutil.TxStmt(txn, s.selectNetworkPublishedStmt)
rows, err = stmt.QueryContext(ctx, published, networkID)
} else {
stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
rows, err = stmt.QueryContext(ctx, published, includeAllNetworks)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -722,9 +722,9 @@ func (d *Database) storeEvent(
}, redactionEvent, redactedEventID, err }, redactionEvent, redactedEventID, err
} }
func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) error { func (d *Database) PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.PublishedTable.UpsertRoomPublished(ctx, txn, roomID, publish) return d.PublishedTable.UpsertRoomPublished(ctx, txn, roomID, appserviceID, networkID, publish)
}) })
} }
@ -732,8 +732,8 @@ func (d *Database) GetPublishedRoom(ctx context.Context, roomID string) (bool, e
return d.PublishedTable.SelectPublishedFromRoomID(ctx, nil, roomID) return d.PublishedTable.SelectPublishedFromRoomID(ctx, nil, roomID)
} }
func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { func (d *Database) GetPublishedRooms(ctx context.Context, networkID string, includeAllNetworks bool) ([]string, error) {
return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true) return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, networkID, true, includeAllNetworks)
} }
func (d *Database) MissingAuthPrevEvents( func (d *Database) MissingAuthPrevEvents(

View file

@ -0,0 +1,64 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"context"
"database/sql"
"fmt"
)
func UpPulishedAppservice(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, ` ALTER TABLE roomserver_published RENAME TO roomserver_published_tmp;
CREATE TABLE IF NOT EXISTS roomserver_published (
room_id TEXT NOT NULL,
appservice_id TEXT NOT NULL,
network_id TEXT NOT NULL,
published BOOLEAN NOT NULL DEFAULT false,
CONSTRAINT unique_published_idx PRIMARY KEY (room_id, appservice_id, network_id)
);
INSERT
INTO roomserver_published (
room_id, published
) SELECT
room_id, published
FROM roomserver_published_tmp
;
DROP TABLE roomserver_published_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownPublishedAppservice(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, ` ALTER TABLE roomserver_published RENAME TO roomserver_published_tmp;
CREATE TABLE IF NOT EXISTS roomserver_published (
room_id TEXT NOT NULL PRIMARY KEY,
published BOOLEAN NOT NULL DEFAULT false
);
INSERT
INTO roomserver_published (
room_id, published
) SELECT
room_id, published
FROM roomserver_published_tmp
;
DROP TABLE roomserver_published_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
) )
@ -27,31 +28,49 @@ const publishedSchema = `
-- Stores which rooms are published in the room directory -- Stores which rooms are published in the room directory
CREATE TABLE IF NOT EXISTS roomserver_published ( CREATE TABLE IF NOT EXISTS roomserver_published (
-- The room ID of the room -- The room ID of the room
room_id TEXT NOT NULL PRIMARY KEY, room_id TEXT NOT NULL,
-- The appservice ID of the room
appservice_id TEXT NOT NULL,
-- The network_id of the room
network_id TEXT NOT NULL,
-- Whether it is published or not -- Whether it is published or not
published BOOLEAN NOT NULL DEFAULT false published BOOLEAN NOT NULL DEFAULT false,
PRIMARY KEY (room_id, appservice_id, network_id)
); );
` `
const upsertPublishedSQL = "" + const upsertPublishedSQL = "" +
"INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)" "INSERT INTO roomserver_published (room_id, appservice_id, network_id, published) VALUES ($1, $2, $3, $4)" +
" ON CONFLICT (room_id, appservice_id, network_id) DO UPDATE SET published = $4"
const selectAllPublishedSQL = "" + const selectAllPublishedSQL = "" +
"SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC" "SELECT room_id FROM roomserver_published WHERE published = $1 AND CASE WHEN $2 THEN 1=1 ELSE network_id = '' END ORDER BY room_id ASC"
const selectNetworkPublishedSQL = "" +
"SELECT room_id FROM roomserver_published WHERE published = $1 AND network_id = $2 ORDER BY room_id ASC"
const selectPublishedSQL = "" + const selectPublishedSQL = "" +
"SELECT published FROM roomserver_published WHERE room_id = $1" "SELECT published FROM roomserver_published WHERE room_id = $1"
type publishedStatements struct { type publishedStatements struct {
db *sql.DB db *sql.DB
upsertPublishedStmt *sql.Stmt upsertPublishedStmt *sql.Stmt
selectAllPublishedStmt *sql.Stmt selectAllPublishedStmt *sql.Stmt
selectPublishedStmt *sql.Stmt selectPublishedStmt *sql.Stmt
selectNetworkPublishedStmt *sql.Stmt
} }
func CreatePublishedTable(db *sql.DB) error { func CreatePublishedTable(db *sql.DB) error {
_, err := db.Exec(publishedSchema) _, err := db.Exec(publishedSchema)
return err if err != nil {
return err
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "roomserver: published appservice",
Up: deltas.UpPulishedAppservice,
})
return m.Up(context.Background())
} }
func PreparePublishedTable(db *sql.DB) (tables.Published, error) { func PreparePublishedTable(db *sql.DB) (tables.Published, error) {
@ -63,14 +82,15 @@ func PreparePublishedTable(db *sql.DB) (tables.Published, error) {
{&s.upsertPublishedStmt, upsertPublishedSQL}, {&s.upsertPublishedStmt, upsertPublishedSQL},
{&s.selectAllPublishedStmt, selectAllPublishedSQL}, {&s.selectAllPublishedStmt, selectAllPublishedSQL},
{&s.selectPublishedStmt, selectPublishedSQL}, {&s.selectPublishedStmt, selectPublishedSQL},
{&s.selectNetworkPublishedStmt, selectNetworkPublishedSQL},
}.Prepare(db) }.Prepare(db)
} }
func (s *publishedStatements) UpsertRoomPublished( func (s *publishedStatements) UpsertRoomPublished(
ctx context.Context, txn *sql.Tx, roomID string, published bool, ctx context.Context, txn *sql.Tx, roomID, appserviceID, networkID string, published bool,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt) stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt)
_, err := stmt.ExecContext(ctx, roomID, published) _, err := stmt.ExecContext(ctx, roomID, appserviceID, networkID, published)
return err return err
} }
@ -86,10 +106,17 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
} }
func (s *publishedStatements) SelectAllPublishedRooms( func (s *publishedStatements) SelectAllPublishedRooms(
ctx context.Context, txn *sql.Tx, published bool, ctx context.Context, txn *sql.Tx, networkID string, published, includeAllNetworks bool,
) ([]string, error) { ) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) var rows *sql.Rows
rows, err := stmt.QueryContext(ctx, published) var err error
if networkID != "" {
stmt := sqlutil.TxStmt(txn, s.selectNetworkPublishedStmt)
rows, err = stmt.QueryContext(ctx, published, networkID)
} else {
stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
rows, err = stmt.QueryContext(ctx, published, includeAllNetworks)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -146,9 +146,9 @@ type Membership interface {
} }
type Published interface { type Published interface {
UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error) UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID, appserviceID, networkID string, published bool) (err error)
SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error) SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error)
SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, published bool) ([]string, error) SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, networkdID string, published, includeAllNetworks bool) ([]string, error)
} }
type RedactionInfo struct { type RedactionInfo struct {

View file

@ -2,16 +2,18 @@ package tables_test
import ( import (
"context" "context"
"fmt"
"sort" "sort"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres" "github.com/matrix-org/dendrite/roomserver/storage/postgres"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/stretchr/testify/assert"
) )
func mustCreatePublishedTable(t *testing.T, dbType test.DBType) (tab tables.Published, close func()) { func mustCreatePublishedTable(t *testing.T, dbType test.DBType) (tab tables.Published, close func()) {
@ -46,10 +48,12 @@ func TestPublishedTable(t *testing.T) {
// Publish some rooms // Publish some rooms
publishedRooms := []string{} publishedRooms := []string{}
asID := ""
nwID := ""
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
published := i%2 == 0 published := i%2 == 0
err := tab.UpsertRoomPublished(ctx, nil, room.ID, published) err := tab.UpsertRoomPublished(ctx, nil, room.ID, asID, nwID, published)
assert.NoError(t, err) assert.NoError(t, err)
if published { if published {
publishedRooms = append(publishedRooms, room.ID) publishedRooms = append(publishedRooms, room.ID)
@ -61,19 +65,36 @@ func TestPublishedTable(t *testing.T) {
sort.Strings(publishedRooms) sort.Strings(publishedRooms)
// check that we get the expected published rooms // check that we get the expected published rooms
roomIDs, err := tab.SelectAllPublishedRooms(ctx, nil, true) roomIDs, err := tab.SelectAllPublishedRooms(ctx, nil, "", true, true)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, publishedRooms, roomIDs) assert.Equal(t, publishedRooms, roomIDs)
// test an actual upsert // test an actual upsert
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
err = tab.UpsertRoomPublished(ctx, nil, room.ID, true) err = tab.UpsertRoomPublished(ctx, nil, room.ID, asID, nwID, true)
assert.NoError(t, err) assert.NoError(t, err)
err = tab.UpsertRoomPublished(ctx, nil, room.ID, false) err = tab.UpsertRoomPublished(ctx, nil, room.ID, asID, nwID, false)
assert.NoError(t, err) assert.NoError(t, err)
// should now be false, due to the upsert // should now be false, due to the upsert
publishedRes, err := tab.SelectPublishedFromRoomID(ctx, nil, room.ID) publishedRes, err := tab.SelectPublishedFromRoomID(ctx, nil, room.ID)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, publishedRes) assert.False(t, publishedRes, fmt.Sprintf("expected room %s to be unpublished", room.ID))
// network specific test
nwID = "irc"
room = test.NewRoom(t, alice)
err = tab.UpsertRoomPublished(ctx, nil, room.ID, asID, nwID, true)
assert.NoError(t, err)
publishedRooms = append(publishedRooms, room.ID)
sort.Strings(publishedRooms)
// should only return the room for network "irc"
allNWPublished, err := tab.SelectAllPublishedRooms(ctx, nil, nwID, true, true)
assert.NoError(t, err)
assert.Equal(t, []string{room.ID}, allNWPublished)
// check that we still get all published rooms regardless networkID
roomIDs, err = tab.SelectAllPublishedRooms(ctx, nil, "", true, true)
assert.NoError(t, err)
assert.Equal(t, publishedRooms, roomIDs)
}) })
} }

View file

@ -32,6 +32,12 @@ type ClientAPI struct {
// Boolean stating whether catpcha registration is enabled // Boolean stating whether catpcha registration is enabled
// and required // and required
RecaptchaEnabled bool `yaml:"enable_registration_captcha"` RecaptchaEnabled bool `yaml:"enable_registration_captcha"`
// Recaptcha api.js Url, for compatible with hcaptcha.com, etc.
RecaptchaApiJsUrl string `yaml:"recaptcha_api_js_url"`
// Recaptcha div class for sitekey, for compatible with hcaptcha.com, etc.
RecaptchaSitekeyClass string `yaml:"recaptcha_sitekey_class"`
// Recaptcha form field, for compatible with hcaptcha.com, etc.
RecaptchaFormField string `yaml:"recaptcha_form_field"`
// This Home Server's ReCAPTCHA public key. // This Home Server's ReCAPTCHA public key.
RecaptchaPublicKey string `yaml:"recaptcha_public_key"` RecaptchaPublicKey string `yaml:"recaptcha_public_key"`
// This Home Server's ReCAPTCHA private key. // This Home Server's ReCAPTCHA private key.
@ -75,6 +81,18 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey)
checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey)
checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI) checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI)
if c.RecaptchaSiteVerifyAPI == "" {
c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify"
}
if c.RecaptchaApiJsUrl == "" {
c.RecaptchaApiJsUrl = "https://www.google.com/recaptcha/api.js"
}
if c.RecaptchaFormField == "" {
c.RecaptchaFormField = "g-recaptcha"
}
if c.RecaptchaSitekeyClass == "" {
c.RecaptchaSitekeyClass = "g-recaptcha-response"
}
} }
// Ensure there is any spam counter measure when enabling registration // Ensure there is any spam counter measure when enabling registration
if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled { if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled {

View file

@ -14,6 +14,9 @@ type Global struct {
// The name of the server. This is usually the domain name, e.g 'matrix.org', 'localhost'. // The name of the server. This is usually the domain name, e.g 'matrix.org', 'localhost'.
ServerName gomatrixserverlib.ServerName `yaml:"server_name"` ServerName gomatrixserverlib.ServerName `yaml:"server_name"`
// The secondary server names, used for virtual hosting.
SecondaryServerNames []gomatrixserverlib.ServerName `yaml:"-"`
// Path to the private key which will be used to sign requests and events. // Path to the private key which will be used to sign requests and events.
PrivateKeyPath Path `yaml:"private_key"` PrivateKeyPath Path `yaml:"private_key"`
@ -120,6 +123,18 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) {
c.Cache.Verify(configErrs, isMonolith) c.Cache.Verify(configErrs, isMonolith)
} }
func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool {
if c.ServerName == serverName {
return true
}
for _, secondaryName := range c.SecondaryServerNames {
if secondaryName == serverName {
return true
}
}
return false
}
type OldVerifyKeys struct { type OldVerifyKeys struct {
// Path to the private key. // Path to the private key.
PrivateKeyPath Path `yaml:"private_key"` PrivateKeyPath Path `yaml:"private_key"`

View file

@ -132,7 +132,7 @@ func Enable(
base.PublicFederationAPIMux.Handle("/unstable/event_relationships", httputil.MakeExternalAPI( base.PublicFederationAPIMux.Handle("/unstable/event_relationships", httputil.MakeExternalAPI(
"msc2836_event_relationships", func(req *http.Request) util.JSONResponse { "msc2836_event_relationships", func(req *http.Request) util.JSONResponse {
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
req, time.Now(), base.Cfg.Global.ServerName, keyRing, req, time.Now(), base.Cfg.Global.ServerName, base.Cfg.Global.IsLocalServerName, keyRing,
) )
if fedReq == nil { if fedReq == nil {
return errResp return errResp

View file

@ -64,7 +64,7 @@ func Enable(
fedAPI := httputil.MakeExternalAPI( fedAPI := httputil.MakeExternalAPI(
"msc2946_fed_spaces", func(req *http.Request) util.JSONResponse { "msc2946_fed_spaces", func(req *http.Request) util.JSONResponse {
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
req, time.Now(), base.Cfg.Global.ServerName, keyRing, req, time.Now(), base.Cfg.Global.ServerName, base.Cfg.Global.IsLocalServerName, keyRing,
) )
if fedReq == nil { if fedReq == nil {
return errResp return errResp

View file

@ -76,6 +76,13 @@ func GetMemberships(
} }
} }
if joinedOnly && !queryRes.IsInRoom {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("You aren't a member of the room and weren't previously a member of the room."),
}
}
db, err := syncDB.NewDatabaseSnapshot(req.Context()) db, err := syncDB.NewDatabaseSnapshot(req.Context())
if err != nil { if err != nil {
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -102,19 +109,15 @@ func GetMemberships(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
result, err := db.Events(req.Context(), eventIDs) qryRes := &api.QueryEventsByIDResponse{}
if err != nil { if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs}, qryRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("db.Events failed") util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
result := qryRes.Events
if joinedOnly { if joinedOnly {
if !queryRes.IsInRoom {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("You aren't a member of the room and weren't previously a member of the room."),
}
}
var res getJoinedMembersResponse var res getJoinedMembersResponse
res.Joined = make(map[string]joinedMember) res.Joined = make(map[string]joinedMember)
for _, ev := range result { for _, ev := range result {

View file

@ -101,7 +101,7 @@ func (p *PDUStreamProvider) CompleteSync(
) )
if jerr != nil { if jerr != nil {
req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed") req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed")
if err == context.DeadlineExceeded || err == context.Canceled || err == sql.ErrTxDone { if ctxErr := req.Context.Err(); ctxErr != nil || jerr == sql.ErrTxDone {
return from return from
} }
continue continue
@ -216,6 +216,9 @@ func (p *PDUStreamProvider) IncrementalSync(
return newPos return newPos
} }
// Limit the recent events to X when going backwards
const recentEventBackwardsLimit = 100
// nolint:gocyclo // nolint:gocyclo
func (p *PDUStreamProvider) addRoomDeltaToResponse( func (p *PDUStreamProvider) addRoomDeltaToResponse(
ctx context.Context, ctx context.Context,
@ -229,9 +232,15 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
) (types.StreamPosition, error) { ) (types.StreamPosition, error) {
originalLimit := eventFilter.Limit originalLimit := eventFilter.Limit
if r.Backwards { // If we're going backwards, grep at least X events, this is mostly to satisfy Sytest
eventFilter.Limit = int(r.From - r.To) if r.Backwards && originalLimit < recentEventBackwardsLimit {
eventFilter.Limit = recentEventBackwardsLimit // TODO: Figure out a better way
diff := r.From - r.To
if diff > 0 && diff < recentEventBackwardsLimit {
eventFilter.Limit = int(diff)
}
} }
recentStreamEvents, limited, err := snapshot.RecentEvents( recentStreamEvents, limited, err := snapshot.RecentEvents(
ctx, delta.RoomID, r, ctx, delta.RoomID, r,
eventFilter, true, true, eventFilter, true, true,
@ -242,8 +251,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
} }
return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err) return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err)
} }
recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents) recentEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering(
delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back snapshot.StreamEventsToEvents(device, recentStreamEvents),
gomatrixserverlib.TopologicalOrderByPrevEvents,
)
prevBatch, err := snapshot.GetBackwardTopologyPos(ctx, recentStreamEvents) prevBatch, err := snapshot.GetBackwardTopologyPos(ctx, recentStreamEvents)
if err != nil { if err != nil {
return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err) return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err)
@ -254,10 +265,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
return r.To, nil return r.To, nil
} }
// Sort the events so that we can pick out the latest events from both sections.
recentEvents = gomatrixserverlib.HeaderedReverseTopologicalOrdering(recentEvents, gomatrixserverlib.TopologicalOrderByPrevEvents)
delta.StateEvents = gomatrixserverlib.HeaderedReverseTopologicalOrdering(delta.StateEvents, gomatrixserverlib.TopologicalOrderByAuthEvents)
// Work out what the highest stream position is for all of the events in this // Work out what the highest stream position is for all of the events in this
// room that were returned. // room that were returned.
latestPosition := r.To latestPosition := r.To
@ -305,6 +312,14 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
limited = true limited = true
} }
// Now that we've filtered the timeline, work out which state events are still
// left. Anything that appears in the filtered timeline will be removed from the
// "state" section and kept in "timeline".
delta.StateEvents = gomatrixserverlib.HeaderedReverseTopologicalOrdering(
removeDuplicates(delta.StateEvents, recentEvents),
gomatrixserverlib.TopologicalOrderByAuthEvents,
)
if len(delta.StateEvents) > 0 { if len(delta.StateEvents) > 0 {
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
} }
@ -498,7 +513,6 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for:
// "Can sync a room with a message with a transaction id" - which does a complete sync to check. // "Can sync a room with a message with a transaction id" - which does a complete sync to check.
recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents) recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents)
events := recentEvents events := recentEvents
// Only apply history visibility checks if the response is for joined rooms // Only apply history visibility checks if the response is for joined rooms
@ -512,7 +526,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
limited = limited && len(events) == len(recentEvents) limited = limited && len(events) == len(recentEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents)
if stateFilter.LazyLoadMembers { if stateFilter.LazyLoadMembers {
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -40,4 +40,9 @@ Accesing an AS-hosted room alias asks the AS server
Guest users can join guest_access rooms Guest users can join guest_access rooms
# This will fail in HTTP API mode, so blacklisted for now # This will fail in HTTP API mode, so blacklisted for now
If a device list update goes missing, the server resyncs on the next one
If a device list update goes missing, the server resyncs on the next one
# Might be a bug in the test because leaves do appear :-(
Leaves are present in non-gapped incremental syncs

View file

@ -699,7 +699,7 @@ We do send redundant membership state across incremental syncs if asked
Rejecting invite over federation doesn't break incremental /sync Rejecting invite over federation doesn't break incremental /sync
Gapped incremental syncs include all state changes Gapped incremental syncs include all state changes
Old leaves are present in gapped incremental syncs Old leaves are present in gapped incremental syncs
Leaves are present in non-gapped incremental syncs #Leaves are present in non-gapped incremental syncs
Members from the gap are included in gappy incr LL sync Members from the gap are included in gappy incr LL sync
Presence can be set from sync Presence can be set from sync
/state returns M_NOT_FOUND for a rejected message event /state returns M_NOT_FOUND for a rejected message event
@ -757,4 +757,6 @@ Can get rooms/{roomId}/messages for a departed room (SPEC-216)
Local device key changes appear in /keys/changes Local device key changes appear in /keys/changes
Can get rooms/{roomId}/members at a given point Can get rooms/{roomId}/members at a given point
Can filter rooms/{roomId}/members Can filter rooms/{roomId}/members
Current state appears in timeline in private history with many messages after Current state appears in timeline in private history with many messages after
AS can publish rooms in their own list
AS and main public room lists are separate

View file

@ -36,6 +36,7 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f
Monolithic: true, Monolithic: true,
}) })
cfg.Global.JetStream.InMemory = true cfg.Global.JetStream.InMemory = true
cfg.FederationAPI.KeyPerspectives = nil
switch dbType { switch dbType {
case test.DBTypePostgres: case test.DBTypePostgres:
cfg.Global.Defaults(config.DefaultOpts{ // autogen a signing key cfg.Global.Defaults(config.DefaultOpts{ // autogen a signing key
@ -106,6 +107,7 @@ func Base(cfg *config.Dendrite) (*base.BaseDendrite, nats.JetStreamContext, *nat
} }
cfg.Global.JetStream.InMemory = true cfg.Global.JetStream.InMemory = true
cfg.SyncAPI.Fulltext.InMemory = true cfg.SyncAPI.Fulltext.InMemory = true
cfg.FederationAPI.KeyPerspectives = nil
base := base.NewBaseDendrite(cfg, "Tests") base := base.NewBaseDendrite(cfg, "Tests")
js, jc := base.NATS.Prepare(base.ProcessContext, &cfg.Global.JetStream) js, jc := base.NATS.Prepare(base.ProcessContext, &cfg.Global.JetStream)
return base, js, jc return base, js, jc

View file

@ -318,8 +318,9 @@ type QuerySearchProfilesResponse struct {
// PerformAccountCreationRequest is the request for PerformAccountCreation // PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformAccountCreationRequest struct { type PerformAccountCreationRequest struct {
AccountType AccountType // Required: whether this is a guest or user account AccountType AccountType // Required: whether this is a guest or user account
Localpart string // Required: The localpart for this account. Ignored if account type is guest. Localpart string // Required: The localpart for this account. Ignored if account type is guest.
ServerName gomatrixserverlib.ServerName // optional: if not specified, default server name used instead
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
Password string // optional: if missing then this account will be a passwordless account Password string // optional: if missing then this account will be a passwordless account
@ -360,7 +361,8 @@ type PerformLastSeenUpdateResponse struct {
// PerformDeviceCreationRequest is the request for PerformDeviceCreation // PerformDeviceCreationRequest is the request for PerformDeviceCreation
type PerformDeviceCreationRequest struct { type PerformDeviceCreationRequest struct {
Localpart string Localpart string
AccessToken string // optional: if blank one will be made on your behalf ServerName gomatrixserverlib.ServerName // optional: if blank, default server name used
AccessToken string // optional: if blank one will be made on your behalf
// optional: if nil an ID is generated for you. If set, replaces any existing device session, // optional: if nil an ID is generated for you. If set, replaces any existing device session,
// which will generate a new access token and invalidate the old one. // which will generate a new access token and invalidate the old one.
DeviceID *string DeviceID *string
@ -384,7 +386,8 @@ type PerformDeviceCreationResponse struct {
// PerformAccountDeactivationRequest is the request for PerformAccountDeactivation // PerformAccountDeactivationRequest is the request for PerformAccountDeactivation
type PerformAccountDeactivationRequest struct { type PerformAccountDeactivationRequest struct {
Localpart string Localpart string
ServerName gomatrixserverlib.ServerName // optional: if blank, default server name used
} }
// PerformAccountDeactivationResponse is the response for PerformAccountDeactivation // PerformAccountDeactivationResponse is the response for PerformAccountDeactivation
@ -434,6 +437,18 @@ type Device struct {
AccountType AccountType AccountType AccountType
} }
func (d *Device) UserDomain() gomatrixserverlib.ServerName {
_, domain, err := gomatrixserverlib.SplitID('@', d.UserID)
if err != nil {
// This really is catastrophic because it means that someone
// managed to forge a malformed user ID for a device during
// login.
// TODO: Is there a better way to deal with this than panic?
panic(err)
}
return domain
}
// Account represents a Matrix account on this home server. // Account represents a Matrix account on this home server.
type Account struct { type Account struct {
UserID string UserID string
@ -577,7 +592,9 @@ type Notification struct {
} }
type PerformSetAvatarURLRequest struct { type PerformSetAvatarURLRequest struct {
Localpart, AvatarURL string Localpart string
ServerName gomatrixserverlib.ServerName
AvatarURL string
} }
type PerformSetAvatarURLResponse struct { type PerformSetAvatarURLResponse struct {
Profile *authtypes.Profile `json:"profile"` Profile *authtypes.Profile `json:"profile"`
@ -606,7 +623,9 @@ type QueryAccountByPasswordResponse struct {
} }
type PerformUpdateDisplayNameRequest struct { type PerformUpdateDisplayNameRequest struct {
Localpart, DisplayName string Localpart string
ServerName gomatrixserverlib.ServerName
DisplayName string
} }
type PerformUpdateDisplayNameResponse struct { type PerformUpdateDisplayNameResponse struct {

View file

@ -46,9 +46,9 @@ import (
type UserInternalAPI struct { type UserInternalAPI struct {
DB storage.Database DB storage.Database
SyncProducer *producers.SyncAPI SyncProducer *producers.SyncAPI
Config *config.UserAPI
DisableTLSValidation bool DisableTLSValidation bool
ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS // AppServices is the list of all registered AS
AppServices []config.ApplicationService AppServices []config.ApplicationService
KeyAPI keyapi.UserKeyAPI KeyAPI keyapi.UserKeyAPI
@ -62,8 +62,8 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if err != nil { if err != nil {
return err return err
} }
if domain != a.ServerName { if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot update account data of remote users (server name %s)", domain)
} }
if req.DataType == "" { if req.DataType == "" {
return fmt.Errorf("data type must not be empty") return fmt.Errorf("data type must not be empty")
@ -104,7 +104,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure") logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure")
return nil return nil
} }
if domain != a.ServerName { if !a.Config.Matrix.IsLocalServerName(domain) {
return nil return nil
} }
@ -171,6 +171,11 @@ func addUserToRoom(
} }
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
serverName := req.ServerName
if serverName == "" {
serverName = a.Config.Matrix.ServerName
}
// XXXX: Use the server name here
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if err != nil { if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
@ -188,8 +193,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
res.Account = &api.Account{ res.Account = &api.Account{
AppServiceID: req.AppServiceID, AppServiceID: req.AppServiceID,
Localpart: req.Localpart, Localpart: req.Localpart,
ServerName: a.ServerName, ServerName: serverName,
UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName),
AccountType: req.AccountType, AccountType: req.AccountType,
} }
return nil return nil
@ -235,6 +240,12 @@ func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.Pe
} }
func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error {
serverName := req.ServerName
if serverName == "" {
serverName = a.Config.Matrix.ServerName
}
_ = serverName
// XXXX: Use the server name here
util.GetLogger(ctx).WithFields(logrus.Fields{ util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart, "localpart": req.Localpart,
"device_id": req.DeviceID, "device_id": req.DeviceID,
@ -259,8 +270,8 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
if err != nil { if err != nil {
return err return err
} }
if domain != a.ServerName { if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot PerformDeviceDeletion of remote users: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot PerformDeviceDeletion of remote users (server name %s)", domain)
} }
deletedDeviceIDs := req.DeviceIDs deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 { if len(req.DeviceIDs) == 0 {
@ -392,8 +403,8 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if err != nil { if err != nil {
return err return err
} }
if domain != a.ServerName { if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot query profile of remote users (server name %s)", domain)
} }
prof, err := a.DB.GetProfileByLocalpart(ctx, local) prof, err := a.DB.GetProfileByLocalpart(ctx, local)
if err != nil { if err != nil {
@ -443,8 +454,8 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if err != nil { if err != nil {
return err return err
} }
if domain != a.ServerName { if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot query devices of remote users (server name %s)", domain)
} }
devs, err := a.DB.GetDevicesByLocalpart(ctx, local) devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
if err != nil { if err != nil {
@ -460,8 +471,8 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
if err != nil { if err != nil {
return err return err
} }
if domain != a.ServerName { if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot query account data of remote users (server name %s)", domain)
} }
if req.DataType != "" { if req.DataType != "" {
var data json.RawMessage var data json.RawMessage
@ -509,10 +520,13 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
} }
return err return err
} }
localPart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localPart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return err return err
} }
if !a.Config.Matrix.IsLocalServerName(domain) {
return nil
}
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart) acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
if err != nil { if err != nil {
return err return err
@ -547,7 +561,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
AccountType: api.AccountTypeAppService, AccountType: api.AccountTypeAppService,
} }
localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName) localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -572,8 +586,16 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again. // PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again.
func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
serverName := req.ServerName
if serverName == "" {
serverName = a.Config.Matrix.ServerName
}
if !a.Config.Matrix.IsLocalServerName(serverName) {
return fmt.Errorf("server name %q not locally configured", serverName)
}
evacuateReq := &rsapi.PerformAdminEvacuateUserRequest{ evacuateReq := &rsapi.PerformAdminEvacuateUserRequest{
UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName),
} }
evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{} evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{}
if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil { if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil {
@ -584,7 +606,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
} }
deviceReq := &api.PerformDeviceDeletionRequest{ deviceReq := &api.PerformDeviceDeletionRequest{
UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName),
} }
deviceRes := &api.PerformDeviceDeletionResponse{} deviceRes := &api.PerformDeviceDeletionResponse{}
if err := a.PerformDeviceDeletion(ctx, deviceReq, deviceRes); err != nil { if err := a.PerformDeviceDeletion(ctx, deviceReq, deviceRes); err != nil {

View file

@ -31,8 +31,8 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
if err != nil { if err != nil {
return err return err
} }
if domain != a.ServerName { if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot create a login token for a remote user (server name %s)", domain)
} }
tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data) tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data)
if err != nil { if err != nil {
@ -63,8 +63,8 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
if err != nil { if err != nil {
return err return err
} }
if domain != a.ServerName { if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain)
} }
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil { if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
res.Data = nil res.Data = nil

View file

@ -76,7 +76,7 @@ func NewInternalAPI(
userAPI := &internal.UserInternalAPI{ userAPI := &internal.UserInternalAPI{
DB: db, DB: db,
SyncProducer: syncProducer, SyncProducer: syncProducer,
ServerName: cfg.Matrix.ServerName, Config: cfg,
AppServices: appServices, AppServices: appServices,
KeyAPI: keyAPI, KeyAPI: keyAPI,
RSAPI: rsAPI, RSAPI: rsAPI,

View file

@ -66,8 +66,8 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap
} }
return &internal.UserInternalAPI{ return &internal.UserInternalAPI{
DB: accountDB, DB: accountDB,
ServerName: cfg.Matrix.ServerName, Config: cfg,
}, accountDB, func() { }, accountDB, func() {
close() close()
baseclose() baseclose()